网络介绍:

Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军,下面是Alexnet的网络结构:

网络结构较为简单,共有五个卷积层和三个全连接层,原文作者在训练时使用了多卡一起训练,具体细节可以阅读原文得到。
Alexnet文章链接:http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
作者在网络中使用了Relu激活函数和Dropout等方法来防止过拟合,更多细节看文章。

数据集介绍

使用的是MNIST手写数字识别数据集,torchvision中自带有数据集的下载地址。

定义网络结构

就按照网络结构图中一层一层的定义就行,其中第1,2,5层卷积层后面接有Max pooling层和Relu激活函数,五层卷积之后得到图像的特征表示,送入全连接层中进行分类。

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/2 下午3:25import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optimclass AlexNet(nn.Module):def __init__(self,width_mult=1):super(AlexNet,self).__init__()#定义每一个就卷积层self.layer1 = nn.Sequential(#卷积层  #输入图像为1*28*28nn.Conv2d(1,32,kernel_size=3,padding=1),#池化层nn.MaxPool2d(kernel_size=2,stride=2)  ,   #池化层特征图通道数不改变,每个特征图的分辨率变小#激活函数Relunn.ReLU(inplace=True),)self.layer2 = nn.Sequential(nn.Conv2d(32,64,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=2,stride=2),nn.ReLU(inplace=True),)self.layer3 = nn.Sequential(nn.Conv2d(64,128,kernel_size=3,padding=1),)self.layer4 = nn.Sequential(nn.Conv2d(128,256,kernel_size=3,padding=1),)self.layer5 = nn.Sequential(nn.Conv2d(256,256,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3, stride=2),nn.ReLU(inplace=True),)#定义全连接层self.fc1 = nn.Linear(256 * 3 * 3,1024)self.fc2 = nn.Linear(1024,512)self.fc3 = nn.Linear(512,10)#对应十个类别的输出def forward(self,x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.layer5(x)x = x.view(-1,256*3*3)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x

训练

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/2 下午3:38import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
from alexnet import AlexNet
import cv2
from utils import plot_image,plot_curve,one_hot
#定义使用GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")#设置超参数
epochs = 30
batch_size = 256
lr = 0.01train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),#数据归一化torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True
)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = 256,shuffle = False
)#定义损失函数
criterion = nn.CrossEntropyLoss()#定义网络
net = AlexNet().to(device)#定义优化器
optimzer = optim.SGD(net.parameters(),lr=lr,momentum = 0.9)#train
train_loss = []
for epoch in range(epochs):sum_loss = 0.0for batch_idx,(x,y) in enumerate(train_loader):print(x.shape)x = x.to(device)y = y.to(device)#梯度清零optimzer.zero_grad()pred = net(x)loss = criterion(pred, y)loss.backward()optimzer.step()train_loss.append(loss.item())sum_loss += loss.item()if batch_idx % 100 == 99:print('[%d, %d] loss: %.03f'% (epoch + 1, batch_idx + 1, sum_loss / 100))sum_loss = 0.0
torch.save(net.state_dict(),'/home/lwf/code/pytorch学习/alexnet图像分类/model/model.pth')
plot_curve(train_loss)

使用交叉熵损失函数和SGD优化器来训练网络,训练后保存模型至本地。

训练过程中损失函数的收敛过程:

测试准确率

完整代码:https://github.com/SPECTRELWF/pytorch-cnn-study/tree/main/Alexnet-MNIST
个人主页:http://liuweifeng.top:8090/

使用PYTORCH复现ALEXNET实现MNIST手写数字识别相关推荐

  1. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  2. Pytorch实现mnist手写数字识别

    2020/6/29 Hey,突然想起来之前做的一个入门实验,用pytorch实现mnist手写数字识别.可以在这个基础上增加网络层数,或是尝试用不同的数据集,去实现不一样的功能. Mnist数据集如图 ...

  3. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  4. PyTorch入门一:卷积神经网络实现MNIST手写数字识别

    先给出几个入门PyTorch的好的资料: PyTorch官方教程(中文版):http://pytorch123.com <动手学深度学习>PyTorch版:https://github.c ...

  5. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  6. Pytorch实战1:LeNet手写数字识别 (MNIST数据集)

    版权说明:此文章为本人原创内容,转载请注明出处,谢谢合作! Pytorch实战1:LeNet手写数字识别 (MNIST数据集) 实验环境: Pytorch 0.4.0 torchvision 0.2. ...

  7. Pytorch入门——MNIST手写数字识别代码

    MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...

  8. C语言底层搭建CNN实现MNIST手写数字识别

    手写数字识别 手写数字识别是指使用计算机自动识别手写体阿拉伯数字的技术.作为光学字符识别OCR的一个分支,它可以被广泛应用到手写数据的自动录入场景中.传统的识别方法如最近邻算法k-NN.支持向量机SV ...

  9. FlyAi实战之MNIST手写数字识别练习赛(准确率99.55%)

    欢迎关注WX公众号:[程序员管小亮] 文章目录 欢迎关注WX公众号:[程序员管小亮] 一.介绍 二.代码实现 1_数据加载 2_归一化 3_定义网络结构 4_设置优化器和退火函数 5_数据增强 6_拟 ...

最新文章

  1. ***“出更”---获取源码的***
  2. poj2823 线段树模板题 点修改(也可以用单调队列)
  3. Salesforce中所有常用类型字段的取值与赋值
  4. 抽了几天用Flex写了个上传小工具,支持批量上传,支持配置
  5. 团子大家族(clannad)
  6. Python lambda表达式
  7. WebAssembly和Blazor:解决了一个存在十年的老问题
  8. python获取某文件路径_Python获取当前文件路径
  9. Spring : Spring @Transactional-嵌套事物回滚
  10. Redis缓存穿透、缓存雪崩和缓存击穿理解
  11. java 显示锁_Java 实现一个自己的显式锁Lock(有超时功能)
  12. WPF DataGrid 获取选中的当前行某列值
  13. eclipse 点击 ctrl+鼠标左键看不了源码问题解决
  14. 通信总线模块:RS485、SP3232
  15. 如何设置PPT里的表格行高等高
  16. springboot微信登陆
  17. Linux服务器挂载ntfs硬盘,Linux中挂载NTFS格式的硬盘的方法
  18. pyraformer: low-complexity pyramidal attention for long-range time series modeling and forecasting
  19. Java-构造方法(constructor)
  20. 使用runas命令让域用户可以以管理员权限运行程序

热门文章

  1. 三十四、使用pytesser3 和pillow完成图形验证码的识别
  2. spring mvc 渲染html,在Spring MVC中使用Thymeleaf模板渲染Web视图
  3. GPLinker:基于GlobalPointer的事件联合抽取
  4. NeurIPS 2021 | PCAN:高效时序建模,提升多目标追踪与分割性能
  5. 报名|第2期“DI极客说”,揭秘决策AI创新应用带来的行业变革
  6. AI顶会直播丨深度学习顶级会议ICLR 2021中国预讲会明天召开,为期三天五大论坛...
  7. NeurIPS 2020 | 自步对比学习:充分挖掘无监督学习样本
  8. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类
  9. dataset_flickr8k.json与dataset_flickr30k.json的比较
  10. InfluxData【环境搭建 01】时序数据库 InfluxDB 最新版本安装启动验证(在线安装+离线安装及各版本下载地址)