Module ktrain.graph.sg_wrappers

Expand source code
from ..dataset import SequenceDataset
from ..imports import *

# import stellargraph
try:
    import stellargraph as sg
    from stellargraph.mapper import link_mappers, node_mappers
except Exception as err:
    # Catch Exception rather than using a bare except so KeyboardInterrupt and
    # SystemExit still propagate; chain the original error so the underlying
    # import failure remains visible in the traceback.
    raise Exception(SG_ERRMSG) from err
# stellargraph releases older than 0.8 are rejected with the same message.
if version.parse(sg.__version__) < version.parse("0.8"):
    raise Exception(SG_ERRMSG)


class NodeSequenceWrapper(node_mappers.NodeSequence, SequenceDataset):
    """
    Adapter that exposes a stellargraph ``NodeSequence`` as a ``SequenceDataset``.

    The wrapped sequence is kept in ``self.node_seq``. A handful of attributes
    are copied onto the wrapper at construction time; ``batch_size`` is
    forwarded to the wrapped sequence's generator and the names listed in
    ``_SEQ_ATTRS`` are forwarded to the wrapped sequence itself, for both
    reads and writes.
    """

    # attribute names that are read from / written to the wrapped sequence
    _SEQ_ATTRS = ("data_size", "shuffle", "head_node_types", "_sampling_schema")

    def __init__(self, node_seq):
        """
        Args:
            node_seq: the stellargraph ``NodeSequence`` instance to wrap.

        Raises:
            ValueError: if ``node_seq`` is not a ``NodeSequence``.
        """
        if not isinstance(node_seq, node_mappers.NodeSequence):
            raise ValueError("node_seq must be a stellargraph NodeSequence object")
        self.node_seq = node_seq
        self.targets = node_seq.targets
        self.generator = node_seq.generator
        self.ids = node_seq.ids
        # NOTE: Python resolves len(obj) and obj[i] on the type, not the
        # instance, so these two assignments are only visible through explicit
        # attribute access (e.g. obj.__len__()).
        self.__len__ = node_seq.__len__
        self.__getitem__ = node_seq.__getitem__
        self.on_epoch_end = node_seq.on_epoch_end
        self.indices = node_seq.indices

    def __setattr__(self, name, value):
        """
        Route attribute writes.

        ``batch_size`` is written to the generator, names in ``_SEQ_ATTRS``
        are written to the wrapped sequence under the same name, and anything
        else is stored on the wrapper itself.
        """
        if name == "batch_size":
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # write under the same name that __getattr__ reads back
            setattr(self.node_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        Raises:
            AttributeError: if ``name`` is not a forwarded attribute and is
                not stored on the wrapper.
        """
        if name == "batch_size":
            return self.generator.batch_size
        if name in self._SEQ_ATTRS:
            return getattr(self.node_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators
            # return a zero-argument lambda that returns None
            return lambda: None
        if name == "graph":
            return self.generator.graph
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    def nsamples(self):
        """Return the size of the first dimension of ``self.targets``."""
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets copied from the wrapped sequence."""
        return self.targets

    def xshape(self):
        """
        Return the shape, minus its leading axis, of the first array in the
        first element of the first batch.
        """
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the second dimension of the second element of the first batch."""
        return self[0][1].shape[1]


class LinkSequenceWrapper(link_mappers.LinkSequence, SequenceDataset):
    """
    Adapter that exposes a stellargraph ``LinkSequence`` as a ``SequenceDataset``.

    The wrapped sequence is kept in ``self.link_seq``. A handful of attributes
    are copied onto the wrapper at construction time; ``batch_size`` is
    forwarded to the wrapped sequence's generator and the names listed in
    ``_SEQ_ATTRS`` are forwarded to the wrapped sequence itself, for both
    reads and writes.
    """

    # attribute names that are read from / written to the wrapped sequence
    _SEQ_ATTRS = ("data_size", "shuffle", "head_node_types", "_sampling_schema")

    def __init__(self, link_seq):
        """
        Args:
            link_seq: the stellargraph ``LinkSequence`` instance to wrap.

        Raises:
            ValueError: if ``link_seq`` is not a ``LinkSequence``.
        """
        if not isinstance(link_seq, link_mappers.LinkSequence):
            raise ValueError("link_seq must be a stellargraph LinkSequence object")
        self.link_seq = link_seq
        self.targets = link_seq.targets
        self.generator = link_seq.generator
        self.ids = link_seq.ids
        # NOTE: Python resolves len(obj) and obj[i] on the type, not the
        # instance, so these two assignments are only visible through explicit
        # attribute access (e.g. obj.__len__()).
        self.__len__ = link_seq.__len__
        self.__getitem__ = link_seq.__getitem__
        self.on_epoch_end = link_seq.on_epoch_end
        self.indices = link_seq.indices

    def __setattr__(self, name, value):
        """
        Route attribute writes.

        ``batch_size`` is written to the generator, names in ``_SEQ_ATTRS``
        are written to the wrapped sequence under the same name, and anything
        else is stored on the wrapper itself.
        """
        if name == "batch_size":
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # write under the same name that __getattr__ reads back
            setattr(self.link_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        Raises:
            AttributeError: if ``name`` is not a forwarded attribute and is
                not stored on the wrapper.
        """
        if name == "batch_size":
            return self.generator.batch_size
        if name in self._SEQ_ATTRS:
            return getattr(self.link_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators
            # return a zero-argument lambda that returns None
            return lambda: None
        if name == "graph":
            return self.generator.graph
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    def nsamples(self):
        """Return the size of the first dimension of ``self.targets``."""
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets copied from the wrapped sequence."""
        return self.targets

    def xshape(self):
        """
        Return the shape, minus its leading axis, of the first array in the
        first element of the first batch.
        """
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the number of classes, which is fixed at 2 for this wrapper."""
        # The original followed this with an unreachable lookup of the class
        # count from the first batch; that dead statement has been removed.
        return 2

Classes

class LinkSequenceWrapper (link_seq)

Keras-compatible data generator to use with the Keras methods :meth:keras.Model.fit_generator, :meth:keras.Model.evaluate_generator, and :meth:keras.Model.predict_generator. This class generates data samples for link inference models and should be created using the :meth:flow method of :class:GraphSAGELinkGenerator or :class:HinSAGELinkGenerator or :class:Attri2VecLinkGenerator.

Args

generator
An instance of :class:GraphSAGELinkGenerator or :class:HinSAGELinkGenerator or
:class:Attri2VecLinkGenerator.
ids : list or iterable
Link IDs to batch, each link id being a tuple of (src, dst) node ids. (The graph nodes must have a "feature" attribute that is used as input to the GraphSAGE/Attri2Vec model.) These are the links to be used for training or inference; the embeddings calculated for these links, via a binary operator applied to their source and destination nodes, are passed to the downstream task of link prediction or link attribute inference. The source and target nodes of the links are used as head nodes for which subgraphs are sampled. The subgraphs are sampled from all nodes.
targets : list or iterable
Labels corresponding to the above links, e.g., 0 or 1 for the link prediction problem.
shuffle : bool
If True (default) the ids will be randomly shuffled every epoch.
Expand source code
class LinkSequenceWrapper(link_mappers.LinkSequence, SequenceDataset):
    """
    Adapter that exposes a stellargraph ``LinkSequence`` as a ``SequenceDataset``.

    The wrapped sequence is kept in ``self.link_seq``. A handful of attributes
    are copied onto the wrapper at construction time; ``batch_size`` is
    forwarded to the wrapped sequence's generator and the names listed in
    ``_SEQ_ATTRS`` are forwarded to the wrapped sequence itself, for both
    reads and writes.
    """

    # attribute names that are read from / written to the wrapped sequence
    _SEQ_ATTRS = ("data_size", "shuffle", "head_node_types", "_sampling_schema")

    def __init__(self, link_seq):
        """
        Args:
            link_seq: the stellargraph ``LinkSequence`` instance to wrap.

        Raises:
            ValueError: if ``link_seq`` is not a ``LinkSequence``.
        """
        if not isinstance(link_seq, link_mappers.LinkSequence):
            raise ValueError("link_seq must be a stellargraph LinkSequence object")
        self.link_seq = link_seq
        self.targets = link_seq.targets
        self.generator = link_seq.generator
        self.ids = link_seq.ids
        # NOTE: Python resolves len(obj) and obj[i] on the type, not the
        # instance, so these two assignments are only visible through explicit
        # attribute access (e.g. obj.__len__()).
        self.__len__ = link_seq.__len__
        self.__getitem__ = link_seq.__getitem__
        self.on_epoch_end = link_seq.on_epoch_end
        self.indices = link_seq.indices

    def __setattr__(self, name, value):
        """
        Route attribute writes.

        ``batch_size`` is written to the generator, names in ``_SEQ_ATTRS``
        are written to the wrapped sequence under the same name, and anything
        else is stored on the wrapper itself.
        """
        if name == "batch_size":
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # write under the same name that __getattr__ reads back
            setattr(self.link_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        Raises:
            AttributeError: if ``name`` is not a forwarded attribute and is
                not stored on the wrapper.
        """
        if name == "batch_size":
            return self.generator.batch_size
        if name in self._SEQ_ATTRS:
            return getattr(self.link_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators
            # return a zero-argument lambda that returns None
            return lambda: None
        if name == "graph":
            return self.generator.graph
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    def nsamples(self):
        """Return the size of the first dimension of ``self.targets``."""
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets copied from the wrapped sequence."""
        return self.targets

    def xshape(self):
        """
        Return the shape, minus its leading axis, of the first array in the
        first element of the first batch.
        """
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the number of classes, which is fixed at 2 for this wrapper."""
        # The original followed this with an unreachable lookup of the class
        # count from the first batch; that dead statement has been removed.
        return 2

Ancestors

Methods

def get_y(self)
Expand source code
def get_y(self):
    """Return the object stored in ``self.targets``."""
    labels = self.targets
    return labels
def nsamples(self)
Expand source code
def nsamples(self):
    """Return the first entry of ``self.targets.shape``."""
    shape = self.targets.shape
    return shape[0]

Inherited members

class NodeSequenceWrapper (node_seq)

Keras-compatible data generator to use with the Keras methods :meth:keras.Model.fit_generator, :meth:keras.Model.evaluate_generator, and :meth:keras.Model.predict_generator.

This class generates data samples for node inference models and should be created using the .flow(…) method of :class:GraphSAGENodeGenerator or :class:DirectedGraphSAGENodeGenerator or :class:HinSAGENodeGenerator or :class:Attri2VecNodeGenerator.

GraphSAGENodeGenerator, DirectedGraphSAGENodeGenerator, and HinSAGENodeGenerator are classes that capture the graph structure and the feature vectors of each node. These generator classes are used within the NodeSequence to generate samples of k-hop neighbourhoods in the graph and to return to this class the features from the sampled neighbourhoods.

Attri2VecNodeGenerator is the class that captures node feature vectors of each node.

Args

generator
GraphSAGENodeGenerator, DirectedGraphSAGENodeGenerator or HinSAGENodeGenerator or Attri2VecNodeGenerator. The generator object containing the graph information.
ids
list A list of the node_ids to be used as head-nodes in the downstream task.
targets
list, optional (default=None) A list of targets or labels to be used in the downstream class.
shuffle : bool
If True (default) the ids will be randomly shuffled every epoch.
Expand source code
class NodeSequenceWrapper(node_mappers.NodeSequence, SequenceDataset):
    """
    Adapter that exposes a stellargraph ``NodeSequence`` as a ``SequenceDataset``.

    The wrapped sequence is kept in ``self.node_seq``. A handful of attributes
    are copied onto the wrapper at construction time; ``batch_size`` is
    forwarded to the wrapped sequence's generator and the names listed in
    ``_SEQ_ATTRS`` are forwarded to the wrapped sequence itself, for both
    reads and writes.
    """

    # attribute names that are read from / written to the wrapped sequence
    _SEQ_ATTRS = ("data_size", "shuffle", "head_node_types", "_sampling_schema")

    def __init__(self, node_seq):
        """
        Args:
            node_seq: the stellargraph ``NodeSequence`` instance to wrap.

        Raises:
            ValueError: if ``node_seq`` is not a ``NodeSequence``.
        """
        if not isinstance(node_seq, node_mappers.NodeSequence):
            raise ValueError("node_seq must be a stellargraph NodeSequence object")
        self.node_seq = node_seq
        self.targets = node_seq.targets
        self.generator = node_seq.generator
        self.ids = node_seq.ids
        # NOTE: Python resolves len(obj) and obj[i] on the type, not the
        # instance, so these two assignments are only visible through explicit
        # attribute access (e.g. obj.__len__()).
        self.__len__ = node_seq.__len__
        self.__getitem__ = node_seq.__getitem__
        self.on_epoch_end = node_seq.on_epoch_end
        self.indices = node_seq.indices

    def __setattr__(self, name, value):
        """
        Route attribute writes.

        ``batch_size`` is written to the generator, names in ``_SEQ_ATTRS``
        are written to the wrapped sequence under the same name, and anything
        else is stored on the wrapper itself.
        """
        if name == "batch_size":
            self.generator.batch_size = value
        elif name in self._SEQ_ATTRS:
            # write under the same name that __getattr__ reads back
            setattr(self.node_seq, name, value)
        else:
            self.__dict__[name] = value

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        Raises:
            AttributeError: if ``name`` is not a forwarded attribute and is
                not stored on the wrapper.
        """
        if name == "batch_size":
            return self.generator.batch_size
        if name in self._SEQ_ATTRS:
            return getattr(self.node_seq, name)
        if name == "reset":
            # stellargraph did not implement reset for its generators
            # return a zero-argument lambda that returns None
            return lambda: None
        if name == "graph":
            return self.generator.graph
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    def nsamples(self):
        """Return the size of the first dimension of ``self.targets``."""
        return self.targets.shape[0]

    def get_y(self):
        """Return the targets copied from the wrapped sequence."""
        return self.targets

    def xshape(self):
        """
        Return the shape, minus its leading axis, of the first array in the
        first element of the first batch.
        """
        return self[0][0][0].shape[1:]  # returns 1st neighborhood only

    def nclasses(self):
        """Return the second dimension of the second element of the first batch."""
        return self[0][1].shape[1]

Ancestors

Methods

def get_y(self)
Expand source code
def get_y(self):
    """Return the object stored in ``self.targets``."""
    labels = self.targets
    return labels
def nsamples(self)
Expand source code
def nsamples(self):
    """Return the first entry of ``self.targets.shape``."""
    shape = self.targets.shape
    return shape[0]

Inherited members