model.train()

启用 Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

model.eval()

不启用 Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

注意:训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

    # 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层

BN

BN的作用主要是对网络中间的每层进行归一化处理,并且使用变换重构(Batch Normalization Transform)保证每层所提取的特征分布不会被破坏。
训练时是针对每个mini-batch的,但是在测试中往往是针对单张图片,即不存在mini-batch的概念。由于网络训练完毕后参数都是固定的,因此每个batch的均值和方差都是不变的,因此直接结算所有batch的均值和方差。所有Batch Normalization的训练和测试时的操作不同。

Dropout

Dropout能够克服Overfitting,在每个训练Batch中,通过忽略一半的特征检测器,可以明显的减少过拟合现象

在训练中,每个隐层的神经元先乘以概率P,然后再进行激活。
在测试中,所有的神经元先进行激活,然后每个隐层神经元的输出乘P

想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

下面我们看一个我们写代码的时候常遇见的错误写法:

在这个特定的例子中,似乎每50次迭代就会降低准确度。
如果我们检查一下代码, 我们看到确实在train函数中设置了训练模式。

def train(model, optimizer, epoch, train_loader, validation_loader):model.train() # ???????????? 错误的位置for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):# model.train() # 正确的位置,保证每一个batch都能进入model.train()的模式data, target = Variable(data), Variable(target)# Inferenceoutput = model(data)loss_t = F.nll_loss(output, target)# The iconic grad-back-step triooptimizer.zero_grad()loss_t.backward()optimizer.step()if batch_idx % args.log_interval == 0:train_loss = loss_t.item()train_accuracy = get_correct_count(output, target) * 100.0 / len(target)experiment.add_metric(LOSS_METRIC, train_loss)experiment.add_metric(ACC_METRIC, train_accuracy)print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx, len(train_loader),100. * batch_idx / len(train_loader), train_loss))with experiment.validation():val_loss, val_accuracy = test(model, validation_loader) # ????????????experiment.add_metric(LOSS_METRIC, val_loss)experiment.add_metric(ACC_METRIC, val_accuracy)

这个问题不太容易注意到,在循环中我们调用了test函数。

def test(model, test_loader):model.eval()# ...

在test函数内部,我们将模式设置为eval。这意味着,如果我们在训练过程中调用了test函数,我们就会进eval模式,直到下一次train函数被调用。这就导致了每一个epoch中只有一个batch使用了dropout ,这就导致了我们看到的性能下降。

修复很简单 —— 我们将model.train() 向下移动一行,让其在训练循环中。理想的模式设置是尽可能接近推理步骤,以避免忘记设置它。修正后,我们的训练过程看起来更合理,没有中间的峰值出现。

补充:model.eval()和torch.no_grad()的区别

在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

1.主要用于通知dropout层和BN层在train和validation/test模式间切换:

  • 在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
  • 在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

2.该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

Pytorch——model.train 和 model.eval_Vic_Hao的博客-CSDN博客

【Pytorch】model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别相关推荐

  1. Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

    model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层. model.train() 官方文档 启用 Batch Normaliz ...

  2. 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层

    requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...

  3. (深入理解)model.eval() 、model.train()以及torch.no_grad() 的区别

    文章目录 简要版解释 深入版解释 简要版解释 在PyTorch中进行validation或者test的时侯,会使model.eval()切换到测试模式,在该模式下,model.training=Fas ...

  4. 【Pytorch】model.train() 和 model.eval() 原理与用法

    文章目录 一.两种模式 二.功能 1. model.train() 2. model.eval() 为什么测试时要用 model.eval() ? 3. 总结与对比 三.Dropout 简介 参考链接 ...

  5. Pytorch model.train()

    文章目录 1.前言 2.作用及原因 2.1.Batch Normalization 2.1.1训练时的BN层 2.1.2测试时的BN层 2.2.Dropout 3.总结 1.前言 在使用Pytorch ...

  6. 【pytorch】model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的 ...

  7. Pytorch: model.eval(), model.train() 讲解

    文章目录 1. model.eval() 2. model.train() 两者只在一定的情况下有区别:训练的模型中含有dropout 和 batch normalization 1. model.e ...

  8. Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...

  9. model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】

    1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...

最新文章

  1. 发现问题,是解决问题的第一步
  2. 为什么说python是世界上最好的语言-python是世界上最好的语言
  3. go语言使用go-sciter创建桌面应用(八) 窗口显示时,自动加载后端数据。
  4. C/Cpp / 设计模式 / 简单工厂模式
  5. php 数据类型伪类型,PHP之伪类型与变量
  6. ------------------uniq 去重复
  7. 4.Model Validation
  8. [SpecialJudge]构造“神秘“字符串(洛谷P3742题题解,Java语言描述)
  9. laravel mongodb如何声明数据类型_什么是MongoDB?简介,架构,功能和示例
  10. 登陆窗体相关的控件 1124
  11. Selenium模拟JQuery滑动解锁
  12. 运营前线2:一线运营专家的运营方法、技巧与实践03 3步策略做好内容管理
  13. 计算机服务重置,怎么重置电脑网络设置
  14. 鸿蒙系统和安装包,鸿蒙系统安装包
  15. golang 结构体使用chan
  16. Unity的IOS PlayerSettings的设置说明
  17. Python 办公效率化学习(自学)四.Excel文件的写入
  18. 一个简单的猜数字游戏(附带关机惩罚)
  19. ubuntu 18.04 安装gdb
  20. SSO 轻量级实现指南(原生 Java 实现):SSO Client 部分

热门文章

  1. 开发指南专题二:JEECG微云快速开发平台JEECG框架初探
  2. SpringCloud微服务(04):Turbine组件,实现微服务集群监控
  3. (九)模型驱动和属性驱动
  4. 电脑不能上网win7 解决办法
  5. Moment.js常见用法总结 1
  6. C# 之 提高WebService性能大数据量网络传输处理
  7. ubuntu12.04 安装Android Studio笔记
  8. 黑马程序员---java基础------------------多线程
  9. oracle的cols,Oraclecols_as_rows比对数据
  10. python开发转行渗透测试_月薪45K的大牛用Python开发一款密码攻击测试器!密码形同虚设!...