【2022.3.22 更新】

论文方法笔记参考:【论文阅读笔记】(2s-AGCN)Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognit_小吴同学真棒的博客-CSDN博客


目录

NTU RGB+D 120 数据示例

A(Graph)的定义

模型 Model 的定义

模型 Model 的输入 Input

TCN_GCN_unit

unit_tcn

unit_gcn

补充学习

1、nn.BatchNorm1()、nn.BatchNorm2() 和 nn.BatchNorm3() 的定义与区别

2、torch.nn.Parameter() & Variable

3、requires_grad


论文名称:Two-stream adaptive graph convolutional networks for skeleton-based action recognition

论文下载:https://openaccess.thecvf.com/content_CVPR_2019/papers/Shi_Two-Stream_Adaptive_Graph_Convolutional_Networks_for_Skeleton-Based_Action_Recognition_CVPR_2019_paper.pdf

论文代码:https://github.com/lshiwjx/2s-AGCN


NTU RGB+D 120 数据示例


A(Graph)的定义

https://github.com/BizhuWu/2s-AGCN/blob/953c14fc10883cd869646328f5d522e9e9282063/graph/ntu_rgb_d.py#L17

https://github.com/BizhuWu/2s-AGCN/blob/953c14fc10883cd869646328f5d522e9e9282063/graph/tools.py#L22

最开始没搞明白 inward_ori_index 里面的数据是什么意思?后面对照论文里的图搞明白了,就是:每个节点与向心节点的连接对

所以最后的 A 是由 单位矩阵每个节点相邻的向心节点每个节点相邻的离心节点 堆叠而成的。


模型 Model 的定义

https://github.com/BizhuWu/2s-AGCN/blob/master/model/agcn.py

对应论文里的图(模型的定义多了一层 l4 ?):


模型 Model 的输入 Input

https://github.com/BizhuWu/2s-AGCN/blob/master/model/agcn.py

N:Batch Size

C:channels

T:Frames' Number

V:Joints‘ Number

M:People's Number

模型的输入维度(N * M, C, T, V)


TCN_GCN_unit

https://github.com/BizhuWu/2s-AGCN/blob/953c14fc10883cd869646328f5d522e9e9282063/model/agcn.py#L112

每个 TCN_GCN_unit 是由一个 unit_gcn 和 一个 unit_tcn 组成的。还有额外的 residual

其中

  • 如果输入输出通道数一致 且 stride = 1,那么 residual 为传入进来的参数 x 本身
  • 如果输入输出通道数不一致 或者 且 stride ≠ 1,对传入进来的参数做一次 unit_tcn,结合后面的 unit_tcn 代码来看 kernel_size 的值为 1 意味着只是做一个 1*1 的 conv,用来改变通道数量,以便 residual 和 传入进来的参数 x 进行 element-wise 的相加操作。

输入维度为:(N*M,C,T,V)


unit_tcn

https://github.com/BizhuWu/2s-AGCN/blob/953c14fc10883cd869646328f5d522e9e9282063/model/agcn.py#L36

The convolution for the temporal dimension is the same as ST-GCN, i.e., performing the Kt × 1 convolution on the C×T×N feature maps. Both the spatial GCN and temporal GCN are followed by a batch normalization (BN) layer and a ReLU layer. As shown in Fig. 3, one basic block is the combination of one spatial GCN (Convs), one temporal GCN (Convt) and an additional dropout layer with the drop rate set as 0.5. To stabilize the training, a residual connection is added for each block.

这个模块的输入的维度也是(N*M,C,T,V),相当于 CNN 的(N,C,H,W)。unit_tcn 中的 con2dkernel_size(kernel_size, 1)。说明只对 T 这个维度进行卷积

最后一层 relu 在 TCN_GCN_unit 中。


unit_gcn

https://github.com/BizhuWu/2s-AGCN/blob/master/model/agcn.py#L53

输入维度(N*M,C,T,V)

根据论文里的图可以推断:

代码里的 A 对应图中的 Ak

