用 MPNN 框架解读 GAT

Justin Gilmer 提出了 MPNN(Message Passing Neural Network)框架1 ,用于描述被用来做图上的监督学习的图神经网络模型。我发现这是一个很好用的框架,可以很好理解不同的 GNN 模型是如何工作的,方便快速弄清楚不同的 GNN 模型之间的差别。我们考虑图 $G$ 上的一个节点 $v$,它的向量表示 $h_v$ 的更新方式如下: $$m_v^{t+1}=\sum_{u\in \mathcal{N}(v)}M_t(h_v^t,h_u^t,e_{vu})$$ $$h_v^{t+1}=U_t(h_v^t,m_v^{t+1})$$

其中

  • $u$ 为 $v$ 的邻居节点,$\mathcal{N}(v)$ 则表示节点 $v$ 的所有邻居
  • $e_{vu}$ 是可选项,表示边上的特征(若有)
  • $M_t$ 是消息函数,$m_v^{t+1}$ 就是所有邻居节点的消息聚合之后的结果
  • $U_t$ 是节点更新函数

在更新完图上所有节点的向量表示之后,我们可能需要要做图级别的分类任务,在 MPNN 框架中对应为: $$\hat y=R({h_v^T|v\in G})$$

其中:

  • $R$ 为图读出函数,输入是图上所有节点的向量表示,输出为一个特征向量用于图级别的分类任务

🧐 我发现将论文的公式和具体的代码联系起来总是能够帮助理解,因此下面关于 GAT 公式的讲解我会在用 MPNN 框架的基础上,同时附上相关的代码(用 ... 表示其他省略的部分)。代码来自于 DGL 官方提供的 GAT 模块的源码

🧐 GAT 可以方便堆叠多层,下面的讨论都是从某一层 $l$某个节点 $v$ 的角度来谈的

$$h_v^{l}=W^lh_v^{l}$$

设每个节点的向量表示的长度为 $F$,在第一步中,对图上每个结点的向量先做一个线性变换,其中 $W\in\mathcal{R}^{F’\times F}$, 因此每个节点的向量都更新长度为 $F’$。为了区分不同层的向量,用上角标 $l$ 表示这是第 $l$ 层的。注意在第 $l$ 层中,所有节点共用同一个权重矩阵 $W$

📒 注意后面的 $h_v^l$ 或者 $h_u^l$ 都是经过线性变换之后的

python

class GATConv(nn.Module):
    def __init__(self, ...):
        ...
        self.fc = nn.Linear(
            self._in_src_feats, out_feats * num_heads, bias=False
        )
        ...
    def forward(self, graph, feat, ...):
        """
        Args
        ----
            feat: (N, *, D_in) where D_in is the size of input feature

        Returns
        -------
            torch.Tensor
                (N, *, num_heads, D_out)

        """
        ...
        src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
        h_src = h_dst = self.feat_drop(feat)
        # h_src: (N, *, D_in)
        feat_src = feat_dst = self.fc(h_src).view(
            *src_prefix_shape, self._num_heads, self._out_feats
        )
        # feat_src/feat_dst: (N, *, num_heads, out_feat)
        ...

注意,上面代码中的 h_src 会出来两个一样的 feat_srcfeat_dst 是因为 DGL 采用了数学上等价但是计算效率会更高的实现。后面会解释

$$e_{vu}^l=LeakyReLU\Big((a^l)^T[h_v^{l}||h_u^{l}]\Big)$$

$$\alpha_{vu}^l=Softmax_u(e_{vu}^l)$$

第二步就是要计算中心节点 $v$ 和它的所有邻居节点之间的注意力了。在上面的公式中:

  • $e_{vu}^l$ 表示注意力系数(attention coefficient)。论文中提到完全可以采用不同的注意力计算方式,在 GAT 的论文中,作者选择用一层前馈神经网络计算注意力2注意这里的 $e_{vu}^l$ 跟 MPNN 的 $e_{vu}$ 是没有关系的,只是恰好记号一样了而已
  • $||$ 表示拼接操作,即我们会将中心节点和它对应的邻居结点的向量表示拼接起来,得到一个长度为 $2F’$ 的向量,对应公式中的 $[h_v^{l}||h_u^{l}]$,然后把它送入到前面提到的一层前馈神经网络中,对应公式中的 $(a^l)^T[h_v^{l}||h_u^{l}]$,其中 $(a^l)^T$ 指的是第 $l$ 层的单层前馈神经网络的参数
  • 激活函数选用的是 $LeakyReLU$
  • 最后在节点 $v$ 的所有邻居节点之间用 Softmax 做归一化

🤔️ 步骤一和步骤二对应 MPNN 框架的 $m_v^{t+1}$ 的计算

正如 Transformer 中有多头注意力一样,GAT 的作者同样在节点更新的时候采用了多头注意力的机制: $$h_v^{l+1}= ||^{K^l} \sigma(\sum_{u\in\mathcal{N}(i)}\alpha_{vu}^{(k,l)}W^{(k,l)}h_u^{l})$$

