Module ktrain.graph.preprocessor
Expand source code
from .. import utils as U
from ..imports import *
from ..preprocessor import Preprocessor
class NodePreprocessor(Preprocessor):
    """
    ```
    Node preprocessing base class

    Wraps a networkx graph plus a DataFrame of node attributes (whose last
    column is "target") and turns subsets of nodes into GraphSAGE node
    generators wrapped in NodeSequenceWrapper objects.
    ```
    """

    def __init__(self, G_nx, df, sample_size=10, missing_label_value=None):
        """
        ```
        G_nx (Graph): networkx graph containing the nodes
        df (DataFrame): node attributes indexed by node ID; last column must be "target"
        sample_size (int): neighbor sample size (passed for both layers of
                           the GraphSAGENodeGenerator)
        missing_label_value: target value marking unlabeled nodes; excluded
                             from the class names when it occurs in df
        ```
        """
        self.sampsize = sample_size  # neighbor sample size
        self.df = df  # node attributes and targets
        # TODO: eliminate storage redundancy
        self.G = G_nx  # networkx graph
        self.G_sg = None  # StellarGraph, built by preprocess_train/preprocess_valid

        # clean df
        # NOTE: this converts the index of the caller's DataFrame (which is
        # also self.df) to str in place, so node IDs are matched as strings
        df.index = df.index.map(str)
        # local view restricted to nodes present in the graph;
        # self.df itself is intentionally left unfiltered
        df = df[df.index.isin(list(self.G.nodes()))]

        # feature names + target (validated before "target" is read below,
        # so a malformed frame yields this message instead of a KeyError)
        self.colnames = list(df.columns.values)
        if not self.colnames or self.colnames[-1] != "target":
            raise ValueError('last column of df must be "target"')

        # class names: sorted distinct targets, minus the missing-label marker
        self.c = list(set([c[0] for c in df[["target"]].values]))
        if missing_label_value is not None and missing_label_value in self.c:
            self.c.remove(missing_label_value)
        self.c.sort()

        # set by preprocess_train
        self.y_encoding = None

    def get_preprocessor(self):
        """
        ```
        Return the (networkx graph, node DataFrame) pair held by this preprocessor.
        ```
        """
        return (self.G, self.df)

    def get_classes(self):
        """
        ```
        Return the sorted list of class names found in the "target" column.
        ```
        """
        return self.c

    @property
    def feature_names(self):
        """
        ```
        Names of all feature columns (every column except the trailing "target").
        ```
        """
        return self.colnames[:-1]

    def preprocess(self, df, G):
        """
        ```
        Alias for preprocess_test (inductive inference on new nodes).
        ```
        """
        return self.preprocess_test(df, G)

    def ids_exist(self, node_ids):
        """
        ```
        check validity of node IDs
        Returns True if at least one of node_ids appears in the index of self.df.
        ```
        """
        return bool(self.df.index.isin(node_ids).any())

    def preprocess_train(self, node_ids):
        """
        ```
        preprocess training set
        node_ids (list): IDs of the training nodes (must appear in self.df)
        Fits the one-hot target encoder (self.y_encoding) and builds
        self.G_sg, then returns a NodeSequenceWrapper around a shuffled
        GraphSAGE node generator over the training nodes.
        ```
        """
        if not self.ids_exist(node_ids):
            raise ValueError("node_ids must exist in self.df")

        # subset df for training nodes
        df_tr = self.df[self.df.index.isin(node_ids)]

        # one-hot-encode target
        self.y_encoding = sklearn.feature_extraction.DictVectorizer(sparse=False)
        train_targets = self.y_encoding.fit_transform(
            df_tr[["target"]].to_dict("records")
        )

        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGENodeGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # return generator
        G_sg = sg.StellarGraph(self.G, node_features=self.df[self.feature_names])
        self.G_sg = G_sg
        generator = GraphSAGENodeGenerator(
            G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
        )
        train_gen = generator.flow(df_tr.index, train_targets, shuffle=True)
        from .sg_wrappers import NodeSequenceWrapper

        return NodeSequenceWrapper(train_gen)

    def preprocess_valid(self, node_ids):
        """
        ```
        preprocess validation nodes (transductive inference)
        node_ids (list): list of node IDs that generator will yield
        Requires preprocess_train to have been called first (its fitted
        encoder is reused). Returns a NodeSequenceWrapper around an
        unshuffled GraphSAGE node generator.
        ```
        """
        if not self.ids_exist(node_ids):
            raise ValueError("node_ids must exist in self.df")
        if self.y_encoding is None:
            raise Exception(
                "Unset parameters. Are you sure you called preprocess_train first?"
            )

        # subset df for validation nodes
        df_val = self.df[self.df.index.isin(node_ids)]

        # one-hot-encode target with the encoder fitted on the training nodes
        val_targets = self.y_encoding.transform(df_val[["target"]].to_dict("records"))

        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGENodeGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # return generator (reuse the StellarGraph built earlier, if any)
        if self.G_sg is None:
            self.G_sg = sg.StellarGraph(
                self.G, node_features=self.df[self.feature_names]
            )
        generator = GraphSAGENodeGenerator(
            self.G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
        )
        val_gen = generator.flow(df_val.index, val_targets, shuffle=False)
        from .sg_wrappers import NodeSequenceWrapper

        return NodeSequenceWrapper(val_gen)

    def preprocess_test(self, df_te, G_te):
        """
        ```
        preprocess for inductive inference
        df_te (DataFrame): pandas dataframe containing new node attributes
        G_te (Graph): a networkx Graph containing new nodes
        The nodes of self.G must be a subset of the nodes of G_te. If df_te
        has a "target" column it is one-hot encoded with the training
        encoder; otherwise every node gets the placeholder target -1.
        Returns a NodeSequenceWrapper around an unshuffled generator.
        ```
        """
        # check the cheap precondition before importing optional dependencies
        if self.y_encoding is None:
            raise Exception(
                "Unset parameters. Are you sure you called preprocess_train first?"
            )
        try:
            import networkx as nx
        except ImportError as err:
            raise ImportError("Please install networkx: pip install networkx") from err

        # get aggregated df
        # NOTE: node IDs present in both df_te and self.df are not de-duplicated
        df_agg = pd.concat([df_te, self.df])

        # get aggregated graph
        is_subset = set(self.G.nodes()) <= set(G_te.nodes())
        if not is_subset:
            raise ValueError("Nodes in self.G must be subset of G_te")
        G_agg = nx.compose(self.G, G_te)

        # one-hot-encode target
        if "target" in df_te.columns:
            test_targets = self.y_encoding.transform(
                df_te[["target"]].to_dict("records")
            )
        else:
            # one placeholder per row (shape[0] is already an int; the former
            # len(df_te.shape[0]) raised TypeError)
            test_targets = [-1] * df_te.shape[0]

        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGENodeGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # return generator
        G_sg = sg.StellarGraph(G_agg, node_features=df_agg[self.feature_names])
        generator = GraphSAGENodeGenerator(
            G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
        )
        test_gen = generator.flow(df_te.index, test_targets, shuffle=False)
        from .sg_wrappers import NodeSequenceWrapper

        return NodeSequenceWrapper(test_gen)
