pytorch中基于简单图片分类问题的实现大致可以分为以下几个步骤:

1.建立处理图片的神经网络,提前设置好损失函数(图片分类问题一般使用交叉熵损失函数),以及优化器。

2.在每一个学习的步骤中,将训练集的图片输入神经网络,并根据结果对神经网络进行更新,更新后,将测试集的图片输入神经网络得到每一步的得分

3.根据每一步的得分,观察训练的成果。

此处列举一个简单的对CIFAR10图片集进行分类的代码:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# model是提前写好的神经网络,此处直接import进来,下面会给出model的网络结构大致形状
from model import *# 从pytor上获得CIFAR10的训练集和测试集
train_data= torchvision.datasets.CIFAR10(root="18-data", train=True,transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="18-data2", train=False,transform=torchvision.transforms.ToTensor(), download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
# print(f"训练数据集合的长度为{train_data_size}")train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
bao = Bao()# 损失函数:交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器,学习率为0.01,写在外面是因为在某些情况下可以在学习过程中对学习率进行修改
learning_rate = 0.01
optimizer = torch.optim.SGD(bao.parameters(), lr=learning_rate)# 设置训练网络的一些参数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练轮数
epoch = 10# tensorboard进行记录想要观察的对象
writer = SummaryWriter("18-logs")for i in range(epoch):bao.train()    # 将网络设置成训练模式,只有当网络模型中有一些特定的层的时候才有用,其他时候可以省略for data in train_dataloader:imgs, targets = dataoutputs = bao(imgs)    # 得到输出loss = loss_fn(outputs, targets)    # 得到损失值optimizer.zero_grad()    # 对梯度清零,防止上一轮的梯度影响下一轮的学习loss.backward()    # 根据loss值,求出新一轮的梯度optimizer.step()    # 根据新一轮的梯度,对神经网络的参数进行更新total_train_step = total_train_step + 1# 测试步骤bao.eval()  # 将模型设置成验证状态test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = bao(imgs)loss = loss_fn(outputs, targets)test_loss = test_loss + loss.item()    # item()的作用:将tensor数据类型的loss变为真正的数字accuracy = (outputs.argmax(1) == targets).sum()    # 求出给出的最大值得种类和实际种类相同的个数total_accuracy = total_accuracy + accuracytotal_test_step = total_test_step + 1writer.add_scalar("loss_per_epoch", test_loss, total_test_step)writer.add_scalar("accuracy_per_epoch", total_accuracy/test_data_size, total_test_step)print(f"整体测试集上的loss为{test_loss}")print(f"整体测试集上的正确率为{total_accuracy/test_data_size}")torch.save(bao, f"bao_{i}.pth")writer.close()

其中的model.py中定义了神经网络:

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linearclass Bao(nn.Module):def __init__(self):super(Bao, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=(5, 5), padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':bao = Bao()input = torch.ones((64, 3, 32, 32))output = bao(input)print(output.shape)

运行程序,在终端中输入tensorboard --logdir=绝对路径,打开tensorboard进行参数观察,得到:

可见,正确率在逐渐上升,loss在逐渐下降。

基于pytorch的简单图片分类问题实现相关推荐

  1. PyTorch入门-简单图片分类

    一. CNN图像分类 PyTorch Version: 1.0.0 import torch import torch.nn as nn import torch.nn.functional as F ...

  2. 基于Pytorch实现猫狗分类

    基于Pytorch实现猫狗分类 一.环境配置 二.数据集准备 三.猫狗分类的实例 四.实现分类预测测试 五.参考资料 一.环境配置 1.环境使用 Anaconda 2.配置Pytorch pip in ...

  3. Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)

    Kaggle猫狗大战--基于Pytorch的CNN网络分类:数据获取.预处理.载入(1) 第一次写CSDN博客,之前一直是靠着CSDN学学代码,这次不得不亲自上场了,就想着将学习的过程都记录下来.新人 ...

  4. 轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur

    作者丨科技猛兽 编辑丨极市平台 清华大学自动化系智能计算实验室团队开源基于 PyTorch 的视频 (图片) 去模糊框架 SimDeblur. 基于 PyTorch 的视频 (图片) 去模糊框架 Si ...

  5. 基于keras的CNN图片分类模型的搭建以及参数调试

    基于keras的CNN图片分类模型的搭建与调参 更新一下这篇博客,因为最近在CNN调参方面取得了一些进展,顺便做一下总结. 我的项目目标是搭建一个可以分五类的卷积神经网络,然后我找了一些资料看了一些博 ...

  6. 基于android手机相册,基于安卓的手机图片分类软件的设计与实现.pdf

    ELECTRONICS WORLD ・技术交流 基于安卓的手机图片分类软件的设计与实现 武警工程大学研究生管理大队12队 张 鑫 武警广州指挥学院 姜 波 [摘要] 本文针对安卓手机中图片浏览器的快速 ...

  7. PyTorch ResNet 实现图片分类

    PyTorch ResNet 实现图片分类 建党 100 年 Resnet 深度网络退化 代码实现 残差块 超参数 ResNet 18 网络 获取数据 训练 测试 完整代码 建党 100 年 百年风雨 ...

  8. 基于Pytorch的简单深度学习项目实战

    目录 任务详情 训练设置 网络模型 损失函数 优化器 训练步骤 具体代码 导入环境 导入数据 加载数据 创建网络 损失函数 优化器 开始训练 任务详情 利用CIFAR10数据集,基于pytorch环境 ...

  9. 基于Pytorch的猫狗分类

    无偿分享~ 猫狗二分类文件下载地址 在下一章说        猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类.刚开始我找到了也 ...

最新文章

  1. 只学python找工作难吗-只学python语言找工作难吗?
  2. thinkPHP开发基础知识 包括变量神马的
  3. Go的cannot convert t(type interface {}) to type string: need type assertion 使用fmt.Sprintf转换
  4. IDEA中注解注释快捷键及模板
  5. 蔬菜大棚成本_蔬菜大棚建设标准和成本
  6. 100999凑整到万位进一_四年级数学第一单元练习作业
  7. 虚拟桌面显示未注册的情况,可以检查一下几点
  8. Ubuntu、CentOS、redHat的区别与联系
  9. Atitit 翻页功能的解决方案与版本历史 v4 r49
  10. java getname threads_Java8并发教程:Threads和Executors
  11. 计算机信息检索自考知识点,计算机信息检索02139自考资料.doc
  12. 用摄动法证明fibs的一个公式(变形)
  13. 苹果屏蔽更新描述文件_iPhone|IOS10-IOS12屏蔽系统更新描述文件|去除设置①小红点教程...
  14. 什么农村大学生大多混得比较差-第一性原理分析
  15. win10去除右下角激活水印
  16. sql server 替换字段中的部分字符,替换指定字符的方法
  17. NOI 2016 游记
  18. 文件下载时设置文件名以及中文被转换成下划线的解决办法
  19. Perameter estimation for text analyse (下)
  20. 银行业务中的清算和结算分别是什么样的过程

热门文章

  1. EXCEL自定义函数无法运行的原因:可以在VBA里运行的函数,在EXCEL用自定义函数为什么报错?
  2. Unity Shader自定义光照模型
  3. 【大数据 AI】视觉ChatGPT来了,微软发布,代码已开源
  4. c++ memset函数使用及头文件
  5. 北科大计算机专业研究生多少分能考上,2022年北京科技大学计算机考研初试复习经验分享...
  6. Xmind基础教程-下钻和上钻
  7. sq语句l补充(一)
  8. Python使用opencv 打开摄像头
  9. 关于仕族_仕族信息_服务中心_仕族服务_男装:衬衫、法式衬衫、袖扣领带、西服西裤等男士正装服饰-仕族官网...
  10. 小程序学习记录(二)——view、text、image标签、flex布局