Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition

Paper: https://arxiv.org/abs/1801.07455
Code: https://github.com/yysijie/st-gcn

1 First, how the adjacency matrix from the paper is implemented

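import numpy as np


class Graph():
    def __init__(self,
                 layout='openpose',
                 strategy='uniform',
                 max_hop=1,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        # Only three main steps follow.
        # get_edge collects the edge list (which joints are connected).
        self.get_edge(layout)
        # Hop distance between every pair of nodes; the distance
        # partitioning strategy in the paper is built on it.
        self.hop_dis = get_hop_distance(
            self.num_node, self.edge, max_hop=max_hop)
        # Finally, build the adjacency matrix A.
        self.get_adjacency(strategy)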

Let's look at how each of the three methods is implemented.

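1 get_edge. This method is simple: different datasets define different skeleton joints, and therefore different connections.

In other words, it produces the list of connected edges and the center node.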

    def get_edge(self, layout):
        if layout == 'openpose':
            self.num_node = 18
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12),
                             (12, 11), (10, 9), (9, 8), (11, 5), (8, 2),
                             (5, 1), (2, 1), (0, 1), (15, 0), (14, 0),
                             (17, 15), (16, 14)]
            self.edge = self_link + neighbor_link
            self.center = 1
        elif layout == 'ntu-rgb+d':
            self.num_node = 25
            self_link = [(i, i) for i in range(self.num_node)]
            # Joint indices are 1-based here and shifted to 0-based below.
            neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),
                              (6, 5), (7, 6), (8, 7), (9, 21), (10, 9),
                              (11, 10), (12, 11), (13, 1), (14, 13),
                              (15, 14), (16, 15), (17, 1), (18, 17),
                              (19, 18), (20, 19), (22, 23), (23, 8),
                              (24, 25), (25, 12)]
            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
            self.edge = self_link + neighbor_link
            self.center = 21 - 1
        elif layout == 'ntu_edge':
            self.num_node = 24
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5),
                              (7, 6), (8, 7), (9, 2), (10, 9), (11, 10),
                              (12, 11), (13, 1), (14, 13), (15, 14),
                              (16, 15), (17, 1), (18, 17), (19, 18),
                              (20, 19), (21, 22), (22, 8), (23, 24),
                              (24, 12)]
            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
            self.edge = self_link + neighbor_link
            self.center = 2
        # elif layout=='customer settings'
        #     pass
        else:
            raise ValueError("Do Not Exist This Layout.")

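2 get_hop_distance, which underlies distance partitioning and spatial configuration partitioning in the paper

In short, the nodes of a neighborhood are partitioned into subsets.

Uni-labeling: all neighbors share the same weight.

Distance partitioning simply splits the nodes into two subsets.

Spatial configuration partitioning splits the nodes into three subsets.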

Uni-labeling. The simplest and most straight forward partition strategy is to have subset, which is the whole neighbor set itself. In this strategy, feature vectors on every neighboring node will have a inner product with the same weight vector. Actually, this strategy resembles the propagation rule introduced in (Kipf and Welling 2017).

Distance partitioning. Another natural partitioning strategy is to partition the neighbor set according to the nodes’ distance $d(\cdot, v_{ti})$ to the root node $v_{ti}$. In this work, because we set $D = 1$, the neighbor set will then be separated into two subsets, where $d = 0$ refers to the root node itself and remaining neighbor nodes are in the $d = 1$ subset. Thus we will have two different weight vectors and they are capable of modeling local differential properties such as the relative translation between joints. Formally, we have $K = 2$ and $l_{ti}(v_{tj}) = d(v_{tj}, v_{ti})$.

Spatial configuration partitioning. Since the body skeleton is spatially localized, we can still utilize this specific spatial configuration in the partitioning process. We design a strategy to divide the neighbor set into three subsets: 1) the root node itself; 2) centripetal group: the neighboring nodes that are closer to the gravity center of the skeleton than the root node; 3) otherwise the centrifugal group. Here the average coordinate of all joints in the skeleton at a frame is treated as its gravity center. This strategy is inspired by the fact that motions of body parts can be broadly categorized as concentric and eccentric motions. Formally, $l_{ti}(v_{tj}) = 0$ if $r_j = r_i$, $1$ if $r_j < r_i$, and $2$ if $r_j > r_i$, where $r_i$ is the average distance from the gravity center to joint $i$.

def get_hop_distance(num_node, edge, max_hop=1):
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        # Build the symmetric adjacency matrix A.
        A[j, i] = 1
        A[i, j] = 1

    # compute hop steps
    hop_dis = np.zeros((num_node, num_node)) + np.inf
    # matrix_power(A, 0) returns the identity matrix with the same shape as A.
    # With max_hop = 1 this yields two matrices:
    #   d == 0: the identity matrix, i.e. every node linked to itself
    #   d == 1: the adjacency matrix A itself
    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
    arrive_mat = (np.stack(transfer_mat) > 0)
    # d runs over 1, 0.
    # After this loop, adjacent nodes have distance 1, non-adjacent nodes
    # have distance inf, and every node has distance 0 to itself (the root).
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis
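To see what the hop-distance matrix looks like, here is a minimal sketch on a toy 4-joint chain 0-1-2-3. The chain and the names `hop_distance`, `chain_edges`, `hop_1` and `hop_2` are illustrative assumptions, not part of the repository; the function body follows the same steps as get_hop_distance above.

```python
import numpy as np


def hop_distance(num_node, edge, max_hop=1):
    # Symmetric adjacency matrix built from the edge list.
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[i, j] = 1
        A[j, i] = 1
    hop_dis = np.full((num_node, num_node), np.inf)
    # A**d is nonzero at [i, j] when j is reachable from i in d steps.
    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
    arrive_mat = np.stack(transfer_mat) > 0
    # Go from the largest hop down so that smaller distances overwrite larger ones.
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis


# Toy chain 0-1-2-3 with self-links (hypothetical layout).
chain_edges = [(i, i) for i in range(4)] + [(0, 1), (1, 2), (2, 3)]
hop_1 = hop_distance(4, chain_edges, max_hop=1)
hop_2 = hop_distance(4, chain_edges, max_hop=2)
print(hop_1)
print(hop_2)
```

With max_hop=1, joints 0 and 2 are at distance inf from each other; raising max_hop to 2 turns that entry into 2 while joints 0 and 3 remain at inf.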

3 get_adjacency

Builds a different adjacency matrix for each partitioning strategy.

    def get_adjacency(self, strategy):
        # With the defaults self.dilation = 1 and self.max_hop = 1,
        # valid_hop is [0, 1]: only the node itself and its directly
        # connected neighbors are considered.
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            # Binary adjacency matrix: connected nodes are 1 and the root
            # node (the diagonal) is 1 as well. It differs from hop_dis on
            # the diagonal (0 there) and on unconnected pairs, which are 0
            # here but inf in hop_dis.
            adjacency[self.hop_dis == hop] = 1
        # Normalize the matrix with the degree matrix.
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            # Uni-labeling partitioning strategy, where all nodes in a
            # neighborhood have the same label. In the paper's words:
            # feature vectors on every neighboring node will have a inner
            # product with the same weight vector.
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            # Distance partitioning: the nodes are split into two subsets,
            # where d = 0 refers to the root node itself and
            # remaining neighbor nodes are in the d = 1 subset.
            # shape (2, num_node, num_node)
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                # hop == 0: entries where hop_dis is 0, i.e. the root node itself
                # hop == 1: entries where hop_dis is 1, i.e. the connected neighbors
                A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
                                                                hop]
            self.A = A
        elif strategy == 'spatial':
            # Spatial configuration partitioning: three subsets.
            # 1) the root node itself;
            # 2) centripetal group: the neighboring nodes
            #    that are closer to the gravity center of the skeleton than the root node;
            # 3) otherwise the centrifugal group
            # The partial matrices are collected in a list.
            A = []
            for hop in valid_hop:
                # root node
                a_root = np.zeros((self.num_node, self.num_node))
                # the neighboring nodes that are closer to the gravity center
                a_close = np.zeros((self.num_node, self.num_node))
                # otherwise the centrifugal group
                a_further = np.zeros((self.num_node, self.num_node))
                # How the labels are assigned:
                # 0 if rj = ri
                # 1 if rj < ri
                # 2 if rj > ri
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        # Keep only valid entries, hop_dis equal to 0 or 1:
                        # connected pairs, including the root node itself.
                        if self.hop_dis[j, i] == hop:
                            if self.hop_dis[j, self.center] == self.hop_dis[
                                    i, self.center]:
                                # Root assignment.
                                # hop == 0: only i == j reaches this branch (the root itself).
                                # hop == 1: i and j are connected and equally far from the
                                # center; this includes pairs that are both out of reach
                                # of the center (inf == inf).
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif self.hop_dis[j, self.center] > self.hop_dis[
                                    i, self.center]:
                                # the neighboring nodes that are closer to the gravity center
                                # i is closer to the center than j. Reached when hop == 1 and
                                # hop_dis[j, center] == inf, hop_dis[i, center] == 1.0, or
                                # hop_dis[j, center] == 1.0, hop_dis[i, center] == 0.0.
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                # otherwise the centrifugal group
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            # Stack into a 3-D array that is fed to the model as weights.
            # shape (3, num_node, num_node)
            # A[0]: hop == 0, the root nodes (self-links) with their normalized weights
            # A[1]: hop == 1, a_root + a_close: neighbors equally far from or
            #       closer to the center
            # A[2]: hop == 1, a_further: neighbors farther from the center
            A = np.stack(A)
            self.A = A
        else:
            raise ValueError("Do Not Exist This Strategy")
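The three-way split can be checked on a small graph. The following is a vectorized sketch that assumes max_hop = 1 and dilation = 1, applied to a hypothetical 5-joint chain 0-1-2-3-4 with joint 0 as the center. The helper `spatial_partition` and the toy layout are assumptions made for illustration; the branch conditions mirror the nested loops in get_adjacency.

```python
import numpy as np


def spatial_partition(hop_dis, norm_adj, center):
    """Split a normalized adjacency into root / centripetal / centrifugal parts."""
    d = hop_dis[:, center]              # hop distance of every node to the center
    d_j, d_i = d[:, None], d[None, :]   # entry [j, i] compares d[j] with d[i]
    same = d_j == d_i                   # note: inf == inf is True
    closer = d_j > d_i                  # mirrors the a_close branch
    further = ~(same | closer)          # mirrors the final else branch
    is_self = hop_dis == 0
    is_neighbor = hop_dis == 1
    return np.stack([
        np.where(is_self, norm_adj, 0.0),                        # hop 0: a_root
        np.where(is_neighbor & (same | closer), norm_adj, 0.0),  # hop 1: a_root + a_close
        np.where(is_neighbor & further, norm_adj, 0.0),          # hop 1: a_further
    ])


# Toy 5-joint chain with joint 0 as the center (hypothetical layout).
idx = np.arange(5)
gap = np.abs(idx[:, None] - idx[None, :])
hop_dis = np.where(gap <= 1, gap, np.inf)      # what max_hop = 1 produces on a chain
adjacency = (gap <= 1).astype(float)
norm_adj = adjacency / adjacency.sum(axis=0)   # column-normalized adjacency
A = spatial_partition(hop_dis, norm_adj, center=0)
print(A.shape)
print(A[1])
```

Joints 2, 3 and 4 are all out of reach of the center, so the edges among them land in A[1] through the inf == inf comparison. The three slices always sum back to the normalized adjacency.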

4 Matrix normalization

def normalize_digraph(A):
    # Degree of each node (column sums of A).
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            # Diagonal matrix holding the inverse degree of each node.
            Dn[i, i] = Dl[i] ** (-1)
    AD = np.dot(A, Dn)
    return AD
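Multiplying by the inverse degree matrix on the right divides each column of A by its column sum, so every non-empty column of the result sums to 1. A small sketch of that property, using broadcasting instead of an explicit diagonal matrix (the helper name `normalize_columns` and the 3-joint chain are illustrative assumptions):

```python
import numpy as np


def normalize_columns(A):
    # Column sums play the role of the node degrees.
    col_sum = A.sum(axis=0)
    inv = np.zeros_like(col_sum)
    nonzero = col_sum > 0
    inv[nonzero] = 1.0 / col_sum[nonzero]
    # Broadcasting scales column j by inv[j], the same as A @ diag(inv).
    return A * inv


# 3-joint chain 0-1-2 with self-links.
adjacency = np.array([[1.0, 1.0, 0.0],
                      [1.0, 1.0, 1.0],
                      [0.0, 1.0, 1.0]])
normalized = normalize_columns(adjacency)
print(normalized)
print(normalized.sum(axis=0))
```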

Finally, let's go through the network structure.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self, in_channels, num_class, graph_args,
                 edge_importance_weighting, **kwargs):
        super().__init__()

        # load graph
        # Build the graph and its adjacency matrix.
        self.graph = Graph(**graph_args)
        # Convert to a tensor.
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        # Register A as a persistent buffer of the module.
        self.register_buffer('A', A)

        # build networks
        # A has shape (3, 18, 18) for the spatial strategy on the openpose layout.
        # spatial_kernel_size depends on the partitioning strategy.
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        # kernel_size = (9, 3)
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        # The input first goes through batch normalization.
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        # Stack of st_gcn blocks (spatial temporal convolution).
        self.st_gcn_networks = nn.ModuleList((
            st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 128, kernel_size, 2, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 256, kernel_size, 2, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        # Learnable edge weights, the "Learnable edge importance weighting"
        # of the paper: we add a learnable mask M on every layer of spatial
        # temporal graph convolution. In practice the mask is multiplied with A.
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for i in self.st_gcn_networks
            ])
        else:
            # Otherwise the edge weights are not learnable: they are fixed
            # to 1, so every edge has the same importance.
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # fcn for prediction
        # A final 1*1 convolution performs the classification.
        self.fcn = nn.Conv2d(256, num_class, kernel_size=1)

    def forward(self, x):
        # N: number of videos; a batch commonly holds 256 videos (any value
        #    works, a power of 2 is preferable).
        # C: features per joint, usually 3 (4 for 3D skeletons).
        # T: number of frames, typically 150 per video.
        # V: number of joints, usually 18 annotated per person.
        # M: number of people per frame, usually the 2 with the highest
        #    average confidence.

        # data normalization
        N, C, T, V, M = x.size()
        # (N, M, V, C, T)
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        # Normalize the data.
        x = self.data_bn(x)
        # Reshape back.
        x = x.view(N, M, V, C, T)
        # (N, M, C, T, V)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        # (N * M, C, T, V)
        x = x.view(N * M, C, T, V)

        # forward
        # Pass through the spatial temporal convolution blocks.
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # x shape (N * M, C, T, V)
        # global pooling
        # x shape (N * M, C, 1, 1)
        x = F.avg_pool2d(x, x.size()[2:])
        # Average over the M people: shape (N, C, 1, 1)
        x = x.view(N, M, -1, 1, 1).mean(dim=1)

        # prediction
        # shape (N, num_class, 1, 1)
        x = self.fcn(x)
        # shape (N, num_class)
        x = x.view(x.size(0), -1)

        return x
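The permute and view calls around data_bn are easy to get wrong, so here is a small sketch of the same index bookkeeping. It uses NumPy `transpose`/`reshape` as a stand-in for the PyTorch calls, and the tiny dimensions are arbitrary assumptions chosen to keep the arrays small.

```python
import numpy as np

# Arbitrary small sizes (assumptions for illustration only).
N, C, T, V, M = 2, 3, 4, 5, 2
x = np.arange(N * C * T * V * M).reshape(N, C, T, V, M)

# (N, C, T, V, M) -> (N, M, V, C, T) -> (N * M, V * C, T): the layout given to data_bn,
# where each (joint, channel) pair is normalized as its own feature.
y = x.transpose(0, 4, 3, 1, 2).reshape(N * M, V * C, T)

# Back to (N, M, V, C, T) -> (N, M, C, T, V) -> (N * M, C, T, V) for the st_gcn blocks.
z = y.reshape(N, M, V, C, T).transpose(0, 1, 3, 4, 2).reshape(N * M, C, T, V)

print(y.shape, z.shape)
# Person m of video n ends up at batch index n * M + m.
print(z[1 * M + 1, 2, 3, 4] == x[1, 2, 3, 4, 1])
```

Each person is thus treated as an independent sample until the final `mean(dim=1)` averages the M people of a video.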

Now let's focus on how the spatial temporal convolution is implemented.

class st_gcn(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,  # (9, 3)
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, 0)

        # Spatial (graph) convolution.
        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])

        # Input layout: (N * M, C, T, V)
        # Temporal convolution: convolve along the T dimension.
        # A plain bn + relu + conv2d + bn + dropout stack.
        # Kernel size (9, 1): the convolution acts only along T.
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        # x shape (N * M, C, T, V)
        # A shape (3, 18, 18)
        # If in_channels != out_channels (or stride != 1), the residual
        # branch applies conv2d + bn first.
        # If in_channels == out_channels and stride == 1, x is passed through unchanged.
        res = self.residual(x)
        # Spatial convolution.
        x, A = self.gcn(x, A)
        # Convolution along the temporal dimension.
        x = self.tcn(x) + res

        return self.relu(x), A
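The padding `(kernel_size[0] - 1) // 2` keeps the temporal length unchanged at stride 1, and each stride-2 block roughly halves it. A quick sketch of that arithmetic, using the standard convolution output-length formula with dilation 1; the helper name `temporal_out_len` and the input length of 150 frames are assumptions for illustration, while the stride sequence is the one used by the ten blocks of the model.

```python
def temporal_out_len(t, kernel=9, stride=1):
    # Same padding rule as in st_gcn: (kernel - 1) // 2 on both sides of T.
    padding = (kernel - 1) // 2
    return (t + 2 * padding - kernel) // stride + 1


# Strides of the ten st_gcn blocks in the model.
strides = [1, 1, 1, 1, 2, 1, 1, 2, 1, 1]
lengths = [150]  # assumed number of input frames
for s in strides:
    lengths.append(temporal_out_len(lengths[-1], stride=s))
print(lengths)
```

Under these assumptions the sequence goes from 150 frames to 75 after the fifth block and to 38 after the eighth, which is the temporal size that the global average pooling then collapses.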

Finally, the implementation of ConvTemporalGraphical.

class ConvTemporalGraphical(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()
        # kernel_size = 3 here; it depends on the partitioning strategy.
        self.kernel_size = kernel_size
        # A plain 1*1 convolution: it only expands the feature dimension
        # of each node.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias)

    def forward(self, x, A):
        # A.size(0) must equal self.kernel_size for the chosen partitioning strategy.
        assert A.size(0) == self.kernel_size
        # (N * M, C, T, V)
        x = self.conv(x)
        n, kc, t, v = x.size()
        # shape (n, 3, c, t, v)
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        # The einsum aggregates the features of neighboring nodes according
        # to the adjacency matrix. The output is back in (N * M, C, T, V)
        # layout and goes into the tcn.
        # It is simply a matrix multiplication summed over the K partitions.
        x = torch.einsum('nkctv,kvw->nctw', (x, A))

        return x.contiguous(), A
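The einsum string 'nkctv,kvw->nctw' can be read as: for every partition k, multiply the features by A[k] along the node axis, then sum over k. The sketch below checks that reading with `np.einsum`, used here as a stand-in for `torch.einsum`; the random shapes are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, c, t, v = 2, 3, 4, 5, 6          # arbitrary small sizes
x = rng.standard_normal((n, k, c, t, v))
A = rng.standard_normal((k, v, v))

# Same contraction as in ConvTemporalGraphical.forward.
out_einsum = np.einsum('nkctv,kvw->nctw', x, A)

# Explicit version: one matrix product per partition, summed over the partitions.
out_loop = sum(x[:, i] @ A[i] for i in range(k))

print(out_einsum.shape)
print(np.allclose(out_einsum, out_loop))
```

Each partition therefore has its own slice of the 1*1 convolution output and its own adjacency slice, which is how the K weight vectors of the paper are realized.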

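Summary

From the code we can see that the convolution along the spatial dimension takes only two steps:

1 A convolution with a 1*1 kernel that expands the feature dimension

2 Feature aggregation with the adjacency matrix

This exposes a key limitation: if two nodes are not connected, then A[i, j] = 0, and their term in the aggregation is 0 as well. The learnable edge importance is multiplied elementwise with A, so it cannot turn a zero entry into a nonzero one, and a layer cannot learn a direct relation between two unconnected nodes.

The hands and the feet, for example, are strongly related in many actions, yet ST-GCN cannot express that relation directly.

This is what motivates the follow-up networks that improve on ST-GCN.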
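A minimal numerical sketch of this limitation, assuming a toy 3-node graph in which node 2 is connected to neither node 0 nor node 1 (the matrices and names below are illustrative, not taken from the repository). The mask is applied as in the model, as an elementwise product with A.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy normalized adjacency: node 2 has no edge to nodes 0 and 1.
A = np.array([[0.5, 0.5, 0.0],
              [0.5, 0.5, 0.0],
              [0.0, 0.0, 1.0]])
# Stand-in for a learned edge-importance mask with arbitrary positive values.
importance = rng.uniform(0.5, 2.0, size=A.shape)
masked = A * importance          # zero entries of A stay zero

x = rng.standard_normal((4, 3))  # (channels, nodes)
out = x @ masked                 # aggregation over the node axis

# Perturb the features of node 2 only.
x_changed = x.copy()
x_changed[:, 2] += 10.0
out_changed = x_changed @ masked

print(masked)
print(np.allclose(out[:, :2], out_changed[:, :2]))
```

Whatever values the mask takes, the outputs at nodes 0 and 1 do not react to node 2 within this single aggregation step.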
