Conditional Batch Normalization 的概念来源于这篇文章:Modulating early visual processing by language 。后来又先后被用在 cGANs With Projection Discriminator 和Self-Attention Generative Adversarial Networks 。本文将首先简略介绍 Modulating early visual processing by language ,接着结合 Self-Attention GANs 的 pytorch 代码,详细分析 categorical conditional Batch Normalization 的具体实现。

太长不看版

Modulating early visual processing by language

这篇文章改进了一个基于图片的问答系统 (VQA: Visual Question Answering)。系统的输入为一张图片和一个针对图片的问题,系统输出问题的答案,如下图所示:

这类系统通常是这样设计的:一个预训练的图像识别网络,例如 ResNet,用于提取图片特征;一个 sequential 模型,例如 LSTM、GRU 等,用于提取句子的特征,并根据句子预测应该关注图片的什么位置(attention);将语言特征、由 attention 加权过后的图片特征结合起来,共同输入一个网络,最终输出问题的答案。

上图左侧为传统的 VQA 系统,我们发现,LSTM 提取的特征只在 ResNet 的顶层才和图片特征结合起来,因为通常意义上讲,神经网络的底层提取的是基础的几何特征,顶层是有具体含义的语义特征,因此,应该把语言模型提取的句子特征在网络顶层和图片特征结合。然而,作者认为,底层的图片特征也应该结合语言特征。理由是,神经科学证明:语言会帮助图片识别。例如,如果事先告诉一个人关于图片的内容,然后再让他看图片,那么这个人识别图片的速度会大大加快。因此,作者首创了将图片底层信息和语言信息结合的模型,如上图右侧所示。

https://github.com/ap229997/Conditional-Batch-Norm/blob/master/model/cbn.py​github.com

Categorical Conditional Batch Normalization

在 conditional generative model 里面,存在一个隐隐让人不安的问题:一个 batch 里面不同类别的训练数据,放在一起做 Batch Normalization 不太妥当。因为不同类别的数据理应对应不同的均值和方差,其归一化、放缩、偏置也应该不同。针对这个问题,一个解决方案是不再考虑整个 batch 的统计特征,各个图像只在自己的 feature map 内部归一化,例如采用 Instance Normalization 和 Layer Normalization 来代替 BN。但是这些替代品的表现都不如 BN 稳定,接受程度不如 BN 高。

这时我们想到了上一节中介绍的 conditional BN。CBN 以 LSTM 提取的自然语言特征作为 condition,预测 BN 层参数的增量,达到对不同的输入,都有相对应的归一化参数。既然自然语言特征可以作为 condition,用于预测 BN 参数的变化,那么图片的类别信息自然也可以作为 condition 来预测 BN 层的参数。因此 cGANs With Projection Discriminator 和 Self-Attention GANs 借鉴了 CBN 里面的 condition 的思想,稍加修改,用在了自己的 conditional GAN 模型中。

接下来我们将研究其具体的实现,代码来自:

https://github.com/crcrpar/pytorch.sngan_projection/blob/master/links/conditional_batchnorm.py​github.com

class ConditionalBatchNorm2d(nn.BatchNorm2d):"""Conditional Batch Normalization"""def __init__(self, num_features, eps=1e-05, momentum=0.1,affine=False, track_running_stats=True):super(ConditionalBatchNorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)def forward(self, input, weight, bias, **kwargs):self._check_input_dim(input)exponential_average_factor = 0.0if self.training and self.track_running_stats:self.num_batches_tracked += 1if self.momentum is None:  # use cumulative moving averageexponential_average_factor = 1.0 / self.num_batches_tracked.item()else:  # use exponential moving averageexponential_average_factor = self.momentumoutput = F.batch_norm(input, self.running_mean, self.running_var,self.weight, self.bias,self.training or not self.track_running_stats,exponential_average_factor, self.eps)if weight.dim() == 1:weight = weight.unsqueeze(0)if bias.dim() == 1:bias = bias.unsqueeze(0)size = output.size()weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)return weight * output + bias 

