实现简单图像分类器

  • 1. 数据加载
    • 1.1 常用公共数据集加载
    • 1.2 私人数据集加载方法
  • 2. 定义神经网络
  • 3. 定义权值更新与损失函数
  • 4. 训练与测试神经网络
  • 5. 神经网络的保存与载入

本篇博客的目标是实现一个简单的图像分类器, 本篇博客主要分为以下几个步骤:数据的加载与归一、定义神经网络、定义损失函数、训练与测试神经网络以及神经网络存储与读取。

1. 数据加载

数据加载就是把训练数据导入到神经网络中并对神经网络进行训练,图像分类器训练数据一般比较大,无法一次性加载所有数据,例如:

  • CIFAR10数据集含有10个类共计6万张图片
  • ImageNet数据集含有1000个类超过100万张图片

因为数据集比较大,所以一般需要用mini-batch形式进行加载并训练:

  • 每个mini-batch只加载所有训练数据集中的一部分数据
  • 任意两个mini-batch之间的数据不重叠
  • 当所有的训练数据集中的数据都被加载并训练完一次被称作一个epoch

因为图像像素值比较大,所以需要对数据归一化:

  • 图像数据像素值一般在[0-255]之间
  • 在训练神经网络时,要把输入数据值变成[0-1]或者[-1-1]之间

与数据加载与归一化相关的PyTorch库为:

  • 数据加载:

    • torchvision.dataset
  • 数据归一:
    • torchvision.transforms
1.1 常用公共数据集加载

本篇博客所用到的数据库为CIFAR10数据库, 一共有10类,每类图片有6000张,图像参数为:

  • 大小: 32x32x3
  • 通道:R,G,B三个通道
  • 像素:每个通道有32x32个像素
# 导入相关pkg
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm# 定义归一化方法
transform = transforms.Compose(# 首先装换数据为tensor张量[transforms.ToTensor(),# 对数据进行正态分布归一化,RGB三个通道每个通道均值为0.5,标准差为0.5transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载训练数据集
# root为数据存放的目录,train=Ture表示训练集,download=True表示要下载,transform为之前定义的归一化方法
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 定义数据集的加载方法, trainset为训练集, batch_size为单词训练的样本数,shuffle=True表示随机抽取样本,num_workers表示加载的线程数
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,shuffle=True, num_workers=2)# 加载测试数据集
# root为数据存放的目录,train=False表示测试集,download=True表示要下载,transform为之前定义的归一化方法
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)# 定义数据集的加载方法, testset为测试集, batch_size为单词训练的样本数,shuffle=False表示不随机抽取样本,num_workers表示加载的线程数
testloader = torch.utils.data.DataLoader(testset, batch_size=16,shuffle=False, num_workers=2)
1.2 私人数据集加载方法

在加载私人数据集时,数据集所在文件夹应保持以下结构:

  • /toy_dataset

    • /class_1
    • /class_2

    • 程序上只需要做如下改动:
privateset = torchvision.datasets.ImageFolder(root=image_path, train=True, download=True, transform=transform)

2. 定义神经网络

该部分与上一篇博客中的代码结构一样,只不过网络的参数不同:

# 导入torch包
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义神经网络类
class Net(nn.Module):#定义神经网络结构, 输入数据 1x32x32def __init__(self): super(Net, self).__init__()# 第一层(卷积层)# 输入频道3, 输出频道6, 卷积3x3self.conv1 = nn.Conv2d(3,6,3) # 第二层(卷积层)# 输入频道6, 输出频道16, 卷积3x3self.conv2 = nn.Conv2d(6,16,3) # 第三层(全连接层)# 输入维度16x28x28=12544,输出维度 512self.fc1 = nn.Linear(16*28*28, 512) # 第四层(全连接层)# 输入维度512, 输出维度64self.fc2 = nn.Linear(512, 64) # 第五层(全连接层)# 输入维度64, 输出维度10self.fc3 = nn.Linear(64, 10) # 定义数据流向def forward(self, x): # 数据先经过第一层卷积层x = self.conv1(x)# 经过激活函数x = F.relu(x)# 数据经过第二层卷积层x = self.conv2(x)# 经过激活函数x = F.relu(x)# 调整数据维度,‘-1’表示自动计算维度x = x.view(-1, 16*28*28)# 数据经过第三层全连接层x = self.fc1(x)# 数据经过激活函数x = F.relu(x)# 数据经过第四层全连接层x = self.fc2(x)# 数据经过激活函数x = F.relu(x)# 数据经过第五层全连接层,输出结果x = self.fc3(x)return x

3. 定义权值更新与损失函数

# 新建一个网络net
net = Net()# 导入torch中优化器相关的包
import torch.optim as optim# 定义损失函数为交叉熵函数
criterion = nn.CrossEntropyLoss()# 优化器函数为随机梯度下降, 学习率为0。0001
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

4. 训练与测试神经网络

