打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
图计算怎么用?——图神经网络
userphoto

2023.10.01 辽宁

关注

前面两篇分别从图数据的两种不同的应用形式,讲解了图数据如何增强传统的风险模型。

  • 图计算怎么用?——图特征工程 图特征工程的方法较多依赖于风险领域的专业知识和业务专家经验,业务专家可借助特征工程的方法,快速搭建图数据增强的风险识别模型。
  •  图计算怎么用?——图表示学习图表示学习则走了类似词向量表示的方向,由无监督的方式为图中节点生成 Embedding 表示,从而将图的拓扑结构信息引入到下游任务中。

传统的 Graph Embedding 生成算法较为通用,生成过程并未将拓扑结构信息和节点、边的属性信息进行有效的融合,且与下游任务是无关的。可见通用的 Graph Embedding 会带来一定的信息增益,但对于特定任务的帮助有一定的局限性。本篇是《图计算怎么用?》系列方法论层面的最后一篇,将重点讲述当前图计算的热门方向——图神经网络方法(Graph Neural Network)的应用。之前在公众号上也单独做过一些经典的图风控方向的工业应用论文解读,这里先做个链接,有空可以回顾回顾: 

一、图学习三大任务

在介绍具体 GNN 方法之前,现讲讲图深度学习的三大类任务,明确我们需要解决的问题目标。图学习三大任务:

  1. 节点分类(Node Classification):给定图结构、节点属性和边属性,部分节点标签,预测其他节点的标签;这是个半监督的学习任务(Semi-Supervised),在反欺诈场景很常见,已知一部分欺诈用户(bad)和部分好用户(good),其他都是大量的未知用户(unknown),需要预测这些 unknown 标签客户中,是欺诈客户(is bad?)的概率。
  2. 边分类/边预测(Link Prediction):给定图结构、节点属性和边属性,预测两个节点之间产生连接的可能性;这个任务的实例,来自于推荐场景,已知历史用户(User)和物品(Item)交互的二部图(bipartite),需要给用户推荐感兴趣的下一个 Item,问 User 和 某 Item 之间是否有边存在。
  3. 图分类(Graph Classification):给定图结构(此时可以有多个图结构)、节点属性和边属性,部分图的标签,预测其他图的标签;图分类的任务,在互联网和金融领域都较为少见,我们更多是做的零售的场景,一般用户、实体都是作为节点的存在,预测节点相关的属性、连接。前几年比较火的 DeemMind 发布的 AlphaFold 是生物分子领域的一些深度学习应用,作为蛋白质、化学分子结构,图分类任务预测其功效和属性就成为一个典型的例子。

我们在金融风控场景用的最多的还是图节点分类任务,给出用户节点违约或者欺诈的概率预测,供业务决策使用。但并不是说,风控场景的问题不适合抽象成为边预测和图分类任务,例如:

  • 在风控中也被用来做信号的提纯(并非单纯的一个模型),可看看 GEM 中的做法
  • 金融场景见过的是——度小满在 人行征信数据挖掘上的 GNN 模型的尝试(并未见论文)

二、图神经网络方法

GNN的方法是我最不想讲的 Topic,就像深度学习的推荐算法一样,各式各样的 Model 和结构适配特定场景任务,不是越复杂越fancy的就一定好。另外,一提到基础算法就头大的问题在于,如何尽量通俗易懂、公式符合含量低的呈现出来。

图神经网络(Graph Neural Network,GNN)方法,和诸多的神经网络(DNN)方法一样,可以把其结构简单分为两块——特征部分和预测任务,GNN 的特征部分精髓在于消息传递(Message Passing)的模式。

和 CNN 的卷积类似,节点 的邻居子图(k-hop)的信息(包括邻居节点、边的属性,邻接关系)与节点 强相关。MPNN 范式是一种用节点 的邻居子图信息,表征中心节点的方式。 

Message Passing: 

消息传递神经网络(Message Passing Neural Network,MPNN)范式,是一种空域(Spatial Domain)的表达。对于古老一点的谱域(Spectral Domain)表达(e.g. GCN),图傅里叶变换,图信号处理等概念,这里不做过多的赘述(需要大量的公式和数学推导)。

