文章目录

  • GAT: 图注意力模型介绍及代码分析
    • 原理
      • 图注意力层(Graph Attentional Layer)
        • 情境一:节点和它的一个邻居
        • 情境二:节点和它的多个邻节点
      • 聚合(Aggregation)
        • 多重聚合
    • 代码分析
      • 图注意力层(Graph Attentional Layer)
        • 初始化参数
        • 前向传播
          • 拼接特征向量的处理
      • GAT(含聚合)
  • 参考资料

GAT: 图注意力模型介绍及代码分析

原理

图注意力层(Graph Attentional Layer)

图中每一个节点都由ddd维实数特征向量表示(比如可以作为节点的嵌入编码),nnn个节点的特征向量可以写成Mn×dM_{n\times d}Mn×d​的矩阵形式。
经过图注意力层以后,输出每一个节点的d′d'd′维新特征并形成矩阵Mn×d′′M'_{n\times d'}Mn×d′′​. 有Wd′×dW^{d'\times d}Wd′×d作为变换矩阵作用于每个节点的特征向量hi(d×1)h_i^{(d\times 1)}hi(d×1)​.

为了保留图的结构信息,GAT的一种计算方式只考虑节点的邻接节点并且认为该节点的每个邻节点对其影响力可以用不同的权重(即注意力系数, attention coefficient)表示,以下用一阶邻节点举例(另一种则是考虑全部的顶点):

情境一:节点和它的一个邻居

假设节点iii有一个邻节点jjj, 经过线性变换以后分别是WhiWh_iWhi​和WhjWh_jWhj​。再假设有一个映射a:Rd′×Rd′→Ra: \R^{d'}\times \R^{d'}\to \Ra:Rd′×Rd′→R, 那么邻节点对该节点的注意力系数是:
eij=a(Whi,Whj)e_{ij}=a(Wh_i,Wh_j) eij​=a(Whi​,Whj​)
GAT的具体实现方式是a∈R1×2d′a\in \R^{1\times 2d'}a∈R1×2d′,(Whi,Whj)=[Whi∣∣Whj](Wh_i,Wh_j) = [Wh_i||Wh_j](Whi​,Whj​)=[Whi​∣∣Whj​](把两个向量合并在一起,形成一个(2d′,1)(2d',1)(2d′,1)维的向量),那么eij=aT[Whi∣∣Whj]e_{ij}=a^\mathsf{T}[Wh_i||Wh_j]eij​=aT[Whi​∣∣Whj​].

情境二:节点和它的多个邻节点

当节点有多个邻节点时,为了避免某个注意力系数的值远大于其他值不便于训练,需要normalization。同时为了泛化模型的拟合能力,对线性变化后的值可以加入非线性激活函数, 最终得到的注意力系数:
αij=exp⁡(LeakyRelu⁡(eij)))∑k∈Niexp⁡(LeakyRelu⁡(eik)\alpha_{ij}=\frac{\exp(\operatorname{LeakyRelu}(e_{ij})))}{\sum_{k\in\mathcal{N_i} }\exp(\operatorname{LeakyRelu}(e_{ik})} αij​=∑k∈Ni​​exp(LeakyRelu(eik​)exp(LeakyRelu(eij​)))​
其中ei∈Nie_i\in \mathcal{N_i}ei​∈Ni​.

聚合(Aggregation)

得出节点及其邻节点的注意力系数以后,就可以用于结合WWW更好地更新h′h'h′了,论文中使用的聚合函数:
hi′=σ(∑j∈NiαijWhj)h'_i = \sigma(\sum_{j\in\mathcal{N_i} }\alpha_{ij}Wh_j) hi′​=σ(j∈Ni​∑​αij​Whj​)

多重聚合

为了提高聚合器的表现,论文中采用了multi-head attention, 即使用kkk个独立的注意力机制(采用不同的aaa和WWW),然后将得到的结果再次拼接——
hi′=∣∣⁡k=1kσ(∑j∈Niαij(k)W(k)hj)h_i'=\operatorname{||}_{k=1}^k \sigma(\sum_{j\in\mathcal{N_i} }\alpha^{(k)}_{ij}W^{(k)}h_j) hi′​=∣∣k=1k​σ(j∈Ni​∑​αij(k)​W(k)hj​)
这会导致hi′h_i'hi′​有更高的维度(1,kd′)(1,kd')(1,kd′),所以只可以做中间层而不可以做输出层。
所以对于输出层,一种聚合方式是将各注意力机制的h′h'h′平均
Output=h′=σ(1k∑i=1k∑j∈Niαij(k)W(k)hj)\text{Output=}h'=\sigma(\frac{1}{k}\sum_{i=1}^{k}\sum_{j\in\mathcal{N_i} }\alpha^{(k)}_{ij}W^{(k)}h_j) Output=h′=σ(k1​i=1∑k​j∈Ni​∑​αij(k)​W(k)hj​)

