基于pytorch构建一个非常简单的卷积神经网络,以Mnist数据集为例演示基本的流程

1、导工具包

2、读取数据

(把该写的超参数全部写出来)

PS:当前输入图像的大小,注意这里使用卷积网络处理Mnist数据他就不是一个一个像素点了,既然我们要用卷积网络去做,那输入的他得是一张图像,对于一张图像我们现在的输入得是28x28x1的三维的数据,我们现在需要的数据他是三维的他是三维的。

3、卷积网络模块构建

定义的conv1不光是做了一个卷积,他是一个卷积模块,包含了卷积、池化、Relu全部加进去了,

定义conv2也是一个卷积模块;

在做卷积层的时候其实很简单,需要在nn模块当中,把Conv2d拿出来就可以了,其中:

in_channel:表示当前输入的特征图个数,对于第一个卷积来说,他的输入应该是我们的Mnist数据,这个数据是一个灰度图,所以说他的In_channels=1,这是我输入的颜色通道,或者说输入特征图的个数;

out_channels:表示输出特征图的个数,就是说你用多少个卷积核来对当前数据或者对当前的这个图进行特征提取,这里的16表示你用16个卷积核,16个不同的卷积核,肯定会得到16个不同的特征图,所以out这里的意思就是你想得到多少个特征图的意思;

kernel_size:表示我们现在做卷积,你得告诉我我每隔多大的区域进行特征提取,这里等于5就是我用一个5x5的区域来去在当前原始的输入图像当中进行特征提取,kernel_size=5表示卷积核的大小;

stride:表示当前的步长;我在做卷积的窗口进行滑动过程当中,每隔几个单元滑动,一般情况下步长都是为1的。

padding:表示做边缘填充;这里为2表示加2圈0。

PS:如果不想写这些参数的名字,直接向nn.Conv2d那样直接全部输入数字也行。

在nn.Conv2d中的16表示输入,之前输入是1表示灰度图默认就是1,

这里的16表示(大家记住一点就是我当前这个卷积层他的输入大小就是之前得到多少个特征图)之前得到的16个特征图,所以这里我们的输入也是16;

32表示使用32个卷积核去提取特征,得到的是32个图;

5表示kernel_size=5;

3.1、拉伸操作

在做卷积的时候,最后不管是卷积层还是Relu层还是pooling层,无论这三个层中的哪一个,我们最终得到的都是一个特征图,所谓这个特征图不考虑batch_size的前提下,他是一个三维的,比如我最终得到的是一个32x32x256的结果,他是一个图,他不是一个矩阵,我们最终需要得到一个10分类的结果,怎么样得到一个10分类的结果呢?现在我需要把这个立体的东西给他拉长,转成一个矩阵或者说是向量,比如说这里他是一个2048维度(假设的)的向量,接下来我连接上一个全连接层(一个权重参数w,一个偏置参数b),我就能得到我最终预测出来的一个结果,比如说10个类别,我就能够去做了。

PS:所以说这里我得多做一步,把当前得到的特征图,给他拉长,拉成一个长向量,基于这个向量我才能对他做一个全连接层,得到最终的一个预测结果,所以这里有一步拉长的操作。

在做拉长操作之前还得做一件事,我们得知道,你最后一层这个全连接层里边这个w他的一个维度,w的第2维度很简单肯定是个10,因为得到的是10个类别,第一个维度就是你得到这个特征图里边他有多少个特征,这里的2048就是把3个数乘在一起, 所以在做卷积的时候最后得到的这个特征图他的规格,它的大小是等于多少。

这里我们计算一下他得到的特征图大小是多少?

一开始输入是28x28x1--------------------->经过第一个卷积层之后得到28x28x16(因为用了16个卷积核)------------->经过Relu不变还是28x28x16------------->经过最大池化层是14x14x16,------------>经过第二个卷积层是14x14x32(因为用了32个卷积核)------------->经过Relu不变还是14x14x32------------->经过最大池化层是7x7x32,所以最终的w第一个维度就等于1568=7x7x32,第二个维度是10;

根据计算公式:

所以写了一个最终输出层,输出层里边我们是全连接操作,然后全连接里边他是32X7x7表示经过这几次卷积之后得到的一个结果,10就是最终想要输出的类别的个数。

3.2、把网络串起来

之后进行前向传播,前向传播比较简单,一开始经过conv1,再经过conv2,下一句特别的,做了一个reshape操作(x=x.view(x.size(0),-1)),这个reshape操作就是刚才说的咱们得把当前结果转化成全部向量的格式(因为下一层要做全连接了),接下来用向量再乘上我的全连接层,就是wx+b,最终就得到了当前这个输入属于10个类别中的各自的一个结果。

