pytorch的 model.eval()和model.train()作用

pytorch中model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

model.eval():认为停止Batch Normalization的均值和方差统计,否则,即使不训练,因为有输入数据,BN的均值和方差也会改变。Dropout关闭,所有神经元都参与计算。

model.train():Batch Normalization的均值和方差统计开启,使得网络用到每一批数据的均值和方差,Dropout功能开启,定义好模型后,默认是model.train()模式。

torch.no_grad()用于停止autograd模块的工作,以起到加速和节省显存的作用,也就是不保存计算图,默认是保存的。

pytorch的模型搭建与训练流程

step1, 创建模型类,初始化模型的网络结构,在这里给出模型有哪几个模块的定义。

def forward()是pytorch模型类必有函数,用来定义模型的数据流,数据的输出从这里进入,逐级到最后一层,返回模型的输出。

# 搭建神经网络
class myModel(nn.Module):def __init__(self) -> None:super().__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.Dropout(p=0.6),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.Dropout(p=0.6),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.Dropout(p=0.6),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, input):input = self.model(input)return input

step2, 模型的训练,训练优化器的选择

模型结构已经构建好了,接下来需要给出训练优化的一些设定,

创建损失函数

定义优化器,包括不同优化器对应的参数设置,如torch.optim.SGD随机梯度下降优化器的学习率设置。

# 创建网络模型
model = myModel().to(device)# 创建损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2        # 1e-2 = 1 * (10)^(-2) = 1 / 100 = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

step3, 训练时候数据流的定义,反向传播函数的连接定义

for i in range(epoch):print("------第 {} 轮训练开始------".format(i+1))# 训练步骤开始total_train_accuracy=0model.train()for data in train_dataloader:imgs, targets = dataoutputs = model(imgs)               # 将训练的数据放入loss = loss_fn(outputs, targets)    # 得到损失值optimizer.zero_grad()               # 优化过程中首先要使用优化器进行梯度清零loss.backward()                     # 调用得到的损失,利用反向传播,得到每一个参数节点的梯度optimizer.step()                    # 对参数进行优化total_train_step += 1               # 上面就是进行了一次训练,训练次数 +1accuracy_train = (outputs.argmax(1) == targets).sum()total_train_accuracy += accuracy_train# 只有训练步骤是100 倍数的时候才打印数据,可以减少一些没有用的数据,方便我们找到其他数据if total_train_step % 100 == 0:print("训练次数: {}, Loss: {}".format(total_train_step, loss))

step4,模型验证集的数据流程

 model.eval()with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = model(imgs)loss = loss_fn(outputs, targets)            # 这里的 loss 只是一部分数据(data) 在网络模型上的损失total_test_loss = total_test_loss + loss    # 整个测试集的lossaccuracy=(outputs.argmax(1)==targets).sum()total_accuracy+=accuracyprint("整体测试集上的loss: {}, test accuracy is: {}".format(total_test_loss,total_accuracy.cpu().numpy()/test_data_size))

setp5,单张图像的测试

注意,如果训练时候用了cuda, 测试时候的输入也要转换为.cuda(),否则报错,tensor的float或者int类型也必须要一致。

如果保存的模型是基于cuda的,测试时候想要改成cpu,则在load模型后,加一行代码model.to('cpu')即可切换为cpu格式的。

model=torch.load('model_9.pth')#之前保存了整个的模型,所以直接load模型了
model.eval()fpath='./dataset/test/0_125.jpg'
img=Image.open(fpath)
img=np.array(img).transpose(2,0,1)
img=np.expand_dims(img,axis=0)  #(N,Ci, Hi, Wi)img=torch.tensor(img,dtype=torch.float32).cuda()
out=model(img)
print(out.argmax())

