BN层参数详解(1,2)

  • 一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对于BN层测试的均值和方差是通过统计训练的时候所有的batch的均值和方差的平均值)或者Dropout层(对于Dropout层在测试的时候所有神经元都是激活的)。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
  • 同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainningaffinetrack_running_stats
  • 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=Falseγ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。(这里是一个可学习参数,)
  • trainningtrack_running_statstrack_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性(意思就是说新的batch依赖于之前的batch的均值和方差这里使用momentum参数,参考了指数移动平均的算法EMA)。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

应用技巧:(1,2)

       通常pytorch都会用到optimizer.zero_grad() 来清空以前的batch所累加的梯度,因为pytorch中Variable计算的梯度会进行累计,所以每一个batch都要重新清空一次梯度,原始的做法是下面这样的:

问题:参数non_blocking,以及pytorch的整体框架??

代码(1)

for index,data,target in enumerate(dataloader):data = data.cuda(non_blocking=True)target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye)output = model(data)loss = criterion(output,target)#清空梯度optimizer.zero_grad()loss.backward()optimizer.step()

而这里为了模仿minibacth,我们每次batch不清0,累积到一定次数再清0,再更新权重:

for index, data, target in enumerate(dataloader):#如果不是Tensor,一般要用到torch.from_numpy()data = data.cuda(non_blocking = True)target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True)output = model(data)loss = criterion(data, target)loss.backward()if index%accumulation == 0:#用累积的梯度更新权重optimizer.step()#清空梯度optimizer.zero_grad()

虽然这里的梯度是相当于原来的accumulation倍,但是实际在前向传播的过程中,对于BN几乎没有影响,因为前向的BN还是只是一个batch的均值和方差,这个时候可以用pytorch中BN的momentum参数,默认是0.1,BN参数如下,就是指数移动平均

x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum

参考链接:https://www.jianshu.com/p/a646cbc913b4,https://www.zhihu.com/question/303070254/answer/573037166

pytorch---之BN层参数详解及应用(1,2,3)(1,2)?相关推荐

  1. caffe各层参数详解

    在prototxt文件中,层都是用layer{}的结构表示,而里面包含的层的参数可以在caffe.proto文件中找到,比如说Data类型的结构由message DataParameter所定义,Co ...

  2. Dropout和BN(层归一化)详解

    无论是机器学习,还是深度学习,模型过拟合是很常见的问题,解决手段无非是两个层面,一个是算法层面,一个是数据层面.数据层面一般是使用数据增强手段,算法层面不外乎是:正则化.模型集成.earlystopp ...

  3. [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)

    文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...

  4. PyTorch实现AlexNet模型及参数详解

    文章目录 一.卷积池化层原理 二.全连接层原理 三.模型参数详解 注:AlexNet论文错误点 1.卷积池化层1 (1)卷积运算 (2)分组 (3)激活函数层 (4)池化层 (5)归一化处理 (6)参 ...

  5. pytorch MSELoss参数详解

    pytorch MSELoss参数详解 import torch import numpy as np loss_fn = torch.nn.MSELoss(reduce=False, size_av ...

  6. pytorch中DataLoader的num_workers参数详解与设置大小建议

    Q:在给Dataloader设置worker数量(num_worker)时,到底设置多少合适?这个worker到底怎么工作的? train_loader = torch.utils.data.Data ...

  7. Pytorch|YOWO原理及代码详解(二)

    Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...

  8. darknet 框架中.cfg文件的参数详解,以yolov3为例

    参考:darknet中cfg文件里参数的理解_zerojava0的博客-CSDN博客 参考:[Darknet源码 ]cfg文件参数详解_橘子都吃不起!的博客-CSDN博客 1.基础参数解释 batch ...

  9. Pytorch | yolov3原理及代码详解(二)

    阅前可看: Pytorch | yolov3原理及代码详解(一) https://blog.csdn.net/qq_24739717/article/details/92399359 分析代码: ht ...

最新文章

  1. 关于丢番图方程x^2-dy^2=-1
  2. 03:计算矩阵边缘元素之和
  3. 出现这些迹象,说明你面试可能没戏了
  4. MFC开发IM-第二十四篇、使用 acl 库针对 C++ 对象进行序列化及反序列编程
  5. python携程使用_简单了解python gevent 协程使用及作用
  6. 可能是史上最详细-Faster RCNN Pytorch 复现全纪录
  7. 机器学习 - [源码实现决策树小专题]决策树中子数据集的划分(不允许调用sklearn等库的源代码实现)
  8. AWK学习笔记四:awk的环境变量
  9. Mysql 常用函数集
  10. python 计算器功能实现
  11. Day4--MATLAB简介
  12. 安徽大学计算机专硕奖学金,2019年安徽大学新闻传播跨考华东师范大学计算机专硕,总分418,排名第一经验分享!...
  13. springmvc源码阅读之启动加载(2)-----------初始化参数
  14. HBuilderX连接安卓模拟器
  15. msrcr图像增强算法 matlab,图像处理之Retinex增强算法(SSR、MSR、MSRCR)
  16. 在Linux平台中调试C/C++内存泄漏方法 (腾讯和MTK面试的时候问到的)
  17. 设计出python_《设》字意思读音、组词解释及笔画数 - 新华字典 - 911查询
  18. 数据结构基础 之 递归算法实例讲解
  19. php微信一次性订阅消息demo,微信一次性订阅消息公众号或网页接入文档说明
  20. C语言中-条件编译#ifdef的妙用详解_透彻

热门文章

  1. bzoj 1659: [Usaco2006 Mar]Lights Out 关灯(IDA*)
  2. bzoj 1650: [Usaco2006 Dec]River Hopscotch 跳石子(二分)
  3. bzoj 1669: [Usaco2006 Oct]Hungry Cows饥饿的奶牛
  4. Pytorch生成Tensor常用方法汇总
  5. Java将excel文件转成json文件(有错误)
  6. ThinkpadT470接通电源开机显示电量0%充不进电且电源指示灯不亮的解决办法
  7. matlab设计一个简单图像直方图均衡的GUI程序
  8. 555定时器的应用——施密特触发器
  9. STM32CUBEF4 实现USB 虚拟串口
  10. appium 原理解析