用PyTorch实现一个卷积神经网络进行图像分类

原创 2017年07月18日 21:27:13
  • 标签:
  • 卷积神经网络 /
  • PyTorch /
  • 深度学习

1. 回顾

在进入这一篇博客的内容之前,我们先确保已经成功安装好PyTorch,可以参考我之前的一篇博客“Ubuntu12.04下PyTorch详细安装记录”:

http://blog.csdn.net/wblgers1234/article/details/72902016
  • 1

接下来,我们用设计一个简单的卷积神经网络的方式来熟悉PyTorch的用法。

2. 设计卷积神经网络

在设计复杂的神经网络之前,我们依然考虑按照斯坦福大学的“UFLDL Tutorial”的CNN部分来构建一个简单的卷积神经网络,即按照以下的设计:

输入层->二维特征卷积->sigmoid激励->均值池化->全连接网络->softmax输出
  • 1

按照下面的代码对应来看神经网络的结构。注释得很清晰,有不清楚的可以留言,这里就不再赘述。

class CNN_net(nn.Module):def __init__(self):# 先运行nn.Module的初始化函数super(CNN_net, self).__init__()# 卷积层的定义,输入为1channel的灰度图,输出为4特征,每个卷积kernal为9*9self.conv = nn.Conv2d(1, 4, 9)# 均值池化self.pool = nn.AvgPool2d(2, 2)# 全连接后接softmaxself.fc = nn.Linear(10*10*4, 10)self.softmax = nn.Softmax()def forward(self, x):# 卷积层,分别是二维卷积->sigmoid激励->池化out = self.conv(x)out = F.sigmoid(out)out = self.pool(out)print(out.size())# 将特征的维度进行变化(batchSize*filterDim*featureDim*featureDim->batchSize*flat_features)out = out.view(-1, self.num_flat_features(out))# 全连接层和softmax处理out = self.fc(out)out = self.softmax(out)return outdef num_flat_features(self, x):# 四维特征,第一维是batchSizesize = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_features
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

3. 数据准备

还记得torchvision吗?我们在做和图像有关的实验时会更多地与它打交道。这次我们选择最简单也是最广为人知的MNIST数据库来训练和测试CNN。同时在torchvision中有一个torchvision.datasets,它为很多常用的图像数据库提供接口,其中就包括MNIST。

from torchvision.datasets import MNIST
  • 1

需要先下载MNIST,并且转换为PyTorch可以识别的数据格式:

