批量归一化

批量归一化(batch normalization)层能让较深的神经网络的训练变得更加容易

通常来说,数据标准化预处理对于浅层模型就足够有效了。随着模型训练的进行,当每层中参数更新时,靠近输出层的输出较难出现剧烈变化
但对深层神经网络来说,即使输入数据已做标准化,训练中模型参数的更新依然很容易造成靠近输出层输出的剧烈变化。这种计算数值的不稳定性通常令我们难以训练出有效的深度模型。

批量归一化的提出正是为了应对深度模型训练的挑战。

在模型训练时,批量归一化利用小批量上的均值和标准差不断调整神经网络中间输出,从而使整个神经网络在各层的中间输出的数值更稳定。

批量归一化层

对全连接层和卷积层做批量归一化的方法稍有不同

1. 对全连接层做批量归一化

通常将批量归一化层置于全连接层中的仿射变换和激活函数之间。设全连接层的输入为u\boldsymbol{u}u,权重参数和偏差参数分别为W\boldsymbol{W}W和b\boldsymbol{b}b,激活函数为ϕ\phiϕ。设批量归一化的运算符为BN\text{BN}BN。那么,使用批量归一化的全连接层的输出为

ϕ(BN(x)),\phi(\text{BN}(\boldsymbol{x})),ϕ(BN(x)),
其中批量归一化输入x\boldsymbol{x}x由仿射变换

x=Wu+b\boldsymbol{x} = \boldsymbol{W\boldsymbol{u} + \boldsymbol{b}}x=Wu+b

得到。

考虑一个由mmm个样本组成的小批量,仿射变换的输出为一个新的小批量
B=x(1),…,x(m)\mathcal{B} = {\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(m)} }B=x(1),…,x(m)
它们正是批量归一化层的输入。对于小批量B\mathcal{B}B中任意样本x(i)∈Rd,1≤i≤m\boldsymbol{x}^{(i)} \in \mathbb{R}^d, 1 \leq i \leq mx(i)∈Rd,1≤i≤m,批量归一化层的输出同样是ddd维向量

y(i)=BN(x(i)),\boldsymbol{y}^{(i)} = \text{BN}(\boldsymbol{x}^{(i)}),y(i)=BN(x(i)),

首先,对小批量B\mathcal{B}B求均值和方差:
μB←1m∑i=1mx(i),\boldsymbol{\mu}_\mathcal{B} \leftarrow \frac{1}{m}\sum{i = 1}^{m} \boldsymbol{x}^{(i)},μB​←m1​∑i=1mx(i), σB2←1m∑i=1m(x(i)−μB)2,\boldsymbol{\sigma}_\mathcal{B}^2 \leftarrow \frac{1}{m} \sum{i=1}^{m}(\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B})^2,σB2​←m1​∑i=1m(x(i)−μB​)2,
其中的平方计算是按元素求平方。

接下来,使用按元素开方和按元素除法对x(i)\boldsymbol{x}^{(i)}x(i)标准化:
x^(i)←x(i)−μBσB2+ϵ,\hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}},x^(i)←σB2​+ϵ​x(i)−μB​​,

这里ϵ>0\epsilon > 0ϵ>0是一个很小的常数,保证分母大于0
在上面标准化的基础上,批量归一化层引入了两个可以学习的模型参数:

  • 拉伸(scale)参数 γ\boldsymbol{\gamma}γ
  • 偏移(shift)参数 β\boldsymbol{\beta}β。

这两个参数和x(i)\boldsymbol{x}^{(i)}x(i)形状相同,皆为ddd维向量。它们与x(i)\boldsymbol{x}^{(i)}x(i)分别做按元素乘法(符号⊙\odot⊙)和加法计算

y(i)←γ⊙x^(i)+β{\boldsymbol{y}}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}y(i)←γ⊙x^(i)+β

至此,我们得到了x(i)\boldsymbol{x}^{(i)}x(i)的批量归一化的输出y(i)\boldsymbol{y}^{(i)}y(i)。

值得注意的是,可学习的拉伸和偏移参数保留了不对x^(i)\hat{\boldsymbol{x}}^{(i)}x^(i)做批量归一化的可能

  • 此时只需学出γ=σB2+ϵ\boldsymbol{\gamma} = \sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}γ=σB2​+ϵ​和β=μB\boldsymbol{\beta} = \boldsymbol{\mu}_\mathcal{B}β=μB​。
  • 我们可以对此这样理解:如果批量归一化无益,理论上,学出的模型可以不使用批量归一化

