import torch
import divide
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
from model import Netepochs = 10    #迭代训练五次
cirterion = nn.CrossEntropyLoss()    #定义损失函数
optimizer = optim.SGD(Net.parameters(), lr=0.0001, momentum=0.9)    #定义优化器for epoch in range(epochs):running_loss = 0    #损失值train_correct = 0    #分类样本正确的总数train_total = 0    #分类样本总数for i, data in enumerate(divide.train_loader, 0):inputs, train_labels = datainputs, labels = Variable(inputs), Variable(train_labels)optimizer.zero_grad()outputs = Net(inputs)_, train_predicted = torch.max(outputs.data, 1)itrain_correct += (train_predicted == labels.data).sum()loss = cirterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()train_total += train_labels.size(0)print('train %d epoch loss: %.3f  acc: %.3f ' % (epoch + 1, running_loss / train_total, 100 * train_correct / train_total))print('finished training!')

运行:

train 1 epoch loss: 0.172  acc: 56.029
train 2 epoch loss: 0.163  acc: 62.057
train 3 epoch loss: 0.155  acc: 65.274
train 4 epoch loss: 0.149  acc: 68.006
train 5 epoch loss: 0.144  acc: 70.131
train 6 epoch loss: 0.139  acc: 71.749
train 7 epoch loss: 0.135  acc: 72.994
train 8 epoch loss: 0.130  acc: 74.360
train 9 epoch loss: 0.126  acc: 75.486
train 10 epoch loss: 0.122  acc: 76.520
finished training!
  1. 定义epoch:迭代训练的次数
  2. 定义cirterion、optimizer:损失函数、优化器
  3. 使用循环语句进行数据的迭代训练,一共迭代训练epoch次
  4. runing_loss:总损失值、train_correct:分类正确总数量、train_total:训练图片总数量
  5. enumerate返回两个值,i是下标,由enumerate()函数给定,data表示数据,数据包括Tensor(图像)与标签
  6. input存储Tensor,train_labels存储标签,将这两个参数设置为Variable,表示参数存储的值是随着训练不断更新的
  7. zero_grad()将梯度置零,将损失值关于权重的导数置零,每次训练不同的batch,将上一个batch反向传播训练的导数置零,也就是优化器权重的清除
  8. Tensor传入神经网络模型Net
  9. torch.max返回两个值,第一个值是tensor中每行最大值,第二个值是最大值的索引,最大值就是神经网络判断属于各种类别概率值中的最大概率值,索引就是这个概率值最大的类别
  10. 用train_predicted存储这个索引,即神经网络判断出的类别,概率不记录
  11. 如果这个类别与实际的类别相同即分类正确,我们将train_correct进行+1操作,记录分类正确图片总数量
  12. 将神经网络判断出的类别与正确类别传入loss中计算损失值,再进行loss的反向传播,使用step()进行优化器权重的更新
  13. 将损失值相加记录总损失值
  14. size(0)表示train_labels的第一维度,就是每次训练图片的数量:batch_size的大小,求和并由train_total记录,表示训练图片总数量
  15. 输出:running_loss / train_total为平均的损失值,train_correct / train_total为正确率