我们看到,这个 ConditionalBatchNorm2d类,继承自 pytorch 的 BatchNorm2d类,对比这个代码和官方的 BatchNorm2d 的代码,发现其构造函数的参数和BatchNorm2d完全相同,构造函数中直接调用了基类,也就是BatchNorm2d的构造函数。而 forward函数中,多了weightbias两个参数。forward的代码大部分也是直接 copy 自 BatchNorm2d的基类_BatchNorm的代码,无非是设置一下 moving average 的 momentum,记录一下总共读取了多少个 batch,以便在没有设置 momentum 的情况下,在全体样本上计算均值和方差。直到调用官方的底层 C 函数库 F.batch_norm,代码完全没有对_BatchNorm类的forward函数做出任何修改,其output 就是对输入的 feature map 做了一次 BatchNorm2d。 真正修改的是后面加的几行:

        if weight.dim() == 1:weight = weight.unsqueeze(0)if bias.dim() == 1:bias = bias.unsqueeze(0)size = output.size()weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)return weight * output + bias 

这里用到了forward函数参数中的 weightbias。由于是在图像 feature 上操作,需要对 weight 和 bias 的维度做一些改变,使其与 feature map output的维度相同。最后代码返回weight*output+bias 。似乎很 naive,可是说好的 condition 呢?说好的 categorical 信息呢?别着急,它们都隐藏在 weightbias中。这个类只不过是个基类,下面的类才是真正要用到的类:

class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,affine=False, track_running_stats=True):super(CategoricalConditionalBatchNorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)self.weights = nn.Embedding(num_classes, num_features)self.biases = nn.Embedding(num_classes, num_features)self._initialize()def _initialize(self):init.ones_(self.weights.weight.data)init.zeros_(self.biases.weight.data)def forward(self, input, c, **kwargs):weight = self.weights(c)bias = self.biases(c)return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)

这个类的构造函数中比它的基类多加了一项num_classes。构造函数中,首先调用了它的基类,也就是ConditionalBatchNorm2d的构造函数,用于初始化大部分参数。接下来设置了两个网络层:

        self.weights = nn.Embedding(num_classes, num_features)self.biases = nn.Embedding(num_classes, num_features)

nn.Embedding层的作用是,把图片的 label 转换成 dense 向量,而不像 one-hot-encoding,只能把 label 转换成稀疏向量。nn.Embedding的第一个参数表示总共有多少个类,第二个参数表示每个 label 映射成多少维的向量。这个网络层的好处是,可以任意指定 label vector 的 dimension,它的本质是一个 num_classesnum_feature列的矩阵,这个矩阵的参数随着网络的训练不断更新。前向传播时,label 是几就取第几行的向量出来,用以表示这个 label。其实这个 Embedding 相当于把 one-hot encoding 输入一个 bias 为 0 的 linear layer。

在构造函数的最后,通过调用 self._initialize初始化 self.weights 和 self.bias,分别把它们初始化为全 1 和全 0。这样在网络训练的初期,这俩相当于不存在一样,整个类就是一个BatchNorm2d

接下来看前向传播函数:

    def forward(self, input, c, **kwargs):weight = self.weights(c)bias = self.biases(c)return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)

这个函数也很简单,输入 feature map input和类别标签c,注意c 应该是 LongTensor 格式的,否则会报错。接下来,根据 c 挑出 weights embedding 层和 biases embedding 层中的第c行,作为 weight 和 bias 输入基类的前向传播函数,最终得到 Conditional Batch Normalization 的输出。这个 categorical condition 发挥作用的阶段,就是 embedding 的阶段。

总结

提出 conditional Batch Normalization 这一思想的论文 Modulating early visual processing by language,是为了解决特定问题:即在预训练 ResNet 提取的图片底层信息中,融合进自然语言信息,用于辅助图片信息的提取。

而后面的 cGANs With Projection Discriminator 和Self-Attention Generative Adversarial Networks 则是利用 condition 的思想,把图片的 categorical 信息用来指导生成 BN 层的映射参数。我们发现,网络训练完成后,同一个类别的图片,将对应同一套 BN 层参数,不同类别的图片,将对应不同的 BN 层参数。

通过这个微小的改动,我们终于可以愉快地在 conditional generative model 上使用 Batch Normalization 操作,而不必担心不同类别的图片对应不同的映射参数了。