记号越来越复杂了,但是仔细思索一番还是可以理清楚的,上角标 $(k,l)$ 的意思是第 $l$ 层的第 $k$ 个头。其中 $K^l$ 为第 $l$ 层“头”的数量。上面的公式的意思就是每个“头”会计算出一个向量表示,然后不同“头”之间的会拼接起来

python

class GATConv(nn.Module):
    def __init__(self, ...):
        ...
        self.attn_l = nn.Parameter(
            th.FloatTensor(size=(1, num_heads, out_feats))
        )
        self.attn_r = nn.Parameter(
            th.FloatTensor(size=(1, num_heads, out_feats))
        )
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        ...

    def forward(self, graph, feat, ...):
        """
        Args
        ----
            feat: (N, *, D_in) where D_in is the size of input feature

        Returns
        -------
            torch.Tensor
                (N, *, num_heads, D_out)

        """
        # feat_src/feat_dst: (N, *, num_heads, out_feat)
        el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
        er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
        # el/er: (N, *, num_heads, 1)
        graph.srcdata.update({"ft": feat_src, "el": el})
        graph.dstdata.update({"er": er})
        graph.apply_edges(fn.u_add_v("el", "er", "e"))
        e = self.leaky_relu(graph.edata.pop("e"))
        # e: (N, *, num_heads, 1)

        # normalization
        graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
        # a: (N, *, num_heads, 1)

        # weighted sum
        graph.update_all(
            # ft: (N, *, num_heads, out_feat)
            # a: (N, *, num_heads, 1)
            # m: (N, *, num_heads, out_feat)
            fn.u_mul_e("ft", "a", "m"), 
            fn.sum("m", "ft")
        )
        rst = graph.dstdata["ft"]
        # rst: (N, *, num_heads, out_feat)
        ...

DGL 的实现基于下面这个事实: $$a^T[h_v||h_u]=a_l^Th_v+a_r^Th_u$$ 为什么会更为高效呢?

  • 不需要存储 $[h_v||h_u]$ 这个中间变量了(DGL 会将消息存储为边的属性)
  • 加法可以用 DGL 优化过的 fn.u_add_v 函数

最后节点更新的方式就是计算邻居的向量表示的加权和(基于前面算好的注意力): $$h_v^{l+1}=\sigma(\sum_{u\in \mathcal{N}(i)}\alpha_{vu}W^{l}h_u^{l})$$

🤔️ 上面就对应 MPNN 框架中的 $h_v^{t+1}$ 的计算,注意 GAT 算 $h_t^{l+1}$ 的时候并不会用到自己上一层表示 $h_t^l$。同时 GAT 的提出是用于解决图上的节点分类问题,因此也没有图读出的操作。

在 GAT 的论文中,作者是要将 GAT 用于节点级别的分类任务2。假设我们堆叠了 $L$ 层的 GAT 之后,在最后一层如果还采用拼接的方式显然是不合理的。因此作者在最后一层 GAT 中,是取多个头的平均值之后才应用了激活函数,这里的激活函数如果采用 Softmax,就可以直接做节点分类了2。公式如下所示: $$final\ embedding\ of\ h_v= \sigma\Big(\frac{1}{K^L}\sum_{k=1}^{K^L}\sum_{u\in\mathcal{N}(i)}\alpha_{vu}^{(k,L)}W^{(k,L)}h_u^{L}\Big)$$

🤔️ 图神经网络的流行框架之一 DGL 的设计就是基于 MPNN 框架,不过他们的公式会稍有不同,他们还有一个聚合函数 $\rho$,用于决定一个节点如何聚合从邻居那边收到的所有信息。我认为他们的公式更具有泛化性,能够适用于更多种情况。他们还贴心地写了关于如何使用 DGL 的 MPNN 相关函数的教程,推荐一看👍👍👍

至于 GAT 的实现,DGL 不仅提供了 GAT 模块,而且还写了一篇不错的用自带的 message_funcreduce_func 手动实现 GAT 的教程。一个完整的 GAT 任务训练脚本可以看这里

以上就是如何用 MPNN 框架解释 GAT 的方法,其中加上了 DGL 的源码分析进行解释,用注意力计算节点之间的关系看起来是一个很自然的思路,可以看成是对 GCN 的一种泛化。GAT 能够很好学习到图的局部结构表示,而且计算注意力的方式可以并行,是十分高效的🍻🍻🍻


  1. Gilmer J, Schoenholz S S, Riley P F, et al. Neural message passing for quantum chemistry[C]//International conference on machine learning. PMLR, 2017: 1263-1272. arXiv ↩︎

  2. Veličković P, Cucurull G, Casanova A, et al. Graph attention networks[J]. arXiv preprint arXiv:1710.10903, 2017. arXiv ↩︎ ↩︎ ↩︎