import random
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from grave import plot_network, use_attributes
from tabulate import tabulate
import torch
from cogdl import oagbert
from cogdl.datasets import build_dataset_from_name
class Pipeline:
    """Base class for every pipeline application.

    Stores the application name and whatever keyword arguments were supplied
    at construction time. Subclasses are expected to override ``__call__``.
    """

    def __init__(self, app: str, **kwargs):
        """Remember the app name and the raw construction options."""
        self.kwargs = kwargs
        self.app = app

    def __call__(self, **kwargs):
        """Run the pipeline; the base class has no implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class DatasetPipeline(Pipeline):
    """Common base for pipelines that operate on datasets.

    Calling an instance normalises the ``dataset`` argument and forwards it to
    ``self._call``, which this class does not define; concrete subclasses
    supply it.
    """

    def __init__(self, app: str, **kwargs):
        super().__init__(app, **kwargs)

    def __call__(self, dataset, **kwargs):
        """Dispatch to ``_call`` with ``dataset`` wrapped in a list if it is a str.

        Args:
            dataset: a single dataset name, or any other value, which is
                passed through unchanged.
            **kwargs: forwarded verbatim to ``_call``.

        Returns:
            Whatever ``self._call`` returns.
        """
        names = [dataset] if isinstance(dataset, str) else dataset
        return self._call(names, **kwargs)
class DatasetStatsPipeline(DatasetPipeline):
    """Print and return a table of basic statistics for named datasets."""

    def __init__(self, app: str, **kwargs):
        super(DatasetStatsPipeline, self).__init__(app, **kwargs)

    def _call(self, dataset=None, **kwargs):
        """Collect one row of statistics per dataset name and print the table.

        Args:
            dataset: a dataset name, an iterable of dataset names, or ``None``
                (treated as "no datasets"). Each name is resolved through
                ``build_dataset_from_name``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            list[list]: one row per dataset, in the column order
            ``Dataset, #nodes, #edges, #features, #classes, #labeled data``.
        """
        # ``None`` replaces the former mutable ``[]`` default; the observable
        # behaviour for a missing argument (empty table) is unchanged.
        if dataset is None:
            dataset = []
        elif isinstance(dataset, str):
            dataset = [dataset]

        col_names = [
            "Dataset",
            "#nodes",
            "#edges",
            "#features",
            "#classes",
            "#labeled data",
        ]

        tab_data = []
        for name in dataset:
            # Use a dedicated local instead of rebinding ``dataset`` (the
            # collection being iterated) to the loaded dataset object.
            data = build_dataset_from_name(name)[0]
            tab_data.append(
                [
                    name,
                    data.x.shape[0],
                    # Length of the first component of edge_index; presumably
                    # edge_index is a (row, col) pair of tensors -- TODO confirm.
                    data.edge_index[0].shape[0],
                    data.x.shape[1],
                    len(set(data.y.numpy())),
                    sum(data.train_mask.numpy()),
                ]
            )

        print(tabulate(tab_data, headers=col_names, tablefmt="psql"))
        return tab_data
class DatasetVisualPipeline(DatasetPipeline):
    """Draw a breadth-first neighbourhood of one node and save it as a PNG."""

    def __init__(self, app: str, **kwargs):
        super(DatasetVisualPipeline, self).__init__(app, **kwargs)

    def _call(self, dataset="cora", seed=-1, depth=3, **kwargs):
        """Sample an ego network around ``seed`` and write ``<dataset>.png``.

        Args:
            dataset: dataset name, or a list whose first element is the name.
            seed: node to start from; ``-1`` picks a random node of the graph.
            depth: number of breadth-first expansion rounds.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            list: the nodes discovered in the final expansion round (empty if
            the expansion ran out of new nodes earlier; ``[seed]`` when
            ``depth`` is 0).
        """
        if isinstance(dataset, list):
            dataset = dataset[0]
        name = dataset
        data = build_dataset_from_name(name)[0]

        # Build an undirected graph with one edge per column of edge_index.
        G = nx.Graph()
        edge_index = torch.stack(data.edge_index)
        G.add_edges_from([tuple(edge_index[:, i].numpy()) for i in range(edge_index.shape[1])])

        if seed == -1:
            seed = random.choice(list(G.nodes()))

        # Level-by-level BFS: node_index maps each reached node to its level,
        # max_index counts how many levels were actually populated.
        q = [seed]
        node_set = {seed}
        node_index = {seed: 0}
        max_index = 1
        for _ in range(depth):
            nq = []
            for x in q:
                for key in G[x]:
                    if key not in node_set:
                        nq.append(key)
                        node_set.add(key)
                        node_index[key] = node_index[x] + 1
            if nq:
                max_index += 1
            q = nq

        # One colour per level; nodes closer to the seed are drawn larger.
        cmap = cm.rainbow(np.linspace(0.0, 1.0, max_index))
        for node, index in node_index.items():
            G.nodes[node]["color"] = cmap[index]
            G.nodes[node]["size"] = (max_index - index) * 50

        pic_file = f"{name}.png"
        fig, _ = plt.subplots()
        try:
            plot_network(G.subgraph(list(node_set)), node_style=use_attributes())
            plt.savefig(pic_file)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it so repeated calls do not accumulate open figures.
            plt.close(fig)
        print(f"Sampled ego network saved to {pic_file}")
        return q
class OAGBertInferencePipepline(Pipeline):
    """Pipeline that tokenizes text and feeds it through an OAG-BERT model."""

    def __init__(self, app: str, model: str, **kwargs):
        """Load the tokenizer/model pair returned by ``oagbert(model, ...)``.

        Args:
            app: application name, stored by the base class.
            model: model identifier handed to ``oagbert``.
            **kwargs: extra options kept on the instance; ``load_weights``
                (default ``True``) is additionally forwarded to ``oagbert``.
        """
        super().__init__(app, model=model, **kwargs)
        load_weights = kwargs.get("load_weights", True)
        self.tokenizer, self.bert_model = oagbert(model, load_weights=load_weights)

    def __call__(self, sequence, **kwargs):
        """Tokenize ``sequence`` (padded, PyTorch tensors) and run the model.

        Returns:
            The model's output for the tokenized input.
        """
        encoded = self.tokenizer(sequence, return_tensors="pt", padding=True)
        return self.bert_model(**encoded)
# Registry of available apps: each entry names the implementing pipeline class
# ("impl") and the keyword arguments used when the caller supplies none
# ("default").
SUPPORTED_APPS = {
    "dataset-stats": {
        "impl": DatasetStatsPipeline,
        "default": {"dataset": "cora"},
    },
    "dataset-visual": {
        "impl": DatasetVisualPipeline,
        "default": {"dataset": "cora"},
    },
    "oagbert": {
        "impl": OAGBertInferencePipepline,
        "default": {"model": "oagbert-v1"},
    },
}
def check_app(app: str):
    """Look up ``app`` in ``SUPPORTED_APPS`` and return its registry entry.

    Raises:
        KeyError: if ``app`` is not a registered application name; the message
            lists the available names.
    """
    if app not in SUPPORTED_APPS:
        raise KeyError("Unknown app {}, available apps are {}".format(app, list(SUPPORTED_APPS.keys())))
    return SUPPORTED_APPS[app]
def pipeline(app: str, **kwargs) -> Pipeline:
    """Instantiate the pipeline registered under ``app``.

    Args:
        app: application name; must be a key of ``SUPPORTED_APPS``.
        **kwargs: constructor options; they override the app's registered
            defaults on a per-call basis.

    Returns:
        Pipeline: a new instance of the app's implementing class.

    Raises:
        KeyError: propagated from ``check_app`` for an unknown ``app``.
    """
    targeted_app = check_app(app)
    task_class = targeted_app["impl"]
    # Merge into a fresh dict. Updating the registry's "default" dict in place
    # would make one call's overrides leak into every later call for this app.
    merged_args = {**targeted_app["default"], **kwargs}
    return task_class(app=app, **merged_args)