目录

简要介绍PyTorch、张量和NumPy

为什么选择卷积神经网络(CNNs)?

识别服装问题

使用PyTorch实现CNNs

1.简要介绍PyTorch、张量和NumPy

让我们快速回顾一下第一篇文章中涉及的内容。我们讨论了PyTorch和张量的基础知识,还讨论了PyTorch与NumPy的相似之处。

PyTorch是一个基于python的库,提供了以下功能:

用于创建可序列化和可优化模型的TorchScript

以分布式训练进行并行化计算

动态计算图,等等

PyTorch中的张量类似于NumPy的n维数组,也可以与gpu一起使用。在这些张量上执行操作几乎与在NumPy数组上执行操作类似。这使得PyTorch非常易于使用和学习。

在本系列的第1部分中,我们构建了一个简单的神经网络来解决一个案例研究。使用我们的简单模型,我们在测试集中获得了大约65%的基准准确度。现在,我们将尝试使用卷积神经网络来提高这个准确度。

2.为什么选择卷积神经网络(CNNs)?

在我们进入实现部分之前,让我们快速地看看为什么我们首先需要CNNs,以及它们是如何工作的。

我们可以将卷积神经网络(CNNs)看作是帮助从图像中提取特征的特征提取器。

在一个简单的神经网络中,我们把一个三维图像转换成一维图像,对吧?让我们看一个例子来理解这一点:

你能认出上面的图像吗?这似乎说不通。现在,让我们看看下面的图片:

我们现在可以很容易地说,这是一只狗。如果我告诉你这两个图像是一样的呢?相信我,他们是一样的!唯一的区别是第一个图像是一维的,而第二个图像是相同图像的二维表示

空间定位

人工神经网络也会丢失图像的空间方向。让我们再举个例子来理解一下:

你能分辨出这两幅图像的区别吗?至少我不能。由于这是一个一维的表示,因此很难确定它们之间的区别。现在,让我们看看这些图像的二维表示:

在这里,图像某些定位已经改变,但我们无法通过查看一维表示来识别它。

这就是人工神经网络的问题——它们失去了空间定位。

大量参数

神经网络的另一个问题是参数太多。假设我们的图像大小是28283 -所以这里的参数是2352。如果我们有一个大小为2242243的图像呢?这里的参数数量为150,528。

这些参数只会随着隐藏层的增加而增加。因此,使用人工神经网络的两个主要缺点是:

丢失图像的空间方向

参数的数量急剧增加

那么我们如何处理这个问题呢?如何在保持空间方向的同时减少可学习参数?

这就是卷积神经网络真正有用的地方。CNNs有助于从图像中提取特征,这可能有助于对图像中的目标进行分类。它首先从图像中提取低维特征(如边缘),然后提取一些高维特征(如形状)。

我们使用滤波器从图像中提取特征,并使用池技术来减少可学习参数的数量。

在本文中,我们不会深入讨论这些主题的细节。如果你希望了解滤波器如何帮助提取特征和池的工作方式,我强烈建议你从头开始学习卷积神经网络的全面教程。

3.问题:识别服装

理论部分已经铺垫完了,开始写代码吧。我们将讨论与第一篇文章相同的问题陈述。这是因为我们可以直接将我们的CNN模型的性能与我们在那里建立的简单神经网络进行比较。

你可以从这里下载“识别”Apparels问题的数据集。

https://datahack.analyticsvidhya.com/contest/practice-problem-identify-the-apparels/?utmsource=blog&utmmedium=building-image-classification-models-cnn-pytorch

让我快速总结一下问题陈述。我们的任务是通过观察各种服装形象来识别服装的类型。我们总共有10个类可以对服装的图像进行分类:

数据集共包含70,000张图像。其中60000张属于训练集,其余10000张属于测试集。所有的图像都是大小(28*28)的灰度图像。数据集包含两个文件夹,一个用于训练集,另一个用于测试集。每个文件夹中都有一个.csv文件,该文件具有图像的id和相应的标签;

准备好开始了吗?我们将首先导入所需的库:

加载数据集

现在,让我们加载数据集,包括训练,测试样本:

该训练文件包含每个图像的id及其对应的标签

另一方面,测试文件只有id,我们必须预测它们对应的标签

样例提交文件将告诉我们预测的格式

我们将一个接一个地读取所有图像,并将它们堆叠成一个数组。我们还将图像的像素值除以255,使图像的像素值在[0,1]范围内。这一步有助于优化模型的性能。

让我们来加载图像:

如你所见,我们在训练集中有60,000张大小(28,28)的图像。由于图像是灰度格式的,我们只有一个单一通道,因此形状为(28,28)。

现在让我们研究数据和可视化一些图像:

以下是来自数据集的一些示例。我鼓励你去探索更多,想象其他的图像。接下来,我们将把图像分成训练集和验证集。

创建验证集并对图像进行预处理

我们在验证集中保留了10%的数据,在训练集中保留了10%的数据。接下来将图片和目标转换成torch格式:

同样,我们将转换验证图像:

我们的数据现在已经准备好了。最后,是时候创建我们的CNN模型了!

4.使用PyTorch实现CNNs

我们将使用一个非常简单的CNN架构,只有两个卷积层来提取图像的特征。然后,我们将使用一个完全连接的Dense层将这些特征分类到各自的类别中。

让我们定义一下架构:

现在我们调用这个模型,定义优化器和模型的损失函数:

这是模型的架构。我们有两个卷积层和一个线性层。接下来,我们将定义一个函数来训练模型:

最后,我们将对模型进行25个epoch的训练,并存储训练和验证损失:

可以看出,随着epoch的增加,验证损失逐渐减小。让我们通过绘图来可视化训练和验证的损失:

啊,我喜欢想象的力量。我们可以清楚地看到,训练和验证损失是同步的。这是一个好迹象,因为模型在验证集上进行了很好的泛化。

让我们在训练和验证集上检查模型的准确性:

训练集的准确率约为72%,相当不错。让我们检查验证集的准确性:

正如我们看到的损失,准确度也是同步的-我们在验证集得到了72%的准确度。

为测试集生成预测

最后是时候为测试集生成预测了。我们将加载测试集中的所有图像,执行与训练集相同的预处理步骤,最后生成预测。

所以,让我们开始加载测试图像:

现在,我们将对这些图像进行预处理步骤,类似于我们之前对训练图像所做的:

最后,我们将生成对测试集的预测:

用预测替换样本提交文件中的标签,最后保存文件并提交到排行榜:

你将在当前目录中看到一个名为submission.csv的文件。你只需要把它上传到问题页面的解决方案检查器上,它就会生成分数。链接:https://datahack.analyticsvidhya.com/contest/practice-problem-identify-the-apparels/?utmsource=blog&utmmedium=building-image-classification-models-cnn-pytorch

我们的CNN模型在测试集上给出了大约71%的准确率,这与我们在上一篇文章中使用简单的神经网络得到的65%的准确率相比是一个很大的进步。

5.结尾

在这篇文章中,我们研究了CNNs是如何从图像中提取特征的。他们帮助我们将之前的神经网络模型的准确率从65%提高到71%,这是一个重大的进步。

你可以尝试使用CNN模型的超参数,并尝试进一步提高准确性。要调优的超参数可以是卷积层的数量、每个卷积层的滤波器数量、epoch的数量、全连接层的数量、每个全连接层的隐藏单元的数量等。