4、评估

评估函数计算一下当前的准确率。

5、训练网络模型

把一开始定义的CNN拿到手,定义损失函数,指定优化器。接下来遍历每个epoch,每个epoch里边我一个batch一个batch的取数据,然后定义一个train模块(net.train()),然后在train模块当中更新我们的一个权重参数,就可以了;

每隔一百次可以去在验证集上看一下当前验证集的效果等于多少;接下来计算准确率打印当前结果;

PyTorch框架:(5)使用PyTorch框架构建卷积神经网络相关推荐

  1. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  2. PyTorch 入门实战(四)——利用Torch.nn构建卷积神经网络

    承接上一篇:PyTorch 入门实战(三)--Dataset和DataLoader PyTorch入门实战 1.博客:PyTorch 入门实战(一)--Tensor 2.博客:PyTorch 入门实战 ...

  3. 基于Pytorch再次解析使用块的现代卷积神经网络(VGG)

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  4. keras构建卷积神经网络(CNN(Convolutional Neural Networks))进行图像分类模型构建和学习

    keras构建卷积神经网络(CNN(Convolutional Neural Networks))进行图像分类模型构建和学习 全连接神经网络(Fully connected neural networ ...

  5. TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略

    TF之CNN:Tensorflow构建卷积神经网络CNN的简介.使用方法.应用之详细攻略 目录 TensorFlow 中的卷积有关函数入门 1.tf.nn.conv2d函数 案例应用 1.TF之CNN ...

  6. tensorflow 图像教程 の TF Layers 教程:构建卷积神经网络

    文章目录 TF Layers 教程:构建卷积神经网络 卷积神经网络的简介 构建基于卷积神经网络的 MNIST 分类器 输入层 第一个卷积层 第一个池化层 第二个卷积层和池化层 全连接层 Logits ...

  7. keras构建卷积神经网络_在Keras中构建,加载和保存卷积神经网络

    keras构建卷积神经网络 This article is aimed at people who want to learn or review how to build a basic Convo ...

  8. keras构建卷积神经网络_通过此简单教程学习在网络上构建卷积神经网络

    keras构建卷积神经网络 by John David Chibuk 约翰·大卫·奇布克(John David Chibuk) 通过此简单教程学习在网络上构建卷积神经网络 (Learn to buil ...

  9. keras构建卷积神经网络_在python中使用tensorflow s keras api构建卷积神经网络的初学者指南...

    keras构建卷积神经网络 初学者的深度学习 (DEEP LEARNING FOR BEGINNERS) Welcome to Part 2 of the Neural Network series! ...

  10. 通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类

    实验目的 通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类. 实验环境 import tensorflow as tfprint(tf.__version__) output ...

最新文章

  1. 【转】 Pro Android学习笔记(二九):用户界面和控制(17):include和merge
  2. python 类和对象 有必要学吗_类与对象-python学习19
  3. Javaweb练手项目
  4. 利用fiddler给android模拟器抓包
  5. Web缓存相关知识整理
  6. 开发一个简单的WebPart
  7. 迷惑行为!淘宝上线新版“相亲名片”:上来先告诉相亲对象你花了多少钱?...
  8. Mootools:Hash中的null值
  9. LaTeX 阿拉伯语
  10. Linux下服务器密码正确,登录的时候却提示密码错误
  11. 如何在面试中介绍自己的项目经验
  12. GWAS 教程之QC
  13. 安卓投屏大师TC DS如何把手机声音传输到电脑教程
  14. 数据库之区分DB\DBMS\DBS
  15. 胡凡算法笔记第二章摘录
  16. 微型计算机系统的发展历史,计算机的系统发展历史
  17. 剑灵系统推荐加点_新版剑灵怎么加点(2019剑灵技能加点在哪里)
  18. 火狐浏览器图形验证码刷新不生效的问题(图片src重新赋值不生效的问题)
  19. 以程序员的视角带你看西安
  20. 龙芯 python_二代龙芯派 VS 树莓派 3B+:性能孰胜一筹?

热门文章

  1. 客快物流大数据项目(十五):DockeFile常用命令
  2. 2021年大数据Spark(三十六):SparkStreaming实战案例一 WordCount
  3. php-fpm开启报错-ERROR: An another FPM instance seems to already listen on /tmp/php-cgi.sock
  4. Python案例:使用XPath的爬虫
  5. f是一个python内部变量类型,Python基础变量类型——List浅析
  6. elasticsearch 监控
  7. 第2节 mapreduce深入学习:4, 5
  8. Django进阶-auth集成认证模块
  9. 洛谷 P1142 轰炸
  10. Charles抓取https请求