PyTorch 实现经典模型2:AlexNet
AlexNet
网络结构
论文总结
- 成功使用ReLU作为CNN的激活函数,并验证效果在较深的网络中超过了Sigmoid。成功解决了Sigmoid在网络较深时的梯度弥散问题。
- 训练时使用Dropout随机忽略了一部分神经元,以避免模型过拟合。Dropout虽有专门的论文论述,但是AlexNet将其实用化,通过实践证明了它的效果。在AlexNet中主要是最后几个全连接层使用了Dropout。
- 在CNN中使用重叠的最大池化。此前CNN中普遍使用平均池化,AlexNet全部使用最大池化,避免平均池化的模糊化效果。并且让步长比池化核的尺寸小,这样池化层的输出之间会有重叠和覆盖,提升了特征丰富性。
- 提出了LRN层,对局部神经元的活动创建竞争机制,使得其中相应比较大的值变得相对更大,并抑制其他反馈较小的神经元。增强了模型的泛化能力(但后来的VGG证明这个作用不大)。
- 使用CUDA加速深度卷积网络的训练,利用GPU强大的并行计算能力,处理神经网络训练时大量的矩阵计算。AlexNet使用了两块GTX580GPU进行训练,同时AlexNet的设计让GPU之间的通信只在网络的某些层进行,控制了通信的性能损耗。
- 数据增强。随机地从256x256的原始图像中截取224x224大小的区域(以及水平翻转的镜像)对图像的RGB数据进行PCA处理,并对主成分做一个标准差为0.1的高斯扰动,增加一些噪声,这个Trick可以让错误率再下降1%。
代码实现
1) 导入必需的包
# 1) 导入必需的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
2) 搭建网络模型
# 2) 搭建网络模型
class AlexNet(nn.Module):def __init__(self):super(AlexNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), # LRN(local_size=5, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True))self.layer2 = nn.Sequential(nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2, groups=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), # LRN(local_size=5, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True))self.layer3 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),nn.ReLU(inplace=True))self.layer5 = nn.Sequential(nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2))# 由此从卷积变为全连接层self.layer6 = nn.Sequential(nn.Linear(in_features=6*6*256, out_features=4096), nn.ReLU(inplace=True), nn.Dropout())self.layer7 = nn.Sequential(nn.Linear(in_features=4096, out_features=4096), nn.ReLU(inplace=True), nn.Dropout())self.layer8 = nn.Linear(in_features=4096, out_features=1000)def forward(self, x):x = self.layer5(self.layer4(self.layer3(self.layer2(self.layer1(x)))))x = x.view(-1, 6*6*256)x = self.layer8(self.layer7(self.layer6(x)))return x
3) 导入使用的数据集、网络结构、优化器、损失函数等
4) 训练模型
5) 保存模型结构参数
6) 加载模型并测试模型效果
Ref
- Pytorch手撕经典网络之AlexNet
PyTorch 实现经典模型2:AlexNet相关推荐
- PyTorch 实现经典模型3:VGG
VGG 网络结构 代码 1) 导入必需的包 # 1) 导入必需的包 import torch import torch.nn as nn import torch.nn.functional as F ...
- PyTorch 实现经典模型6:RCNN (Fast RCNN, Faster RCNN)
RCNN (Fast RCNN, Faster RCNN) 适用范围 图像分类(Image classification) 目标检测(Object detection) 网络结构 代码 Ref R-C ...
- PyTorch 实现经典模型5:ResNet
ResNet 网络结构 代码 #------------------------------用50行代码搭建ResNet---------------------------------------- ...
- PyTorch 实现经典模型4:GoogLeNet
GoogLeNet 创新点: 通过多种卷积核叠加网络复杂性 学习多种卷积 提升计算效率 GoogLeNet网络深度达到22层 1x1卷积降低维度 降低计算量,提升计算效率 网络结构 代码 import ...
- PyTorch 实现经典模型1:LeNet5
模型:LeNet5 网络结构 符号说明 网络参数 代码实现 1) 导入必需的包 2) 搭建网络模型 3) 导入使用的数据集 4) 训练模型 5) 保存模型 6) 测试模型效果 所遇错误 '_Incom ...
- PyTorch 实现经典模型8:FCN
FCN 网络结构 代码 class fcn(nn.Module):def __init__(self, num_classes):super(fcn, self).__init__()self.sta ...
- PyTorch 实现经典模型7:YOLO (v1, v2, v3, v4)
YOLO (v1, v2, v3, v4) 网络结构 YOLO v3 网络结构 代码 Ref <机器爱学习>YOLO v1深入理解 <机器爱学习>YOLOv2 / YOLO90 ...
- 计算机视觉:基于眼疾分类数据集iChallenge-PM图像分类经典模型剖析(LeNet,AlexNet,VGG,GoogLeNet,ResNet)
计算机视觉:图像分类经典模型 LeNet AlexNet VGG GoogLeNet ResNet 图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉的核心,是物体检测.图像分割.物体跟 ...
- PyTorch Hub发布获Yann LeCun强推!一行代码调用经典模型
作者 | Team PyTorch 译者 | Monanfei 责编 | 夕颜 出品 | AI科技大本营(ID: rgznai100) 导读:6月11日,Facebook PyTorch 团队推出了全 ...
最新文章
- 单片机是否为嵌入式技术,单片机和嵌入式学哪个?
- 深度学习- Dropout 稀疏化原理解析
- [我的1024开源程序]60元写的宠物网页和音乐网页
- 加密、解密、摘要、签名、证书一文搞懂
- java学习之三种常用设计模式
- leftjoin多个on条件_MYSQL|为什么LEFT JOIN会这么慢?
- java解析html jsoup_2020-06-02 jsoup java解析html
- Linux Vim显示行号方法详解
- mysql 参数调整_mysql需要调整的参数-阿里云开发者社区
- 带你认识MindSpore量子机器学习库MindQuantum
- mate30后续用鸿蒙系统,mate30可以升级鸿蒙不?升级后还能退回原系统吗?
- js复杂对象和简单对象的简单转化
- wps文档一敲空格就换行_wps敲空格变成点
- 浏览器主页被劫持成360导航.每次打开都是360导航https://hao.360.cn/?src=lmls=n36a7f6a197
- 两台电脑间的串口通信
- 电脑光驱不见了(错误代码39 黄色感叹号)的解决办法
- ~囍~ Evening Star 篇
- 如何找到两个圆的公切线?
- word在使用Endnote时变得非常卡解决办法
- R统计笔记(二):投影运算与转换