这里用经典的 GraphSAGE 模型来举例 Message Passing,其他的空域 GNN 算法都是类似的。

  1. 图中红色 Node 代表目标中心节点,为其生成表征;
  2. 需要其 1阶()和 2阶()的邻居的信息(这里以2阶举例,深层的 GNN 有 over smoothing 的问题,又是个别样的话题);
  3. 考虑到邻居节点数量随着 k-hop 的深度会指数增长,GraphSAGE 采用了固定数量的邻居采样;
  4. 被采样到的邻居节点,会向上一级节点传递消息,直到全部信息经 2 级的操作,都汇聚到中心节点上;
  5. 这样图中的每个节点的表征,都是基于其自身属性(的部分)与其邻居信息(的部分)组合得到,兼顾了自身属性信息、邻居属性信息、邻居关系信息,并且邻居的关联与属性提前进行了融合;
  6. 相对于传统的图表示学习(无监督、下游任务无关的),GNN 使用 MPNN 得到的表征,直接参与下游任务的训练,监督信号会反向指导预测任务部分和特征部分的权重,即会影响到邻居聚合的权重,以适配特定任务的数据。

其他的空域 GNN 算法基本都是在 MPNN 范式下,设计特定的邻居聚合方式,例如:

  • GAT(Graph Attention Network):Attention Is All You Need!不是所有的邻居都一般重要,自然的想法就是引入注意力机制,来选择对下游任务有效的部分。当然后续也出了 GATv2 来弥补一些 Attention 权重同质化的问题。

  • SGC(Simplified Graph Convolution):常规的 MPNN 可以理解成为基于邻接矩阵的特征传播(EP)和基于非线性激活函数的特征变换(ET),并且特征传播与特征变换是匹配的,即一次特征传播紧接着一次特征变换。有实验性研究(腾讯 《Evaluating Deep Graph Neural Networks》)表明,GNN 的 over smoothing 问题主要来自多个图卷积层的特征变换(ET 操作)。SGC 就是将 EP 和 ET 解耦,做深特征传播以求使用更深邻居信息,提升模型效果的目的。

三、图神经网络实例

本来想在这个 Section 讲个 GNN 的工业化实例的,比如 PinSage 就是首个将 GraphSAGE 工业应用,网上这方面的介绍也很多。另外,面向特定任务设计的 GNN 结构,更多的是给予一定的启发,对自己的任务,我认为还是得基于数据分析——关联关系是否有强信号?有没有挖掘的机会和空间?先回答这些问题,再去考虑如何设计和改造已有的 GNN 算法,来解决你的问题;如果回到是否定的,不需要用图计算,那么“抓到老鼠就是好猫”。

我们都绕不开的是,如何将 GNN 算法给 run 起来,包括模型训练 和 模型推理这些基本要求。快速构建 GNN 的模型项目,并使用测试数据进行测试。通常会先做算法 POC,验证通过后再实际启动项目,要不一开始整图数据,体量大占资源较多,后续风险也很高。

GNN 构建工具——DGL

当然大厂有自己的图平台,更多都是支持分布式的,例如腾讯的 Angel,阿里的 GraphScope,字节也说有 等等。如果你的公司还没有支持 GNN 的平台,但你总有个支持 Machine Learning 和 Deep Learning 的平台吧,此时无需图数据库等,就可以支持 GNN 的算法开发和部署。