2. 对卷积层做批量归一化

对卷积层来说,批量归一化发生在卷积计算之后、应用激活函数之前

  • 如果卷积计算输出多个通道,我们需要对这些通道的输出分别做批量归一化,且每个通道都拥有独立的拉伸和偏移参数,并均为标量。

设小批量中有mmm个样本。在单个通道上,假设卷积计算输出的高和宽分别为ppp和qqq。我们需要对该通道中m×p×qm \times p \times qm×p×q个元素同时做批量归一化。对这些元素做标准化计算时,我们使用相同的均值和方差,即该通道中m×p×qm \times p \times qm×p×q个元素的均值和方差。

3. 预测时的批量归一化

  • 使用批量归一化训练时,我们可以将批量大小设得大一点,从而使批量内样本的均值和方差的计算都较为准确
  • 将训练好的模型用于预测时,我们希望模型对于任意输入都有确定的输出。因此,单个样本的输出不应取决于批量归一化所需要的随机小批量中的均值和方差
  • 一种常用的方法是通过移动平均估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出
  • 可见,和丢弃层一样,批量归一化层在训练模式和预测模式下的计算结果也是不一样的。

实现批量归一化层

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):# 判断当前模式是训练模式还是预测模式if not is_training:# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持# X的形状以便后面可以做广播运算mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)# 训练模式下用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 拉伸和偏移return Y, moving_mean, moving_var

自定义一个BatchNorm层。它保存参与求梯度和迭代的拉伸参数gamma和偏移参数beta,同时也维护移动平均得到的均值和方差,以便能够在模型预测时被使用。

BatchNorm实例所需指定的num_features参数对于全连接层来说应为输出个数,对于卷积层来说则为输出通道数。该实例所需指定的num_dims参数对于全连接层和卷积层来说分别为2和4。

