Pytorch:三、数据的迭代训练(猫狗)
import torch
import divide
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
from model import Netepochs = 10 #迭代训练五次
cirterion = nn.CrossEntropyLoss() #定义损失函数
optimizer = optim.SGD(Net.parameters(), lr=0.0001, momentum=0.9) #定义优化器for epoch in range(epochs):running_loss = 0 #损失值train_correct = 0 #分类样本正确的总数train_total = 0 #分类样本总数for i, data in enumerate(divide.train_loader, 0):inputs, train_labels = datainputs, labels = Variable(inputs), Variable(train_labels)optimizer.zero_grad()outputs = Net(inputs)_, train_predicted = torch.max(outputs.data, 1)itrain_correct += (train_predicted == labels.data).sum()loss = cirterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()train_total += train_labels.size(0)print('train %d epoch loss: %.3f acc: %.3f ' % (epoch + 1, running_loss / train_total, 100 * train_correct / train_total))print('finished training!')
运行:
train 1 epoch loss: 0.172 acc: 56.029
train 2 epoch loss: 0.163 acc: 62.057
train 3 epoch loss: 0.155 acc: 65.274
train 4 epoch loss: 0.149 acc: 68.006
train 5 epoch loss: 0.144 acc: 70.131
train 6 epoch loss: 0.139 acc: 71.749
train 7 epoch loss: 0.135 acc: 72.994
train 8 epoch loss: 0.130 acc: 74.360
train 9 epoch loss: 0.126 acc: 75.486
train 10 epoch loss: 0.122 acc: 76.520
finished training!
- 定义epoch:迭代训练的次数
- 定义cirterion、optimizer:损失函数、优化器
- 使用循环语句进行数据的迭代训练,一共迭代训练epoch次
- runing_loss:总损失值、train_correct:分类正确总数量、train_total:训练图片总数量
- enumerate返回两个值,i是下标,由enumerate()函数给定,data表示数据,数据包括Tensor(图像)与标签
- input存储Tensor,train_labels存储标签,将这两个参数设置为Variable,表示参数存储的值是随着训练不断更新的
- zero_grad()将梯度置零,将损失值关于权重的导数置零,每次训练不同的batch,将上一个batch反向传播训练的导数置零,也就是优化器权重的清除
- Tensor传入神经网络模型Net
- torch.max返回两个值,第一个值是tensor中每行最大值,第二个值是最大值的索引,最大值就是神经网络判断属于各种类别概率值中的最大概率值,索引就是这个概率值最大的类别
- 用train_predicted存储这个索引,即神经网络判断出的类别,概率不记录
- 如果这个类别与实际的类别相同即分类正确,我们将train_correct进行+1操作,记录分类正确图片总数量
- 将神经网络判断出的类别与正确类别传入loss中计算损失值,再进行loss的反向传播,使用step()进行优化器权重的更新
- 将损失值相加记录总损失值
- size(0)表示train_labels的第一维度,就是每次训练图片的数量:batch_size的大小,求和并由train_total记录,表示训练图片总数量
- 输出:running_loss / train_total为平均的损失值,train_correct / train_total为正确率
Pytorch:三、数据的迭代训练(猫狗)相关推荐
- ResNet-50 训练猫狗分类
ResNet-50 训练猫狗分类 1.ResNet网络 2.猫狗数据 3.训练 4.测试 这里介绍一下怎么搭建ResNet网络,并说明一下残差网络的结构,并使用ResNet来训练一个二分类问题 1.R ...
- kaggle(一)训练猫狗数据集
记录第一次使用kaggle训练猫狗数据集 import os import shutil os.listdir('../input/train/train') base_dir = './cat_do ...
- 【猫狗数据集】pytorch训练猫狗数据集之创建数据集
数据集下载地址: 链接:https://pan.baidu.com/s/1tJQIY0ob2EyQn3cDipPkow?pwd=7gch 提取码:7gch 猫狗数据集的分为训练集25000张,在训练 ...
- tensorflow.js在nodejs训练猫狗分类模型在浏览器上使用
目录 本人系统环境 注意事项 前言 数据集准备 处理数据集 数据集初步处理 将每一张图片数据转换成张量数据(tensor) 将图片转换成张量数组的代码和运行效果 将图片的标注转换成张量数据(tenso ...
- 训练猫狗数据集(及图像增强后训练)
目录 一.所需环境(安装附链接) 二.数据集准备 三.网络模型 四.Data preprocessing (数据预处理) 五.训练 六.使用数据填充 一.所需环境(安装附链接) tensorflow和 ...
- MobileNetv1训练猫狗图片
MobileNetv1是谷歌提出的轻量级的卷积神经网络(同VGG相比),它主要采用了深度可分离的卷积,从而大大降低了参数数目和网络的计算量.深度可分离卷积包括两个部分,分别是Depthwise卷积和P ...
- pytorch 猫狗二分类 resnet
深度学习(猫狗二分类) 题目要求 数据获取与预处理 网络模型 模型原理 Resnet背景 Resnet原理 代码实现 模型构建 训练过程 批验证过程 单一验证APP 运行结果 训练结果 批验证结果 A ...
- 基于tensorflow的猫狗分类
基于tensorflow的猫狗分类 数据的准备 引入库 数据集来源 准备数据 显示一张图片的内容 搭建网络模型 构建网络 模型的编译 数据预处理 模型的拟合与评估 模型的拟合 预测一张图片 损失和精度 ...
- 含噪数据的有效训练,谷歌地标图像检索竞赛2020冠军方案解读
2020年谷歌地标图像检索竞赛(Google Landmark Retrieval 2020)是今年举行的大型图像检索算法竞赛,该比赛在Kaggle 竞赛平台进行,吸引了全球541支团队参赛,最终来自 ...
- 实战:利用pytorch搭建VGG-16实现从数据获取到模型训练的猫狗分类网络
起 在学习了卷积神经网络的理论基础和阅读了VGG的论文之后,对卷积有了大致的了解,但这都只是停留在理论上,动手实践更为重要,于是便开始了0基础学习pytorch.图像处理,搭建模型. pytorch学 ...
最新文章
- 为什么阿里P8、P9技术大牛反复强调“结构化思维”?
- 我也来晒Flex编写的工作流编辑器
- python3 循环语句
- java for(o t :object) 获取顺序号_java中线程的生命周期
- ansible之setup模块常用的信息
- 一款不错的网站压力测试工具webbench
- final 最终 演练 java
- C语言之文件读写探究(六):fscanf、fprintf(格式化读写文件)
- [恩分到动归分类好了]取石子游戏
- Qt插件机制介绍及实现
- TCP/IP三次握手四次挥手
- 维纳滤波python 函数_python实现逆滤波与维纳滤波示例
- 数字转为大写金额(C#)
- LVDS接口分类,时序,输出格式
- Oracle身份证校验函数
- swarm bzz 安装0.5.3版本基础解析。
- VS code,Live Server更改默认浏览器
- 计算机演示文稿操作,计算机操作与应用 PowerPoint 演示文稿的设计与制作.ppt
- 极限、连续、导数与微分
- C语言 生成随机数 srand用法 伪随机函数rand srand需不需要重新播种问题 srand该不该放在循环里
热门文章
- chrome主页篡改修复
- 3Q之战广东高院上演“熟人新案”
- 一维码(条形码)二维码三维码基本原理
- 12306html布局,12306更新验证码
- Spire.XLS的使用
- bzoj1698 / P1606 [USACO07FEB]白银莲花池Lilypad Pond
- wirehark数据分析与取证attack.pcap
- ESP8266开发之旅 基础篇⑤ ESP8266 SPI通信和I2C通信
- 单片机控制光耦开关继而控制电机转动
- 如何看待越来越多年轻人追捧「摸鱼哲学」,拒绝努力的年轻人真比老一辈活得更通透吗?