个人用过比较好的是 AWS 上海 AI 开发的 DGL (Deep Graph Library,https://www.dgl.ai/),这个在上期的图表示学习有过介绍,这里再详细讲讲。DGL 是建立在主流深度学习框架之上的 GNN 库:

  • 后端支持 PyTorch、TensorFlow和 MXNet
  • 支持 CPU 单机,多机,GPU 单机多卡、多机多卡 多种计算模式
  • 封装好了众多 SOTA 的 GNN 模块,方便进行组装和算法改造
  • 最重要的就是文档友好(https://docs.dgl.ai/index.html),社区活跃,方便上手应用
  • 版本的更新较为及时,对一些重要功能的完善和对前沿算法的支撑,还是走在了前列
  • 在 DGL 库之上,也提供了用户友好的命令行调用框架 DGL-GO(https://github.com/dmlc/dgl/tree/master/dglgo),支持快速的 GNN 实验
image.png

DGL 模型示例

以下展示个使用 DGL 工具,做 GCN 模型的案例,找找感觉 。

数据的定义,DGL 内置了一些测试数据,也支持外部导入和自己构建,这里使用了 Core 数据集。图的全部信息都存储在 g 中,包括拓扑关系、节点和边的属性。

dataset = dgl.data.CoraGraphDataset()
print(f'Number of categories: {dataset.num_classes}')
g = dataset[0]
print('Node features')
print(g.ndata)
print('Edge features')
print(g.edata)

DGL 对 GNN 模块的支持 

在 dgl.nn 中,提供了大量的基础模型(https://docs.dgl.ai/api/python/nn-pytorch.html,dgl.nn 是区分后端的,需要先配置 backend)

from dgl.nn import GraphConv

class GCN(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(GCN, self).__init__()
self.conv1 = GraphConv(in_feats, h_feats)
self.conv2 = GraphConv(h_feats, num_classes)

def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
return h

模型训练过程与常规的 Deep Learning 较为类似。支持 Full Batch 和 Mini-Batch 不同的训练方式。

# Create the model with given dimensions
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)

def train(g, model):
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
best_val_acc = 0
best_test_acc = 0

features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
for e in range(100):
# Forward
logits = model(g, features)

# Compute prediction
pred = logits.argmax(1)

# Compute loss
# Note that you should only compute the losses of the nodes in the training set.
loss = F.cross_entropy(logits[train_mask], labels[train_mask])

# Compute accuracy on training/validation/test
train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

# Save the best validation accuracy and the corresponding test accuracy.
if best_val_acc < val_acc:
best_val_acc = val_acc
best_test_acc = test_acc

# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()

if e % 5 == 0:
print(
f'In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})'
)


model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

DGL 的学习途径

  • A Blitz Introduction to DGL(https://docs.dgl.ai/tutorials/blitz/index.html),提供了6个入门的 case
  • DGL User Guide(https://docs.dgl.ai/guide/index.html),支持中英文,非常详细的文档了
  • Learn By Example(https://github.com/dmlc/dgl/tree/master/examples/pytorch),大量的用 DGL 开发的模型,都可以下载下来尝试和改造

GNN 测试数据集 —— OGB

实验需要数据集的支持,准确的说需要 Benchmark 数据集,大家都在上面实验过,有一些历史参考数据。类似 CV 领域的 ImageNet 项目,在 Stanford 大牛 Jure 牵头组织下,成立了 OGB(Open Graph Benchmark,https://ogb.stanford.edu/)项目,旨在针对图学习的三大任务,提供可靠的 Benchmark 数据集。

a diverse set of challenging and realistic benchmark datasets to facilitate scalable, robust, and reproducible graph machine learning (ML) research

OGB 数据集的来源广泛,涉及的领域很多,数据集的规模也有较为合理的梯度分布。同时提供 Leaderboard,鼓励大家在 GNN 算法上创新,通过标准 Benchmark 数据集上的实战效果,来推动 GNN 领域的发展。
虽然是个偏向学术研究的数据集,对于 GNN 算法的测试,也是个可用的基本盘,在实操工业数据集之前,不妨一试。笔者去年11月在 OGB 上刷过 No.10,今天看了下掉到了 No.14。另外,测试算法还是需要规模稍大一些的数据集,在跑得动且能快速完成的基础上,会提前帮你解决一些工程上的问题,比如提高训练的速度,检验代码中的问题,以及对模型性能有个预判。

最后的总结

写这个系列的原因,主要有两点:

  • 其一:下面这位银行哥们问了我个问题,当然我的回答有点草率。整理了下我过往在图计算方面,经历过的,看到过的,和想到过的,算是个完整点的回应。
  • 其二:在我现在的银行里,做这些“高大上”算法的团队,都较难生存(这个在第一篇里面有讲述,落地难的问题),当然也有一些需要跟传统思想持久战的因素。对结果负责当然是一种职业素养,但对理念的追求也不能停止。所以虽然现在我没在继续做图的工作,写下这些内容,也作为一种对他们的声援,一种坚持。

最后,祝大家双节愉快!

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
周伟鹏:如何用GNN来提升关系图谱的反欺诈效果?
10行代码搞定图Transformer,图神经网络框架DGL迎来1.0版本
开源图深度学习框架的机遇与挑战(消息函数的目的是在边上通过输入边上的源节点以及目标节点以及其边上的特征来计算出要传递怎样的消息)
图神经网络框架DGL v0.4.3 新版本发布
深度学习之上,图神经网络(GNN )崛起
2022简单易懂「图神经网络前沿进展与应用」中文综述
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服