代码分析

代码地址(PyTorch版本):https://github.com/Diego999/pyGAT

图注意力层(Graph Attentional Layer)

图注意力层(即上文原理中提到的注意力机制)的功能是接受由各节点特征向量组成的特征矩阵Hn×dH_{n\times d}Hn×d​, 输出新的特征矩阵Hn×d′H_{n\times d'}Hn×d′​.

初始化参数

一共有两组参数,WWW和aaa,需要训练。其中aaa适用于所有的特征向量对。

self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
nn.init.xavier_uniform_(self.W.data, gain=1.414)
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)

前向传播

    def forward(self, h, adj):# 首先对节点的本身特征向量进行线性变换Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)a_input = self._prepare_attentional_mechanism_input(Wh)e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))  # 计算未normalized的注意力系数zero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)  # 只计算邻接节点attention = F.softmax(attention, dim=1)attention = F.dropout(attention, self.dropout, training=self.training)h_prime = torch.matmul(attention, Wh)# 输出层的self.concat为False, 不进行非线性变化if self.concat:return F.elu(h_prime)else:return h_prime
拼接特征向量的处理
    def _prepare_attentional_mechanism_input(self, Wh):N = Wh.size()[0]  # number of nodesWh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)  # 对每一个特征向量重复N次Wh_repeated_alternating = Wh.repeat(N, 1)  # 将特征矩阵重复N次# 下面得到每个节点和其他所有节点组合并拼接而成的特征向量all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)# all_combinations_matrix.shape == (N * N, 2 * out_features)return all_combinations_matrix.view(N, N, 2 * self.out_features)

GAT(含聚合)

整个GAT的框架就非常直观了,在输入层添加dropout防止过拟合,只有中间层使用了拼接法。

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):"""Dense version of GAT."""super(GAT, self).__init__()self.dropout = dropoutself.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]for i, attention in enumerate(self.attentions):self.add_module('attention_{}'.format(i), attention)self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.attentions], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.elu(self.out_att(x, adj))return F.log_softmax(x, dim=1)

参考资料

  • Velicˇkovic, P., Cucurull, G., Casanova, A., and Bengio, Y.: ‘GRAPH ATTENTION NETWORKS’. Proc. ICLR2018
  • 向往的GAT(图注意力模型) - superbrother的文章 - 知乎
    https://zhuanlan.zhihu.com/p/81350196
  • https://github.com/Diego999/pyGAT