代码里的 self.PA 对应图中的 Bk(这里有点奇怪,明明论文里说的是初始化为 0,但是代码里初始化为 1e-6。。。)

代码里的 self.conv_a 对应图中的 θk

代码里的 self.conv_b 对应图中的 φk

代码里的 self.conv_d 对应图中的 wk


补充学习

1、nn.BatchNorm1()、nn.BatchNorm2() 和 nn.BatchNorm3() 的定义与区别

1d:BatchNorm1d — PyTorch 1.11.0 documentation

2d:BatchNorm2d — PyTorch 1.11.0 documentation

3d:BatchNorm2d — PyTorch 1.11.0 documentation

the Batch Normalization is done over the C dimension

1d:Applies Batch Normalization over a 2D or 3D input

2d:Applies Batch Normalization over a 4D input

3d:Applies Batch Normalization over a 5D input

2、torch.nn.Parameter() & Variable

PyTorch里面的torch.nn.Parameter() - 简书

Pytorch 中的 Tensor , Variable和Parameter区别与联系_念及她名的博客-CSDN博客_pytorch variable和parameter

Variable

Variable 这个数据结构,是为了引入计算图(自动求导),方便构建神经网络。

每一个 Variable 被构建的时候,都包含三个属性:

  • Variable 中所包含的 tensor
  • tensor 的梯度 .grad
  • 何种方式得到这种梯度 .grad_fn

通过调用 backward(),我们可以对某个 Variable(譬如说y)进行一次自动求导,但如果我们再对这个 Variable 进行一次 backward() 操作,会发现程序报错。这是因为 PyTorch 默认做完一次自动求导后,就把计算图丢弃了。我们可以通过设置 retain_graph 来实现多次求导

引入 Parameter 的原因:

  • Variable 默认是不需要求梯度的,那还需要手动设置参数 requires_grad = True
  • Variable 因为要多次反向传播,那么在 backward 的时候还要手动注明参数 w.backward(retain_graph=True)

Parameter

