基于pytorch的简单图片分类问题实现
pytorch中基于简单图片分类问题的实现大致可以分为以下几个步骤:
1.建立处理图片的神经网络,提前设置好损失函数(图片分类问题一般使用交叉熵损失函数),以及优化器。
2.在每一个学习的步骤中,将训练集的图片输入神经网络,并根据结果对神经网络进行更新,更新后,将测试集的图片输入神经网络得到每一步的得分
3.根据每一步的得分,观察训练的成果。
此处列举一个简单的对CIFAR10图片集进行分类的代码:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# model是提前写好的神经网络,此处直接import进来,下面会给出model的网络结构大致形状
from model import *# 从pytor上获得CIFAR10的训练集和测试集
train_data= torchvision.datasets.CIFAR10(root="18-data", train=True,transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="18-data2", train=False,transform=torchvision.transforms.ToTensor(), download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
# print(f"训练数据集合的长度为{train_data_size}")train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
bao = Bao()# 损失函数:交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器,学习率为0.01,写在外面是因为在某些情况下可以在学习过程中对学习率进行修改
learning_rate = 0.01
optimizer = torch.optim.SGD(bao.parameters(), lr=learning_rate)# 设置训练网络的一些参数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练轮数
epoch = 10# tensorboard进行记录想要观察的对象
writer = SummaryWriter("18-logs")for i in range(epoch):bao.train() # 将网络设置成训练模式,只有当网络模型中有一些特定的层的时候才有用,其他时候可以省略for data in train_dataloader:imgs, targets = dataoutputs = bao(imgs) # 得到输出loss = loss_fn(outputs, targets) # 得到损失值optimizer.zero_grad() # 对梯度清零,防止上一轮的梯度影响下一轮的学习loss.backward() # 根据loss值,求出新一轮的梯度optimizer.step() # 根据新一轮的梯度,对神经网络的参数进行更新total_train_step = total_train_step + 1# 测试步骤bao.eval() # 将模型设置成验证状态test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = bao(imgs)loss = loss_fn(outputs, targets)test_loss = test_loss + loss.item() # item()的作用:将tensor数据类型的loss变为真正的数字accuracy = (outputs.argmax(1) == targets).sum() # 求出给出的最大值得种类和实际种类相同的个数total_accuracy = total_accuracy + accuracytotal_test_step = total_test_step + 1writer.add_scalar("loss_per_epoch", test_loss, total_test_step)writer.add_scalar("accuracy_per_epoch", total_accuracy/test_data_size, total_test_step)print(f"整体测试集上的loss为{test_loss}")print(f"整体测试集上的正确率为{total_accuracy/test_data_size}")torch.save(bao, f"bao_{i}.pth")writer.close()
其中的model.py中定义了神经网络:
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linearclass Bao(nn.Module):def __init__(self):super(Bao, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=(5, 5), padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':bao = Bao()input = torch.ones((64, 3, 32, 32))output = bao(input)print(output.shape)
运行程序,在终端中输入tensorboard --logdir=绝对路径,打开tensorboard进行参数观察,得到:
可见,正确率在逐渐上升,loss在逐渐下降。
基于pytorch的简单图片分类问题实现相关推荐
- PyTorch入门-简单图片分类
一. CNN图像分类 PyTorch Version: 1.0.0 import torch import torch.nn as nn import torch.nn.functional as F ...
- 基于Pytorch实现猫狗分类
基于Pytorch实现猫狗分类 一.环境配置 二.数据集准备 三.猫狗分类的实例 四.实现分类预测测试 五.参考资料 一.环境配置 1.环境使用 Anaconda 2.配置Pytorch pip in ...
- Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)
Kaggle猫狗大战--基于Pytorch的CNN网络分类:数据获取.预处理.载入(1) 第一次写CSDN博客,之前一直是靠着CSDN学学代码,这次不得不亲自上场了,就想着将学习的过程都记录下来.新人 ...
- 轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur
作者丨科技猛兽 编辑丨极市平台 清华大学自动化系智能计算实验室团队开源基于 PyTorch 的视频 (图片) 去模糊框架 SimDeblur. 基于 PyTorch 的视频 (图片) 去模糊框架 Si ...
- 基于keras的CNN图片分类模型的搭建以及参数调试
基于keras的CNN图片分类模型的搭建与调参 更新一下这篇博客,因为最近在CNN调参方面取得了一些进展,顺便做一下总结. 我的项目目标是搭建一个可以分五类的卷积神经网络,然后我找了一些资料看了一些博 ...
- 基于android手机相册,基于安卓的手机图片分类软件的设计与实现.pdf
ELECTRONICS WORLD ・技术交流 基于安卓的手机图片分类软件的设计与实现 武警工程大学研究生管理大队12队 张 鑫 武警广州指挥学院 姜 波 [摘要] 本文针对安卓手机中图片浏览器的快速 ...
- PyTorch ResNet 实现图片分类
PyTorch ResNet 实现图片分类 建党 100 年 Resnet 深度网络退化 代码实现 残差块 超参数 ResNet 18 网络 获取数据 训练 测试 完整代码 建党 100 年 百年风雨 ...
- 基于Pytorch的简单深度学习项目实战
目录 任务详情 训练设置 网络模型 损失函数 优化器 训练步骤 具体代码 导入环境 导入数据 加载数据 创建网络 损失函数 优化器 开始训练 任务详情 利用CIFAR10数据集,基于pytorch环境 ...
- 基于Pytorch的猫狗分类
无偿分享~ 猫狗二分类文件下载地址 在下一章说 猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类.刚开始我找到了也 ...
最新文章
- 只学python找工作难吗-只学python语言找工作难吗?
- thinkPHP开发基础知识 包括变量神马的
- Go的cannot convert t(type interface {}) to type string: need type assertion 使用fmt.Sprintf转换
- IDEA中注解注释快捷键及模板
- 蔬菜大棚成本_蔬菜大棚建设标准和成本
- 100999凑整到万位进一_四年级数学第一单元练习作业
- 虚拟桌面显示未注册的情况,可以检查一下几点
- Ubuntu、CentOS、redHat的区别与联系
- Atitit 翻页功能的解决方案与版本历史 v4 r49
- java getname threads_Java8并发教程:Threads和Executors
- 计算机信息检索自考知识点,计算机信息检索02139自考资料.doc
- 用摄动法证明fibs的一个公式(变形)
- 苹果屏蔽更新描述文件_iPhone|IOS10-IOS12屏蔽系统更新描述文件|去除设置①小红点教程...
- 什么农村大学生大多混得比较差-第一性原理分析
- win10去除右下角激活水印
- sql server 替换字段中的部分字符,替换指定字符的方法
- NOI 2016 游记
- 文件下载时设置文件名以及中文被转换成下划线的解决办法
- Perameter estimation for text analyse (下)
- 银行业务中的清算和结算分别是什么样的过程
热门文章
- EXCEL自定义函数无法运行的原因:可以在VBA里运行的函数,在EXCEL用自定义函数为什么报错?
- Unity Shader自定义光照模型
- 【大数据 AI】视觉ChatGPT来了,微软发布,代码已开源
- c++ memset函数使用及头文件
- 北科大计算机专业研究生多少分能考上,2022年北京科技大学计算机考研初试复习经验分享...
- Xmind基础教程-下钻和上钻
- sq语句l补充(一)
- Python使用opencv 打开摄像头
- 关于仕族_仕族信息_服务中心_仕族服务_男装:衬衫、法式衬衫、袖扣领带、西服西裤等男士正装服饰-仕族官网...
- 小程序学习记录(二)——view、text、image标签、flex布局