Training
本文档详细介绍了如何使用TuGraph进行图神经网络(GNN)的训练。
1. 训练
使用TuGraph 图学习模块进行训练时,可以分为全图训练和mini-batch训练。 全图训练即把全图从TuGraph db加载到内存中,再进行GNN的训练。而mini-batch训练则使用上面提到的TuGraph 图学习模块的采样算子,将全图数据进行采样后,再送入训练框架中进行训练。
2. Mini-Batch训练
Mini-Batch训练需要使用TuGraph 图学习模块的采样算子,目前支持Neighbor Sampling、Edge Sampling、Random Walk Sampling和Negative Sampling。 TuGraph 图学习模块的采样算子进行采样后的结果以List的形式返回。 下面以Neighbor Sampling为例,介绍如何将采样后的结果,进行格式转换,送入到训练框架中进行训练。 用户需要提供一个Sample类:
class TuGraphSample(object):
    def __init__(self, args=None):
        super(TuGraphSample, self).__init__()
        self.args = args
    def sample(self, g, seed_nodes):
        args = self.args
        # 1. 加载图数据
        galaxy = PyGalaxy(args.db_path)
        galaxy.SetCurrentUser(args.username, args.password)
        db = galaxy.OpenGraph(args.graph_name, False)
        sample_node = seed_nodes.tolist()
        length = args.randomwalk_length
        NodeInfo = []
        EdgeInfo = []
        # 2. 采样方法,结果存储在NodeInfo和EdgeInfo中
        if args.sample_method == 'randomwalk':
            randomwalk.Process(db, 100, sample_node, length, NodeInfo, EdgeInfo)
        elif args.sample_method == 'negative':
            negativesample.Process(db, 100)
        else:
            neighborsample(db, 100, sample_node, args.nbor_sample_num, NodeInfo, EdgeInfo)
        del db
        del galaxy
        # 3. 对结果进行格式转换,使之符合训练格式
        remap(EdgeInfo[0], EdgeInfo[1], NodeInfo[0])
        g = dgl.graph((EdgeInfo[0], EdgeInfo[1]))
        g.ndata['feat'] = torch.tensor(NodeInfo[1])
        g.ndata['label'] = torch.tensor(NodeInfo[2])
        return g
如代码所示,首先将图数据加载到内存中。然后使用采样算子对图数据进行采样,结果存储在NodeInfo和EdgeInfo中。NodeInfo和EdgeInfo是python list结果,其存储的信息结果如下:
| 图数据 | 存储信息位置 | 
|---|---|
| 边起点 | EdgeInfo[0] | 
| 边终点 | EdgeInfo[1] | 
| 顶点ID | NodeInfo[0] | 
| 顶点特征 | NodeInfo[1] | 
| 顶点标签 | NodeInfo[2] | 
最后对结果进行格式转换,使之符合训练格式。这里我们使用的是DGL训练框架,因此使用结果数据构造了DGL Graph,最终将DGL Graph返回。 我们提供TuGraphSample类之后,就可以使用它进行Mini-Batch训练了。 令DGL的数据加载部分使用TuGraphSample的实例sampler:
    sampler = TugraphSample(args)
    fake_g = construct_graph() # just make dgl happy
    dataloader = dgl.dataloading.DataLoader(fake_g,
        torch.arange(train_nids),
        sampler,
        batch_size=batch_size,
        device=device,
        use_ddp=True,
        num_workers=0,
        drop_last=False,
        )
使用DGL进行模型训练:
def train(dataloader, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    model.train()
    s = time.time()
    for graph in dataloader:
        load_time = time.time()
        graph = dgl.add_self_loop(graph)
        logits = model(graph, graph.ndata['feat'])
        loss = loss_fcn(logits, graph.ndata['label'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_time = time.time()
        print('load time', load_time - s, 'train_time', train_time - load_time)
        s = time.time()
    return float(loss)
3. 全图训练
GNN(图神经网络)的全图训练是一种涉及一次处理整个训练数据集的训练。它是 GNN 最简单、最直接的训练方法之一,整个图被视为单个实例。 在全图训练中,整个数据集被加载到内存中,模型在整个图上进行训练。这种类型的训练对于中小型图特别有用,并且主要用于不随时间变化的静态图。 在算子调用时,使用以下方式:
getdb.Process(db, olapondb, feature_len, NodeInfo, EdgeInfo)
获取全图数据,然后将全图送入训练框架中进行训练。 完整代码:请参考learn/examples/train_full_cora.py。