class LinkPreprocessor(Preprocessor):
    """
    ```
    Link preprocessing base class

    Turns (edge_ids, edge_labels) over a networkx graph into GraphSAGE link
    generators wrapped in LinkSequenceWrapper objects.
    ```
    """

    def __init__(self, G, sample_sizes=[10, 20]):
        """
        ```
        G (networkx graph): original graph under consideration with all original links
        sample_sizes (list): neighbor sample sizes passed to GraphSAGELinkGenerator
        ```
        """
        # keep a private copy so neither the shared default list nor a list
        # the caller may later mutate is aliased by this instance
        self.sample_sizes = list(sample_sizes)
        self.G = G  # original graph under consideration with all original links
        # class names: index 0 is "negative", index 1 is "positive"
        self.c = ["negative", "positive"]

    def get_preprocessor(self):
        """
        ```
        Return this preprocessor itself.
        ```
        """
        return self

    def get_classes(self):
        """
        ```
        Return the class names ["negative", "positive"].
        ```
        """
        return self.c

    def preprocess(self, G, edge_ids):
        """
        ```
        Build an unshuffled link sequence for edge_ids, labelling every edge 1
        (presumably placeholders for inference on unlabeled edges).
        ```
        """
        edge_labels = [1] * len(edge_ids)
        return self.preprocess_valid(G, edge_ids, edge_labels)

    def preprocess_train(self, G, edge_ids, edge_labels, mode="train"):
        """
        ```
        preprocess training set
        Args:
          G (networkx graph): networkx graph
          edge_ids(list): list of tuples representing edge ids
          edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
          mode (str): "train" shuffles the batches; any other value disables shuffling
        Returns:
          LinkSequenceWrapper around a GraphSAGELinkGenerator flow
        ```
        """
        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGELinkGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # NOTE(review): node_features="feature" presumably names the node
        # attribute of G that holds each node's features — confirm.
        G_sg = sg.StellarGraph(G, node_features="feature")
        shuffle = mode == "train"
        link_seq = GraphSAGELinkGenerator(G_sg, U.DEFAULT_BS, self.sample_sizes).flow(
            edge_ids, edge_labels, shuffle=shuffle
        )
        from .sg_wrappers import LinkSequenceWrapper

        return LinkSequenceWrapper(link_seq)

    def preprocess_valid(self, G, edge_ids, edge_labels):
        """
        ```
        preprocess validation set
        Args:
          G (networkx graph): networkx graph
          edge_ids(list): list of tuples representing edge ids
          edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
        Same as preprocess_train but the batches are not shuffled.
        ```
        """
        return self.preprocess_train(G, edge_ids, edge_labels, mode="valid")