for epoch in range(2):#训练for i, data in enumerate(trainloader):# 获得数据与标签images, labels = data        # 得到网络的输出outputs = net(images)# 计算损失loss = criterion(outputs, labels) # 清零梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新权重optimizer.step()# 计算总的损失running_loss += loss.item()#每1000 mini batch 测试一次if(i%1000 == 0):print('Epoch: %d,Step: %d, Loss: %.3f'%(epoch,i,loss.item()) )

运行的一部分结果为:

可以发现,随着训练数据的增加,Loss有明显减少的趋势。

5. 神经网络的保存与载入

保存模型:

# 保存模型参数到路径'./model.pt'中
torch.save(net.state_dict(), './model.pt')

读取模型

# 载入模型
# 定义新Net类的网络net_2
net_2 = Net()
# 将原参数路径的参数加载到net_2网络即可
net_2.load_state_dict(torch.load('./model.pt'))

PyTorch入门(三)--实现简单图像分类器相关推荐

  1. [CS231n Assignment #1] 简单图像分类器——高级图像特征分类

    文章目录 作业介绍 1. 加载数据 2. 提取特征 3. 使用SVM进行训练 5. 使用神经网络训练特征 6. 测试集上测试 作业介绍 作业主页:Assignment #1 作业目的: 在之前的作业中 ...

  2. PyTorch入门(二)--实现简单神经网络

    实现简单神经网络 1. 神经网络基本介绍 2. Autograd包 3. 实现神经网络 3.1 定义神经网络与训练流程 3.2 运行神经网络与计算损失 3.3 反向传递与权值更新 3.4 神经网络中损 ...

  3. 使用 Fastai 构建食物图像分类器

    背景 社交媒体平台是分享有趣的图像的常用方式.食物图像,尤其是与不同的美食和文化相关的图像,是一个似乎经常流行的话题.Instagram 等社交媒体平台拥有大量属于不同类别的图像.我们都可能使用谷歌图 ...

  4. 【PyTorch】实现一个简单的CNN图像分类器

    本文记录了一个简单的基于pytorch的图像多分类器模型构造过程,参考自Pytorch官方文档.磐创团队的<PyTorch官方教程中文版>以及余霆嵩的<PyTorch 模型训练实用教 ...

  5. 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  6. 【深度学习】翻译:60分钟入门PyTorch(三)——神经网络

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  7. 如何用PyTorch训练图像分类器

    本文为 AI 研习社编译的技术博客,原标题 : How to Train an Image Classifier in PyTorch and use it to Perform Basic Infe ...

  8. 用PyTorch创建一个图像分类器?So easy!(Part 1)

    经过了几个月的学习和实践,我完成了优达学城网站上<Python Programming with Python Nanodegree>课程的学习,该课程的终极项目就是使用Pytorch为1 ...

  9. pytorch学习笔记(1):开始一个简单的分类器

    参考文档:https://mp.weixin.qq.com/s/wj8wxeaGblJijiHFZA6lXQ 回想了一下自己关于 pytorch 的学习路线,一开始找的各种资料,写下来都能跑,但是却没 ...

最新文章

  1. 学历高的人,喜欢关注什么?
  2. php和mysql的概述_PHP的MySQL扩展:MySQL数据库概述_MySQL
  3. Stanford UFLDL教程 稀疏编码自编码表达
  4. [PBI催化剂]国际水准,中国首款重量级PowerBIDeskTop外部工具问世
  5. openlayers3 根据经纬度 自动画框_用这软件,让你的电脑自动搞黄色
  6. Ubuntu创建python虚拟环境
  7. iOS开发日记49-详解定位CLLocation
  8. reactrouter4路由钩子_react router @4 和 vue路由 详解(八)vue路由守卫
  9. vue学习日志-过滤器
  10. exp oracle 904,9i exp时出现ORA-904、ORA-1003的解决过程
  11. 面试题---测试用例设计
  12. coldfusion_ColdFusion中的数据结构简介
  13. pytorch中的nn.Unfold()函数和fold(函数详解
  14. C++Primer 习题 第7章
  15. 怎样使用Scrapy爬取NVD网站上的数据
  16. AISummit全球人工智能技术大会顺利开幕:首日精彩回顾
  17. idea让字体更圆滑
  18. web前端-CSS Border(边框)-011
  19. vu16和u16的区别 volatile关键字的用法简介
  20. 订单规格数据统计功能总结

热门文章

  1. 湖首大学计算机科学硕士申请,湖首大学王牌专业之一丨计算机科学专业
  2. 2016版系统集成项目管理工程师下午案例分析考试范围
  3. java log4j 热部署_JAVA类加载器分析--热部署的缺陷(有代码示例,及分析)
  4. WordPress疑难问题以及解决方案汇总
  5. CSS浮动(Float)(二)
  6. gulp-notify处理报错----gulp系列(二)
  7. hdu 1556:Color the ball(第二类树状数组 —— 区间更新,点求和)
  8. DEDE留言板调用导航的方法
  9. Javascript学习笔记一 之 数据类型
  10. Centos 云服务器磁盘占用率90%以上的排查解决