GNN-4: Dataset classes and practice of node prediction and edge prediction tasks

Reference (open-source learning material): Datawhale

1. Introduction to the InMemoryDataset base class

In PyG, we subclass InMemoryDataset to define a custom dataset class that keeps all of its data in memory.

class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)

Official InMemoryDataset documentation: torch_geometric.data.InMemoryDataset

As the constructor interface of the InMemoryDataset class above shows, each dataset needs a root folder, which indicates where the dataset should be saved. There are at least two folders under the root directory:

  • One folder is raw_dir, which stores the unprocessed files. Dataset files downloaded from the network are stored here.
  • The other folder is processed_dir, where the processed dataset is saved.
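
As a minimal sketch of this layout, the snippet below derives the two folders from a root path, following PyG's convention that raw_dir is root/raw and processed_dir is root/processed. The helper name dataset_dirs and the example root path are hypothetical and only serve the illustration.

```python
import os.path as osp

def dataset_dirs(root):
    """Return (raw_dir, processed_dir) for a dataset rooted at `root`.

    Hypothetical helper, not part of the PyG API: it only mirrors the
    `<root>/raw` and `<root>/processed` convention.
    """
    return osp.join(root, 'raw'), osp.join(root, 'processed')

# 'dataset/MyOwnDataset' is an assumed example path.
raw_dir, processed_dir = dataset_dirs('dataset/MyOwnDataset')
print(raw_dir)
print(processed_dir)
```
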

In addition, a class that inherits from InMemoryDataset can be given a transform function, a pre_transform function and a pre_filter function, all of which are None by default.

  • The transform function takes a Data object as its parameter and returns it after conversion. This function is called every time the data is accessed, so it is best used for data augmentation.
  • The pre_transform function takes a Data object (https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data) as its parameter and returns it after conversion. This function is called before the Data object is saved to file, so it is best used for heavy precomputation that only needs to be done once.
  • The pre_filter function can filter out Data objects before they are saved. One use case is to keep only samples of certain categories.
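
To make the timing of the three hooks concrete, here is a toy sketch in plain Python. ToyDataset is a hypothetical stand-in that uses dicts instead of Data objects and is not PyG's implementation; it only shows that pre_filter and pre_transform run once before the samples are stored, while transform runs on every access.

```python
class ToyDataset:
    """Toy stand-in for a dataset class, NOT PyG's implementation."""

    def __init__(self, samples, transform=None, pre_transform=None, pre_filter=None):
        self.transform = transform
        # pre_filter and pre_transform run once, before the samples are stored.
        if pre_filter is not None:
            samples = [s for s in samples if pre_filter(s)]
        if pre_transform is not None:
            samples = [pre_transform(s) for s in samples]
        self._stored = samples

    def __len__(self):
        return len(self._stored)

    def __getitem__(self, idx):
        sample = dict(self._stored[idx])  # copy, so the stored sample stays untouched
        # transform runs on every access.
        return sample if self.transform is None else self.transform(sample)

calls = {'pre_transform': 0, 'transform': 0}

def keep_class_0_and_1(sample):
    return sample['y'] in (0, 1)

def add_degree(sample):
    calls['pre_transform'] += 1
    return {**sample, 'degree': len(sample['neighbors'])}

def scale_feature(sample):
    calls['transform'] += 1
    sample['x'] = sample['x'] * 2
    return sample

samples = [
    {'x': 1.0, 'y': 0, 'neighbors': [1, 2]},
    {'x': 2.0, 'y': 2, 'neighbors': [0]},
    {'x': 3.0, 'y': 1, 'neighbors': [0, 1, 2]},
]
dataset = ToyDataset(samples, transform=scale_feature,
                     pre_transform=add_degree, pre_filter=keep_class_0_and_1)
first = dataset[0]
again = dataset[0]
```

Accessing dataset[0] twice calls scale_feature twice, whereas add_degree was called only once per sample that survived the filter.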

To create an InMemoryDataset, we need to implement four basic methods:

  • raw_file_names(): a property method that returns a list of file names. These files should exist in the raw_dir folder; otherwise download() is called to download them to raw_dir.
  • processed_file_names(): a property method that returns a list of file names. These files should exist in the processed_dir folder; otherwise process() is called to preprocess the samples and save them to processed_dir.
  • download(): downloads the raw data files to the raw_dir folder.
  • process(): preprocesses the samples and saves them to the processed_dir folder.
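
The file checks described in this list can be sketched as follows. This is a toy version written with the standard library only, not PyG's actual code; the prepare, fake_download and fake_process names are hypothetical. It shows that download() is skipped when the raw files already exist and process() is skipped when the processed files already exist.

```python
import os
import tempfile

def files_exist(folder, names):
    """True only if every file in `names` exists inside `folder`."""
    return all(os.path.exists(os.path.join(folder, n)) for n in names)

def prepare(root, raw_names, processed_names, download, process):
    """Toy version of the existence checks, not PyG's implementation."""
    raw_dir = os.path.join(root, 'raw')
    processed_dir = os.path.join(root, 'processed')
    os.makedirs(raw_dir, exist_ok=True)
    os.makedirs(processed_dir, exist_ok=True)
    if not files_exist(raw_dir, raw_names):
        download(raw_dir)
    if not files_exist(processed_dir, processed_names):
        process(raw_dir, processed_dir)

log = []

def fake_download(raw_dir):
    log.append('download')
    with open(os.path.join(raw_dir, 'some_file_1'), 'w') as f:
        f.write('raw contents')

def fake_process(raw_dir, processed_dir):
    log.append('process')
    with open(os.path.join(processed_dir, 'data.pt'), 'w') as f:
        f.write('processed contents')

with tempfile.TemporaryDirectory() as root:
    # First call: nothing exists yet, so both steps run.
    prepare(root, ['some_file_1'], ['data.pt'], fake_download, fake_process)
    # Second call: all files exist, so neither step runs again.
    prepare(root, ['some_file_1'], ['data.pt'], fake_download, fake_process)

print(log)  # ['download', 'process']
```
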

The conversion of samples from the raw files into Data objects is defined in the process() function:

import torch
from torch_geometric.data import InMemoryDataset, download_url

class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root=root, transform=transform, pre_transform=pre_transform, pre_filter=pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`. `url` is a placeholder for the dataset's download address.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In process(), we read the raw files, build a list of Data objects and save it to processed_dir. Because it is quite slow for Python to save a huge list, we first use collate() to assemble the list into one huge Data object before saving. collate() also returns a slices dictionary, which is used to reconstruct individual samples from this object. Finally, in the constructor, we load the Data object and the slices dictionary into the attributes self.data and self.slices.
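
The idea behind collate() and the slices dictionary can be illustrated with a toy version for a single attribute. The collate and get functions below are hypothetical plain-Python stand-ins that work on lists, whereas PyG's collate() works on Data objects and tensors and keeps one slice entry per attribute.

```python
def collate(data_list):
    """Concatenate per-sample lists and record each sample's boundaries.

    Toy illustration only, not PyG's collate().
    """
    big_x, slices = [], [0]
    for x in data_list:
        big_x.extend(x)
        slices.append(len(big_x))  # end offset of this sample
    return big_x, slices

def get(big_x, slices, idx):
    """Reconstruct sample `idx` from the concatenated storage."""
    return big_x[slices[idx]:slices[idx + 1]]

data_list = [[1, 2, 3], [4], [5, 6]]
big_x, slices = collate(data_list)
print(big_x)                  # [1, 2, 3, 4, 5, 6]
print(slices)                 # [0, 3, 4, 6]
print(get(big_x, slices, 1))  # [4]
```
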

2. Practice of node prediction and edge prediction tasks

2.1 Node prediction task practice

Neural network model

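import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU
from torch_geometric.nn import GATConv, Sequential

class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels_list, num_classes):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        # Channel sizes of consecutive layers: the input size followed by the hidden sizes.
        hns = [num_features] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            # Each GATConv maps hns[idx] channels to hns[idx+1] channels and is followed by a ReLU.
            conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
            conv_list.append(ReLU(inplace=True))

        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        # The final linear layer maps the last hidden representation to class scores.
        x = self.linear(x)
        return x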

dataset.num_features: 500

GAT(
  (convseq): Sequential(
    (0): GATConv(500, 200, heads=1)
    (1): ReLU(inplace=True)
    (2): GATConv(200, 100, heads=1)
    (3): ReLU(inplace=True)
  )
  (linear): Linear(in_features=100, out_features=3, bias=True)
)
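
The layer sizes in this printout follow from the hns list built in the constructor. The short sketch below reproduces that pairing in plain Python; the value hidden_channels_list=[200, 100] is inferred from the printout rather than stated in the text, and layer_sizes is a hypothetical helper name.

```python
def layer_sizes(num_features, hidden_channels_list):
    """Return the (in_channels, out_channels) pair of each GATConv layer."""
    hns = [num_features] + hidden_channels_list
    return [(hns[i], hns[i + 1]) for i in range(len(hidden_channels_list))]

# 500 input features; hidden sizes assumed from the printed model.
sizes = layer_sizes(500, [200, 100])
print(sizes)  # [(500, 200), (200, 100)]
```
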

2.2 Edge prediction task practice

Edge prediction task: predict whether an edge exists between two nodes.
When we get a graph dataset, we have the node feature matrix x and edge_index, which records which node pairs are connected by edges. edge_index stores the positive samples. To build an edge prediction task, we also need to generate negative samples, that is, sample some node pairs without edges as negative sample edges, and the positive and negative samples should be balanced.
In addition, the samples should be divided into three sets: a training set, a validation set and a test set.

PyG provides a ready-made method, train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1). The first parameter is a torch_geometric.data.Data object, the second is the proportion of the validation set, and the third is the proportion of the test set. The function samples negative edges automatically and divides the positive and negative samples into training, validation and test sets. It replaces the edge_index attribute with the attributes train_pos_edge_index, train_neg_adj_mask, val_pos_edge_index, val_neg_edge_index, test_pos_edge_index and test_neg_edge_index. Note that for the training set it stores a mask of candidate negative pairs (train_neg_adj_mask) rather than a sampled list of negative edges.
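
The following is a simplified plain-Python sketch of what such a split involves: sampling as many non-edges as there are edges, then cutting both lists with the same ratios. It is not PyG's train_test_split_edges; the split_edges function, the 10-node example graph and the ratios are assumptions made for illustration, and unlike PyG it samples the training negatives up front.

```python
import random

def split_edges(edges, num_nodes, val_ratio, test_ratio, seed=0):
    """Toy split of undirected edges with balanced negative sampling."""
    rng = random.Random(seed)
    # Store each undirected edge once as (smaller_id, larger_id).
    pos = sorted({(min(u, v), max(u, v)) for u, v in edges})
    existing = set(pos)
    # Negative candidates: every node pair that is not connected by an edge.
    candidates = [(u, v) for u in range(num_nodes)
                  for v in range(u + 1, num_nodes) if (u, v) not in existing]
    neg = rng.sample(candidates, len(pos))  # balanced: one negative per positive
    rng.shuffle(pos)
    n_val = int(len(pos) * val_ratio)
    n_test = int(len(pos) * test_ratio)

    def cut(pairs):
        return {'val': pairs[:n_val],
                'test': pairs[n_val:n_val + n_test],
                'train': pairs[n_val + n_test:]}

    return cut(pos), cut(neg)

# Assumed example graph: a 10-node ring plus chords to the second neighbor (20 edges).
edges = [(i, (i + 1) % 10) for i in range(10)] + [(i, (i + 2) % 10) for i in range(10)]
pos, neg = split_edges(edges, num_nodes=10, val_ratio=0.1, test_ratio=0.2)
for name in ('train', 'val', 'test'):
    print(name, len(pos[name]), len(neg[name]))
```

Each of the three sets ends up with the same number of positive and negative pairs, and no negative pair coincides with an existing edge.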

Neural network model

import torch
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

The neural network used for edge prediction consists of two parts. The first is encoding, which generates node representations in the same way as introduced earlier. The second is decoding, which scores how likely an edge is to exist from the representations of the nodes at its two ends; decode() returns the inner product of the two representations, a raw score that a sigmoid turns into a probability. decode_all(self, z) is used in the inference stage, where edges are predicted for all pairs of input nodes.
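
As a small worked example of the decoding step, the sketch below recomputes the inner-product scores in plain Python for three hand-picked node representations. The embedding values and edge pairs are made up for illustration, and the functions are list-based stand-ins for the tensor code in Net, not a replacement for it.

```python
import math

def decode(z, edges):
    """Inner-product score for each (u, v) pair, mirroring Net.decode."""
    return [sum(a * b for a, b in zip(z[u], z[v])) for u, v in edges]

def decode_all(z):
    """All node pairs whose inner product is positive, mirroring Net.decode_all."""
    n = len(z)
    return [(u, v) for u in range(n) for v in range(n)
            if sum(a * b for a, b in zip(z[u], z[v])) > 0]

def sigmoid(score):
    return 1.0 / (1.0 + math.exp(-score))

# Made-up 2-dimensional representations of three nodes.
z = [[1.0, 0.0], [1.0, 1.0], [-1.0, 0.0]]
pos_edges = [(0, 1)]
neg_edges = [(0, 2)]

scores = decode(z, pos_edges + neg_edges)
labels = [1.0] * len(pos_edges) + [0.0] * len(neg_edges)  # targets for a binary loss
probs = [sigmoid(s) for s in scores]
predicted = decode_all(z)
print(scores)  # [1.0, -1.0]
```

The positive pair gets a score above zero (probability above 0.5) and the negative pair a score below zero, which is the behavior a binary cross-entropy loss on these scores would push the encoder toward.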

Keywords: Python neural networks

Added by djcritch on Wed, 26 Jan 2022 03:33:19 +0200