Classes
class LinkPreprocessor (G, sample_sizes=[10, 20])
-
Link preprocessing base class
Expand source code
class LinkPreprocessor(Preprocessor):
    """
    ```
    Link preprocessing base class

    Turns (edge_ids, edge_labels) over a networkx graph into GraphSAGE link
    generators wrapped in LinkSequenceWrapper objects.
    ```
    """

    def __init__(self, G, sample_sizes=[10, 20]):
        """
        ```
        G (networkx graph): original graph under consideration with all original links
        sample_sizes (list): neighbor sample sizes passed to GraphSAGELinkGenerator
        ```
        """
        # keep a private copy so neither the shared default list nor a list
        # the caller may later mutate is aliased by this instance
        self.sample_sizes = list(sample_sizes)
        self.G = G  # original graph under consideration with all original links
        # class names: index 0 is "negative", index 1 is "positive"
        self.c = ["negative", "positive"]

    def get_preprocessor(self):
        """
        ```
        Return this preprocessor itself.
        ```
        """
        return self

    def get_classes(self):
        """
        ```
        Return the class names ["negative", "positive"].
        ```
        """
        return self.c

    def preprocess(self, G, edge_ids):
        """
        ```
        Build an unshuffled link sequence for edge_ids, labelling every edge 1
        (presumably placeholders for inference on unlabeled edges).
        ```
        """
        edge_labels = [1] * len(edge_ids)
        return self.preprocess_valid(G, edge_ids, edge_labels)

    def preprocess_train(self, G, edge_ids, edge_labels, mode="train"):
        """
        ```
        preprocess training set
        Args:
          G (networkx graph): networkx graph
          edge_ids(list): list of tuples representing edge ids
          edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
          mode (str): "train" shuffles the batches; any other value disables shuffling
        Returns:
          LinkSequenceWrapper around a GraphSAGELinkGenerator flow
        ```
        """
        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGELinkGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # NOTE(review): node_features="feature" presumably names the node
        # attribute of G that holds each node's features — confirm.
        G_sg = sg.StellarGraph(G, node_features="feature")
        shuffle = mode == "train"
        link_seq = GraphSAGELinkGenerator(G_sg, U.DEFAULT_BS, self.sample_sizes).flow(
            edge_ids, edge_labels, shuffle=shuffle
        )
        from .sg_wrappers import LinkSequenceWrapper

        return LinkSequenceWrapper(link_seq)

    def preprocess_valid(self, G, edge_ids, edge_labels):
        """
        ```
        preprocess validation set
        Args:
          G (networkx graph): networkx graph
          edge_ids(list): list of tuples representing edge ids
          edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
        Same as preprocess_train but the batches are not shuffled.
        ```
        """
        return self.preprocess_train(G, edge_ids, edge_labels, mode="valid")
Ancestors
- Preprocessor
- abc.ABC
Methods
def get_classes(self)
-
Expand source code
def get_classes(self):
    """
    ```
    Return the list of class names stored on the preprocessor.
    ```
    """
    class_names = self.c
    return class_names
def get_preprocessor(self)
-
Expand source code
def get_preprocessor(self):
    """
    ```
    Return the preprocessor object itself.
    ```
    """
    this_preprocessor = self
    return this_preprocessor
