文章目录

  • 一、两种模式
  • 二、功能
    • 1. model.train()
    • 2. model.eval()
      • 为什么测试时要用 model.eval() ?
    • 3. 总结与对比
  • 三、Dropout 简介
  • 参考链接

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train()model.eval()

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。


二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。


三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。


参考链接

  1. PyTorch中train()方法的作用是什么
  2. 【pytorch】model.train()和model.evel()的用法
  3. pytorch中net.eval() 和net.train()的使用
  4. Pytorch学习笔记11----model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析、输出整个tensor的方法
  5. 好文:Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

【Pytorch】model.train() 和 model.eval() 原理与用法相关推荐

  1. model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】

    1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...

  2. model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析

    转载:原文地址-传送门 1.model.train()与model.eval()的用法 看别人的面经时,浏览到一题,问的就是这个.自己刚接触pytorch时套用别人的框架,会在训练开始之前写上mode ...

  3. 【Pytorch】model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

    model.train() 启用 Batch Normalization 和 Dropout 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model. ...

  4. Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

    model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层. model.train() 官方文档 启用 Batch Normaliz ...

  5. 【pytorch】model.train()和model.evel()的用法

    1.model.train()与model.eval()的用法 看别人的面经时,浏览到一题,问的就是这个.自己刚接触pytorch时套用别人的框架,会在训练开始之前写上model.trian(),在测 ...

  6. pytroch:model.train()、model.eval()的使用

    前言:最近在把两个模型的代码整合到一起,发现有一个模型的代码整合后性能大不如前,但基本上是源码迁移,找了一天原因才发现是因为model.eval()和model.train()放错了位置!!!故在此介 ...

  7. model.train()和model.eval()的用法及model.eval()可能导致测试准确率的下降

    问题导入: 一般我们在训练模型时会在前面加上:model.train() 在测试模型时会在前面使用:model.eval() 但是在某次使用网络测试模型时,训练准确率很高,但测试准确率很低,排查了各种 ...

  8. PyTorch:train模式与eval模式的那些坑

    文章目录 前言 1. train模式与eval模式 2. BatchNorm 3. 数学原理 结束语 前言   博主在最近开发过程中不小心被pytorch中train模式与eval模式坑了一下o(*≧ ...

  9. 【pytorch】model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的 ...

最新文章

  1. Java黑皮书课后题第8章:**8.11(游戏:九个硬币的正反面)一个3*3的矩阵中放置了9个硬币,这些硬币有些面朝上有朝下。1表示正面0表示反面,每个状态使用一个二进制数表示。使用十进制数表示状态
  2. Java黑皮书课后题第7章:**7.18(冒泡排序)使用冒泡排序算法编写一个排序方法。编写一个测试程序,读取10个double型的值,调用这个方法,然后显示排序好的数字
  3. oracle ora01732,一天一小步_2008.5.02: ora-01732错误
  4. AM,DSB,SSB,FM信号调制matlab
  5. 计算机模拟与生态工程,2018年环境生态工程专业分析及就业前景
  6. R语言 相关分析和典型相关分析
  7. 追踪盗窃12亿用户登录数据的网络犯罪团伙
  8. LSTM网络(Long Short-Term Memory )
  9. cmake编译多个文件夹_CMake应用技巧:在一个工程中编译运行多个文件
  10. 【九度OJ】题目1084:整数拆分
  11. 区块链 xuperchain 同步模式 纯异步模式 异步阻塞模式 怎么启动
  12. 拼音模糊查询+java,Java将中文转换成拼音,用于字母的模糊查询
  13. struts的增删查改
  14. 电阻转换温度值c语言,PT1000电阻值转化为温度值的计算公式
  15. GNSS/INS组合导航(九):三维简化的INS/GPS组合导航系统
  16. 手写一个Spring Boot Starter
  17. 台式机通过网线连接笔记本的wifi网络
  18. 参考答案-数据库原理测试一
  19. 星际战甲服务器维护时间,星际战甲 官网:2月4日服务器维护结束公告
  20. 数据分析 - 单表简单查询

热门文章

  1. 今日头条、抖音创始人张一鸣
  2. fxmarket:9月25日黄金、沪深300、恒指策略分析
  3. BZOJ1050 旅行comf(kruskal)
  4. 把音频中的某个人声去掉_怎样去掉音频中的背景音乐 只保留人声?
  5. 计算机日常故障DIY维修有哪些,电脑故障排除及优化完全DIY
  6. 联想服务器能够上固态硬盘吗,拯救我的台式机:Lenovo 联想 固态硬盘 入手记
  7. 微信小程序获取二维码scene报错40129
  8. android 删除短信无效,android删除短信(绕过权限)
  9. 【C#】Excel舍入函数Round、RoundUp、RoundDown的C#版
  10. 心电图特效代码 html5,用canvas画心电图的示例代码