[深度应用]·实战掌握PyTorch图片分类简明教程

个人网站--> http://www.yansongsong.cn

项目GitHub地址--> https://github.com/xiaosongshine/image_classifier_PyTorch

1.引文

深度学习的比赛中,图片分类是很常见的比赛,同时也是很难取得特别高名次的比赛,因为图片分类已经被大家研究的很透彻,一些开源的网络很容易取得高分。如果大家还掌握不了使用开源的网络进行训练,再慢慢去模型调优,很难取得较好的成绩。

我们在[PyTorch小试牛刀]实战六·准备自己的数据集用于训练讲解了如何制作自己的数据集用于训练,这个教程在此基础上,进行训练与应用。

2.数据介绍

数据 下载地址

这次的实战使用的数据是交通标志数据集,共有62类交通标志。其中训练集数据有4572张照片(每个类别大概七十个),测试数据集有2520张照片(每个类别大概40个)。数据包含两个子目录分别train与test:

为什么还需要测试数据集呢?这个测试数据集不会拿来训练,是用来进行模型的评估与调优。

train与test每个文件夹里又有62个子文件夹,每个类别在同一个文件夹内:

我从中打开一个文件间,把里面图片展示出来:

其中每张照片都类似下面的例子,100*100*3的大小。100是照片的照片的长和宽,3是什么呢?这其实是照片的色彩通道数目,RGB。彩色照片存储在计算机里就是以三维数组的形式。我们送入网络的也是这些数组。

3.网络构建

1.导入Python包,定义一些参数

import torch as t
import torchvision as tv
import os
import time
import numpy as np
from tqdm import tqdmclass DefaultConfigs(object):data_dir = "./traffic-sign/"data_list = ["train","test"]lr = 0.001epochs = 10num_classes = 62image_size = 224batch_size = 40channels = 3gpu = "0"train_len = 4572test_len = 2520use_gpu = t.cuda.is_available()config = DefaultConfigs()