Pytorch:三、数据的迭代训练(猫狗)相关推荐

  1. ResNet-50 训练猫狗分类

    ResNet-50 训练猫狗分类 1.ResNet网络 2.猫狗数据 3.训练 4.测试 这里介绍一下怎么搭建ResNet网络,并说明一下残差网络的结构,并使用ResNet来训练一个二分类问题 1.R ...

  2. kaggle(一)训练猫狗数据集

    记录第一次使用kaggle训练猫狗数据集 import os import shutil os.listdir('../input/train/train') base_dir = './cat_do ...

  3. 【猫狗数据集】pytorch训练猫狗数据集之创建数据集

    数据集下载地址: 链接:https://pan.baidu.com/s/1tJQIY0ob2EyQn3cDipPkow?pwd=7gch  提取码:7gch 猫狗数据集的分为训练集25000张,在训练 ...

  4. tensorflow.js在nodejs训练猫狗分类模型在浏览器上使用

    目录 本人系统环境 注意事项 前言 数据集准备 处理数据集 数据集初步处理 将每一张图片数据转换成张量数据(tensor) 将图片转换成张量数组的代码和运行效果 将图片的标注转换成张量数据(tenso ...

  5. 训练猫狗数据集(及图像增强后训练)

    目录 一.所需环境(安装附链接) 二.数据集准备 三.网络模型 四.Data preprocessing (数据预处理) 五.训练 六.使用数据填充 一.所需环境(安装附链接) tensorflow和 ...

  6. MobileNetv1训练猫狗图片

    MobileNetv1是谷歌提出的轻量级的卷积神经网络(同VGG相比),它主要采用了深度可分离的卷积,从而大大降低了参数数目和网络的计算量.深度可分离卷积包括两个部分,分别是Depthwise卷积和P ...

  7. pytorch 猫狗二分类 resnet

    深度学习(猫狗二分类) 题目要求 数据获取与预处理 网络模型 模型原理 Resnet背景 Resnet原理 代码实现 模型构建 训练过程 批验证过程 单一验证APP 运行结果 训练结果 批验证结果 A ...

  8. 基于tensorflow的猫狗分类

    基于tensorflow的猫狗分类 数据的准备 引入库 数据集来源 准备数据 显示一张图片的内容 搭建网络模型 构建网络 模型的编译 数据预处理 模型的拟合与评估 模型的拟合 预测一张图片 损失和精度 ...

  9. 含噪数据的有效训练,谷歌地标图像检索竞赛2020冠军方案解读

    2020年谷歌地标图像检索竞赛(Google Landmark Retrieval 2020)是今年举行的大型图像检索算法竞赛,该比赛在Kaggle 竞赛平台进行,吸引了全球541支团队参赛,最终来自 ...

  10. 实战:利用pytorch搭建VGG-16实现从数据获取到模型训练的猫狗分类网络

    起 在学习了卷积神经网络的理论基础和阅读了VGG的论文之后,对卷积有了大致的了解,但这都只是停留在理论上,动手实践更为重要,于是便开始了0基础学习pytorch.图像处理,搭建模型. pytorch学 ...

最新文章

  1. 为什么阿里P8、P9技术大牛反复强调“结构化思维”?
  2. 我也来晒Flex编写的工作流编辑器
  3. python3 循环语句
  4. java for(o t :object) 获取顺序号_java中线程的生命周期
  5. ansible之setup模块常用的信息
  6. 一款不错的网站压力测试工具webbench
  7. final 最终 演练 java
  8. C语言之文件读写探究(六):fscanf、fprintf(格式化读写文件)
  9. [恩分到动归分类好了]取石子游戏
  10. Qt插件机制介绍及实现
  11. TCP/IP三次握手四次挥手
  12. 维纳滤波python 函数_python实现逆滤波与维纳滤波示例
  13. 数字转为大写金额(C#)
  14. LVDS接口分类,时序,输出格式
  15. Oracle身份证校验函数
  16. swarm bzz 安装0.5.3版本基础解析。
  17. VS code,Live Server更改默认浏览器
  18. 计算机演示文稿操作,计算机操作与应用 PowerPoint 演示文稿的设计与制作.ppt
  19. 极限、连续、导数与微分
  20. C语言 生成随机数 srand用法 伪随机函数rand srand需不需要重新播种问题 srand该不该放在循环里

热门文章

  1. chrome主页篡改修复
  2. 3Q之战广东高院上演“熟人新案”
  3. 一维码(条形码)二维码三维码基本原理
  4. 12306html布局,12306更新验证码
  5. Spire.XLS的使用
  6. bzoj1698 / P1606 [USACO07FEB]白银莲花池Lilypad Pond
  7. wirehark数据分析与取证attack.pcap
  8. ESP8266开发之旅 基础篇⑤ ESP8266 SPI通信和I2C通信
  9. 单片机控制光耦开关继而控制电机转动
  10. 如何看待越来越多年轻人追捧「摸鱼哲学」,拒绝努力的年轻人真比老一辈活得更通透吗?