GAT: 图注意力模型介绍及PyTorch代码分析相关推荐

  1. GAT:图注意力模型介绍及PyTorch代码分析

    文章目录 1.计算注意力系数 2.聚合 2.1 附录--GAT代码 2.2 附录--相关代码 3.完整实现 3.1 数据加载和预处理 3.2 模型训练 1.计算注意力系数 对于顶点 iii ,通过计算 ...

  2. Graph Attention Network (GAT) 图注意力模型

    文章目录 1. GAT基本原理 1.1 计算注意力系数(attention coefficient) 1.2 特征加权求和(aggregate) 1.3 multi-head attention 2. ...

  3. 【ICLR 2018图神经网络论文解读】Graph Attention Networks (GAT) 图注意力模型

    论文题目:Graph Attention Networks 论文地址:https://arxiv.org/pdf/1710.10903.pdf 论文代码:https://github.com/Peta ...

  4. 数据挖掘期末-图注意力模型

    PyGAT图注意力模型 ​  PyGAT实现的分类器: https://www.aliyundrive.com/s/vfK8ndntpyc   还在发烧,不是特别清醒,就简单写了写.用GAT进行关系预 ...

  5. 深度学习100+经典模型TensorFlow与Pytorch代码实现大合集

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! [导读]深度学习在过去十年获得了极大进展,出现很多新的模型,并且伴随TensorF ...

  6. 特征图注意力_向往的GAT(图注意力模型)

    0 GRAPH ATTENTION NETWORKS的诞生 随着GCN的大红大紫(可以参考如何理解 Graph Convolutional Network(GCN)?),graph领域的deep le ...

  7. 深度学习中一些注意力机制的介绍以及pytorch代码实现

    文章目录 前言 注意力机制 软注意力机制 代码实现 硬注意力机制 多头注意力机制 代码实现 参考 前言 因为最近看论文发现同一个模型用了不同的注意力机制计算方法,因此懵了好久,原来注意力机制也是多种多 ...

  8. 【从零开始学习深度学习】25.卷积神经网络之LeNet模型介绍及其Pytorch实现【含完整代码】

    目录 1. LeNet模型介绍与实现 2. 输入为Fashion-MNIST时各层输出形状 3. 获取Fashion-MNIST数据和并使用LeNet模型进行训练 4.完整代码 之前我们对Fashio ...

  9. 图深度学习,入门教程七,残差多层图注意力模型

    深度学习还没学完,怎么图深度学习又来了?别怕,这里有份系统教程,可以将0基础的你直接送到图深度学习.还会定期更新哦. 主要是基于图深度学习的入门内容.讲述最基本的基础知识,其中包括深度学习.数学.图神 ...

最新文章

  1. 中国平民百姓与富翁的五大差距
  2. 在ubuntu上安装wireshark之后提示Couldn't run /usr/bin/dumpcap in child process:权限不够
  3. 重温1 Android系统架构及版本
  4. 10g中如何修改数据库字符集-2
  5. Docker入门之五数据管理
  6. UCB DS100 讲义《数据科学的原理与技巧》校对活动正式启动 | ApacheCN
  7. DB2数据库连接问题:java.lang.NoClassDefFoundError
  8. 小米电视共享计算机权限,小米电视局域网共享文件 小米盒子局域网共享视频通用方法...
  9. matlab鲍威尔方法求函数,基于MATLAB的鲍威尔法求极值问题
  10. CMD中可执行的结束进程命令
  11. 手机可用熵_思想丨在商言“熵”
  12. NPOI导出Excel自动计算公式问题
  13. 微信公众号推送多图文消息,直接跳转至外部链接(wxJava)
  14. 超详细Redis入门教程——Redis命令(上)
  15. 应用生命周期、页面生命周期、组件生命周期
  16. 手机ANR问题处理方法及策略
  17. CSMA/CD最大/最小帧长 争用期
  18. 2021-09-02 Day17-JS-第七天 Web APIs和DOM
  19. 2012最具有技术影响力本版图书评选
  20. android 身高体重设计,Android开发--身高体重指数(BIM)计算--设计用户界面--指定输入类型(InputType)...

热门文章

  1. 广发聚丰股票型证券投资基金
  2. android判断存储卡,Android中判断SD卡状态
  3. 选课策略——0-1整数规划
  4. 申请评分卡分析及建模
  5. android微信下拉出现小程序,Android仿微信首页下拉显示小程序列表
  6. python爬虫scrapy框架爬取糗妹妹段子首页
  7. 如何将手机中的视频做成动图?手机端视频转gif怎么操作
  8. 盒子里面图片的位置怎么设置_绝地求生视角位置是什么?怎么才能选择正确的设置呢?...
  9. 125张图告诉你全世界最前沿的科学问题
  10. Python变量的引用、标识、相等性 is和==区别