pytorch的训练测试流程总结,以及model.evel(), model.train(),torch.no_grad()作用相关推荐

  1. yolov1模型结构和训练测试流程详解

    一.网络结构 ①首先经过一个VGG主干网络提取特征,这里的主干网络可以自己选择,使用resnet也可以. ②reshape为一维,然后进行全连接,in_dim=25088,out_dim=4096,需 ...

  2. Pytorch: model.eval(), model.train() 讲解

    文章目录 1. model.eval() 2. model.train() 两者只在一定的情况下有区别:训练的模型中含有dropout 和 batch normalization 1. model.e ...

  3. Pytorch模型训练和模型验证

    文章目录 前言 模型训练套路 1.准备数据集 2.训练数据集和测试数据集的长度 3.搭建网络模型 4.创建网络模型.损失函数以及优化器 5.添加tensorboard 6.设置训练网络的一些参数 7. ...

  4. 装不了 pytorch=0.4.0? ubuntu下 基于 cuda=92 和 pytorch=1.2 配置环境跑通 CornerNet 训练和测试流程 步骤详解(包括GPU限制问题详解)

    ubuntu下跑通CornerNet的流程步骤 环境配置 写这篇博客原因? 更改conda_packagelist.txt conda下基于conda_packagelist.txt创建新环境 安装p ...

  5. pytorch神经网络训练及测试流程代码

    神经网络的训练及测试其实是个相对固定的流程,下面进行详细说明,包括命令行设置基本参数.如数据集路径等其他参数的设置.学习率.损失函数.模型参数的保存与加载及最终train.py与test.py的mai ...

  6. pytorch dataset读取数据流程_高效 PyTorch :如何消除训练瓶颈

    加入极市专业CV交流群,与 10000+来自港科大.北大.清华.中科院.CMU.腾讯.百度 等名校名企视觉开发者互动交流! 同时提供每月大咖直播分享.真实项目需求对接.干货资讯汇总,行业技术交流.关注 ...

  7. Pytorch分布式训练/多卡训练(二) —— Data Parallel并行(DDP)(2.2)(代码示例)(BN同步主卡保存梯度累加多卡测试inference随机种子seed)

    DDP的使用非常简单,因为它不需要修改你网络的配置.其精髓只有一句话 model = DistributedDataPrallel(model, device_ids=[local_rank], ou ...

  8. Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...

  9. MMSegmentation 训练测试全流程

    MMSegmentation 训练测试全流程 1.按照执行顺序的流程梳理 Level 0: 运行 Shell 命令: Level 1: 在 tools/train.py 内: Level 2: 转进到 ...

最新文章

  1. 【转】读马化腾的产品设计观
  2. binder IPC TRANSACTION过程分析(BC_TRANSACTION-Binder Driver)
  3. 二分图行列匹配--- hdu2119,hdu1498
  4. 使用脚本编写 Vim 编辑器,第 5 部分: 事件驱动的脚本编写和自动化
  5. 干货 | 机器学习正在面临哪些主要挑战?
  6. SAP loyalty management点击了公式超链接后的处理逻辑
  7. c++中用new和不用new创建对象的本质区别
  8. 需要在计算机安装msxml版本,Win7安装Office2010提示需要MSXML 6.10.1129.0组件怎么办?...
  9. 树莓派4B设置USB启动
  10. 免流解密之SAOML二开
  11. 5.RefineDNet论文阅读
  12. 4103 yxc 的日常
  13. 解决:el-input添加clearable属性后出现2个×清除图标
  14. Oracle定时任务-查询-创建-删除-调用-定时任务时间参数
  15. 香饽饽:腾讯强推的Redis天花板笔记,帮助初学者快速入门和提高(核心笔记+面试高频解析)
  16. 谷歌浏览器截图快捷键是什么?谷歌浏览器截图操作方法介绍
  17. 5g网络模式是以什么划分的_5g组网模式有几种
  18. php behaviors,从behaviors()来研究组件绑定行为的原理
  19. [转载] 信息系统项目管理师视频教程——25 战略管理
  20. *.accdb数据文件的数据解析工具类

热门文章

  1. 在Python中如何判断一个对象的类型?
  2. 汉锐USB会议摄像机、1080P让商务视频会议更加轻松
  3. 树莓派系统剪裁、克隆
  4. JS原生轮播(JS篇)
  5. Estimator::relativePose
  6. sqlmap 使用方法
  7. 为什么总学不好PS?300集PS从入门到高级自学教程,全面且系统
  8. mac java串口驱动,使用CH340/341的模块在Mac上驱动安装
  9. 联想台式机计算机接口,如果不能使用Lenovo台式计算机的USB接口怎么办
  10. 【用CSS让单行文本溢出显示省略号】