def preprocess(self, G, edge_ids)
-
Expand source code
def preprocess(self, G, edge_ids):
    """
    ```
    Delegate to preprocess_valid, giving every edge in edge_ids the label 1
    (presumably placeholders for inference on unlabeled edges).
    ```
    """
    return self.preprocess_valid(G, edge_ids, [1] * len(edge_ids))
def preprocess_train(self, G, edge_ids, edge_labels, mode='train')
-
preprocess training set Args: G (networkx graph): networkx graph edge_ids(list): list of tuples representing edge ids edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
Expand source code
def preprocess_train(self, G, edge_ids, edge_labels, mode="train"):
    """
    ```
    preprocess training set
    Args:
      G (networkx graph): networkx graph
      edge_ids(list): list of tuples representing edge ids
      edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
      mode (str): "train" shuffles the batches; any other value disables shuffling
    Returns:
      LinkSequenceWrapper around a GraphSAGELinkGenerator flow
    ```
    """
    # import stellargraph lazily; any import-time failure is reported with
    # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
    try:
        import stellargraph as sg
        from stellargraph.mapper import GraphSAGELinkGenerator
    except Exception as err:
        raise Exception(SG_ERRMSG) from err
    if version.parse(sg.__version__) < version.parse("0.8"):
        raise Exception(SG_ERRMSG)

    # NOTE(review): node_features="feature" presumably names the node
    # attribute of G that holds each node's features — confirm.
    G_sg = sg.StellarGraph(G, node_features="feature")
    shuffle = mode == "train"
    link_seq = GraphSAGELinkGenerator(G_sg, U.DEFAULT_BS, self.sample_sizes).flow(
        edge_ids, edge_labels, shuffle=shuffle
    )
    from .sg_wrappers import LinkSequenceWrapper

    return LinkSequenceWrapper(link_seq)
def preprocess_valid(self, G, edge_ids, edge_labels)
-
preprocess validation set (same as preprocess_train, but batches are not shuffled) Args: G (networkx graph): networkx graph edge_ids(list): list of tuples representing edge ids edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
Expand source code
def preprocess_valid(self, G, edge_ids, edge_labels):
    """
    ```
    preprocess validation set
    Args:
      G (networkx graph): networkx graph
      edge_ids(list): list of tuples representing edge ids
      edge_labels(list): edge labels (1 or 0 to indicate whether it is a true edge in original graph or not)
    Same as preprocess_train but the batches are not shuffled.
    ```
    """
    return self.preprocess_train(G, edge_ids, edge_labels, mode="valid")
