前言:

看了很多关于Stacked Hourglass网络模型的解析的博文,大多数都是只是对模型结构的分析,或者是对论文的翻译,但是自己根据论文以及去复现的时候却遇到很多问题,比如计算loss时要如何计算,acc如何计算,生成的heatmap如何转为关键点等等小问题。这些问题十分重要,但是在论文以及博文中并没有详细的讲到如何处理。因此自己根据在复现的经历,写这一系列博客进行记录,希望能帮到需要的朋友。

  • 该系列围绕下图进行记录,初步打算分 网络结构训练细节运行demo这三个主题进行编写,后期可能会随时改动。

  • 当前已完成至训练细节的loss计算,有些细节还未扣清楚,之后会一点点更新

  • 我的复现进度:mxnet版Stacked Hourglass

概览

模形都由子模块一点点拼凑成一个大网络,该篇也是打算从子模块讲起,再讲讲子模块组成的大一点点的模块,再讲大一点点的模块构成的整体网络。整体模块的组成结构如下图


当然这只是大概的结构,在各个大小模块的组成之间还有一些小的细节。下面开始各个模块的介绍。

Redidual


我在这里贴了俩张图,第二张是我自己画的,感觉会比上一张详细,但是不如上一张明了。

  1. 图一中蓝色的代表BatchNormal 紫色的代表Relu
  2. 图二中的参数是通过官方源码获取的,所以跟着图片设置就好啦

总体上来看,Redidual模块分为俩条支路,一条是convBlock,一条是skipLayer,很显然就是残差网络的那个思想。convBlock对输入的数据进行深度的卷积,以便获取深层次的图像信息,最后再将原来输入的数据经过skipLayer进行element add并输出,这样可以将深层次的信息和原始信息进行相结合,获得更好的效果。

对于convBlock,无脑照着参数卷积就是了
对于skipLayer,当输入数据的channel与Redidual模块要求的output_channel不一致的时候要使用1*1的conv卷积核进行卷积,以便后面的element_add能加的起来

以下是mxnet版本的代码:

class Residual(nn.HybridBlock):def __init__(self, in_channels, out_channels, **kwargs):super(Residual, self).__init__(**kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.residual_conv = nn.HybridSequential()self.residual_skip = nn.HybridSequential()with self.residual_conv.name_scope():# 卷积路self.residual_conv.add(nn.BatchNorm())self.residual_conv.add(nn.Activation('relu'))self.residual_conv.add(nn.Conv2D(self.out_channels // 2, (1, 1)))self.residual_conv.add(nn.BatchNorm())self.residual_conv.add(nn.Activation('relu'))self.residual_conv.add(nn.Conv2D(self.out_channels // 2, (3, 3), (1, 1), (1, 1)))self.residual_conv.add(nn.BatchNorm())self.residual_conv.add(nn.Activation('relu'))self.residual_conv.add(nn.Conv2D(self.out_channels, (1, 1)))# 连接路if not self.in_channels == self.out_channels:with self.residual_skip.name_scope():self.residual_skip.add(nn.Conv2D(self.out_channels, (1, 1)))def hybrid_forward(self, F, x):    temp_x = xx = self.residual_conv(x)if not self.in_channels == self.out_channels:x = x + self.residual_skip(temp_x)else:x = x + temp_xreturn x

Hourglass

这个模块就比较灵活了,是整个模型的核心,是由Redidual为基础模块构成。这个模块还有一个阶数n,不同的阶数有不同的特征。

一阶Hourglass长这样

二阶Hourglass长这样

四阶长这样

各阶区别:

通过观察不难发现,一阶二阶以及多阶的区别就在于虚线框内装的不一样
我在阅读官方源码的时候注意到以下几点:

  1. 绿色的代表Redidual模块,箭头朝下是MaxPool,箭头朝上是UpSamplingNearest
  2. 上面几张图的大部分Redidual块(上方的SkipLayer、下采样部分前后)都是3*Redidual构成,但是在官方代码中,这个的数量通过一个叫opt.nModules参数进行设置的,图片中为3,官方源码为1,本文以及复现代码按照官方源码的1来设置(下面的图会略微不同)。
  3. 官方源码中阶数由参数n控制,默认为4阶。由递归方法去构造n阶的Hourglass(具体看下面源码)。

下图是我自己根据官方源码画的详细一点的Hourglass模块

看图会发现,这里也有残差的思想(跳级结构)(也可看做Skiplayer 与 convBlock),这个思想很无敌啊。。。
其次就是Hourglass(漏斗)的体现就在于先MaxPool后UpSampling,
通过MaxPool将feature map缩小,再UpSampling进行扩大(编码-解码那一套)

用图片画出来就粗略的长这样(下图由2个Hourglass 组成,不含内部构造):

4阶的详细的长这样(单个Hourglass以及其内部构造 4阶):

下面是mxnet源码

  1. 我把Hourglass命名为HourGlassBlock模块
  2. args.nModules参数设置Redidual的数量,根据官方这里设置为1
class HourGlassBlock(nn.HybridBlock):def __init__(self, n, in_channels, **kwargs):'''args:n:              当前HourGlass所在的阶数in_channels:    当前HourGlass输入的channels'''super(HourGlassBlock, self).__init__(**kwargs)self.n = nself.in_channels = in_channelswith self.name_scope():# Upper branchself.up1 = nn.HybridSequential()for _ in range(args.nModules):self.up1.add(Residual(self.in_channels, self.in_channels))# Lower branchself.low1_MaxPool = nn.MaxPool2D((2, 2), (2, 2))self.low1 = nn.HybridSequential()for _ in range(args.nModules):self.low1.add(Residual(self.in_channels, self.in_channels))# 递归生成上文图中虚线框内的子阶Hourglassif self.n > 1:self.low2 = HourGlassBlock(self.n - 1, self.in_channels)else:self.low2 = nn.HybridSequential()for _ in range(args.nModules):self.low2.add(Residual(self.in_channels, self.in_channels))self.low3 = nn.HybridSequential()for _ in range(args.nModules):self.low3.add(Residual(self.in_channels, self.in_channels))def hybrid_forward(self, F, x):up1 = self.up1(x)    # SkipLayer# ConvBlockx = self.low1_MaxPool(x)x = self.low1(x)x = self.low2(x)x = self.low3(x)up2 = F.UpSampling(x, scale=2, sample_type="nearest")return up1 + up2

Linear

在讲完整网络之前要提一下这个模块,该模块主要由1*1的Conv与BN以及Relu组成,紧跟在Hourglass之后

class Lin(nn.HybridBlock):def __init__(self, numOut, **kwargs):super(Lin, self).__init__(**kwargs)self.numOut = numOutself.lin = nn.HybridSequential()with self.lin.name_scope():self.lin.add(nn.Conv2D(numOut, 1))self.lin.add(nn.BatchNorm())self.lin.add(nn.Activation('relu'))def hybrid_forward(self, F, x):return self.lin(x)

完整网络

现在到具体的完整的网络

先看看整个网络大致长啥样


一张图片输入,经过一些简单的卷积模块,再通过N个Hourglass模块,最终得到输出

下面看看N个Hourglass模块之前的卷积模块

  1. 输入的图片是HW3 (尺寸默认是256*256)
  2. 在图片经过N个n阶的Hourglass模块之前是有一些处理模块的,位置:图中从左到右红色4个模块

接下来看看封装的Hourglass模块

N个n阶的Hourglass模块并不是简简单单的拼在一起,每两个Hourglass之间还有一些卷积模块(也就是说单个Hourglass是被一些卷积模块封装起来的,如下图),以及论文中提到的“中继监督”也在这里运用。下面详细说说这里的细节

  1. 这里的输入指上一个相同的Hourglass模块或者上文说的Hourglass模块之前是有一些处理模块
  2. 在这里有三条SkipLayer,通过element_add相加在一起
  3. 在我的代码中,用了一个[]存放N个out(heatmap),中继监督就是这里体现,存放每一个Hourglass模块的输出,N个out都用作loss计算
  4. 关于这个out,首先要说明下N在论文中使用的是8(代码中对应是nStack),那么out中便有8个heatmap
  5. 因为训练的数据集MPII一共有(16个关节点),所以每个heatmap的shape是(16, x, x, x),一个heatmap对应一个关节点,类似这样

    6.用N个heatmap去计算loss时候,或者计算acc的时候,很显然要把标注文件也转为对应的16个heatmap,然后进行训练,这个打算在下一篇中讲

具体代码如下:

class Hourglass(nn.HybridBlock):def __init__(self, **kwargs):super(Hourglass, self).__init__(**kwargs)self.out = []# HourglassBlock模块之前的图片处理模块self.preprocess = nn.HybridSequential(prefix="pre")with self.preprocess.name_scope():self.preprocess.add(nn.Conv2D(64, 7, (2, 2), (3, 3)))self.preprocess.add(nn.BatchNorm())self.preprocess.add(nn.Activation("relu"))self.preprocess.add(Residual(64, 128))self.preprocess.add(nn.MaxPool2D((2, 2), (2, 2)))self.preprocess.add(Residual(128, 128))self.preprocess.add(Residual(128, args.nFeats))# HourglassBlock模块self.hourglass_blocks = nn.HybridSequential(prefix="hg")with self.hourglass_blocks.name_scope():for _ in range(args.nStack):hourglass_block = nn.HybridSequential()hourglass_block.add(HourGlassBlock(4, args.nFeats))   # args.nFeats = 256for _ in range(args.nModules):  # args.nModules = 1hourglass_block.add(Residual(args.nFeats, args.nFeats))hourglass_block.add(Lin(args.nFeats))hourglass_block.add(nn.Conv2D(args.nJoints, (1, 1), (1, 1), (0, 0)))    # args.nJoints = 16 数据集16个关节点 这层之后可以生成outself.hourglass_blocks.add(hourglass_block)self.conv1 = nn.Conv2D(args.nFeats, (1, 1), (1, 1), (0, 0))self.conv2 = nn.Conv2D(args.nFeats, (1, 1), (1, 1), (0, 0))def hybrid_forward(self, F, x):x = self.preprocess(x)for i in range(args.nStack):temp_x = xx = self.hourglass_blocks[i](x)self.out.append(x)if i < args.nStack:x1 = self.conv1(x)x2 = self.conv2(x)x = temp_x + x1 + x2return self.out

参考资料

  • 官方源码
  • pytorch版本
  • keras版本
  • 论文地址(国内查看比较方便的平台)

如有错误,还请指正

Stacked Hourglass笔记源码(一)网络结构相关推荐

  1. C++Primer Plus (第六版)阅读笔记 + 源码分析【目录汇总】

    C++Primer Plus (第六版)阅读笔记 + 源码分析[第一章:预备知识] C++Primer Plus (第六版)阅读笔记 + 源码分析[第二章:开始学习C++] C++Primer Plu ...

  2. 一箭双雕 刷完阿里P8架构师spring学习笔记+源码剖析,涨薪8K

    关于Spring的叙述: 我之前死磕spring的时候,刷各种资料看的我是一头雾水的,后面从阿里的P8架构师那里拿到这两份资料,从源码到案例详细的讲述了spring的各个细节,是我学Spring的启蒙 ...

  3. DCGAN论文笔记+源码解析

    论文地址:UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS 源码 ...

  4. 不止面试题,笔记源码统统都有

    前言 其实前几篇文章已经写了好多有关于Spring源码的文章,事实上,很多同学虽然一直在跟着阅读.学习这些Spring的源码教程,但是一直都很迷茫,这些Spring的源码学习,似乎只是为了面试吹逼用, ...

  5. dubbo笔记+源码刨析

    会不断更新!冲冲冲!跳转连接 https://blog.csdn.net/qq_35349982/category_10317485.html dubbo笔记 1.概念 RPC全称为remote pr ...

  6. C++| 匠心之作 从0到1入门学编程【视频+课件+笔记+源码】

    目录 1.课程简介 1.1.语言特点(支持数据封装和数据隐藏) 1.2.工作原理 1.3.课程目录 2.视频(资料+视频)百度网盘 2.1.视频在线观看地址 2.2.视频源码 3.博客笔记汇总表 第1 ...

  7. zookeeper笔记+源码刨析

    会不断更新!冲冲冲!跳转连接 https://blog.csdn.net/qq_35349982/category_10317485.html zookeeper 1.介绍 Zookeeper 分布式 ...

  8. mysql数据库源码安装_学习笔记-源码安装mariadb 20210128

    源码安装Mariadb数据库 安装之前先检查一下空间: 1 [15:13:16 root@centos8 ~]#free -h(#检查空间)2 total used free shared buff/ ...

  9. MyBatis学习笔记-源码分析篇

    引言 SQL 语句的执行涉及多个组件,其中比较重要的是 Executor. StatementHandler. ParameterHandler 和 ResultSetHandler. Executo ...

  10. InfoGAN论文笔记+源码解析

    论文地址:InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial ...

最新文章

  1. 图片放大显示的jQuery插件
  2. 57-高级路由:分发列表:多协议分发列表实验:DV、LS
  3. Unity制作游戏中的场景
  4. 包含min函数的栈 大坑
  5. this和self区别
  6. cmake 指定编译器_我们需要懂得CMake文件
  7. SQL Calendar Table
  8. 深度学习:卷积神经网络(convolution neural network)
  9. RHCE-samba服务
  10. fundamental-react在POC中的一个应用
  11. Gradle中的buildScript代码块
  12. css32D、3D、动画、过渡
  13. [Git]4.1 分支与合并
  14. CentOS6源码安装VSFTPD3
  15. Msm8960(APQ8064)平台的MSM-AOSP-kitkat编译适配(7):信号通讯
  16. MPEG-TS封装格式
  17. 胆囊息肉,需要切除吗
  18. vim工具——常用插件
  19. xpath爬取智联招聘--大数据开发职位并保存为csv
  20. 解读电商平台10大促销活动类型

热门文章

  1. Java-编辑图片,添加文字
  2. Hadoop技术内幕-Hadoop远程过程调用
  3. 五 IP核行业潜在投资方向和机会
  4. Docker容器技术与应用(项目1 Docker容器简介)
  5. 计算机是人类的好伴侣 作文,有电脑真好作文
  6. Xcode5 创建模板和UIView 关联XIB
  7. mybatis处理xml大于小于号报异常
  8. 关于H5版本及说明-为什么优雅草YYC蜻蜓系统H5版本打包不成功以及相关问题
  9. 【重要】国庆节快乐!有三AI所有课程限时7天优惠
  10. Gabor滤波器详解