1 图片channels_【深度学习】卷积神经网络图片分类案例(pytorch实现)
文 | 菊子皮 (转载请注明出处)B站:科皮子菊
前言
前文已经介绍过卷积神经网络的基本概念【深度学习】卷积神经网络-CNN简单理论介绍[1]。下面开始动手实践吧。本文任务描述如下:从公开数据集CIFAR10中创建训练集、测试集数据,使用Pytorch构建CNN模型对训练集数据进行训练,然后在测试集中测试,查看简单的CNN模型训练效果如何。CIFAR10公开数据地址:http://www.cs.toronto.edu/~kriz/cifar.html[2]。CIFAR-10数据集包含 60000 张 32x32 的彩色 10 中类型的数据, 其中50000张训练图片和10000测试图片。下面是每个类别随机10张图片的结果。
下面就根据CNN的原理:【深度学习】卷积神经网络-CNN简单理论介绍[3]设计相关的网络和程序,建议使用jupyter。
1 数据集
这里结合torchvision包下载相关数据集并进行数据的预处理。代码如下:
import torchimport torchvisionimport torchvision.transforms as transforms
# 数据预处理转换器transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 构建训练集数据,使用transform处理数据集供模型直接使用train_set = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)# 将数据集转成可迭代的批次处理数据train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=1)# 同理构建测试集数据test_set = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=transform)test_loader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=True, num_workers=1)# 数据对应标签classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
注: 如果数据下载较慢,可以在官网下载(cifar-10-python.tar.gz),然后放在data目录下,再将download改为False即可。
另外:transforms.Compose组合两种处理图片的方法,一种是将图片转成模型输入的张量格式数据,另一个则为数据标准化函数。
到此数据应该处理好了。下面我们可以查看一下数据,代码如下:
import matplotlib.pyplot as pltimport numpy as np%matplotlib inline
def imshow(img): """ 展示图片 img:图片数据 """ img = img / 2 + 0.5 # 反标准化 npimg = img.numpy() # 将数据转换成numpy格式 plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 随机获取部分训练数据dataiter = iter(train_loader)images, labels = dataiter.next()# 显示图像imshow(torchvision.utils.make_grid(images))# 打印标签print(" ".join('%5s' % classes[labels[j]] for j in range(4))) # 结果:car ship truck horse
我这边随机显示的图片如下:
到这里数据已基本上没什么问题了。
注: 这里暂不考究图像预处理的内容,感兴趣的可以深入了解。
2 构建网络
下面就是构建CNN网络了,根据之前的介绍以及pytorch,代码如下:
import torch.nn as nnimport torch.nn.functional as F
# 如果有gpu的可使用GPU加速device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class CNNNet(nn.Module): def __init__(self): super(CNNNet, self).__init__() # 定义第一个卷积层 self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1) # 定义第一个池化层 self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 定义第二个卷积层 self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1) # 定义第二个池化层 self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 定义第一个全连接层 self.fc1 = nn.Linear(1296, 128) # 定义第二个全连接层 self.fc2 = nn.Linear(128, 10)
def forward(self, x): # 连接各个cnn各个模块 x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) # print(x.shape) x=x.view(-1, 36*6*6) x=F.relu(self.fc2(F.relu(self.fc1(x)))) # 返回运算后的结果 return x
# 实例化模型net = CNNNet()net.to(device) # 模型设备转移# 查看模型print(net)
打印的模型输出结果如下:
CNNNet( (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1)) (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(16, 36, kernel_size=(3, 3), stride=(1, 1)) (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (fc1): Linear(in_features=1296, out_features=128, bias=True) (fc2): Linear(in_features=128, out_features=10, bias=True))
针对模型中一些层的参数介绍如下:
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,bias=True, padding_model='zeros') 其中主要参数释义:
- in_channels(int):输入图片的通道数目,彩色图片的通道数为3(RGB)
- out_channels(int): 卷积产生的通道数
- kernel_size(int or tuple):卷积核的尺寸,单个值则认为卷积核长宽相同
- stride(int or tuple):卷积步长
- padding(int or tuple, optional):输入的每一条边填充0的圈数,参数可选,默认为0
- bias(bool, optional):如果bias=True,添加偏置。
这里需要补充一下多维数据中卷积是如何计算的
对于一张三通道的图片,每个通道对应一个卷积核,最后计算得到三个结果矩阵,三个结果矩阵对应值相加到最后的结果,即为一个输出通道。输出形状的计算可参考前文。
torch.nn.MaxPool2d(kernel_size,stride=None,padding=0,diltion=1,return_indices=False,ceil_mode=False) 其中主要参数如下:
- kernel_size:池化窗口的大小[height, weight],如果一个数,则两者相等
- stride:窗口在每个维度上滑动的步长,一般为[stride_h, stride_w],如果两者相等,则可为一个数字
输出形状的计算可参考前文。
3 选择优化器和损失函数
由于是分类问题,以及做个简单的模型实现,选择交叉熵损失函数以及带动量的随机梯度下降算法,如下:
## 选择优化器import torch.optim as optimcriterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4 模型训练
模型训练10轮,使用train_loader构建的mini-batch迭代器进行mini-batch的数据训练,每2000个mini-batch打印一次。代码如下:
## 训练模型for epoch in range(10): running_loss = 0.0 # 迭代,批次训练 for i, data in enumerate(train_loader, 0): # 获取训练数据 inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) # 权重参数梯度清零 optimizer.zero_grad()
# 正向传播 outputs = net(inputs) # 计算损失值 loss = criterion(outputs, labels) # 反向传播 loss.backward() # 参数更新 optimizer.step() # 损失值累加 running_loss += loss.item() # 每2000个mini-batch显示一次损失值 if i % 2000 == 1999: print('[%d, %d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0print('Finished Training')
对应的训练过程如下图:
5 模型测试
模型训练结束之后,就可以进行模型的测试了。模型使用时,不需要进去相关梯度计算,则需要使用torch.no_grad()
方法。
correct = 0 # 预测正确数目total = 0 # 测试样本总数with torch.no_grad(): for data in test_loader: images, labels = data images, labels = images.to(device), labels.to(device) outputs = net(images) _, predicted = torch.max(outputs.data, 1) # 获取批次预测结果 total += labels.size(0) # 批次数目累加 correct += (predicted == labels).sum().item() # 预测正确数累加
print('Accuracy of the network on the 10000 test images: %d %%' %(100*correct/total))# Accuracy of the network on the 10000 test images: 68 %
最后得到的结果是:这个简单的CNN模型对10分类的数据正确率达到68%,总体还是可以的。我们可以看看各个类别预测的正确率,使用代码如下:
## 各类别的准确率class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad(): for data in test_loader: images, labels = data images, labels = images.to(device), labels.to(device) outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1
for i in range(10): print('Accuracy of %5s: %2d %%' %(classes[i], 100 * class_correct[i] /class_total[i]))
结果如下:
Accuracy of plane: 74 %Accuracy of car: 83 %Accuracy of bird: 56 %Accuracy of cat: 41 %Accuracy of deer: 65 %Accuracy of dog: 61 %Accuracy of frog: 75 %Accuracy of horse: 73 %Accuracy of ship: 78 %Accuracy of truck: 71 %
6 总结
总体来说,CNN是一个简单而高效的神经网络算法,当前也有很多基于经典CNN改进的卷积神经网络,感兴趣的可以深入学习,后面我们可以看看CNN是如何进行文本分类的,也就是textcnn模型。
参考资料
[1]
【深度学习】卷积神经网络-CNN简单理论介绍: https://piqiandong.blog.csdn.net/article/details/109905697
[2]
http://www.cs.toronto.edu/~kriz/cifar.html: http://www.cs.toronto.edu/~kriz/cifar.html
[3]
【深度学习】卷积神经网络-CNN简单理论介绍: https://piqiandong.blog.csdn.net/article/details/109905697
剑指Offer刷题集| Python基础
更多精彩文章
卷积神经网络-CNN简单理论介绍
RST文件打开——以torchtext官方github文档为例
机器学习分类器性能标准(Accuracy、Precision、Recall、P-R曲线、F1等)你是否真的懂了?
不仅仅要会导别人的包也要会导自定义的包——Python导包总结!
如何用Python书写计算任一多变量函数任一点的偏导导数值?
一个绘制决策树的工具——graphviz,但你忽视了它的其他功能,如结构句法分析结果!
VSCode再添新功能——绘制流程图,来看看如何操作吧!
如何解决下载Github源码慢的问题
Python中的迭代器和生成器
最优化问题:拉格朗日乘子法、KKT条件以及对偶问题
1 图片channels_【深度学习】卷积神经网络图片分类案例(pytorch实现)相关推荐
- pytorch卷积神经网络_【深度学习】卷积神经网络图片分类案例(pytorch实现)
文 | 菊子皮 (转载请注明出处)B站:科皮子菊 前言 前文已经介绍过卷积神经网络的基本概念[深度学习]卷积神经网络-CNN简单理论介绍[1].下面开始动手实践吧.本文任务描述如下:从公开数据集CIF ...
- 深度学习/联邦学习笔记(六)卷积神经及相关案例+pytorch
深度学习/联邦学习笔记(六) 卷积神经及相关案例+pytorch 卷积神经网络不同于一般的全连接神经网络,卷积神经网络是一个3D容量的神经元,即神经元是以三个维度来排列的:宽度.高度和深度 卷积神经网 ...
- 【人工智能项目】卷积神经网络图片分类框架
[人工智能项目]卷积神经网络图片分类框架 本次硬核分享当时做图片分类的工作,主要是整理了一个图片分类的框架,如果想换模型,引入新模型,在config中修改即可.那么走起来瓷!!! 整体结构 confi ...
- 1 图片channels_深度学习中各种图像库的图片读取方式
深度学习中各种图像库的图片读取方式总结 在数据预处理过程中,经常需要写python代码搭建深度学习模型,不同的深度学习框架会有不同的读取数据方式(eg:Caffe的python接口默认BGR格式,Te ...
- 卷积神经网络图片滤镜_使用深度神经网络创建艺术性的实时视频滤镜
卷积神经网络图片滤镜 将CoreML用于iPhone的复杂视频滤镜和效果 (Using CoreML for complex video filters and effects for iPhone) ...
- DL:深度学习算法(神经网络模型集合)概览之《THE NEURAL NETWORK ZOO》的中文解释和感悟(四)
DL:深度学习算法(神经网络模型集合)概览之<THE NEURAL NETWORK ZOO>的中文解释和感悟(四) 目录 CNN DN DCIGN 相关文章 DL:深度学习算法(神经网络模 ...
- 毕设 深度学习卷积神经网络的花卉识别
文章目录 0 前言 1 项目背景 2 花卉识别的基本原理 3 算法实现 3.1 预处理 3.2 特征提取和选择 3.3 分类器设计和决策 3.4 卷积神经网络基本原理 4 算法实现 4.1 花卉图像数 ...
- 深度学习 卷积神经网络原理
深度学习 卷积神经网络原理 一.前言 二.全连接层的局限性 三.卷积层 3.1 如何进行卷积运算? 3.2 偏置 3.3 填充 3.4 步长 3.5 卷积运算是如何保留图片特征的? 3.6 三维卷积 ...
- 深度学习中神经网络模型压缩的解决办法( flask API、onnx、ncnn在嵌入式、流媒体端口应用)
深度学习中神经网络模型压缩的解决办法( flask API.onnx.ncnn在嵌入式.流媒体端口应用) 1 开发环境的创建 1.1 Conda简介 1.2 miniconda 1.3 conda操作 ...
最新文章
- jupyter notebook xdg-settings 错误
- 你们可能都小看了Windows!
- Java开发环境的搭建(JDK和Eclipse的安装)
- javascript学习-创建json对象数据,遍历
- Linux学习之服务器搭建——DHCP服务器
- 复习Java的精华总结
- ffmpeg 截图 java_Java Web 中使用ffmpeg实现视频转码、视频截图
- .1 matlab,1 MATLAB集成环境
- python递归查找_[Python]递归查找文件(最简洁)
- 第四届UG2研讨会和竞赛:弥合计算成像与视觉识别之间的鸿沟
- 【网摘】ActiveX组件及其注册
- Excel实验情况对比排序
- 产品经理职责和工作内容
- 如何保证缓存一致性?
- C++:实现量化daycounters 日计数器测试实例
- Think-swoole的使用
- 买房贷款在什么情况下会被拒? 你避开这些雷区了吗?
- 编程规范 --- 可读性
- Excel如何分别提取出数值整数部分和小数部分
- Gensim:word2vec(jieba分词,去停用词)