class NodePreprocessor (G_nx, df, sample_size=10, missing_label_value=None)
-
Node preprocessing base class
Expand source code
class NodePreprocessor(Preprocessor):
    """
    ```
    Node preprocessing base class

    Wraps a networkx graph plus a DataFrame of node attributes (whose last
    column is "target") and turns subsets of nodes into GraphSAGE node
    generators wrapped in NodeSequenceWrapper objects.
    ```
    """

    def __init__(self, G_nx, df, sample_size=10, missing_label_value=None):
        """
        ```
        G_nx (Graph): networkx graph containing the nodes
        df (DataFrame): node attributes indexed by node ID; last column must be "target"
        sample_size (int): neighbor sample size (passed for both layers of
                           the GraphSAGENodeGenerator)
        missing_label_value: target value marking unlabeled nodes; excluded
                             from the class names when it occurs in df
        ```
        """
        self.sampsize = sample_size  # neighbor sample size
        self.df = df  # node attributes and targets
        # TODO: eliminate storage redundancy
        self.G = G_nx  # networkx graph
        self.G_sg = None  # StellarGraph, built by preprocess_train/preprocess_valid

        # clean df
        # NOTE: this converts the index of the caller's DataFrame (which is
        # also self.df) to str in place, so node IDs are matched as strings
        df.index = df.index.map(str)
        # local view restricted to nodes present in the graph;
        # self.df itself is intentionally left unfiltered
        df = df[df.index.isin(list(self.G.nodes()))]

        # feature names + target (validated before "target" is read below,
        # so a malformed frame yields this message instead of a KeyError)
        self.colnames = list(df.columns.values)
        if not self.colnames or self.colnames[-1] != "target":
            raise ValueError('last column of df must be "target"')

        # class names: sorted distinct targets, minus the missing-label marker
        self.c = list(set([c[0] for c in df[["target"]].values]))
        if missing_label_value is not None and missing_label_value in self.c:
            self.c.remove(missing_label_value)
        self.c.sort()

        # set by preprocess_train
        self.y_encoding = None

    def get_preprocessor(self):
        """
        ```
        Return the (networkx graph, node DataFrame) pair held by this preprocessor.
        ```
        """
        return (self.G, self.df)

    def get_classes(self):
        """
        ```
        Return the sorted list of class names found in the "target" column.
        ```
        """
        return self.c

    @property
    def feature_names(self):
        """
        ```
        Names of all feature columns (every column except the trailing "target").
        ```
        """
        return self.colnames[:-1]

    def preprocess(self, df, G):
        """
        ```
        Alias for preprocess_test (inductive inference on new nodes).
        ```
        """
        return self.preprocess_test(df, G)

    def ids_exist(self, node_ids):
        """
        ```
        check validity of node IDs
        Returns True if at least one of node_ids appears in the index of self.df.
        ```
        """
        return bool(self.df.index.isin(node_ids).any())

    def preprocess_train(self, node_ids):
        """
        ```
        preprocess training set
        node_ids (list): IDs of the training nodes (must appear in self.df)
        Fits the one-hot target encoder (self.y_encoding) and builds
        self.G_sg, then returns a NodeSequenceWrapper around a shuffled
        GraphSAGE node generator over the training nodes.
        ```
        """
        if not self.ids_exist(node_ids):
            raise ValueError("node_ids must exist in self.df")

        # subset df for training nodes
        df_tr = self.df[self.df.index.isin(node_ids)]

        # one-hot-encode target
        self.y_encoding = sklearn.feature_extraction.DictVectorizer(sparse=False)
        train_targets = self.y_encoding.fit_transform(
            df_tr[["target"]].to_dict("records")
        )

        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGENodeGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # return generator
        G_sg = sg.StellarGraph(self.G, node_features=self.df[self.feature_names])
        self.G_sg = G_sg
        generator = GraphSAGENodeGenerator(
            G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
        )
        train_gen = generator.flow(df_tr.index, train_targets, shuffle=True)
        from .sg_wrappers import NodeSequenceWrapper

        return NodeSequenceWrapper(train_gen)

    def preprocess_valid(self, node_ids):
        """
        ```
        preprocess validation nodes (transductive inference)
        node_ids (list): list of node IDs that generator will yield
        Requires preprocess_train to have been called first (its fitted
        encoder is reused). Returns a NodeSequenceWrapper around an
        unshuffled GraphSAGE node generator.
        ```
        """
        if not self.ids_exist(node_ids):
            raise ValueError("node_ids must exist in self.df")
        if self.y_encoding is None:
            raise Exception(
                "Unset parameters. Are you sure you called preprocess_train first?"
            )

        # subset df for validation nodes
        df_val = self.df[self.df.index.isin(node_ids)]

        # one-hot-encode target with the encoder fitted on the training nodes
        val_targets = self.y_encoding.transform(df_val[["target"]].to_dict("records"))

        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGENodeGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # return generator (reuse the StellarGraph built earlier, if any)
        if self.G_sg is None:
            self.G_sg = sg.StellarGraph(
                self.G, node_features=self.df[self.feature_names]
            )
        generator = GraphSAGENodeGenerator(
            self.G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
        )
        val_gen = generator.flow(df_val.index, val_targets, shuffle=False)
        from .sg_wrappers import NodeSequenceWrapper

        return NodeSequenceWrapper(val_gen)

    def preprocess_test(self, df_te, G_te):
        """
        ```
        preprocess for inductive inference
        df_te (DataFrame): pandas dataframe containing new node attributes
        G_te (Graph): a networkx Graph containing new nodes
        The nodes of self.G must be a subset of the nodes of G_te. If df_te
        has a "target" column it is one-hot encoded with the training
        encoder; otherwise every node gets the placeholder target -1.
        Returns a NodeSequenceWrapper around an unshuffled generator.
        ```
        """
        # check the cheap precondition before importing optional dependencies
        if self.y_encoding is None:
            raise Exception(
                "Unset parameters. Are you sure you called preprocess_train first?"
            )
        try:
            import networkx as nx
        except ImportError as err:
            raise ImportError("Please install networkx: pip install networkx") from err

        # get aggregated df
        # NOTE: node IDs present in both df_te and self.df are not de-duplicated
        df_agg = pd.concat([df_te, self.df])

        # get aggregated graph
        is_subset = set(self.G.nodes()) <= set(G_te.nodes())
        if not is_subset:
            raise ValueError("Nodes in self.G must be subset of G_te")
        G_agg = nx.compose(self.G, G_te)

        # one-hot-encode target
        if "target" in df_te.columns:
            test_targets = self.y_encoding.transform(
                df_te[["target"]].to_dict("records")
            )
        else:
            # one placeholder per row (shape[0] is already an int; the former
            # len(df_te.shape[0]) raised TypeError)
            test_targets = [-1] * df_te.shape[0]

        # import stellargraph lazily; any import-time failure is reported with
        # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
        try:
            import stellargraph as sg
            from stellargraph.mapper import GraphSAGENodeGenerator
        except Exception as err:
            raise Exception(SG_ERRMSG) from err
        if version.parse(sg.__version__) < version.parse("0.8"):
            raise Exception(SG_ERRMSG)

        # return generator
        G_sg = sg.StellarGraph(G_agg, node_features=df_agg[self.feature_names])
        generator = GraphSAGENodeGenerator(
            G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
        )
        test_gen = generator.flow(df_te.index, test_targets, shuffle=False)
        from .sg_wrappers import NodeSequenceWrapper

        return NodeSequenceWrapper(test_gen)
