深度学习面试高频之Batch Normalization
在我找这份工作的时候,面试了将近有十家公司左右。总体分两个方面一个是偏项目,另外一个是偏深度学习的常见应用层。比如为什么使用batch norm可以加速收敛,卷积的实现原理和池化的作用等等。因为我是本科应届生,可能问的问题不会特别的深入。但是个人觉得如果掌握好这些基础知识,无论是在应聘岗位还是在实际工作中。都会起着决定性作用(最近在用C++实现前馈网络和网络结构的剪枝)。如果coding能力跟上,对数据熟悉。调优一个模型是没有太大问题的。
本文主要讲解batch norm的基础实现过程,真正在训练过程中会有momentum(均值,方差平滑)等相关参数
momentum 参数讲解
from torch import nn
import numpy as np
import torch
torch.manual_seed(18)
tensor = torch.randint(1,9,(1,3,3,3)).float() # 创建 [batch,channels,height,width]
batch_norm = nn.BatchNorm2d(3) # channels = 3
# batch_norm.training =False
if __name__ == "__main__":print(tensor)for _ in range(10):batch_norm(tensor)print(batch_norm.running_mean)
我们发现这里的均值是有三个,这里每一个均值对应的一个channel中的feature map。我们来手动计算一下第一个特征图的均值(3+4+1+6+7+2+7+3+3)/9=4.0,我们发现running_mean不仅不等于均值4.0,反而还是在不断变化的。这是我在实现bn过程中遇到的坑,还记得我们前面说过的momentum参数吗?更新公式为running_mean=(1−momentum)∗old_mean+momentum∗current_meanrunning\_mean = (1-momentum)*old\_mean+momentum*current\_meanrunning_mean=(1−momentum)∗old_mean+momentum∗current_mean,这里是指均值,方差平滑的含义。个人感觉是减轻数据中噪音点的影响。这里再解释就很轻松啦,初始化的old_mean=0old\_mean=0old_mean=0。bn默认的momentum=0.1,故第一次计算:0.4=0.14+0.90 ,第二次计算:0.76=0.14+0.90.4…,如果循环100次,均值就非常接近4了。
这种平滑的思想运用的很多,包括梯度平滑,权重平滑等等。都能够有效的增强模型训练过程中对异常点的鲁棒性。
实现原理
# -*- coding: utf-8 -*-
# @Time : 2019/8/8 13:35
# @Author : ljf
import torch
from torch import nn
from torch import optim
import numpy as np# TODO
# 一 数据
train_x = torch.rand(size=[10,3,8,8])
train_y = torch.rand(size=[10,3,8,8])
# np.random.seed(18)
temp_x = [[1,2,3,4,5,6,7,8],[-1,-2,-3,-4,-5,-6,-7,-8],[1,2,3,4,5,6,7,8],[-1,-2,-3,-4,-5,-6,-7,-8],[1, 2, 3, 4, 5, 6, 7, 8],[-1, -2, -3, -4, -5, -6, -7, -8],[1, 2, 3, 4, 5, 6, 7, 8],[-1, -2, -3, -4, -5, -6, -7, -8]]
temp_y = np.array([[temp_x,temp_x,temp_x]])
test_x = torch.Tensor(temp_y)
print(test_x.size())
# print(test_x)
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.bn = nn.BatchNorm2d(num_features=3)# for m in self.modules():# if isinstance(m,nn.Conv2d):# nn.init.constant_(m.weight,1.0)# nn.init.constant_(m.bias,1.0)def forward(self, x):out = self.bn(x)return out
# 三 优化器,损失函数
is_evaluate = True
model = Net()
if is_evaluate:model.load_state_dict(torch.load("./pth/batchnorm2d.pth"))mean = test_x.mean(dim=[2, 3], keepdim=True)var = test_x.var(dim=[2, 3], keepdim=True)# print(model.bn.running_mean)_out = (test_x - model.bn.running_mean.view(1, 3, 1, 1)) / torch.sqrt(model.bn.running_var.view(1, 3, 1, 1) + model.bn.eps)_output = model.bn.weight.view(mean.size()) * _out + model.bn.bias.view(mean.size())print(_output)model.eval()# print(model.bn.eps)pred_y = model(test_x)print(pred_y)
else:optimizer = optim.SGD(model.parameters(), lr=0.001)criterion = nn.MSELoss()# 四 迭代数据for i in range(20):output = model(train_x)if i ==0:print(output.size())loss = criterion(output, train_y)optimizer.zero_grad()loss.backward()optimizer.step()# 五 模型保存torch.save(model.state_dict(),"./pth/batchnorm2d.pth")
常见作用
防止梯度消失
针对sigmoid等激活函数,通过bn能够将数据放缩到线性区。sigmoid在区间[-4,4]之外它的导数基本接近0.
数据分布一致
机器学习的本质是学习分布,NN深了之后会产生偏移,BN使每一层的数据分布缩放到一致。
数据增强
一定程度的数据扩充,加了数据抖动的操作(加了偏移量?)
对参数w适应性更强,训练更稳定
测试阶段使用 Batch Normalization?
训练阶段每一个mini-batch的均值μbatch\mu_{batch}μbatch和σbatch2\sigma^2_{batch}σbatch2的均值作为测试使用,或者使用吴恩达课程中提出的指数加权平均范化和线性变换使得每一层网络的输入数据的均值和方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现每一层网络中独立学习,有利于提高整个神经网络的学习速度。
学习率设置太高时,会使得参数更新步伐过大,容易出现震荡和不收敛。但是使用BN的网络将不会受到参数数值大小的影响。例如,我们对参数WWW进行缩放得到aWa WaW。对于缩放前的值WxWxWx,我们设其均值为μ1\mu_1μ1,方差为σ12\sigma_1^2σ12;对于缩放值(也就是上一层的输出)αWx\alpha W xαWx,设其均值为μ2\mu_2μ2,方差为σ22\sigma_2^2σ22,于是我们有
μ2=aμ1,σ22=a2σ12\mu_{2}=a \mu_{1}, \quad \sigma_{2}^{2}=a^{2} \sigma_{1}^{2}μ2=aμ1,σ22=a2σ12,如果忽略ϵ\epsilonϵ,则有:
BN(aWu)=γ⋅aWx−μ2σ22+β=γ⋅aWx−aμ1a2σ12+β=γ⋅Wu−μ1σ12+β=BN(Wx)B N(a W u)=\gamma \cdot \frac{a W x-\mu_{2}}{\sqrt{\sigma_{2}^{2}}}+\beta=\gamma \cdot \frac{a W x-a \mu_{1}}{\sqrt{a^{2} \sigma_{1}^{2}}}+\beta=\gamma \cdot \frac{W u-\mu_{1}}{\sqrt{\sigma_{1}^{2}}}+\beta=B N(W x)BN(aWu)=γ⋅σ22aWx−μ2+β=γ⋅a2σ12aWx−aμ1+β=γ⋅σ12Wu−μ1+β=BN(Wx)∂BN((aW)u)∂x=γ⋅aWσ22=γ⋅aWa2σ12=∂BN(Wx)∂x\frac{\partial B N((a W) u)}{\partial x}=\gamma \cdot \frac{a W}{\sqrt{\sigma_{2}^{2}}}=\gamma \cdot \frac{a W}{\sqrt{a^{2} \sigma_{1}^{2}}}=\frac{\partial B N(W x)}{\partial x}∂x∂BN((aW)u)=γ⋅σ22aW=γ⋅a2σ12aW=∂x∂BN(Wx) 对输入xxx求导
∂BN((aW)x)∂(aW)=γ⋅xσ22=γ⋅xaσ12=1a⋅∂BN(Wx)∂W\frac{\partial B N((a W) x)}{\partial(a W)}=\gamma \cdot \frac{x}{\sqrt{\sigma_{2}^{2}}}=\gamma \cdot \frac{x}{a \sqrt{\sigma_{1}^{2}}}=\frac{1}{a} \cdot \frac{\partial B N(W x)}{\partial W}∂(aW)∂BN((aW)x)=γ⋅σ22x=γ⋅aσ12x=a1⋅∂W∂BN(Wx)这里的awa waw是放大后的权重值
我们可以看到,经过BN操作以后,权重的缩放值会被“抹去”,因此保证了输入数据分布稳定在一定范围内。另外,权重的缩放并不会影响到对 xxx 的梯度计算;并且当权重越大时,即 aaa 越大, 1a\frac{1}{a}a1 越小,意味着权重 WWW 的梯度反而越小,这样BN就保证了梯度不会依赖于参数的scale,使得参数的更新处在更加稳定的状态。
因此,在使用Batch Normalization之后,抑制了参数微小变化随着网络层数加深被放大的问题,使得网络对参数大小的适应能力更强,此时我们可以设置较大的学习率而不用过于担心模型divergence的风险。
深度学习面试高频之Batch Normalization相关推荐
- 【深度学习】深入理解Batch Normalization批归一化
[深度学习]深入理解Batch Normalization批归一化 转自:https://www.cnblogs.com/guoyaohua/p/8724433.html 这几天面试经常被问到BN层的 ...
- 【深度学习】深入理解Batch Normalization批标准化
这几天面试经常被问到BN层的原理,虽然回答上来了,但还是感觉答得不是很好,今天仔细研究了一下Batch Normalization的原理,以下为参考网上几篇文章总结得出. Batch Normaliz ...
- (转)【深度学习】深入理解Batch Normalization批标准化
原文链接:https://www.cnblogs.com/guoyaohua/p/8724433.html 这几天面试经常被问到BN层的原理,虽然回答上来了,但还是感觉答得不是很好,今天仔细研究了一下 ...
- 【深度学习】简单理解Batch Normalization批标准化
资源 相关的Paper请看这两篇 Batch Normalization Accelerating Deep Network Training by Reducing Internal Covaria ...
- 系统学习深度学习(十三)--Batch Normalization
Batch Normalization,简称BN,来源于<Batch Normalization: Accelerating Deep Network Training by Reducing ...
- 机器学习、深度学习面试知识点汇总
作者丨Oldpan 来源丨oldpan博客 编辑丨极市平台 导读 本文总结了一些秋招面试中会遇到的问题和一些重要的知识点,适合面试前突击和巩固基础知识. 前言 最近这段时间正临秋招,这篇文章是老潘在那 ...
- 常见的12个深度学习面试问题(提高篇)
序言 整理了一篇来自公众号AI公园的文章,原文链接:常见的12个深度学习面试问题,通过对文章知识点整理来巩固所学的知识,也为了以后更好的复习. 正文 1. 介绍Batch Normalization的 ...
- 21个热门的深度学习面试问答的综合指南
本文列出了一系列热门的深度学习面试的问题,每一个问题都有相应的答案.认真阅读,或许你会对深度学习面试的知识有个全面的了解. 介绍 你打算参加深度学习面试吗?你是否已经迈出了第一步,申请了一个深度学习的 ...
- 【深度学习面试八股文】-- 1-5
最近会更新一个深度学习面试中常见问题,及相应的答案,希望对找工作的朋友有所帮助,总结不易,欢迎持续关注.公众号:羽峰码字,欢迎来撩. 目录 1.如何处理样本不均衡问题 2.数据增强方法 3.过拟合的解 ...
最新文章
- webpack中loader加载器(打包非js模块)
- Altium Designer 规则设置Advance(Query)的使用
- python基础知识点-Python基础中的29个知识点
- mysql dump 查看器_mysql备份之mysqldump工具
- CENTOS7配置静态IP后无法ping通外部网络的问题
- IntelliJ IDEA导入一个已经存在的子模块
- linux 抓包文件 导出,tcpdump抓包和scp导出以及Wireshark查看
- JS——实现短信验证码的倒计时功能(没有验证码,只有倒计时)
- [css] 请描述css的权重计算规则
- linux如何卸载conky,Linux Deepin 15.10.2 下折腾 简单自制的 Conky Conky-manager
- php多少内置函数,php有多少个内置函数
- 对抗搜索之【最大最小搜索】【Alpha-Beta剪枝搜索】
- tk芯片智能机刷机方法_MTK通用刷机教程 MTK芯片智能机刷机方法
- l0phtcrack 7(爆破管理员密码)使用教程
- java中长整形怎么定义_java中长整型定义
- 批量剔除consul无效服务
- C语言编程--根据麦克劳林公式计算任意角的正弦余弦
- Nginx for Mac - 苹果系统SSL证书安装
- 【聊聊Java】Java中HashMap常见问题 -- 扩容、树化、死链问题
- 如何在PPT中插入LaTeX公式
热门文章
- cad块无法分解,炸不开怎么办?
- MTK Android N 源码Rom Root
- 索信达2021届校园招聘春招正式启动
- 一键收藏:OEE / TEEP / 六大损失 / SMED / 约束理论 / 持续改进 / 短间隔控制
- 解决微信开放平台分享图片失败问题
- 维修记录:东芝打印机2802am出现故障C449解决方法
- 谷歌趋势图显示“QR码”关键词搜索量创历史新高
- 湖南省湘西土家族苗族自治州谷歌高清卫星地图下载
- Win10插入耳机后无声音,声音问题疑难解答提示“外设似乎没有插上” 三种解决方法
- 一种新的满意度调查方法 NPS(net promoter score 净推荐值)