# Source code for cogdl.data.data

import re

import torch


class Data(object):
    r"""A plain old python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape
            :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph or node targets with arbitrary shape.
            (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.
    """

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos

    @staticmethod
    def from_dict(dictionary):
        r"""Creates a data object from a python dictionary."""
        data = Data()
        for key, item in dictionary.items():
            data[key] = item
        return data

    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    @property
    def keys(self):
        r"""Returns all names of graph attributes."""
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        # Hide dunder-style bookkeeping attributes from the public key list.
        keys = [key for key in keys if key[:2] != '__' and key[-2:] != '__']
        return keys

    def __len__(self):
        r"""Returns the number of all present attributes."""
        return len(self.keys)

    def __contains__(self, key):
        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys

    def __iter__(self):
        r"""Iterates over all present attributes in the data, yielding their
        attribute names and content."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
        their attribute names and content.
        If :obj:`*keys` is not given this method will iterate over all
        present attributes. Requested attributes that are unset or
        :obj:`None` are skipped."""
        for key in sorted(self.keys) if not keys else keys:
            # A key that was never assigned is treated like a None-valued one
            # (skipped) instead of raising AttributeError.
            item = getattr(self, key, None)
            if item is not None:
                yield key, item

    def cat_dim(self, key, value):
        r"""Returns the dimension in which the attribute :obj:`key` with
        content :obj:`value` gets concatenated when creating batches.

        .. note::

            This method is for internal use only, and should only be
            overridden if the batch concatenation process is corrupted for a
            specific data attribute.
        """
        # `*index*` and `*face*` should be concatenated in the last dimension,
        # everything else in the first dimension.
        return -1 if bool(re.search("(index|face)", key)) else 0

    def __inc__(self, key, value):
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating batches.

        .. note::

            This method is for internal use only, and should only be
            overridden if the batch concatenation process is corrupted for a
            specific data attribute.
        """
        # Only `*index*` and `*face*` should be cumulatively summed up when
        # creating batches.
        return self.num_nodes if bool(re.search('(index|face)', key)) else 0

    @property
    def num_edges(self):
        r"""Returns the number of edges in the graph, or :obj:`None` if
        neither :obj:`edge_index` nor :obj:`edge_attr` is present."""
        for key, item in self("edge_index", "edge_attr"):
            # edge_index is [2, E] (edges on the last dim, per cat_dim);
            # edge_attr is [E, F] (edges on the first dim).
            return item.size(self.cat_dim(key, item))
        return None

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph
        (:obj:`0` if :obj:`x` is not set)."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_nodes(self):
        r"""Returns the number of nodes in the graph.

        Uses the first dimension of :obj:`x` when present; otherwise infers
        the count as the largest node index in :obj:`edge_index` plus one.
        Returns :obj:`None` if neither attribute is present.
        """
        if self.x is not None:
            return self.x.shape[0]
        if self.edge_index is None:
            return None
        if self.edge_index.numel() == 0:
            # torch.max raises on an empty tensor; an edgeless graph with no
            # features has no inferable nodes.
            return 0
        # int() so this branch returns the same type as ``x.shape[0]``.
        return int(self.edge_index.max()) + 1

    def is_coalesced(self):
        r"""Returns :obj:`True`, if edge indices are ordered and do not
        contain duplicate entries."""
        row, col = self.edge_index
        # Linearize each (row, col) pair into a single row-major key; the
        # edges are sorted and duplicate-free iff the keys strictly increase.
        index = self.num_nodes * row + col
        return bool((index[1:] > index[:-1]).all())

    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, :obj:`func` is applied to all present
        attributes.
        """
        for key, item in self(*keys):
            self[key] = func(item)
        return self

    def contiguous(self, *keys):
        r"""Ensures a contiguous memory layout for all tensor attributes
        :obj:`*keys`.
        If :obj:`*keys` is not given, all present tensor attributes are
        ensured to have a contiguous memory layout. Non-tensor attributes are
        left unchanged."""
        return self.apply(
            lambda x: x.contiguous() if torch.is_tensor(x) else x, *keys)

    def to(self, device, *keys):
        r"""Performs tensor dtype and/or device conversion to all tensor
        attributes :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all
        present tensor attributes. Non-tensor attributes are left unchanged."""
        return self.apply(
            lambda x: x.to(device) if torch.is_tensor(x) else x, *keys)

    def cuda(self, *keys):
        r"""Moves all tensor attributes :obj:`*keys` (or all present tensor
        attributes if none are given) to the current CUDA device.
        Non-tensor attributes are left unchanged."""
        return self.apply(
            lambda x: x.cuda() if torch.is_tensor(x) else x, *keys)

    def clone(self):
        r"""Returns a copy of this data object: tensor attributes are
        :meth:`clone`-d, all other attributes are deep-copied."""
        import copy

        return Data.from_dict({
            key: item.clone() if torch.is_tensor(item) else copy.deepcopy(item)
            for key, item in self
        })

    def __repr__(self):
        # Tensors are summarized by their shape; other values by their str().
        info = [
            "{}={}".format(
                key, list(item.size()) if torch.is_tensor(item) else item)
            for key, item in self
        ]
        return "{}({})".format(self.__class__.__name__, ", ".join(info))