# MNIST图像数据的转换函数
trans_img = transforms.Compose([transforms.ToTensor()])# 下载MNIST的训练集和测试集
trainset = MNIST('./MNIST', train=True, transform=trans_img, download=True)
testset = MNIST('./MNIST', train=False, transform=trans_img, download=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

我们查看transforms.ToTensor()的解释,将原本的二维图像格式转换为PyTorch的基本单位torch.FloatTensor。

Converts a PIL.Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
  • 1

4. 训练和测试

4.1 训练数据集

从代码中可以清晰的看见“前向传播”,“反向传播”,optimizer的求解。

# 训练过程
for i in range(epoches):running_loss = 0.running_acc = 0.for (img, label) in trainloader:# 转换为Variable类型img = Variable(img)label = Variable(label)optimizer.zero_grad()# feedforwardoutput = net(img)loss = criterian(output, label)# backwardloss.backward()optimizer.step()# 记录当前的lost以及batchSize数据对应的分类准确数量running_loss += loss.data[0]_, predict = torch.max(output, 1)correct_num = (predict == label).sum()running_acc += correct_num.data[0]# 计算并打印训练的分类准确率running_loss /= len(trainset)running_acc /= len(trainset)print("[%d/%d] Loss: %.5f, Acc: %.2f" %(i+1, epoches, running_loss, 100*running_acc))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

在训练完成之后,有一个处理很重要,需要将当前的网络设置为“测试模式”,然后才可以进行测试集的验证。

# 将当前模型设置到测试模式
net.eval()
  • 1
  • 2

4.2 测试数据集

在测试过程中,只有“前向传播”过程对输入的图像进行分类预测。

# 测试过程
testloss = 0.
testacc = 0.
for (img, label) in testloader:# 转换为Variable类型img = Variable(img)label = Variable(label)# feedforwardoutput = net(img)loss = criterian(output, label)# 记录当前的lost以及累加分类正确的样本数testloss += loss.data[0]_, predict = torch.max(output, 1)num_correct = (predict == label).sum()testacc += num_correct.data[0]# 计算并打印测试集的分类准确率
testloss /= len(testset)
testacc /= len(testset)
print("Test: Loss: %.5f, Acc: %.2f %%" %(testloss, 100*testacc))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

4.3 代码运行结果

从下面的结果,可以看到迭代10次的训练分类准确率和测试分类准确率:

CNN_net ((conv): Conv2d(1, 4, kernel_size=(9, 9), stride=(1, 1))(pool): AvgPool2d ()(fc): Linear (400 -> 10)(softmax): Softmax ()
)
[1/10] Loss: 1.78497, Acc: 68.79
[2/10] Loss: 1.54269, Acc: 93.10
[3/10] Loss: 1.52096, Acc: 94.93
[4/10] Loss: 1.51040, Acc: 95.82
[5/10] Loss: 1.50393, Acc: 96.45
[6/10] Loss: 1.49967, Acc: 96.77
[7/10] Loss: 1.49655, Acc: 97.02
[8/10] Loss: 1.49401, Acc: 97.24
[9/10] Loss: 1.49192, Acc: 97.45
[10/10] Loss: 1.49050, Acc: 97.56
Test: Loss: 1.48912, Acc: 97.62 %
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

该工程完整的代码我已经放到github上,有兴趣的可以去下载试试:

https://github.com/wblgers/stanford_dl_cnn/tree/master/PyTorch
  • 1
版权声明:本文为博主原创文章,未经博主允许不得转载。

用PyTorch实现一个卷积神经网络进行图像分类相关推荐

  1. 在PyTorch中使用卷积神经网络建立图像分类模型

    概述 在PyTorch中构建自己的卷积神经网络(CNN)的实践教程 我们将研究一个图像分类问题--CNN的一个经典和广泛使用的应用 我们将以实用的格式介绍深度学习概念 介绍 我被神经网络的力量和能力所 ...

  2. PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类

    在本教程中,我们将使用CIFAR10数据集.它有类别:"飞机"."汽车"."鸟"."猫"."鹿".& ...

  3. pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类

    只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...

  4. 图像处理神经网络python_深度学习使用Python进行卷积神经网络的图像分类教程

    深度学习使用Python进行卷积神经网络的图像分类教程 好的,这次我将使用python编写如何使用卷积神经网络(CNN)进行图像分类.我希望你事先已经阅读并理解了卷积神经网络(CNN)的基本概念,这里 ...

  5. 卷积神经网络和图像分类识别

    Andrew Kirillov 著 Conmajia 译 2019 年 1 月 15 日 原文发表于 CodeProject(2018 年 10 月 28 日). 中文版有小幅修改,已获作者本人授权. ...

  6. PyTorch实现基于卷积神经网络的面部表情识别

    基于卷积神经网络的面部表情识别(Pytorch实现)----台大李宏毅机器学习作业3(HW3) 一.项目说明 给定数据集train.csv,要求使用卷积神经网络CNN,根据每个样本的面部图片判断出其表 ...

  7. 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]

    [神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...

  8. 基于卷积神经网络的图像分类

    实验任务与要求: 使用PyTorch编写并训练卷积神经网络模型,用于识别花卉.花卉数据集17flowers.zip 与Resnet-50 预训练权值文件resnet50.pth 请从https://p ...

  9. DeepDream、反向运行一个卷积神经网络在 DeepDream和卷积神经网络的可视化 中的应用

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 反向运行一个卷积神经网络在 卷积神经网络的可视化 中的应用 D ...

  10. pytorch实现:卷积神经网络识别FashionMNIST

    pytorch实现:卷积神经网络识别FashionMNIST 一.卷积神经网络 1.1 导入需要的包 1.2 图像数据准备 1.3 卷积神经网络搭建 1.4 卷积神经网络训练与预测 二.空洞卷积神经网 ...

最新文章

  1. WebAssembly基础
  2. 打印杨辉三角--for循环
  3. c 语言现代方法13章习题6
  4. 面试官灵魂拷问:为什么 SQL 语句不要过多的 join?
  5. jwt同一会话_在会话中使用JWT
  6. IIR数字滤波器的设计及应用——MATLAB
  7. TypeScript事件处理程序
  8. SAP License:SAP S/4HANA Cloud介绍
  9. 【java笔记】接口的定义,接口的使用
  10. 《Redis设计与实现》阅读:Redis底层研究之简单动态字符串SDS
  11. python3解密栅栏密码的正确方法
  12. cad详图怎么画_CAD结构图怎么画?手把手教你CAD结构图的绘制方法
  13. 动态RAM(DRAM)和静态RAM(SRAM)的比较
  14. 持NPDP和PMP证书,可以享受深圳、北京等多项福利!
  15. inflate的使用
  16. arduino控制小车转向_利用XECU和激光雷达快速搭建入门级的自动驾驶小车
  17. http://www.hi-donet.com/网站
  18. Java的基本特性和优势
  19. 学会用python识别图像
  20. 大学生学完python靠几个接单网站兼职,实现经济独立

热门文章

  1. Javascript脚本之清除浏览器历史数据
  2. 5月30日任务 访问日志不记录静态文件、访问日志切割、静态元素过期时间
  3. 阿里影业“灯塔平台”今日正式启动,阿里影视云解决方案强势推出
  4. PMC 任命Edward Sharp为首席战略及技术官
  5. 简析边缘数据中心技术
  6. Memcached FAQ(2) 集群架构方面的问题
  7. 【BZOJ1500】[NOI2005]维修数列
  8. 转载:详解C中volatile关键字
  9. perl发送天气预报
  10. 2010中国十大杰出IT博客大赛—唯有行动才能改造命运