Module ktrain.graph.sg_wrappers
Expand source code
from ..dataset import SequenceDataset
from ..imports import *
# import stellargraph; fail with ktrain's explanatory message if it is
# missing, cannot be imported, or is too old
try:
    import stellargraph as sg
    from stellargraph.mapper import link_mappers, node_mappers
except Exception as e:
    # keep the underlying import failure as the cause so the root problem
    # stays visible in the traceback (and let KeyboardInterrupt/SystemExit
    # propagate untouched)
    raise Exception(SG_ERRMSG) from e
# the wrappers below require stellargraph 0.8 or newer
if version.parse(sg.__version__) < version.parse("0.8"):
    raise Exception(SG_ERRMSG)
class NodeSequenceWrapper(node_mappers.NodeSequence, SequenceDataset):
    """
    Expose a stellargraph ``NodeSequence`` as a ktrain ``SequenceDataset``.

    The wrapper keeps a reference to the original sequence and forwards a
    fixed set of attributes to it (or to its generator), so that state read
    or written through the wrapper is the state of the wrapped objects.

    Args:
        node_seq: a ``stellargraph.mapper.node_mappers.NodeSequence`` instance.

    Raises:
        ValueError: if ``node_seq`` is not a ``NodeSequence``.
    """

    # Attributes whose storage lives on the wrapped sequence, not the wrapper.
    # ``indices`` is forwarded (instead of being copied in __init__) because
    # ``on_epoch_end`` is delegated to the wrapped sequence; if that call
    # rebinds ``indices`` (e.g. when reshuffling), a copy kept here would go
    # stale.
    _SEQ_ATTRS = frozenset(
        ("data_size", "shuffle", "head_node_types", "_sampling_schema", "indices")
    )

    def __init__(self, node_seq):
        if not isinstance(node_seq, node_mappers.NodeSequence):
            raise ValueError("node_seq must be a stellargraph NodeSequence object")
        self.node_seq = node_seq
        self.targets = node_seq.targets
        self.generator = node_seq.generator
        self.ids = node_seq.ids
        # NOTE: Python looks up len()/[] on the type, not the instance, so these
        # two assignments only affect explicit calls like ``wrapper.__len__()``;
        # ``len(wrapper)`` / ``wrapper[i]`` run the inherited NodeSequence
        # methods against this wrapper (presumably reading attributes such as
        # ``indices`` through the forwarding below).
        self.__len__ = node_seq.__len__
        self.__getitem__ = node_seq.__getitem__
        self.on_epoch_end = node_seq.on_epoch_end

    def __setattr__(self, name, value):
        """Route writes of forwarded attributes to the generator / wrapped sequence."""
        if name == "batch_size":
            # the batch size is owned by the generator that builds the batches
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # same attribute name on both sides, so __getattr__ reads it back
            setattr(self.node_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        ``batch_size`` and ``graph`` come from the generator, the names in
        ``_SEQ_ATTRS`` come from the wrapped sequence, and ``reset`` is a
        no-op callable. Anything else raises AttributeError.
        """
        if name in ("batch_size", "graph"):
            return getattr(self.generator, name)
        if name in self._SEQ_ATTRS:
            return getattr(self.node_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators,
            # so return a zero-argument callable that returns None
            return lambda: None
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def nsamples(self):
        """Return the number of samples: rows of ``targets``, or ``len(ids)`` without targets."""
        if self.targets is None:
            return len(self.ids)
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets of the wrapped sequence (None when none were given)."""
        return self.targets

    def xshape(self):
        """Return the per-sample shape of the first input array of batch 0."""
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the number of classes: the width of batch 0's targets."""
        return self[0][1].shape[1]
class LinkSequenceWrapper(link_mappers.LinkSequence, SequenceDataset):
    """
    Expose a stellargraph ``LinkSequence`` as a ktrain ``SequenceDataset``.

    The wrapper keeps a reference to the original sequence and forwards a
    fixed set of attributes to it (or to its generator), so that state read
    or written through the wrapper is the state of the wrapped objects.

    Args:
        link_seq: a ``stellargraph.mapper.link_mappers.LinkSequence`` instance.

    Raises:
        ValueError: if ``link_seq`` is not a ``LinkSequence``.
    """

    # Attributes whose storage lives on the wrapped sequence, not the wrapper.
    # ``indices`` is forwarded (instead of being copied in __init__) because
    # ``on_epoch_end`` is delegated to the wrapped sequence; if that call
    # rebinds ``indices`` (e.g. when reshuffling), a copy kept here would go
    # stale.
    _SEQ_ATTRS = frozenset(
        ("data_size", "shuffle", "head_node_types", "_sampling_schema", "indices")
    )

    def __init__(self, link_seq):
        if not isinstance(link_seq, link_mappers.LinkSequence):
            raise ValueError("link_seq must be a stellargraph LinkSequence object")
        self.link_seq = link_seq
        self.targets = link_seq.targets
        self.generator = link_seq.generator
        self.ids = link_seq.ids
        # NOTE: Python looks up len()/[] on the type, not the instance, so these
        # two assignments only affect explicit calls like ``wrapper.__len__()``;
        # ``len(wrapper)`` / ``wrapper[i]`` run the inherited LinkSequence
        # methods against this wrapper (presumably reading attributes such as
        # ``indices`` through the forwarding below).
        self.__len__ = link_seq.__len__
        self.__getitem__ = link_seq.__getitem__
        self.on_epoch_end = link_seq.on_epoch_end

    def __setattr__(self, name, value):
        """Route writes of forwarded attributes to the generator / wrapped sequence."""
        if name == "batch_size":
            # the batch size is owned by the generator that builds the batches
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # same attribute name on both sides, so __getattr__ reads it back
            setattr(self.link_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        ``batch_size`` and ``graph`` come from the generator, the names in
        ``_SEQ_ATTRS`` come from the wrapped sequence, and ``reset`` is a
        no-op callable. Anything else raises AttributeError.
        """
        if name in ("batch_size", "graph"):
            return getattr(self.generator, name)
        if name in self._SEQ_ATTRS:
            return getattr(self.link_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators,
            # so return a zero-argument callable that returns None
            return lambda: None
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def nsamples(self):
        """Return the number of samples: rows of ``targets``, or ``len(ids)`` without targets."""
        if self.targets is None:
            return len(self.ids)
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets of the wrapped sequence (None when none were given)."""
        return self.targets

    def xshape(self):
        """Return the per-sample shape of the first input array of batch 0."""
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the number of classes; link prediction is treated as binary (link / no link)."""
        return 2
Classes
class LinkSequenceWrapper (link_seq)
-
Keras-compatible data generator to use with the Keras methods `keras.Model.fit_generator`, `keras.Model.evaluate_generator`, and `keras.Model.predict_generator`.
This class generates data samples for link inference models and should be created using the `flow` method of `GraphSAGELinkGenerator`, `HinSAGELinkGenerator`, or `Attri2VecLinkGenerator`.
Args
generator
- An instance of `GraphSAGELinkGenerator`, `HinSAGELinkGenerator`, or `Attri2VecLinkGenerator`.
ids
: list or iterable
- Link IDs to batch, each link ID being a tuple of (src, dst) node IDs. (The graph nodes must have a "feature" attribute that is used as input to the GraphSAGE/Attri2Vec model.) These are the links used for training or inference; the embeddings calculated for these links, via a binary operator applied to their source and destination nodes, are passed to the downstream task of link prediction or link attribute inference. The source and target nodes of the links are used as head nodes for which subgraphs are sampled. The subgraphs are sampled from all nodes.
targets
: list or iterable
- Labels corresponding to the above links, e.g., 0 or 1 for the link prediction problem.
shuffle
:bool
- If True (default) the ids will be randomly shuffled every epoch.
Expand source code
class LinkSequenceWrapper(link_mappers.LinkSequence, SequenceDataset):
    """
    Expose a stellargraph ``LinkSequence`` as a ktrain ``SequenceDataset``.

    The wrapper keeps a reference to the original sequence and forwards a
    fixed set of attributes to it (or to its generator), so that state read
    or written through the wrapper is the state of the wrapped objects.

    Args:
        link_seq: a ``stellargraph.mapper.link_mappers.LinkSequence`` instance.

    Raises:
        ValueError: if ``link_seq`` is not a ``LinkSequence``.
    """

    # Attributes whose storage lives on the wrapped sequence, not the wrapper.
    # ``indices`` is forwarded (instead of being copied in __init__) because
    # ``on_epoch_end`` is delegated to the wrapped sequence; if that call
    # rebinds ``indices`` (e.g. when reshuffling), a copy kept here would go
    # stale.
    _SEQ_ATTRS = frozenset(
        ("data_size", "shuffle", "head_node_types", "_sampling_schema", "indices")
    )

    def __init__(self, link_seq):
        if not isinstance(link_seq, link_mappers.LinkSequence):
            raise ValueError("link_seq must be a stellargraph LinkSequence object")
        self.link_seq = link_seq
        self.targets = link_seq.targets
        self.generator = link_seq.generator
        self.ids = link_seq.ids
        # NOTE: Python looks up len()/[] on the type, not the instance, so these
        # two assignments only affect explicit calls like ``wrapper.__len__()``;
        # ``len(wrapper)`` / ``wrapper[i]`` run the inherited LinkSequence
        # methods against this wrapper (presumably reading attributes such as
        # ``indices`` through the forwarding below).
        self.__len__ = link_seq.__len__
        self.__getitem__ = link_seq.__getitem__
        self.on_epoch_end = link_seq.on_epoch_end

    def __setattr__(self, name, value):
        """Route writes of forwarded attributes to the generator / wrapped sequence."""
        if name == "batch_size":
            # the batch size is owned by the generator that builds the batches
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # same attribute name on both sides, so __getattr__ reads it back
            setattr(self.link_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        ``batch_size`` and ``graph`` come from the generator, the names in
        ``_SEQ_ATTRS`` come from the wrapped sequence, and ``reset`` is a
        no-op callable. Anything else raises AttributeError.
        """
        if name in ("batch_size", "graph"):
            return getattr(self.generator, name)
        if name in self._SEQ_ATTRS:
            return getattr(self.link_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators,
            # so return a zero-argument callable that returns None
            return lambda: None
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def nsamples(self):
        """Return the number of samples: rows of ``targets``, or ``len(ids)`` without targets."""
        if self.targets is None:
            return len(self.ids)
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets of the wrapped sequence (None when none were given)."""
        return self.targets

    def xshape(self):
        """Return the per-sample shape of the first input array of batch 0."""
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the number of classes; link prediction is treated as binary (link / no link)."""
        return 2
Ancestors
- stellargraph.mapper.link_mappers.LinkSequence
- SequenceDataset
- Dataset
- keras.utils.data_utils.Sequence
Methods
def get_y(self)
-
Expand source code
def get_y(self):
    """Return the targets (labels) stored on this wrapper."""
    labels = self.targets
    return labels
def nsamples(self)
-
Expand source code
def nsamples(self):
    """
    Return the number of samples in the sequence.

    Uses the first dimension of ``targets``; when no targets were supplied
    (``targets`` is None), falls back to the number of ``ids``.
    """
    if self.targets is None:
        return len(self.ids)
    return self.targets.shape[0]
Inherited members
class NodeSequenceWrapper (node_seq)
-
Keras-compatible data generator to use with the Keras methods `keras.Model.fit_generator`, `keras.Model.evaluate_generator`, and `keras.Model.predict_generator`.
This class generates data samples for node inference models and should be created using the `.flow(…)` method of `GraphSAGENodeGenerator`, `DirectedGraphSAGENodeGenerator`, `HinSAGENodeGenerator`, or `Attri2VecNodeGenerator`.
GraphSAGENodeGenerator, DirectedGraphSAGENodeGenerator, and HinSAGENodeGenerator are classes that capture the graph structure and the feature vectors of each node. These generator classes are used within the NodeSequence to generate samples of k-hop neighbourhoods in the graph and to return to this class the features from the sampled neighbourhoods.
Attri2VecNodeGenerator is the class that captures node feature vectors of each node.
Args
generator
- GraphSAGENodeGenerator, DirectedGraphSAGENodeGenerator or HinSAGENodeGenerator or Attri2VecNodeGenerator. The generator object containing the graph information.
ids
- list: A list of the node IDs to be used as head nodes in the downstream task.
targets
- list, optional (default=None): A list of targets or labels to be used in the downstream task.
shuffle
:bool
- If True (default) the ids will be randomly shuffled every epoch.
Expand source code
class NodeSequenceWrapper(node_mappers.NodeSequence, SequenceDataset):
    """
    Expose a stellargraph ``NodeSequence`` as a ktrain ``SequenceDataset``.

    The wrapper keeps a reference to the original sequence and forwards a
    fixed set of attributes to it (or to its generator), so that state read
    or written through the wrapper is the state of the wrapped objects.

    Args:
        node_seq: a ``stellargraph.mapper.node_mappers.NodeSequence`` instance.

    Raises:
        ValueError: if ``node_seq`` is not a ``NodeSequence``.
    """

    # Attributes whose storage lives on the wrapped sequence, not the wrapper.
    # ``indices`` is forwarded (instead of being copied in __init__) because
    # ``on_epoch_end`` is delegated to the wrapped sequence; if that call
    # rebinds ``indices`` (e.g. when reshuffling), a copy kept here would go
    # stale.
    _SEQ_ATTRS = frozenset(
        ("data_size", "shuffle", "head_node_types", "_sampling_schema", "indices")
    )

    def __init__(self, node_seq):
        if not isinstance(node_seq, node_mappers.NodeSequence):
            raise ValueError("node_seq must be a stellargraph NodeSequence object")
        self.node_seq = node_seq
        self.targets = node_seq.targets
        self.generator = node_seq.generator
        self.ids = node_seq.ids
        # NOTE: Python looks up len()/[] on the type, not the instance, so these
        # two assignments only affect explicit calls like ``wrapper.__len__()``;
        # ``len(wrapper)`` / ``wrapper[i]`` run the inherited NodeSequence
        # methods against this wrapper (presumably reading attributes such as
        # ``indices`` through the forwarding below).
        self.__len__ = node_seq.__len__
        self.__getitem__ = node_seq.__getitem__
        self.on_epoch_end = node_seq.on_epoch_end

    def __setattr__(self, name, value):
        """Route writes of forwarded attributes to the generator / wrapped sequence."""
        if name == "batch_size":
            # the batch size is owned by the generator that builds the batches
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # same attribute name on both sides, so __getattr__ reads it back
            setattr(self.node_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        ``batch_size`` and ``graph`` come from the generator, the names in
        ``_SEQ_ATTRS`` come from the wrapped sequence, and ``reset`` is a
        no-op callable. Anything else raises AttributeError.
        """
        if name in ("batch_size", "graph"):
            return getattr(self.generator, name)
        if name in self._SEQ_ATTRS:
            return getattr(self.node_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators,
            # so return a zero-argument callable that returns None
            return lambda: None
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def nsamples(self):
        """Return the number of samples: rows of ``targets``, or ``len(ids)`` without targets."""
        if self.targets is None:
            return len(self.ids)
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets of the wrapped sequence (None when none were given)."""
        return self.targets

    def xshape(self):
        """Return the per-sample shape of the first input array of batch 0."""
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the number of classes: the width of batch 0's targets."""
        return self[0][1].shape[1]
Ancestors
- stellargraph.mapper.node_mappers.NodeSequence
- SequenceDataset
- Dataset
- keras.utils.data_utils.Sequence
Methods
def get_y(self)
-
Expand source code
def get_y(self):
    """Give back the label container held in ``self.targets`` (may be None)."""
    return getattr(self, "targets")
def nsamples(self)
-
Expand source code
def nsamples(self):
    """
    Return the number of samples in the sequence.

    Uses the first dimension of ``targets``; when no targets were supplied
    (``targets`` is None, the documented default), falls back to the number
    of ``ids``.
    """
    if self.targets is None:
        return len(self.ids)
    return self.targets.shape[0]
Inherited members