2.数据准备,采用PyTorch提供的读取方式(具体内容参考[PyTorch小试牛刀]实战六·准备自己的数据集用于训练

注意一点Train数据需要进行随机裁剪,Test数据不要进行裁剪了

normalize = tv.transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])transform = {config.data_list[0]:tv.transforms.Compose([tv.transforms.Resize([224,224]),tv.transforms.CenterCrop([224,224]),tv.transforms.ToTensor(),normalize]#tv.transforms.Resize 用于重设图片大小) ,config.data_list[1]:tv.transforms.Compose([tv.transforms.Resize([224,224]),tv.transforms.ToTensor(),normalize])
}datasets = {x:tv.datasets.ImageFolder(root = os.path.join(config.data_dir,x),transform=transform[x])for x in config.data_list
}dataloader = {x:t.utils.data.DataLoader(dataset= datasets[x],batch_size=config.batch_size,shuffle=True) for x in config.data_list
}

3.构建网络模型(使用resnet18进行迁移学习,训练参数为最后一个全连接层 t.nn.Linear(512,num_classes)) 

def get_model(num_classes):model = tv.models.resnet18(pretrained=True)for parma in model.parameters():parma.requires_grad = Falsemodel.fc = t.nn.Sequential(t.nn.Dropout(p=0.3),t.nn.Linear(512,num_classes))return(model)

如果电脑硬件支持,可以把下述代码屏蔽,则训练整个网络,最终准确率会上升,训练数据会变慢。

for parma in model.parameters():parma.requires_grad = False

模型输出

ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer3): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer4): Sequential((0): BasicBlock((conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)(fc): Sequential((0): Dropout(p=0.3)(1): Linear(in_features=512, out_features=62, bias=True))
)

4.训练模型(支持自动GPU加速,GPU使用教程参考:[开发技巧]·PyTorch如何使用GPU加速

def train(epochs):model = get_model(config.num_classes)print(model)loss_f = t.nn.CrossEntropyLoss()if(config.use_gpu):model = model.cuda()loss_f = loss_f.cuda()opt = t.optim.Adam(model.fc.parameters(),lr = config.lr)time_start = time.time()for epoch in range(epochs):train_loss = []train_acc = []test_loss = []test_acc = []model.train(True)print("Epoch {}/{}".format(epoch+1,epochs))for batch, datas in tqdm(enumerate(iter(dataloader["train"]))):x,y = datasif (config.use_gpu):x,y = x.cuda(),y.cuda()y_ = model(x)#print(x.shape,y.shape,y_.shape)_, pre_y_ = t.max(y_,1)pre_y = y#print(y_.shape)loss = loss_f(y_,pre_y)#print(y_.shape)acc = t.sum(pre_y_ == pre_y)loss.backward()opt.step()opt.zero_grad()if(config.use_gpu):loss = loss.cpu()acc = acc.cpu()train_loss.append(loss.data)train_acc.append(acc)#if((batch+1)%5 ==0):time_end = time.time()print("Batch {}, Train loss:{:.4f}, Train acc:{:.4f}, Time: {}"\.format(batch+1,np.mean(train_loss)/config.batch_size,np.mean(train_acc)/config.batch_size,(time_end-time_start)))time_start = time.time()model.train(False)for batch, datas in tqdm(enumerate(iter(dataloader["test"]))):x,y = datasif (config.use_gpu):x,y = x.cuda(),y.cuda()y_ = model(x)#print(x.shape,y.shape,y_.shape)_, pre_y_ = t.max(y_,1)pre_y = y#print(y_.shape)loss = loss_f(y_,pre_y)acc = t.sum(pre_y_ == pre_y)if(config.use_gpu):loss = loss.cpu()acc = acc.cpu()test_loss.append(loss.data)test_acc.append(acc)print("Batch {}, Test loss:{:.4f}, Test acc:{:.4f}".format(batch+1,np.mean(test_loss)/config.batch_size,np.mean(test_acc)/config.batch_size))t.save(model,str(epoch+1)+"ttmodel.pkl")if __name__ == "__main__":train(config.epochs)

训练结果如下:

Epoch 1/10
115it [00:48,  2.63it/s]
Batch 115, Train loss:0.0590, Train acc:0.4635, Time: 48.985504150390625
63it [00:24,  2.62it/s]
Batch 63, Test loss:0.0374, Test acc:0.6790, Time :24.648272275924683
Epoch 2/10
115it [00:45,  3.22it/s]
Batch 115, Train loss:0.0271, Train acc:0.7576, Time: 45.68823838233948
63it [00:23,  2.62it/s]
Batch 63, Test loss:0.0255, Test acc:0.7524, Time :23.271782875061035
Epoch 3/10
115it [00:45,  3.19it/s]
Batch 115, Train loss:0.0181, Train acc:0.8300, Time: 45.92648506164551
63it [00:23,  2.60it/s]
Batch 63, Test loss:0.0212, Test acc:0.7861, Time :23.80789279937744
Epoch 4/10
115it [00:45,  3.28it/s]
Batch 115, Train loss:0.0138, Train acc:0.8767, Time: 45.27525019645691
63it [00:23,  2.57it/s]
Batch 63, Test loss:0.0173, Test acc:0.8385, Time :23.736321449279785
Epoch 5/10
115it [00:44,  3.22it/s]
Batch 115, Train loss:0.0112, Train acc:0.8950, Time: 44.983638286590576
63it [00:22,  2.69it/s]
Batch 63, Test loss:0.0156, Test acc:0.8520, Time :22.790074348449707
Epoch 6/10
115it [00:44,  3.19it/s]
Batch 115, Train loss:0.0095, Train acc:0.9159, Time: 45.10426950454712
63it [00:22,  2.77it/s]
Batch 63, Test loss:0.0158, Test acc:0.8214, Time :22.80412459373474
Epoch 7/10
115it [00:45,  2.95it/s]
Batch 115, Train loss:0.0081, Train acc:0.9280, Time: 45.30439043045044
63it [00:23,  2.66it/s]
Batch 63, Test loss:0.0139, Test acc:0.8528, Time :23.122379541397095
Epoch 8/10
115it [00:44,  3.23it/s]
Batch 115, Train loss:0.0073, Train acc:0.9300, Time: 44.304762840270996
63it [00:22,  2.74it/s]
Batch 63, Test loss:0.0142, Test acc:0.8496, Time :22.801835536956787
Epoch 9/10
115it [00:43,  3.19it/s]
Batch 115, Train loss:0.0068, Train acc:0.9361, Time: 44.08414030075073
63it [00:23,  2.44it/s]
Batch 63, Test loss:0.0142, Test acc:0.8437, Time :23.604419231414795
Epoch 10/10
115it [00:46,  3.12it/s]
Batch 115, Train loss:0.0063, Train acc:0.9337, Time: 46.76597046852112
63it [00:24,  2.65it/s]
Batch 63, Test loss:0.0130, Test acc:0.8591, Time :24.64351773262024

训练10个Epoch,测试集准确率可以到达0.86,已经达到不错效果。通过修改参数,增加训练,可以达到更高的准确率。

[深度应用]·实战掌握PyTorch图片分类简明教程相关推荐

  1. [深度应用]·实战掌握Dlib人脸识别开发教程

    [深度应用]·实战掌握Dlib人脸识别开发教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com/xi ...

  2. 实战:掌握PyTorch图片分类的简明教程 | 附完整代码

    作者 | 小宋是呢 转载自CSDN博客 1.引文 深度学习的比赛中,图片分类是很常见的比赛,同时也是很难取得特别高名次的比赛,因为图片分类已经被大家研究的很透彻,一些开源的网络很容易取得高分.如果大家 ...

  3. 【tensorflow 深度学习】8.训练图片分类模型

    1.训练图片分类模型的三种方法 (1).从无到有,先确定好算法框架,准备好需要训练的数据集,从头开始训练,参数一开始也是初始化的随机值,一个批次一个批次地进行训练. (2).准备好已经训练好的模型,权 ...

  4. 【深度学习】卷积神经网络-图片分类案例(pytorch实现)

    前言 前文已经介绍过卷积神经网络的基本概念[深度学习]卷积神经网络-CNN简单理论介绍.下面开始动手实践吧.本文任务描述如下: 从公开数据集CIFAR10中创建训练集.测试集数据,使用Pytorch构 ...

  5. python猫狗大战pytorch_深度学习实战---猫狗大战(pytorch实现)

    数据准备 猫狗大战数据集下载链接 微软的数据集已经分好类,直接使用就行, 数据划分 我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil. ...

  6. [深度学习 - 实战项目] CRAFTCRNN_seq2seq图片文字提取

    图片文字提取项目 检测网络:CRAFT,基于字符区域感知的文本检测: CRAFT源码:https://github.com/clovaai/CRAFT-pytorch 识别网络:crnn+seq2se ...

  7. 深度学习实战---猫狗大战(pytorch实现)

    数据准备 猫狗大战数据集下载链接 微软的数据集已经分好类,直接使用就行, 数据划分 我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil. ...

  8. matlab对直方图分类,matlab根据直方图进行图片分类

    matlab根据直方图进行图片分类 matlab根据直方图进行图片分类 感觉还有一些bug需要调试,不过还是先写出来吧 将一张图片由rgb转hsv空间,并进行量化 function [Hh,Vv,Ss ...

  9. Java Web 简明教程

    点此查看 所有教程.项目.源码导航 1. 前言 本教程用于介绍Java Web开发入门的方方面面,包括开发环境.工具.网页.Java.数据库等. 本教程写于2016年底,一些内容相对比较陈旧了,新版的 ...

最新文章

  1. 我是如何在尼日利亚的沃里创立Google Developers Group GDG分会的,并达到了100位成员...
  2. Linux系统查看系统是32位还是64位方法总结 in 创新实训
  3. Android中自定义Dialog外形,去除黑底和白色边框
  4. 为什么我从 npm 到 yarn 再到 npm?
  5. ACM OJ反馈结果大全
  6. 在非主线程中创建窗口
  7. swiper.js使用心得
  8. double类型占几个字节_MongoDB 中的数据类型
  9. 数据仓库组件:Hive环境搭建和基础用法
  10. 微信登陆报错:redirect_uri域名与后台配置不一致,错误码:10003 微信支付报错 微信登录报错 微信开发
  11. mysql 压测结果_用mysqlslap压测mysql
  12. oracle ora00020,ORA-00020: maximum number of processes (1000) 错误处理
  13. ECHAR学习-Part1 文字特效
  14. WeNet语音识别实战
  15. 12306抢票使用教程
  16. python 中的self和cls
  17. 服务器 分辨率问题 显示器不显示不出来,显示器没有最佳分辨率及分辨率调不了的解决方法...
  18. C# 线程的挂起与唤醒 (AutoResetEvent,ManualResetEvent)
  19. WIFI AP和STATION
  20. IIS 返回405报错解决过程

热门文章

  1. 发明专利申请的费用核流程
  2. 赶紧注册你的@live.xx邮箱吧!
  3. 关于 A/B 测试那些事儿
  4. 如何提高自己的知识水平?
  5. 用手机访问计算机共享资源,手机访问电脑文件 手机与电脑如何共享数据?AirShareUp 云悦享...
  6. `英语` 2022/8/24
  7. 【spring】PO,VO,DAO,BO,POJO,Bean之间的区别与解释
  8. 新整理常见互联网公司职级和薪资一览!
  9. 基于SSM实现餐厅收银系统
  10. 《2017中国智慧停车行业大数据报告》