class BatchNorm(nn.Module):def __init__(self, num_features, num_dims):super(BatchNorm, self).__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成0和1self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 不参与求梯度和迭代的变量,全在内存上初始化成0self.moving_mean = torch.zeros(shape)self.moving_var = torch.zeros(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_var, Module实例的traning属性默认为true, 调用.eval()后设成falseY, self.moving_mean, self.moving_var = batch_norm(self.training, X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

使用批量归一化层的LeNet

修改5.5节(卷积神经网络(LeNet))介绍的LeNet模型,从而应用批量归一化层。我们在所有的卷积层或全连接层之后、激活层之前加入批量归一化层。

class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterdef train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
net = nn.Sequential(nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_sizeBatchNorm(6, num_dims=4),nn.Sigmoid(),nn.MaxPool2d(2, 2), # kernel_size, stridenn.Conv2d(6, 16, 5),BatchNorm(16, num_dims=4),nn.Sigmoid(),nn.MaxPool2d(2, 2),FlattenLayer(),nn.Linear(16*4*4, 120),BatchNorm(120, num_dims=2),nn.Sigmoid(),nn.Linear(120, 84),BatchNorm(84, num_dims=2),nn.Sigmoid(),nn.Linear(84, 10))

训练修改后的模型:

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

与自己定义的BatchNorm类相比,Pytorch中nn模块定义的BatchNorm1d和BatchNorm2d类使用起来更加简单,二者分别用于全连接层和卷积层,都需要指定输入的num_features参数值。

用PyTorch实现使用批量归一化的LeNet:

net = nn.Sequential(nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_sizenn.BatchNorm2d(6),nn.Sigmoid(),nn.MaxPool2d(2, 2), # kernel_size, stridenn.Conv2d(6, 16, 5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.MaxPool2d(2, 2),FlattenLayer(),nn.Linear(16*4*4, 120),nn.BatchNorm1d(120),nn.Sigmoid(),nn.Linear(120, 84),nn.BatchNorm1d(84),nn.Sigmoid(),nn.Linear(84, 10))

(pytorch-深度学习)批量归一化相关推荐

  1. PyTorch深度学习-跟着小土堆学习

    目录 学习视频链接 一些问题 P4:Python/PyTorch学习中两大法宝函数-dir().help() P5:PyCharm及Jupyter使用及对比 P6:PyTorch加载数据初认识 P7: ...

  2. 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)

    文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...

  3. 《PyTorch深度学习实践》

    [<PyTorch深度学习实践>完结合集] https://www.bilibili.com/video/BV1Y7411d7Ys/?share_source=copy_web&v ...

  4. 【PyTorch】PyTorch深度学习实践|视频学习笔记|P6-P9

    PyTorch深度学习实践 逻辑斯蒂回归及实现 背景与概念 基于分类问题中属性是类别性的,所以不能采取基于序数的线性回归模型,而提出了新的分类模型--逻辑斯蒂回归模型,输出每个样本在各个预测值上的概率 ...

  5. PyTorch深度学习笔记之四(深度学习的基本原理)

    本文探讨深度学习的基本原理.取材于<PyTorch深度学习实战>一书的第5章.也融入了一些自己的内容. 1. 深度学习基本原理初探 1.1 关于深度学习的过程的概述 给定输入数据和期望的输 ...

  6. 【PyTorch深度学习实践 | 刘二大人】B站视频教程笔记

    资料 [参考:<PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili] [参考 分类专栏:PyTorch 深度学习实践_错错莫的博客-CSDN博客] 全[参考 分类专栏:PyT ...

  7. 笔记|(b站)刘二大人:pytorch深度学习实践(代码详细笔记,适合零基础)

    pytorch深度学习实践 笔记中的代码是根据b站刘二大人的课程所做的笔记,代码每一行都有注释方便理解,可以配套刘二大人视频一同使用. 用PyTorch实现线性回归 # 1.算预测值 # 2.算los ...

  8. PyTorch深度学习实践

    根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...

  9. pytorch深度学习_用于数据科学家的深度学习的最小pytorch子集

    pytorch深度学习 PyTorch has sort of became one of the de facto standards for creating Neural Networks no ...

  10. pytorch深度学习入门笔记

    Pytorch 深度学习入门笔记 作者:梅如你 学习来源: 公众号: 阿力阿哩哩.土堆碎念 B站视频:https://www.bilibili.com/video/BV1hE411t7RN? 中国大学 ...

最新文章

  1. NC:MetaSort通过降低微生物群落复杂度以突破宏基因组组装难题
  2. Linux命令之whereis
  3. UNIX再学习 -- 守护进程(转)
  4. sir跟seir模型有啥区别_H3C B5mini拆机,看一下跟B5有啥区别
  5. 徐小平:全员拥抱区块链是内部分享 1比特币寻泄密者
  6. python flask Blueprint搭建
  7. Linux: 系统配置 crond 和 crontab(有图有代码有真相!!!)
  8. 华为手机免root改mac_拿到华为手机,这4个默认设置一定要改,不然流量电量很快被耗光...
  9. LINQ to SQL自定义映射表关系(1:N or 1:1)
  10. Java双十二活动代码_双十二直播脚本怎么写?戳我速领!
  11. java xmlutil_XmlUtil工具类(toxml()和toBean())
  12. 用python获取实时地球图像作为壁纸(windows)
  13. 发现尖叫--生物电体感
  14. 【数据可视化】第五章—— 基于PyEcharts的数据可视化
  15. [修練營ASP.NET]淺談多層式架構 (Multi Tiers)
  16. Nginx + Lua 搭建网站WAF防火墙
  17. win7 微信 代理服务器,Win7系统使用电脑版微信如何@别人
  18. TOPSIS(优劣解距离法)【附Python实现代码及可视化代码】
  19. RocketMQ 的安装和可视化界面
  20. 优思学院|质量大师的那些名言(一)【质量是免费的】

热门文章

  1. 大数据职业理解_大数据带给我们职业三大根本改变
  2. linux mongo 服务器,如何用MongoDB在Linux服务器上创建大量连接和线程的记忆
  3. python矩阵相乘例题_百道Python入门级练习题(新手友好)第一回合——矩阵乘法...
  4. ubuntu 虚拟机 串口 socket_上篇 | 虚拟机Ubuntu向开发板AMR传送文件
  5. php 字符串比较txt,PHP读到txt中文字符串比较失败
  6. js监听iframe关闭_Node.js文档NET[翻译]
  7. Java开发以及Web 和移动程序员必须了解的10个框架
  8. Java开发需要达到什么样的水平才称得上架构师?
  9. antix linux安装教程,antiX 19.1 发布,轻量级的桌面Linux发行版
  10. yii2 调用未定义函数_Python 函数(三) 使用规则