python实现cnn特征提取_使用PyTorch提取CNNs图像特征相关推荐

  1. python头像转卡通_人工智能:一款图像转卡通的 Python 项目,超级值得你练手

    大家好,我是章鱼猫. 今天给大家推荐的开源项目,我感觉对于想学习 Python,想学习 tensorflow ,pytorch 的同学来讲,真的非常不错,是一个非常值得大家学习和练手的一个开源项目. ...

  2. code epi 光场_基于EPI的光场图像特征点检测方法与流程

    本发明属于计算机视觉技术领域,涉及一种基于epi的光场图像特征点检测方法. 背景技术: 光场成像技术是计算机视觉领域的热点研究方向,光场是一个研究的热点领域,简而言之,光场是一束光在传播过程中,所包含 ...

  3. python训练手势分类器_使用Pytorch训练分类器详解(附python演练)

    [前言]:你已经了解了如何定义神经网络,计算loss值和网络里权重的更新.现在你也许会想数据怎么样? 目录: 一.数据 二.训练一个图像分类器 使用torchvision加载并且归一化CIFAR10的 ...

  4. python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...

  5. python pdf删除图片_使用PyMuPdf提取、删除及替换PDF中的图片文件

    有时候想把PDF中的图片文件提取出来,身为程序员的我当然是自己写段代码来实现,先看看了网上的方法,都是逐行遍历,正则匹配来提取什么的,其实没有那么复杂,PyMuPdf官方文档里自带就有提取图片文件的方 ...

  6. java代码编写的文本特征提取_文本特征词提取算法

    在文本分类中,需要先对文本分词,原始的文本中可能由几十万个中文词条组成,维度非常高.另外,为了提高文本分类的准确性和效率,一般先剔除决策意义不大的词语,这就是特征词提取的目的.本文将简单介绍几种文本特 ...

  7. python中dice常见问题_【Pytorch】 Dice系数与Dice Loss损失函数实现

    由于 Dice系数是图像分割中常用的指标,而在Pytoch中没有官方的实现,下面结合网上的教程进行详细实现. 先来看一个我在网上经常看到的一个版本. def diceCoeff(pred, gt, s ...

  8. 神经网络python分类识别图片_教现有的神经网络图像模型识别新的图像类别

    如何用最小代价重新训练Google Inception V3神经网络模型,以用于新的分类 郑灵翔 2017.06.03 运行本文的代码要求已安装TensorFlow,若没安装的请参考TensorFlo ...

  9. python怎么画长方体_将长方体绘制为图像

    我正在将一个重要的应用程序从PHP迁移到Python.一切都很好,除了一个我似乎无法弄清的问题.在 旧的应用程序使用PHP库Image_3D将一些长方体绘制成图像(SVG,尽管这并不重要),并将它们显 ...

最新文章

  1. abovedisplayskip无效_latex减少图片和图片解释文字之间的距离、调整公式与正文间距离,调整空白大小:...
  2. 高版本转低版本_Tekla高版本模型转低版本模型插件
  3. html实战例子: 在title左侧添加logo
  4. Java中实用类:Date、Calendar、Math、Random、String、StringBuffer的用法
  5. c语言简单的模拟坐标,C语言模拟实现简单扫雷游戏
  6. 物联网将成为第四次工业革命的基石
  7. 华农c语言实验1007答案,华农C语言题目及答案(完整版).docx
  8. C++:类的构造函数
  9. (转载)查看Oracle字符集及怎样修改字符集
  10. raspberry pi_书评:“ Raspberry Pi for Secret Agents”的使用效果不佳
  11. phpstorm 10 注册码
  12. Linux make menuconfig打开失败
  13. 6月7日 PowerPoint 版本支持的媒体格式(跨office版本演示需要了解)
  14. 谈谈joomla1.5中个人遇见的古怪问题
  15. 根据列表内车牌号,统计各省市车牌占有量
  16. 工业互联网环境下的工业控制系统安全防护
  17. 论文介绍--Spatio-Temporal Dynamics and Semantic Attribute Enriched Visual Encoding for Video Captioning
  18. 【附源码】计算机毕业设计JAVA学生公寓管理系统
  19. rufus制作ubuntuU盘启动以及window10和ubuntu20.04.2双系统
  20. 3D游戏中的数学运用

热门文章

  1. Virtual Villagers 攻略
  2. GC overhead limit exceede
  3. 详解Python中的True、False和None
  4. 地摊经济书籍-《城市地摊财富秘籍》
  5. 城市地摊重燃人间烟火,农村赶集却快熄灭烟火
  6. AK F.*ing leetcode 流浪计划之线段树
  7. sublime text 3 mac 注册码
  8. 新版Jsoncpp用法
  9. 计算机应用大赛PPT题库,经管系计算机应用技术专业PPT制作大赛.doc
  10. html中bak是什么文件怎么打开,bak文件怎么打开