Ancestors
- Preprocessor
- abc.ABC
Instance variables
var feature_names
-
Expand source code
@property
def feature_names(self):
    """
    ```
    Names of all feature columns: every column name except the last ("target").
    ```
    """
    all_columns = self.colnames
    return all_columns[:-1]
Methods
def get_classes(self)
-
Expand source code
def get_classes(self):
    """
    ```
    Return the list of class names stored on the preprocessor.
    ```
    """
    class_names = self.c
    return class_names
def get_preprocessor(self)
-
Expand source code
def get_preprocessor(self):
    """
    ```
    Return the (networkx graph, node DataFrame) pair held by the preprocessor.
    ```
    """
    graph, frame = self.G, self.df
    return graph, frame
def ids_exist(self, node_ids)
-
check validity of node IDs
Expand source code
def ids_exist(self, node_ids):
    """
    ```
    check validity of node IDs
    Returns True if at least one of node_ids appears in the index of self.df.
    ```
    """
    return bool(self.df.index.isin(node_ids).any())
def preprocess(self, df, G)
-
Expand source code
def preprocess(self, df, G):
    """
    ```
    Alias for preprocess_test (inductive inference on new nodes).
    ```
    """
    inductive_step = self.preprocess_test
    return inductive_step(df, G)
def preprocess_test(self, df_te, G_te)
-
preprocess for inductive inference df_te (DataFrame): pandas dataframe containing new node attributes G_te (Graph): a networkx Graph containing new nodes
Expand source code
def preprocess_test(self, df_te, G_te):
    """
    ```
    preprocess for inductive inference
    df_te (DataFrame): pandas dataframe containing new node attributes
    G_te (Graph): a networkx Graph containing new nodes
    The nodes of self.G must be a subset of the nodes of G_te. If df_te
    has a "target" column it is one-hot encoded with the training
    encoder; otherwise every node gets the placeholder target -1.
    Returns a NodeSequenceWrapper around an unshuffled generator.
    ```
    """
    # check the cheap precondition before importing optional dependencies
    if self.y_encoding is None:
        raise Exception(
            "Unset parameters. Are you sure you called preprocess_train first?"
        )
    try:
        import networkx as nx
    except ImportError as err:
        raise ImportError("Please install networkx: pip install networkx") from err

    # get aggregated df
    # NOTE: node IDs present in both df_te and self.df are not de-duplicated
    df_agg = pd.concat([df_te, self.df])

    # get aggregated graph
    is_subset = set(self.G.nodes()) <= set(G_te.nodes())
    if not is_subset:
        raise ValueError("Nodes in self.G must be subset of G_te")
    G_agg = nx.compose(self.G, G_te)

    # one-hot-encode target
    if "target" in df_te.columns:
        test_targets = self.y_encoding.transform(
            df_te[["target"]].to_dict("records")
        )
    else:
        # one placeholder per row (shape[0] is already an int; the former
        # len(df_te.shape[0]) raised TypeError)
        test_targets = [-1] * df_te.shape[0]

    # import stellargraph lazily; any import-time failure is reported with
    # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
    try:
        import stellargraph as sg
        from stellargraph.mapper import GraphSAGENodeGenerator
    except Exception as err:
        raise Exception(SG_ERRMSG) from err
    if version.parse(sg.__version__) < version.parse("0.8"):
        raise Exception(SG_ERRMSG)

    # return generator
    G_sg = sg.StellarGraph(G_agg, node_features=df_agg[self.feature_names])
    generator = GraphSAGENodeGenerator(
        G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
    )
    test_gen = generator.flow(df_te.index, test_targets, shuffle=False)
    from .sg_wrappers import NodeSequenceWrapper

    return NodeSequenceWrapper(test_gen)
