【GNN】硬核!一文梳理经典图网络模型
作者 | Chilia
哥伦比亚大学 nlp搜索推荐
整理 | NewBeeNLP
图神经网络已经在NLP、CV、搜索推荐广告等领域广泛应用,今天我们就来整体梳理一些经典常用的图网络模型:DeepWalk、GCN、Graphsage、GAT!
1. DeepWalk [2014]
DeepWalk是来解决图里面节点embedding问题的。Graph Embedding技术将图中的节点以低维稠密向量的形式进行表达,要求在原始图中相似(不同的方法对相似的定义不同)的节点其在低维表达空间也接近。得到的表达向量可以用来进行下游任务,如节点分类(node classification),链接预测(link prediction)等。
1.1 DeepWalk 算法原理
虽然DeepWalk是KDD 2014的工作,但却是我们了解Graph Embedding无法绕过的一个方法。
我们都知道在NLP任务中,word2vec是一种常用的word embedding方法,word2vec通过语料库中的句子序列来描述词与词的共现关系,进而学习到词语的向量表示。
DeepWalk的思想类似word2vec,使用图中节点与节点的共现关系来学习节点的向量表示。那么关键的问题就是如何来描述节点与节点的共现关系,DeepWalk给出的方法是使用**随机游走(RandomWalk)**的方式在图中进行节点采样。
RandomWalk是一种可重复访问visited节点的深度优先遍历算法。给定当前访问起始节点,从其邻居中随机采样节点作为下一个访问节点,重复此过程,直到访问序列长度 = K。获取足够数量的节点访问序列后,使用skip-gram进行向量学习,这样能够把握节点的共现信息。这样就得到了每个节点的embedding。
2. GCN [2016]
GCN的概念首次提出于ICLR 2017:SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS。
为什么要用GCN呢?这是因为对于图结构的数据,CNN、RNN都无法解决。
对于图片来说,我们用卷积核来提取特征,这是因为图片有平移不变性:一个小窗口无论移动到图片的哪一个位置,其内部的结构都是一模一样的,因此CNN可以实现参数共享。RNN主要用在NLP这种序列信息上。图片,或者语言,都属于欧式空间的数据,因此才有维度的概念,欧式空间的数据的特点就是结构很规则。
但是图结构(拓扑结构)如社交网络、知识图谱、分子结构等等是十分不规则的,可以认为是无限维的一种数据,所以它没有平移不变性。每一个节点的周围结构可能都是独一无二的,这种结构的数据,就让传统的CNN、RNN瞬间失效。
GCN,图卷积神经网络,实际上跟CNN的作用一样,就是一个特征提取器,只不过它的对象是图。GCN精妙地设计了一种从图数据中提取特征的方法,从而让我们可以使用这些特征去对图数据进行:
节点分类(node classification) 图分类(graph classification) 链接预测(link prediction)
2.1 GCN的核心公式
假设我们手头有一个图,其中有N个节点,每个节点都有自己的特征embedding,我们设这些节点的特征组成一个N×D维的矩阵 ,然后各个节点之间的关系也会形成一个N×N维的矩阵A(就是邻接矩阵)
GCN也是一个神经网络层,它的层与层之间的传播方式是:
这个公式中:
, 是单位矩阵。 是度矩阵(degree matrix),D[i][i]就是节点i的度。 H是每一层的特征,对于第一层(输入层)的话,就是矩阵 。 σ是非线性激活函数
用这个公式就可以很好地提取图的特征。假设我们构造一个两层的GCN,激活函数分别采用ReLU和Softmax,则整体的正向传播的公式为:
其中, .
那么, 为什么这个公式能提取图的特征呢?
A+I 其实是保证对于每个节点,都能够关注到其所有邻居节点和自己的embedding。 左右乘上度矩阵D是为了对A做一个标准化处理,让A的每一行加起来都是1.
当然,原论文中用非常复杂的数学公式做了很多证明,由于笔者数学不好,只能如此不求甚解的来粗略理解,感兴趣的同学可以自行阅读原论文。
3. GraphSAGE
3.1. GCN的局限
GCN本身有一个局限,即没法快速表示新节点。GCN需要把所有节点都参与训练(整个图都丢进去训练)才能得到node embedding,如果新node来了,没法得到新node的embedding。所以说,GCN是transductive的。(Transductive任务是指:训练阶段与测试阶段都基于同样的图结构)
而GraphSAGE是inductive的。inductive任务是指:训练阶段与测试阶段需要处理的graph不同。通常是训练阶段只是在子图(subgraph)上进行,测试阶段需要处理未知的顶点。
要想得到新节点的表示,需要让新的node或者subgraph去和已经优化好的node embedding去“对齐”。然而每个节点的表示都是受到其他节点的影响(牵一发而动全身),因此添加一个节点,意味着许许多多与之相关的节点的表示都应该调整。
3.2 GraphSAGE
针对这种问题,GraphSAGE模型提出了一种算法框架,可以很方便地得到新node的表示。
3.2.1 Embedding generation(前向传播算法)
Embedding generation算法共聚合K次,总共有K个聚合函数(aggregator),可以认为是K层,来聚合邻居节点的信息。假如 用来表示第k层每个节点的embedding,那么如何 从 得到呢?
就是初始的每个节点embedding。 对于每个节点v,都把它随机采样的若干邻居的k-1层的所有向量表示 、以及节点v自己的k-1层表示聚合成一个向量,这样就得到了第层的表示 。这个聚合方法具体是怎么做的后面再详细介绍。
文中描述如下:
随着层数K的增加,可以聚合越来越远距离的信息。这是因为,虽然每次选择邻居的时候就是从周围的一阶邻居中均匀地采样固定个数个邻居,但是由于节点的邻居也聚合了其邻居的信息,这样,在下一次聚合时,该节点就会接收到其邻居的邻居的信息,也就是聚合到了二阶邻居的信息了。这就像社交图谱中“朋友的朋友”的概念。
3.2.2 聚合函数选择
Mean Pooling:
这个比较好理解,就是当前节点v本身和它所有的邻居在k-1层的embedding的mean,然后经过MLP+sigmoid
LSTM Aggregator:把当前节点v的邻居随机打乱,输入到LSTM中。作者的想法是说LSTM的模型capacity更强。但是节点周围的邻居明明是没有顺序的,这样做似乎有不妥。 Pooling Aggregator:
把节点v的所有邻居节点都单独经过一个MLP+sigmoid得到一个向量,最后把所有邻居的向量做一个element-wise的max-pooling。
3.2.3 GraphSAGE的参数学习
GraphSAGE的参数就是聚合函数的参数。为了学习这些参数,需要设计合适的损失函数。
对于无监督学习,设计的损失函数应该让临近的节点的拥有相似的表示,反之应该表示大不相同。所以损失函数是这样的:
其中,节点v是和节点u在一定长度的random walk上共现的节点,所以它们的点积要尽可能大;后面这项是采了Q个负样本,它们的点积要尽可能小。这个loss和skip-gram中的negative sampling如出一辙。
对于有监督学习,可以直接使用cross-entropy loss等常规损失函数。当然,上面的这个loss也可以作为一个辅助loss。
3.3 和GCN的关系
原始GCN的方法,其实和GraphSAGE的Mean Pooling聚合方法是类似的,即每一层都聚合自己和自己邻居的归一化embedding表示。而GraphSAGE使用了其他capacity更大的聚合函数而已。
此外,GCN是一口气把整个图都丢进去训练,但是来了一个新的节点就不免又要把整个图重新训一次。而GraphSAGE则是在增加了新的节点之后,来增量更新旧的节点,调整整张图的embedding表示。只是生成新节点embedding的过程,实施起来相比于GCN更加灵活方便了。
4. GAT (Graph Attention Network)
4.1 GAT的具体做法
对于每个节点,注意力其在邻居顶点上的注意力。对于顶点 ,逐个计算它的邻居们和它自己之间的相似系数:
首先一个共享参数 的线性映射对于顶点的特征进行了增维,当然这是一种常见的特征增强(feature augment)方法;之后,对变换后的特征进行了拼接(concatenate);最后 a(·)把拼接后的高维特征映射到一个实数上,作者是通过单层的MLP来实现的。
然后,再对此相关系数用softmax做归一化:
最后,根据计算好的注意力系数,把特征加权求和一下。这也是一种aggregation,只是和GCN不同,这个aggregation是带注意力权重的。
就是输出的节点的embedding,融合了其邻居和自身带注意力的权重(这里的注意力是self-attention)。
为了增强特征提取能力,用multi-head attention来进化增强一下:
4.2 与GCN的联系
GCN与GAT都是将邻居顶点的特征聚合到中心顶点上(一种aggregate运算)。不同的是GCN利用了拉普拉斯矩阵,GAT利用attention系数。一定程度上而言,GAT会更强,因为 顶点特征之间的相关性被更好地融入到模型中。
GAT适用于有向图。这是因为GAT的运算方式是逐顶点的运算(node-wise),每一次运算都需要循环遍历图上的所有顶点来完成。逐顶点运算意味着,摆脱了拉普利矩阵的束缚,使得有向图问题迎刃而解。也正因如此,GAT适用于inductive任务。与此相反的是,GCN是一种全图的计算方式,一次计算就更新全图的节点特征。
- END -
往期精彩回顾
适合初学者入门人工智能的路线及资料下载 (图文+视频)机器学习入门系列下载 中国大学慕课《机器学习》(黄海广主讲) 机器学习及深度学习笔记等资料打印 《统计学习方法》的代码复现专辑 AI基础下载 机器学习交流qq群955171419,加入微信群请扫码: