文章目录

  • 复现准备
  • 数据部分
  • 搭建网络结构
    • C1层:
    • S2层:
    • C3层:
    • S4层:
    • C5层:
    • F6层:
    • Output层:
  • 损失函数与优化器:

复现准备

论文开头的一些概念和思想已经分析完.没看过的可以去看一下, CNN基础论文复现----LeNet5 (一)

环境: win10, Pycharm , python 3.8 , pytorch 1.9.1
先把可能用到的包都一股脑的全导进来。

import torch
from torch import optim
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn

数据部分

文章中已经指明了使用minist数据集。
简单介绍一下关于minist数据集:
就是一群人手写的数字 0-9,然后用于训练让机器识别。
数据包 训练集+测试集 一共100MB左右,有10W左右个数据样本, 里面的图片像素大小是 28 * 28,灰度像素范围 [0,255],且全是整数。像下面这样:

MINIST数据集可以自己下载然后拖到Pycharm里用,也可以直接在后面使用代码让他自动下。

MINIST 数据集下载地址: http://yann.lecun.com/exdb/mnist/

论文中只在第10页中介绍了一下MINIST数据集,然后紧接着就是在分析结果了,只给出了一些具体训练过程中的函数和方法的使用,所以复现过程中有很大一部分属于自由发挥,如有错误请大佬在评论区鞭打我。。。。

首先设置一下数据结构。

batch_size = 64
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])

batch_size = 64

设置一次抓取多少个样本进行训练。

transforms.Compose,

transforms相当于是对图片的一个处理工具箱,比如剪辑,旋转,填充变换等等,而 Compose相当于一个集合,将所有对图片的预处理操作放到一起,按步执行。

transforms.ToTensor(),

就是改变图像类型和数据 变成tensor类型,图像(0-255,像素值 28 * 28)值变为图像张量(映射0-1,像素值 1 * 28 * 28)
就是 W * H * C 变为 C * W * H嘛。

transforms.Normalize,

这个函数的目的就是标准化数据,通过某种算法将他限制在一定的范围之内(比如概率中的0-1),方便后期数据处理以及加快收敛速度。
函数后面的两个参数, 均值(mean) 标准差(std),由于这个模型已经非常成熟,所以直接用现有的数字就可以了。
Normalize的具体原理,看了一下,有点费劲,暂且搁置哈哈。。。。

然后我们开始将MINIST数据集弄下来。

train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

这块没啥好说的 看变量名和参数名也比较容易理解,四行分别是训练集 加载训练集 测试集 加载测试集。

shuffle = True 打乱样本顺序
train = True 作为训练集
download= True 从网上下载

搭建网络结构

所用到的网络结构就是论文中给出的LeNet-5卷积网络结构,
就下面这个样子。

在刚开始搭建网络的时候我发现一个问题,这张图是论文中给出的图,但是Input的是 32 * 32的 而MINIST数据集Input是28 * 28的 这是啥情况?

我又跑去看了一下论文,提到了一个归一化的概念,其实就是我们上面用的那个Normalize函数。
文中说:使用归一化算法的抗锯齿。通过计算像素的质心并平移图像,将图像置于28 28图像的中心,从而将该点置于28 28场的中心。在某些情况下,这个28 * 28域被扩展到32 * 32的背景像素,不是很理解这块,可能类似于padding?应该不影响复现过程,后面再说吧 下面的复现都认为输入的尺寸为28*28。

开始~

C1层:

C开头的层就是卷积层。
定义像下面这样,

self.conv1 = torch.nn.Conv2d(输入通道数, 输出通道数, 卷积核尺寸,填充数padding)

显然,输入通道是1,输出通道是6,卷积核尺寸论文中已经给出了为5。

padding用官网给的公式算一下:
H为28,dilation默认为1,kernel_size= 5,
所以这个公式只有 stride和padding不知道了。
可以假设stride为1,则此时padding为2,
如果假设stride为2,则padding就是14+了 显然不合理。
所以此时算出来了stride为1,padding为2。

则就有了第一层卷积层:
self.conv1 = torch.nn.Conv2d(1, 6, 5,padding = 2)

S2层:

S开头的层就是论文中提到的采样层,也就是我们常用的池化层,文中使用的是特殊的平均池化,这里与我们平常使用的平均池化不一样,平常使用的平均池化是直接相加然后取平均值,在论文中7页末和8页开头作者说了一下,这里的平均池化要乘上一个可训练的系数,再加上一个可训练的偏差。最后通过一个sigmoid函数。具体为什么要这样用,看了一下论文,也就只有7和8页介绍了网络结构的时候提到了S2层,具体也没说为什么要这样用。。。。。

大概像这样的: y = (a1+a2+a3+a4) * w + b 然后再加上sigmoid。

这样的话就比较恶心了,因为平常使用池化操作就是调用一行代码的事,现在这样就只能重写池化函数了,于是查阅资料,看到大佬的重写代码:

定义一个自定义类,继承自 torch.nn.Moduel
Class Subsampling(nn.Moduel)
然后初始化一下这个类里的init方法。

def __init__(self, in_channel):super(Subsampling, self).__init__()self.pool = nn.AvgPool2d(2)self.in_channel = in_channelF_in = 4 * self.in_channelself.weight = nn.Parameter(torch.rand(self.in_channel) * 4.8 / F_in - 2.4 / F_in, requires_grad=True)self.bias = nn.Parameter(torch.rand(self.in_channel),  requires_grad=True)

super(Subsampling, self).__init__():

继承父类的 nn.Moduel里的一些方法拿过来用,没啥好说的。

self.pool = nn.AvgPool2d(2)

继承父类,定义了一个池化层,这里是我们平常时候用的那种定义方法,而论文中的还需要做线性处理(乘一个数再加一个数),且文中也指定卷积核的尺寸为2 * 2。

self.weight 和 self.bias
这俩就是在定义作者的那两个可训练参数,权重和偏置值。

nn.Parameter()

看到一篇博客的总结,这里引用一下:

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

不过!!!!

Parameter() 里的 torch.rand() 第一个参数 是size:用来定义tensor的shape ,但是里面的什么4.8 2.4 直接给我整蒙了,大佬的代码就是牛b,压根看不懂(可能是我太菜了哈哈)。

在这里卡了挺久的,也找了很多资料,发现网上几乎99%都直接使用平均池化或者最大池化代替了??为啥?我不理解~~·,能看懂的大佬教教我。。。。

原文明明是:

The four inputs to a unit in S2 are added, then multiplied by a trainable coefficient, and then added to a trainable bias.The result is passed through a sigmoidal function.

将S2中一个单位的四个输入相加,然后乘以一个可训练系数,然后加上一个可训练偏差。结果通过simgoid函数传递。

而最大池化和平均池化显然都不符合这样的标准,怎么能直接拿来用啊 那还叫复现嘛 不过奈何能力有限,直接重写的池化函数看不太懂,所以只能用网上大多数人用的最大池化层来暂行代替。

直接用最大池化就简单了,论文中已经给出了尺寸 2 * 2。
下面这样

nn.MaxPool2d(kernel_size=2)

此时已经深刻体会到框架所带来的API的好处。。。。。

C3层:

又是一个卷积层。

这里又开始恶心了~~这不是普通的卷积层,我们知道,对于普通的卷积操作,就是一个窗口对图中所有特征元素都进行卷积,而这里的卷积层则是只选择一部分特征图进行卷积,看一下他是怎么实现的。

论文中给出了卷积的操作的关系图:


简单的解释一下上面这个图什么意思,

看一下这一层卷积的通道数 为输入6 输出16。所以横坐标表示输出的特征通道0-15,共16个。竖坐标表示输入的特征通道0-5 一共6个。

看横坐标的0号特征通道,在竖左边上的0,1,2 上打了 ‘X’。表明输出的0号通道仅和输入的0,1,2号通道做卷积操作,3,4,5号通道无视。其他的同理。

0-5号为连续的3个特征通道,6-8号为连续的四个特征通道,9-14号为不连续的四个特征通道,最后一个15号为连续的6个,即全部特征通道。

为什么要这样做呢? 这就是在复现的上一篇文章中提到的类似于dropout的算法,忘了的可以再去看一下。目的就是为了 打破对称性和让他们在不同的特征图中被迫提取尽量互补的特征,以防止过拟合降低模型的耦合和提高鲁棒性,还可以减少参数。

这里肯定是又要进行卷积层的重写,不过我这Pytorch伪入门的水平,重写某个类来说比较困难(毕竟源码都没看过),然后我又去查阅了大量资料,几乎没有完全性复现的代码和文章(只有一篇,还没看懂),其余的基本都是用了Pytorch自带的api进行卷积。

无奈只能用自带函数进行卷积。
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),

S4层:

这就简单了,论文中没有过多介绍,就是一个普通的池化层,通道数为 16,卷积核大小为2*2。

nn.MaxPool2d(kernel_size=2

C5层:

这里就很无语了,论文中给的图片是线性层,文字描述却说是卷积层,网上也是有线性的有卷积的,就离谱。。
不过好在无论是线性还是卷积,都可以直接调用函数来做。
这里我直接使用线性来做。
输入 为5 *5 *46 输出是120 * 1 * 1 .

nn.Linear(in_features=5*5*16, out_features=120

F6层:

这里是实打实的线性层。

输入120输出84。

nn.Linear(in_features=120, out_features=84

Output层:

这一层我还以为很简单,但看了一下论文发现并没有那么简单,读了一下论文结合查阅的相关资料明白了。

论文中说的很复杂,我认为没有必要弄的完全懂,能够明白大致什么意思就行了。

最后这一层是由 欧式径向基函数(Euclidean RBF) 组成 。
通俗的理解 ,作用相当于判断正确率,设计一个 7*12的矩阵。将0-9中的每一个数字都展开成 7 * 12的矩阵,其中黑色用1表示 白色用0表示。

所以 数字1 就长下面这样:

[0, 0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 1, 1, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]

远点看 是不是把1的地方涂黑 这就是 印刷字体中的 ‘1’ 了。

论文中使用的是黑色为 +1 白色为 -1。 将0-9的全部数字都定义出来。

然后我们看 上一层F6层输出的向量和0-9的数字中的每一个编码向量求 距离的平方和,就是论文中给出的下面这个公式。

所得距离越小,即越接近0 则对应是该数字的概率就越大。
之所以弄这么一套复杂的编码过程,论文中说是为了高精度辨别类似于数字0和字母O 或者 数字 1和 字母I 这种极其相近的情况。

不过也是能力有限 所以,这里直接使用线性层代替。
nn.Linear(in_features=84, out_features=10)

到这终于是把网络搭建完了,吐了。。

损失函数与优化器:

论文中是在每一个池化层后面加上一个sigmoid作为激活函数,在F6层后面加了一个tanh的激活函数。

虽然还没明白为什么要这么做,不过为了方便,都是用sigmoid作为激活函数,在pytorch里直接都使用CrossEntropyLossr损失函数(softmax + NLLLoss),也是符合文中描写的过程的。

优化器没得说,肯定得随机梯度下降,用SGD啦~

这些函数原理和用法就不在这这里展开说了,不清楚的可以去下面看看:

CrossEntropyLossr详解

Sigmoid详解

SDG详解

到这基本的东西都已经完成了,后面就是训练和绘图分析结果了。

CNN基础论文 精读+复现----LeNet5 (二)相关推荐

  1. CNN基础论文 精读+复现---- ResNet(二)

    文章目录 准备工作 BasicBlock块 ResNet-18.34网络结构 完整代码: 小总结 准备工作 昨天把论文读完了,CNN基础论文 精读+复现---- ResNet(一) ,今天用pytor ...

  2. CNN基础论文 精读+复现----VGG(一)

    文章目录 前言 第1页 第2-3页 第四页 第五页 前言 原文Github地址:https://github.com/shitbro6/paper/blob/main/VGG.pdf 原文arxiv地 ...

  3. CNN基础论文 精读+复现----GoogleNet InceptionV1 (一)

    文章目录 前言 第1页 摘要与引言 第2页 文献综述 第3-4页 第4-5页 inception模块细节 第5-7页 GoogLeNet 第8页 训练细节 第8-10页 ILSVRC 2014 inc ...

  4. 【推荐系统论文精读系列】(二)--Factorization Machines

    文章目录 一.摘要 二.介绍 三.稀疏性下预测 四.分解机(FM) A. Factorization Machine Model B. Factorization Machines as Predic ...

  5. 进阶必备:CNN经典论文代码复现 | 附下载链接

    经常会看到类似的广告<面试算法岗,你被要求复现论文了吗?>不好意思,我真的被问过这个问题.当然也不是所有面试官都会问,究其原因,其实也很好理解.企业肯定是希望自己的产品是有竞争力,有卖点的 ...

  6. 【推荐系统论文精读系列】(八)--Deep Crossing:Web-Scale Modeling without Manually Crafted Combinatorial Features

    文章目录 一.摘要 二.介绍 三.相关工作 四.搜索广告 五.特征表示 5.1 独立特征 5.2 组合特征 六.模型架构 6.1 Embedding层 6.2 Stacking层 6.3 Residu ...

  7. 【推荐系统论文精读系列】(五)--Neural Collaborative Filtering

    文章目录 一.摘要 二.介绍 三.准备知识 3.1 从隐式数据中进行学习 3.2 矩阵分解 四.神经协同过滤 4.1 总体框架 4.1.1 学习NCF 4.2 广义矩阵分解(GMF) 4.3 多层感知 ...

  8. 【推荐系统论文精读系列】(一)--Amazon.com Recommendations

    文章目录 一.摘要 二.推荐算法 三.传统协同过滤 四.聚类模型 五.基于搜索方式 六.基于物品的协同过滤 七.怎样工作? 八.可扩展性 九.总结 References 论文名称:Amazon.com ...

  9. 繁凡的对抗攻击论文精读(二)CVPR 2021 元学习训练模拟器进行超高效黑盒攻击(清华)

    点我轻松弄懂深度学习所有基础和各大主流研究方向入门综述! <繁凡的深度学习笔记>,包含深度学习基础和 TensorFlow2.0,PyTorch 详解,以及 CNN,RNN,GNN,AE, ...

最新文章

  1. Centos7 安装 telnet 服务
  2. android 自定义皮肤,Android Studio 自定义皮肤主题和背景
  3. 将String类型的Json字符串转化对象或对象数组
  4. a partial surjection的题库
  5. MYSQL问题解决方案:Access denied for user 'root'@'localhost' (using password:YES)
  6. acme云服务器生成证书_使用 acme.sh 申请 SSL 证书并且定期自动更新
  7. linq绑定下拉列表,combobox中增加listitem的方法,增加“请选择”
  8. flash大作业一分钟源文件_「百树云课堂」一写作业就像被雷劈,是什么“病”?...
  9. java 获取当前时间,前一天时间
  10. Switch基本知识
  11. Collectors.counting()
  12. PyQt5(designer)入门教程
  13. 《模式识别与机器学习》 简称 PRML 开源了
  14. 2017 年全国大学生电子设计竞赛(本科组)题目√
  15. java接口自动化测试框架搭建
  16. 【AI安全】对抗样本之FGSM的代码实现(TensorFlow2)
  17. php while循环 selecrt下拉框 option默认选中
  18. 封装jquery的方法
  19. 工程伦理第三章学习笔记2020最新
  20. fabric 环境 搭建与安装

热门文章

  1. JAVA面试汇总第四章 Spring及数据库相关
  2. Mac Git 如何设置ssh key
  3. 跨考考研难吗?选择这几个专业更容易上岸!
  4. PSD-BPA南网培训资料
  5. Makefile 零基础学习笔记:if 的用法
  6. bing 高级搜索_如何使用Bing的高级搜索运算符:更好搜索的8条提示
  7. Visual Studio 2012 示例代码浏览器 - 数以千计的开发示例近在手边,唾手可得
  8. ESB和SOA到底是什么?
  9. mysql数据库爆破_mysql数据库密码爆破
  10. Javaweb人才招聘系统