pytorch进行CIFAR-10分类(4)训练

我的系列博文:

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

Pytorch打怪路(一)pytorch进行CIFAR-10分类(2)定义卷积神经网络

Pytorch打怪路(一)pytorch进行CIFAR-10分类(3)定义损失函数和优化器

Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练(本文

Pytorch打怪路(一)pytorch进行CIFAR-10分类(5)测试

1、简述

经过前面的数据加载和网络定义后,就可以开始训练了,这里会看到前面遇到的一些东西究竟在后面会有什么用,所以这一步希望各位也能仔细研究一下

2、代码

for epoch in range(2):  # loop over the dataset multiple times 指定训练一共要循环几个epochrunning_loss = 0.0  #定义一个变量方便我们对loss进行输出for i, data in enumerate(trainloader, 0): # 这里我们遇到了第一步中出现的trailoader,代码传入数据# enumerate是python的内置函数,既获得索引也获得数据,详见下文# get the inputsinputs, labels = data   # data是从enumerate返回的data,包含数据和标签信息,分别赋值给inputs和labels# wrap them in Variableinputs, labels = Variable(inputs), Variable(labels) # 将数据转换成Variable,第二步里面我们已经引入这个模块# 所以这段程序里面就直接使用了,下文会分析# zero the parameter gradientsoptimizer.zero_grad()                # 要把梯度重新归零,因为反向传播过程中梯度会累加上一次循环的梯度# forward + backward + optimize      outputs = net(inputs)                # 把数据输进网络net,这个net()在第二步的代码最后一行我们已经定义了loss = criterion(outputs, labels)    # 计算损失值,criterion我们在第三步里面定义了loss.backward()                      # loss进行反向传播,下文详解optimizer.step()                     # 当执行反向传播之后,把优化器的参数进行更新,以便进行下一轮# print statistics                   # 这几行代码不是必须的,为了打印出loss方便我们看而已,不影响训练过程running_loss += loss.data[0]         # 从下面一行代码可以看出它是每循环0-1999共两千次才打印一次if i % 2000 == 1999:    # print every 2000 mini-batches   所以每个2000次之类先用running_loss进行累加print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))  # 然后再除以2000,就得到这两千次的平均损失值running_loss = 0.0               # 这一个2000次结束后,就把running_loss归零,下一个2000次继续使用print('Finished Training')

3、分析

①autograd

在第二步中我们定义网络时定义了前向传播函数,但是并没有定义反向传播函数,可是深度学习是需要反向传播求导的,
Pytorch其实利用的是Autograd模块来进行自动求导,反向传播。
Autograd中最核心的类就是Variable了,它封装了Tensor,并几乎支持所有Tensor的操作,这里可以参考官方给的详细解释:
http://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py
以上链接详细讲述了variable究竟是怎么能够实现自动求导的,怎么用它来实现反向传播的。
这里涉及到计算图的相关概念,这里我不详细讲,后面会写相关博文来讨论这个东西,暂时不会对我们理解这个程序造成影响
只说一句, 想要计算各个variable的梯度,只需调用根节点的backward方法,Autograd就会自动沿着整个计算图进行反向计算
而在此例子中,根节点就是我们的loss,所以:
程序中的loss.backward()代码就是在实现反向传播,自动计算所有的梯度。
所以训练部分的代码其实比较简单:
running_loss和后面负责打印损失值的那部分并不是必须的,所以关键行不多,总得来说分成三小节
第一节:把最开始放在trainloader里面的数据给转换成variable,然后指定为网络的输入;
第二节:每次循环新开始的时候,要确保梯度归零
第三节:forward+backward,就是调用我们在第三步里面实例化的net()实现前传,loss.backward()实现后传
每结束一次循环,要确保梯度更新

Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练相关推荐

  1. Pytorch打怪路(三)Pytorch创建自己的数据集2

    前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用 ...

  2. gitlab 迁移、升级打怪之路:8.8.5-- 8.10.8 -- 8.17.8 -- 9.5.9 -- 10.1.4 -- 10.2.5

    gitlab 迁移.升级打怪之路:8.8.5--> 8.10.8 --> 8.17.8 --> 9.5.9 --> 10.1.4 --> 10.2.5 gitlab 数据 ...

  3. GPU信息查看以及确认Pytorch使用了GPU计算模块进行深度学习的训练

    GPU信息查看以及确认Pytorch使用了GPU计算模块进行深度学习的训练 目录 GPU信息查看以及确认Pytorch使用了GPU计算模块进行深度学习的训练 GPU基础信息查看 Pytorch是否使用 ...

  4. pytorch 转换onnx_新版PyTorch发布!新增TorchScript API,扩展ONNX导出

    铜灵 发自 凹非寺 量子位 出品 | 公众号 QbitAI 今天,PyTorch 1.2.0版正式发布. 官方表示,和1.1版本相比,新版本在使用体验上又往前迈进了一大步.主要新增/改动的功能包括: ...

  5. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  6. [PyTorch] 基于Python和PyTorch的MNIST的手写数字数据集的分类

    文章目录 讲解 MNIST的介绍 须导入的函数库 检查 pytorch 的版本 定义超参数 下载 MNIST的数据集 定义网络 网络实例化 定义训练函数 定义测试函数 主函数 全部源代码 2020.0 ...

  7. [PyTorch] 基于Python和PyTorch的cifar-10分类

    cifar-10数据集介绍 CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像.有50000个训练图像和10000个测试图像. 数据集分为5个训练批次和1个测 ...

  8. 动物数据集+动物分类识别训练代码(Pytorch)

    动物数据集+动物分类识别训练代码(Pytorch) 目录 动物数据集+动物分类识别训练代码(Pytorch) 1. 前言 2. Animals-Dataset动物数据集说明 (1)Animals90动 ...

  9. Pytorch实战:基于pytorch预测文章阅读量

    介绍 这里用简单的网络来预测CSDN的阅读量. 我希望训练后的模型,给它前7天的阅读量,让它预测出第八天的阅读量. 阅读量的数据(很少)采用CSDN提供的excel文件: 这里采用xrld工具包来读取 ...

最新文章

  1. MEMS传感器科普文
  2. Webix 1.5发布:一个强大的JavaScript UI组件库
  3. php如何生成html,php生成html文件方法总结
  4. python提供的默认的构造方法是什么_Python面试常见问题,__init__是构造函数吗?...
  5. 09945 oracle 解决方法_ORACLE rman与RMAN-00054ORA-09945
  6. python在运维自动化的前景_现在学运维自动化python和大数据?
  7. 用toad实现oracle数据迁移,Oracle 使用TOAD实现导入导出Excel数据
  8. 复制Linux虚拟机后的网卡问题解决
  9. left join,right join,inner join,full join之间的区别
  10. java 注释 代码,如何在Java中注释代码块
  11. java在dos命令_JAVA中如何执行DOS命令
  12. python 批量爬取网页pdf_python爬取网页内容转换为PDF文件
  13. ST的硬盘固件门给数据恢复带来的巨大收益
  14. 【洛谷刷题笔记】P4093 [HEOI2016/TJOI2016] 序列
  15. 制作u盘winpe启动盘_u盘启动盘制作工具教程
  16. 指数型组织到底是什么
  17. 基于天牛须搜索算法的函数寻优算法
  18. 黄小宁罪大恶极!!!!!!!!!!黄小宁罪大恶极!!!!!!!!!!
  19. 自定义 Oh My Zsh 主题 cchi.zsh-theme
  20. win10 关闭自动更新

热门文章

  1. 快速了解SDK和API的区别
  2. 开发者,你是如何做到高效开发的
  3. PR如何修改节目的名字,PR如何修改序列的名字
  4. 一键安装部署SSL(https)——腾讯云服务器提供
  5. 实习没成长,想离职了
  6. C++二维图形的打印 详解
  7. 软件领域的作家、导师兼咨询师杰里·温伯格去世
  8. 大数据在开发的过程中,主要会遇到哪些难点?
  9. 本杰明·格拉汉姆股略
  10. 水中浮力插件buoyancy_NaughtyWaterBuoyancy浮力插件解析