def preprocess_train(self, node_ids)
-
preprocess training set
Expand source code
def preprocess_train(self, node_ids):
    """
    ```
    preprocess training set
    node_ids (list): IDs of the training nodes (must appear in self.df)
    Fits the one-hot target encoder (self.y_encoding) and builds
    self.G_sg, then returns a NodeSequenceWrapper around a shuffled
    GraphSAGE node generator over the training nodes.
    ```
    """
    if not self.ids_exist(node_ids):
        raise ValueError("node_ids must exist in self.df")

    # subset df for training nodes
    df_tr = self.df[self.df.index.isin(node_ids)]

    # one-hot-encode target
    self.y_encoding = sklearn.feature_extraction.DictVectorizer(sparse=False)
    train_targets = self.y_encoding.fit_transform(
        df_tr[["target"]].to_dict("records")
    )

    # import stellargraph lazily; any import-time failure is reported with
    # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
    try:
        import stellargraph as sg
        from stellargraph.mapper import GraphSAGENodeGenerator
    except Exception as err:
        raise Exception(SG_ERRMSG) from err
    if version.parse(sg.__version__) < version.parse("0.8"):
        raise Exception(SG_ERRMSG)

    # return generator
    G_sg = sg.StellarGraph(self.G, node_features=self.df[self.feature_names])
    self.G_sg = G_sg
    generator = GraphSAGENodeGenerator(
        G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
    )
    train_gen = generator.flow(df_tr.index, train_targets, shuffle=True)
    from .sg_wrappers import NodeSequenceWrapper

    return NodeSequenceWrapper(train_gen)
def preprocess_valid(self, node_ids)
-
preprocess validation nodes (transductive inference) node_ids (list): list of node IDs that generator will yield
Expand source code
def preprocess_valid(self, node_ids):
    """
    ```
    preprocess validation nodes (transductive inference)
    node_ids (list): list of node IDs that generator will yield
    Requires preprocess_train to have been called first (its fitted
    encoder is reused). Returns a NodeSequenceWrapper around an
    unshuffled GraphSAGE node generator.
    ```
    """
    if not self.ids_exist(node_ids):
        raise ValueError("node_ids must exist in self.df")
    if self.y_encoding is None:
        raise Exception(
            "Unset parameters. Are you sure you called preprocess_train first?"
        )

    # subset df for validation nodes
    df_val = self.df[self.df.index.isin(node_ids)]

    # one-hot-encode target with the encoder fitted on the training nodes
    val_targets = self.y_encoding.transform(df_val[["target"]].to_dict("records"))

    # import stellargraph lazily; any import-time failure is reported with
    # SG_ERRMSG, but KeyboardInterrupt/SystemExit are no longer swallowed
    try:
        import stellargraph as sg
        from stellargraph.mapper import GraphSAGENodeGenerator
    except Exception as err:
        raise Exception(SG_ERRMSG) from err
    if version.parse(sg.__version__) < version.parse("0.8"):
        raise Exception(SG_ERRMSG)

    # return generator (reuse the StellarGraph built earlier, if any)
    if self.G_sg is None:
        self.G_sg = sg.StellarGraph(
            self.G, node_features=self.df[self.feature_names]
        )
    generator = GraphSAGENodeGenerator(
        self.G_sg, U.DEFAULT_BS, [self.sampsize, self.sampsize]
    )
    val_gen = generator.flow(df_val.index, val_targets, shuffle=False)
    from .sg_wrappers import NodeSequenceWrapper

    return NodeSequenceWrapper(val_gen)