Conditional Batch Normalization 详解(SFT思路来源)相关推荐

  1. Batch Normalization详解(原理+实验分析)

    Batch Normalization详解(原理+实验分析) 1. 计算过程 2. 前向传播过程 3. 反向传播过程 4. 实验分析 4.1 实验一:验证有没有BatchNorm下准确率的区别 4.2 ...

  2. 批归一化(Batch Normalization)详解

    批归一化(Batch Normalization)详解 文章目录 批归一化(Batch Normalization)详解 前言 一.数据归一化 二.BN解决的问题:Internal Covariate ...

  3. batch normalization详解

    1.引入BN的原因 1.加快模型的收敛速度 2.在一定程度上缓解了深度网络中的"梯度弥散"问题,从而使得训练深层网络模型更加容易和稳定. 3.对每一批数据进行归一化.这个数据是可以 ...

  4. 【深度学习】Batch Normalization详解

    Batch Normalization 学习笔记 原文地址:http://blog.csdn.net/hjimce/article/details/50866313 作者:hjimce 一.背景意义 ...

  5. Batch Normalization详解以及pytorch实验

    Batch Normalization是google团队在2015年论文<Batch Normalization: Accelerating Deep Network Training by R ...

  6. TensorFlow实现条件批归一化(Conditional Batch Normalization)

    TensorFlow实现条件批归一化(Conditional Batch Normalization) 条件批归一化(Conditional Batch Normalization) TensorFl ...

  7. @Conditional注解的详解和应用

    Spring中@Conditional注解的详解和应用 一.@Conditional注解的作用 二.条件判断在什么时候执行? 2.1 什么是配置类? 2.2 Spring对配置类的处理阶段 2.3 @ ...

  8. python batch normalization_Batch Normalization 详解

    一.背景意义 本篇博文主要讲解2015年深度学习领域,非常值得学习的一篇文献:<Batch Normalization: Accelerating Deep Network Training b ...

  9. 天刀手游制作人亲笔详解制作思路

    小楼一夜听春雨,江湖何处不飞花.与谁把酒邀明月,将我行兮向天涯.自天刀手游面世以来,强大的引擎技术,细腻的江湖烟火:海阔天空的唯美意境,精彩纷呈的战斗体验等,都吸引了无数少侠. 而在这一项项惊艳大家的 ...

最新文章

  1. javamail gmail
  2. 在CentOS 6.9 x86_64的nginx 1.12.2上开启标准模块ngx_http_auth_request_module实录
  3. 剑指offer:孩子们的游戏(圆圈中最后剩下的数)
  4. python中os.path.isdir()等函数的作用及用法
  5. NET Core的代码安全分析工具 - Security Code Scan
  6. js 引入 缓存_引入故意缓存
  7. C#中数组、ArrayList和List三者的区别(转) ,加修改
  8. aspx转发php_asp,php,aspx一句话合集
  9. php三维数组转换二维数组,php 三维数组转二维数组(多维数组变合拼二维数组)(foreach循环 数组叠加)...
  10. Python使用wordnet工具计算词集与词条基本用法(三)
  11. [转载] Java中如何在方法中return返回多个值
  12. 区块链入门实战教程—看完本文你也会开发区块链
  13. Ubuntu“ System Program Problem Detected”问题
  14. Matlab:Matlab 软件学习之GUI图像用户界面简介(工具栏/菜单栏/对话框)、GUI界面设计案例应用(设计二级菜单栏)之详细攻略
  15. 如何查看虚拟机ip地址
  16. android 闹钟设置铃声,安卓手机闹钟设置音乐铃声的方法
  17. 超级好上手的告白小程序
  18. Dlang如何禁用垃圾回收(GC)
  19. 浅谈微信活码架构及其简易实现
  20. echarts省份地图

热门文章

  1. 268条PCB Layout设计规范
  2. 高层货架一般需要计算机控制,FMS的自动化仓库一般由货架、堆垛机和计算机控制管理系统组成。()...
  3. 麒麟810处理器_荣耀Play4T Pro评测:麒麟810处理器,堪称“真香”千元手机
  4. cocos2dx-lua升级spine
  5. 2020年十大最佳自动化测试工具
  6. AutoML论文笔记(九)CARS Continuous Evolution for Efficient Neural Architecture Search:连续进化神经网络搜索
  7. SIFT地理特征匹配
  8. win10IE浏览器运行VBScript脚本语言的简单方法
  9. 可以识别图片上的文字的小程序
  10. 你还敢继续玩免单、玩好评返现吗?