@Author:Runsen

上次基于CIFAR-10 数据集,使用PyTorch ​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoaderimport torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutilsimport numpy as np
import os
import warnings
from matplotlib import pyplot as plt
warnings.filterwarnings('ignore')`
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载数据集

# number of images in one forward and backward pass
batch_size = 128# number of subprocesses used for data loading
# Normally do not use it if your os is windows
num_workers = 2train_dataset = datasets.CIFAR10('./data/CIFAR10/', train = True, download = True, transform = transform_train)train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)val_dataset = datasets.CIFAR10('./data/CIFAR10', train = True, transform = transform_test)val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)test_dataset = datasets.CIFAR10('./data/CIFAR10', train = False, transform = transform_test)test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)# declare classes in CIFAR10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

之前的transform ’只是进行了缩放和归一,在这里添加RandomCrop和RandomHorizontalFlip

# define a transform to normalize the datatransform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(), # converting images to tensortransforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) # if the image dataset is black and white image, there can be just one number.
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])

可视化具体的图像

# function that will be used for visualizing the datadef imshow(img):img = img / 2 + 0.5  # unnormalizeplt.imshow(np.transpose(img, (1, 2, 0)))  # convert from Tensor image# obtain one batch of imges from train dataset
dataiter = iter(train_loader)
images, labels = dataiter.next()
images = images.numpy() # convert images to numpy for display# plot the images in one batch with the corresponding labels
fig = plt.figure(figsize = (25, 4))# display images
for idx in np.arange(10):ax = fig.add_subplot(1, 10, idx+1, xticks=[], yticks=[])imshow(images[idx])ax.set_title(classes[labels[idx]])

建立常见的CNN模型

# define the CNN architectureclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.main = nn.Sequential(# 3x32x32nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 3, padding = 1), # 3x32x32 (O = (N+2P-F/S)+1)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size = 2, stride = 2), # 32x16x16nn.BatchNorm2d(32),nn.Conv2d(32, 64, kernel_size = 3, padding = 1), # 32x16x16nn.ReLU(inplace=True),nn.MaxPool2d(2, 2), # 64x8x8nn.BatchNorm2d(64),nn.Conv2d(64, 128, 3, padding = 1), # 64x8x8nn.ReLU(inplace=True),nn.MaxPool2d(2, 2), # 128x4x4nn.BatchNorm2d(128),)self.fc = nn.Sequential(nn.Linear(128*4*4, 1024),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(1024, 256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(256, 10))def forward(self, x):# Conv and Poolilng layersx = self.main(x)# Flatten before Fully Connected layersx = x.view(-1, 128*4*4) # Fully Connected Layerx = self.fc(x)return xcnn = CNN().to(device)
cnn


torch.nn.CrossEntropyLoss对输出概率介于0和1之间的分类模型进行分类。

训练模型

# 超参数:Hyper Parameters
learning_rate = 0.001
train_losses = []
val_losses = []# Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr = learning_rate)# define train function that trains the model using a CIFAR10 datasetdef train(model, epoch, num_epochs):model.train()total_batch = len(train_dataset) // batch_sizefor i, (images, labels) in enumerate(train_loader):X = images.to(device)Y = labels.to(device)### forward pass and loss calculation# forward passpred = model(X)#c alculation  of loss valuecost = criterion(pred, Y)### backward pass and optimization# gradient initializationoptimizer.zero_grad()# backward passcost.backward()# parameter updateoptimizer.step()# training statsif (i+1) % 100 == 0:print('Train, Epoch [%d/%d], lter [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, i+1, total_batch, np.average(train_losses)))train_losses.append(cost.item())n# def the validation function that validates the model using CIFAR10 datasetdef validation(model, epoch, num_epochs):model.eval()total_batch = len(val_dataset) // batch_sizefor i, (images, labels) in enumerate(val_loader):X = images.to(device)Y = labels.to(device)with torch.no_grad():pred = model(X)cost = criterion(pred, Y)if (i+1) % 100 == 0:print("Validation, Epoch [%d/%d], lter [%d/%d], Loss: %.4f"% (epoch+1, num_epochs, i+1, total_batch, np.average(val_losses)))val_losses.append(cost.item())def plot_losses(train_losses, val_losses):plt.figure(figsize=(5, 5))plt.plot(train_losses, label='Train', alpha=0.5)plt.plot(val_losses, label='Validation', alpha=0.5)plt.xlabel('Epochs')plt.ylabel('Losses')plt.legend()plt.grid(b=True)plt.title('CIFAR 10 Train/Val Losses Over Epoch')plt.show()num_epochs = 20
for epoch in range(num_epochs):train(cnn, epoch, num_epochs)validation(cnn, epoch, num_epochs)torch.save(cnn.state_dict(), './data/Tutorial_3_CNN_Epoch_{}.pkl'.format(epoch+1))plot_losses(train_losses, val_losses)


测试模型

def test(model):# declare that the model is about to evaluatemodel.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_dataset:images = images.unsqueeze(0).to(device)# forward passoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += 1correct += (predicted == labels).sum().item()print("Accuracy of Test Images: %f %%" % (100 * float(correct) / total))


经过图像数据增强。模型从60提升到了84。

测试模型在哪些类上表现良好,

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))with torch.no_grad():for data in test_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)outputs = cnn(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度相关推荐

  1. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度...

    「@Author:Runsen」 上次基于CIFAR-10 数据集,使用PyTorch构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. imp ...

  2. 【动手学深度学习PyTorch版】27 数据增强

    上一篇请移步[动手学深度学习PyTorch版]23 深度学习硬件CPU 和 GPU_水w的博客-CSDN博客 目录 一.数据增强 1.1 数据增强(主要是关于图像增强) ◼ CES上的真实的故事 ◼ ...

  3. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

  4. 【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext

    @Author:Runsen 对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext. 之前使用 torchDataLoader类直接加载图像并将其转换为张量. ...

  5. 【小白学习PyTorch教程】五、在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据

    「@Author:Runsen」 有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难. 因此,唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作.对此,PyTor ...

  6. 【小白学习PyTorch教程】十六、在多标签分类任务上 微调BERT模型

    @Author:Runsen BERT模型在NLP各项任务中大杀四方,那么我们如何使用这一利器来为我们日常的NLP任务来服务呢?首先介绍使用BERT做文本多标签分类任务. 文本多标签分类是常见的NLP ...

  7. 【小白学习PyTorch教程】十九、 基于torch实现UNet 图像分割模型

    @Author:Runsen 在图像领域,除了分类,CNN 今天还用于更高级的问题,如图像分割.对象检测等.图像分割是计算机视觉中的一个过程,其中图像被分割成代表图像中每个不同类别的不同段. 上面图片 ...

  8. 【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类

    「@Author:Runsen」 上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类. ResNet是 Residual Networks 的缩写,是一种经典的神经网络,用作许多计算 ...

  9. 【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型

    「@Author:Runsen」 当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词.这可以称为记忆.卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网 ...

最新文章

  1. 【Java 并发编程】线程指令重排序问题 ( 指令重排序规范 | volatile 关键字禁止指令重排序 )
  2. 086_html5Input类型
  3. mysql语言中有什么运算_SQL知识点,新手感悟
  4. Shiro系列-Shiro的怎么进行授权操作
  5. P1262 间谍网络 (tarjan缩点 水过去)
  6. Docker使用小结(四)发布镜像
  7. 网站后台开发 java_Java前后台开发
  8. 最新emoji表情代码大全_git commit 时使用 Emoji ?
  9. 手把手教你架构3D引擎高级篇系列二
  10. 【C++】STL学习小总结
  11. 什么是TorchScript
  12. python量化分析
  13. linux 动态监控进程
  14. AliOS Things的启动过程分析(一)
  15. 无需越狱或安装应用在 iPhone 和 iPad 上打开 Flash 视频
  16. CSAPP导读第3章 程序的机器级表示
  17. Profibus DP-Slave in C
  18. Graph Neural Network for Traffic Forecasting: A Survey
  19. 信创操作系统--麒麟Kylin桌面版 (项目七 网络连接:有线、无线网络)
  20. 无法安装战网,提示007D

热门文章

  1. Qt / 伪状态和子部件
  2. mmap 和 shm 区别
  3. python queue 调试_python:如何创建用于调试的持久内存结构
  4. kinux mysql报错10038_navicat连接linux系统中mysql-错误:10038
  5. 日志中台不重不丢实现浅谈
  6. glide默认的缓存图片路径地址_手写一个静态资源中间件,加深了解服务器对文件请求的缓存策略...
  7. python做自动化测试的优点_乐搏讲自动化测试-python语言特点及优缺点(5)
  8. 最小帧长度的计算公式_网络工程师考试常用计算公式汇总(二)
  9. Universal-Image-Loader(UIL)图片载入框架使用简介
  10. 20162303 2016-2017-2 《程序设计与数据结构》第五周学习总结