将一个不可训练的类型 Tensor 转换成可以训练的类型 parameter,并将这个 parameter 绑定到这个 module 里面(net.parameter() 中就有这个绑定的 parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个 self.v 变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

Pytorch 主要通过引入 nn.Parameter 类型的变量和 optimizer 机制来解决了这个问题。Parameter 是 Variable 的子类,本质上和后者一样,只不过 parameter 默认是求梯度的,同时一个网络 net 中的 parameter 变量是可以通过 net.parameters() 来很方便地访问到的,只需将网络中所有需要训练更新的参数定义为 Parameter 类型,再佐以 optimizer,就能够完成所有参数的更新了。

3、requires_grad

手把手教你使用PyTorch(2)-requires_grad&computation graph_慕课手记

只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。

而反之,若所有的输入都不需要保存梯度,那么输出的 requires_grad 会自动设置为 False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。

注意:

  • 对于那些要求梯度的 tensor,PyTorch 会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的 tensor 是不需要梯度的
  • 而我们在使用神经网络时,这些全连接层、卷积层等结构的参数都是默认需要梯度的

解读 2s-AGCN 代码相关推荐

  1. python实现胶囊网络_Capsule Network胶囊网络解读与pytorch代码实现

    本文是论文<Dynamic Routing between Capsules>的论文解读与pytorch代码实现. 如需转载本文或代码请联系作者 @Riroaki 并声明. 众所周知,卷积 ...

  2. 解读阿里巴巴 Java 代码规范(2): 从代码处理等方面解读阿里巴巴 Java 代码规范...

    前言 2017 年阿里云栖大会,阿里发布了针对 Java 程序员的<阿里巴巴 Java 开发手册(终极版)>,这篇文档作为阿里数千位 Java 程序员的经验积累呈现给公众,并随之发布了适用 ...

  3. 解读阿里官方代码规范

    2017年开春,阿里对外公布了「阿里巴巴Java开发手册」从头到尾浏览了一遍这份手册之后,感觉很棒.虽然其中的某些观点笔者不能苟同,但大部分的规范还是值得绝大多数程序员学习和遵守的. 笔者将对这份代码 ...

  4. Resnet论文解读与TensorFlow代码分析

    残差网络Resnet论文解读 1.论文解读 博客地址:https://blog.csdn.net/loveliuzz/article/details/79117397 2.理解ResNet结构与Ten ...

  5. WordCount 官方源码解读及工程代码

    一.WordCount是MapReduce分布式计算框架的demo,可以作为MapReduce入门Demo,了解其思想. WordCount是MapReduce计算的官方demo代码,通过解读Word ...

  6. 分枝定界图解(含 Real-Time Loop Closure in 2D LIDAR SLAM论文部分解读及BB代码部分解读)

    分枝定界图解 网上对分枝定界的解读很多都是根据这篇必不可少的论文<Real-Time Loop Closure in 2D LIDAR SLAM>来的. 分枝定界是一种深度优先的树形搜索方 ...

  7. Kaggle神器LightGBM最全解读(附代码说明)!

    AI派干货 来源:Microstrong,编辑:AI有道 本文主要内容概览: 1. LightGBM简介 GBDT (Gradient Boosting Decision Tree) 是机器学习中一个 ...

  8. java车间调度算法_混合算法(GA+TS)求解作业车间调度问题代码解读+完整JAVA代码...

    程序猿声 代码黑科技的分享区 前两篇文章中,我们介绍了FJSP问题,并梳理了一遍HA算法.这一篇文章对小编实现的(很乱很烂的)代码进行简单解读. 往期回顾: 代码下载请关注公众号,后台回复[FJSPH ...

  9. java制作烟花源码_java源码解读-java烟花代码!

    解读JAVA代码 正在自学java,刚入门,图形界面没太看懂 贴一段简单的代码,求逐行解释.. java程序解析 public class text6{ final String color; pub ...

  10. 深入解读GLIDE/PITI代码

    Diffusion Models专栏文章汇总:入门与实战 前言:GLIDE是diffusion models text-to-image的一项非常经典的模型,PITI是一项基于GLIDE的工作,读懂P ...

最新文章

  1. android ViewPager动画的实现原理及效果
  2. Linux——安装FTP服务器
  3. zabbix监控pppoe线路_Zabbix 完整的监控流程
  4. spring AbstractBeanDefinition创建bean类型是动态代理类的方式
  5. mysql之explain详解(分析索引的最佳使用)
  6. 关于SQL中Between语句查询日期的问题
  7. 编译警告级别之重要性
  8. js实现图片压缩上传
  9. 用自己写的六爻装卦程序了占卜一下 2010 年
  10. TS + vue3.2 + vite2 + element-plus 通用弹框组件封装
  11. 【论文翻译】Playing Atari with Deep Reinforcement Learning
  12. 利用matlab批量修改文件名称或后缀
  13. JavaScript实现鼠标点击监听---弹出社会主义核心价值观(面向对象小练习)
  14. java打印 好看的图形_分享java打印简单图形的实现代码
  15. houdini环境变量服务器文件读不了,Windows下在普通命令行窗口里初始化Houdini环境...
  16. 学计算机要学数学么,学计算机数学要求高吗 数学不好怎么办?
  17. SATA硬盘与IDE硬盘的优劣势对比
  18. linux ram大小 arm,在linux / arm下对RAM(无ECC)进行基准测试的最佳方法是什么?
  19. uniPush消息推送 ios证书配置
  20. Android Q行为变更

热门文章

  1. 什么是 TF-IDF 算法?
  2. 【Applied Algebra】物理学中的群论漫谈1:群论基础
  3. 自动点击按钮html,如何自动点击网页按钮
  4. Unity跑酷游戏中的路点生成算法
  5. 如何使用gif制作软件快速合成gif动图....
  6. 新型病毒来了【PcaPatchDbTask】
  7. win10设置默认浏览器
  8. 打工与创业的有什么区别,丨国仁网络资讯
  9. GO 语言离线安装包
  10. Python数据分析实战之葡萄酒质量分析