# Source code for datasets.pyg_modelnet

import os
import os.path as osp
import shutil
import glob
import numpy as np

import torch_geometric.transforms as T
from torch_geometric.datasets import ModelNet
from cogdl.data import Data, Dataset, download_url
from . import register_dataset


class ModelNet10(ModelNet):
    """One split (train or test) of the PyG ModelNet10 dataset.

    ``T.NormalizeScale()`` is passed as ``pre_transform`` and
    ``T.SamplePoints(1024)`` as ``transform``. Data is stored under
    ``<dir of this file>/../../data/ModelNet10``.

    Args:
        train: ``True`` selects the training split, ``False`` the test split.
    """

    def __init__(self, train):
        dataset = "ModelNet10"
        pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        # The base-class constructor downloads/processes the data itself when it
        # is missing, so no separate warm-up ModelNet instance is created here.
        # (Calling ModelNet(path, "10", transform, pre_transform) positionally
        # would put `transform` into the `train` slot and `pre_transform` into
        # the `transform` slot, caching processed data built without
        # NormalizeScale.) Keyword arguments keep every value in its slot.
        super(ModelNet10, self).__init__(
            path,
            name="10",
            train=train,
            transform=transform,
            pre_transform=pre_transform,
        )
class ModelNet40(ModelNet):
    """One split (train or test) of the PyG ModelNet40 dataset.

    ``T.NormalizeScale()`` is passed as ``pre_transform`` and
    ``T.SamplePoints(1024)`` as ``transform``. Data is stored under
    ``<dir of this file>/../../data/ModelNet40``.

    Args:
        train: ``True`` selects the training split, ``False`` the test split.
    """

    def __init__(self, train):
        dataset = "ModelNet40"
        pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        # The base-class constructor downloads/processes the data itself when it
        # is missing, so no separate warm-up ModelNet instance is created here.
        # (Calling ModelNet(path, "40", transform, pre_transform) positionally
        # would put `transform` into the `train` slot and `pre_transform` into
        # the `transform` slot, caching processed data built without
        # NormalizeScale.) Keyword arguments keep every value in its slot.
        super(ModelNet40, self).__init__(
            path,
            name="40",
            train=train,
            transform=transform,
            pre_transform=pre_transform,
        )
@register_dataset("ModelNet10")
class ModelNetData10(ModelNet):
    """ModelNet10 exposed as one dataset: training graphs first, then test graphs.

    Global indices ``[0, len(train_data))`` address the training split and
    ``[len(train_data), num_graphs)`` address the test split, matching
    ``train_index`` / ``test_index``.
    """

    def __init__(self):
        dataset = "ModelNet10"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        self.train_data = ModelNet10(True)
        self.test_data = ModelNet10(False)
        self.num_graphs = len(self.train_data) + len(self.test_data)
        super(ModelNetData10, self).__init__(path, name="10")

    def get_all(self):
        """Return the ``(train_data, test_data)`` pair."""
        return self.train_data, self.test_data

    def __getitem__(self, item):
        """Return the graph at global index ``item`` (train split, then test split)."""
        n_train = len(self.train_data)
        if item < n_train:
            return self.train_data[item]
        # Test graphs start at global index n_train, so shift into the test
        # split's own 0-based index space.
        return self.test_data[item - n_train]

    def __len__(self):
        """Total number of graphs across both splits."""
        return len(self.train_data) + len(self.test_data)

    @property
    def train_index(self):
        """Half-open ``(start, end)`` global index range of the training split."""
        return 0, len(self.train_data)

    @property
    def test_index(self):
        """Half-open ``(start, end)`` global index range of the test split."""
        return len(self.train_data), self.num_graphs
@register_dataset("ModelNet40")
class ModelNetData40(ModelNet):
    """ModelNet40 exposed as one dataset: training graphs first, then test graphs.

    Global indices ``[0, len(train_data))`` address the training split and
    ``[len(train_data), num_graphs)`` address the test split, matching
    ``train_index`` / ``test_index``.
    """

    def __init__(self):
        dataset = "ModelNet40"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        self.train_data = ModelNet40(True)
        self.test_data = ModelNet40(False)
        self.num_graphs = len(self.train_data) + len(self.test_data)
        super(ModelNetData40, self).__init__(path, name="40")

    def get_all(self):
        """Return the ``(train_data, test_data)`` pair."""
        return self.train_data, self.test_data

    def __getitem__(self, item):
        """Return the graph at global index ``item`` (train split, then test split)."""
        n_train = len(self.train_data)
        if item < n_train:
            return self.train_data[item]
        # Test graphs start at global index n_train, so shift into the test
        # split's own 0-based index space.
        return self.test_data[item - n_train]

    def __len__(self):
        """Total number of graphs across both splits."""
        return len(self.train_data) + len(self.test_data)

    @property
    def train_index(self):
        """Half-open ``(start, end)`` global index range of the training split."""
        return 0, len(self.train_data)

    @property
    def test_index(self):
        """Half-open ``(start, end)`` global index range of the test split."""
        return len(self.train_data), self.num_graphs