1. 分析

论文:LeNet-5 论文中的 image 是 1×32×32 的照片,而 MNIST 的 image 是 1×28×28。28×28 的输入经过两次 5×5 卷积(无 padding)和两次 2×2 池化后,特征图为 16×4×4,展平后是 16×4×4=256 维,因此代码中是:self.hidden1 = nn.Linear(256, 120)。如果 image 是 1×32×32,则特征图为 16×5×5,展平后是 400 维,代码应该是:self.hidden1 = nn.Linear(400, 120)
python代码——训练模型:

import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data

# ---- Hyper-parameters ----
EPOCH = 5              # number of passes over the training set
BATCH_SIZE = 50
LR = 0.002             # Adam learning rate
DOWNLOAD_MINST = True  # name kept for compatibility ("MNIST" is the correct spelling)

# MNIST digits: 1x28x28 grayscale images, labels 0-9.
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MINST,
)
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MINST,
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)


class CNN(nn.Module):
    """LeNet-5 style network for 1x28x28 MNIST inputs.

    Feature-map sizes: 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4,
    so the flattened feature is 16 * 4 * 4 = 256 (it would be 400 for 32x32 input).
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5),
                      stride=(1, 1), padding=0),
            nn.ReLU(),  # original LeNet-5 uses tanh/sigmoid; ReLU trains faster
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5),
                      stride=(1, 1), padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.flat = nn.Flatten()
        self.hidden1 = nn.Linear(256, 120)  # 16 * 4 * 4 = 256 for 28x28 input
        self.hidden2 = nn.Linear(120, 84)
        self.predict = nn.Linear(84, 10)    # 10 digit classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.flat(x)
        x = torch.relu(self.hidden1(x))
        x = torch.relu(self.hidden2(x))
        out = self.predict(x)
        return out


net = CNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
loss_function = torch.nn.CrossEntropyLoss()


def train(data_loader, model, loss_function, optimizer):
    """Run one training epoch over data_loader, printing loss every 100 batches."""
    size = len(data_loader.dataset)
    model.train()
    for batch, (X, y) in enumerate(data_loader):
        X, y = X.to(device), y.to(device)
        # Compute prediction error
        prediction = model(X)
        loss = loss_function(prediction, y)
        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(data_loader, model, loss_fn):
    """Evaluate the model on data_loader; print accuracy and average loss."""
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            prediction = model(X)
            test_loss += loss_fn(prediction, y).item()
            correct += (prediction.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error:\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def call_model():
    """Train for EPOCH epochs, evaluating after each, then save the weights."""
    for t in range(EPOCH):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_loader, net, loss_function, optimizer)
        test(test_loader, net, loss_function)
    # BUG FIX: the original saved to 'net.pkl' while printing "model.pth" and
    # while the companion loading script loads "model.pth"; the file name now
    # matches both the message and the loader.
    torch.save(net.state_dict(), 'model.pth')
    print("Saved PyTorch Model State to model.pth")


if __name__ == '__main__':
    call_model()

LeNet-5模型中激活函数使用的是sigmoid或者tanh函数,而在本文的代码中用的是relu,也就是针对每一个卷积层都是使用:nn.ReLU(),但如果一定要参照原文,则使用:nn.Tanh(),或者nn.Sigmoid()
运行结果:

Epoch 1
-------------------------------
loss: 2.306098  [    0/60000]
loss: 0.412156  [ 5000/60000]
loss: 0.412486  [10000/60000]
loss: 0.152816  [15000/60000]
loss: 0.309540  [20000/60000]
loss: 0.100331  [25000/60000]
loss: 0.130987  [30000/60000]
loss: 0.114419  [35000/60000]
loss: 0.151787  [40000/60000]
loss: 0.164818  [45000/60000]
loss: 0.046076  [50000/60000]
loss: 0.023506  [55000/60000]
Test Error:
Accuracy: 97.9%, Avg loss: 0.065859 Epoch 2
-------------------------------
loss: 0.085328  [    0/60000]
loss: 0.045330  [ 5000/60000]
loss: 0.114169  [10000/60000]
loss: 0.186654  [15000/60000]
loss: 0.104414  [20000/60000]
loss: 0.014921  [25000/60000]
loss: 0.102247  [30000/60000]
loss: 0.136939  [35000/60000]
loss: 0.111537  [40000/60000]
loss: 0.004142  [45000/60000]
loss: 0.025103  [50000/60000]
loss: 0.085847  [55000/60000]
Test Error:
Accuracy: 98.3%, Avg loss: 0.063019 Epoch 3
-------------------------------
loss: 0.062379  [    0/60000]
loss: 0.004411  [ 5000/60000]
loss: 0.028583  [10000/60000]
loss: 0.045782  [15000/60000]
loss: 0.230601  [20000/60000]
loss: 0.026126  [25000/60000]
loss: 0.044285  [30000/60000]
loss: 0.016973  [35000/60000]
loss: 0.015370  [40000/60000]
loss: 0.006403  [45000/60000]
loss: 0.018655  [50000/60000]
loss: 0.017848  [55000/60000]
Test Error:
Accuracy: 98.7%, Avg loss: 0.042296 Epoch 4
-------------------------------
loss: 0.045252  [    0/60000]
loss: 0.102186  [ 5000/60000]
loss: 0.030442  [10000/60000]
loss: 0.055944  [15000/60000]
loss: 0.056801  [20000/60000]
loss: 0.057802  [25000/60000]
loss: 0.060436  [30000/60000]
loss: 0.010024  [35000/60000]
loss: 0.012362  [40000/60000]
loss: 0.021393  [45000/60000]
loss: 0.004517  [50000/60000]
loss: 0.002426  [55000/60000]
Test Error:
Accuracy: 98.7%, Avg loss: 0.040705 Epoch 5
-------------------------------

2. 加载模型,测试所有数据集

import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data

# ---- Hyper-parameters ----
EPOCH = 5              # number of passes over the training set
BATCH_SIZE = 50
LR = 0.002             # Adam learning rate
DOWNLOAD_MINST = True  # name kept for compatibility ("MNIST" is the correct spelling)

# MNIST digits: 1x28x28 grayscale images, labels 0-9.
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MINST,
)
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MINST,
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)


class CNN(nn.Module):
    """LeNet-5 style network for 1x28x28 MNIST inputs.

    Feature-map sizes: 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4,
    so the flattened feature is 16 * 4 * 4 = 256 (it would be 400 for 32x32 input).
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5),
                      stride=(1, 1), padding=0),
            nn.ReLU(),  # original LeNet-5 uses tanh/sigmoid; ReLU trains faster
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5),
                      stride=(1, 1), padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.flat = nn.Flatten()
        self.hidden1 = nn.Linear(256, 120)  # 16 * 4 * 4 = 256 for 28x28 input
        self.hidden2 = nn.Linear(120, 84)
        self.predict = nn.Linear(84, 10)    # 10 digit classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.flat(x)
        x = torch.relu(self.hidden1(x))
        x = torch.relu(self.hidden2(x))
        out = self.predict(x)
        return out


net = CNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
loss_function = torch.nn.CrossEntropyLoss()


def train(data_loader, model, loss_function, optimizer):
    """Run one training epoch over data_loader, printing loss every 100 batches."""
    size = len(data_loader.dataset)
    model.train()
    for batch, (X, y) in enumerate(data_loader):
        X, y = X.to(device), y.to(device)
        # Compute prediction error
        prediction = model(X)
        loss = loss_function(prediction, y)
        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(data_loader, model, loss_fn):
    """Evaluate the model on data_loader; print accuracy and average loss."""
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            prediction = model(X)
            test_loss += loss_fn(prediction, y).item()
            correct += (prediction.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def call_model():
    """Train for EPOCH epochs, evaluating after each, then save the weights."""
    for t in range(EPOCH):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_loader, net, loss_function, optimizer)
        test(test_loader, net, loss_function)
    torch.save(net.state_dict(), 'model.pth')
    print("Saved PyTorch Model State to model.pth")


def load_model():
    """Load the trained weights and report per-sample predictions plus accuracy."""
    model = CNN()
    model.load_state_dict(torch.load("model.pth"))
    # BUG FIX: the original list held FashionMNIST class names
    # ("T-shirt/top", "Trouser", ...), but the dataset here is MNIST
    # hand-written digits, so the class labels are the digits '0'-'9'.
    classes = [str(digit) for digit in range(10)]
    model.eval()
    with torch.no_grad():
        correct = 0
        for images, labels in test_loader:
            predict = model(images)
            for i in range(len(predict)):
                predicted = classes[predict[i].argmax(0)]
                actual = classes[labels[i]]
                print(f'Predicted: "{predicted}", Actual: "{actual}"')
                if predicted == actual:
                    correct += 1
    # len(test_data) is the canonical sample count; the original used the
    # dataset-internal attribute test_data.data, which is MNIST-specific.
    print("accuracy:%.4f" % (correct / len(test_data)))
    print('num:%d' % correct)


if __name__ == '__main__':
    load_model()

运行结果:

...
accuracy:0.9885
num:9885

LeNet-5实现分类MINST数据集(学习笔记四)相关推荐

  1. 数据集学习笔记(四):VOC转COCO数据集并据txt中图片的名字批量提取对应的图片并保存到另一个文件夹

    文章目录 转换代码 根据名字将图片保存在另一个文件夹 转换代码 import os import random import shutil import sys import json import ...

  2. 吴恩达《机器学习》学习笔记四——单变量线性回归(梯度下降法)代码

    吴恩达<机器学习>学习笔记四--单变量线性回归(梯度下降法)代码 一.问题介绍 二.解决过程及代码讲解 三.函数解释 1. pandas.read_csv()函数 2. DataFrame ...

  3. mysql新增表字段回滚_MySql学习笔记四

    MySql学习笔记四 5.3.数据类型 数值型 整型 小数 定点数 浮点数 字符型 较短的文本:char, varchar 较长的文本:text, blob(较长的二进制数据) 日期型 原则:所选择类 ...

  4. STM32F103学习笔记四 时钟系统

    STM32F103学习笔记四 时钟系统 本文简述了自己学习时钟系统的一些框架,参照风水月 1. 单片机中时钟系统的理解 1.1 概述 时钟是单片机的脉搏,是单片机的驱动源 用任何一个外设都必须打开相应 ...

  5. JavaScript学习笔记(四)(DOM)

    JavaScript学习笔记(四) DOM 一.DOM概述 二.元素对象 2.1 获取方式 (1).通过ID获取一个元素对象,如果没有返回null (2).通过`标签名`获取一组元素对象,,如果没有返 ...

  6. C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻

    前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...

  7. IOS学习笔记(四)之UITextField和UITextView控件学习

    IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...

  8. RabbitMQ学习笔记四:RabbitMQ命令(附疑难问题解决)

    RabbitMQ学习笔记四:RabbitMQ命令(附疑难问题解决) 参考文章: (1)RabbitMQ学习笔记四:RabbitMQ命令(附疑难问题解决) (2)https://www.cnblogs. ...

  9. JSP学习笔记(四十九):抛弃POI,使用iText生成Word文档

    POI操作excel的确很优秀,操作word的功能却不敢令人恭维.我们可以利用iText生成rtf文档,扩展名使用doc即可. 使用iText生成rtf,除了iText的包外,还需要额外的一个支持rt ...

  10. Ethernet/IP 学习笔记四

    Ethernet/IP 学习笔记四 EtherNet/IP Quick Start for Vendors Handbook (PUB213R0): https://www.odva.org/Port ...

最新文章

  1. rtsp连接断开_live555学习之RTSP连接建立以及请求消息处理过程
  2. Spring Remoting: Burlap--转
  3. mfc制作登录界面mysql_MFC制作漂亮界面之登录界面
  4. sqlserver和mysql运营_SQLServer和MySql的区别总结
  5. 【A】兼容Core3.0后 Natasha 的隔离域与热编译操作。
  6. 《Linux KVM虚拟化架构实战指南》——导读
  7. adcetris研发历程_抗体类药物质量控制—张伯彦20130730.pdf
  8. java 代码压缩javascript_通过Java压缩JavaScript代码实例分享
  9. java源码-AtomicReference
  10. java为table添加一行_Js实现Table动态添加一行的小例子
  11. 【译】jquery基础教程(jQuery Fundamentals)——(第一部分)概述
  12. 为了方便远程登录写的简单expect脚本
  13. Unity基础补漏(1)_GameObject类_Time类_Transform类_Camera_光面板_物理面板/物理材质_碰撞检测函数_刚体加力
  14. 程序员必须知道的八件事
  15. 非淡泊无以明志,非宁静无以致远
  16. 获取手机IMEI/ICCID/IMSI
  17. 第七章第二十三题(游戏:储物柜难题)(Game: locker problem)
  18. weui 可移动悬浮按钮
  19. CMS是什么?如何识别CMS?
  20. matlab单双极性眼图程序,求通信大神讲讲这个matlab程序每一段的意思

热门文章

  1. 给2012 年考上北邮的同学的几点建议
  2. java高级工程师哪些技术要掌握?
  3. 信贷风控模型搭建及核心风控模式分类
  4. 40个前端新手入门练习项目,学完即可做项目
  5. 《中国科学》中文论文模板使用CCTTEX编译
  6. 已刷高格固件的路由器如何更换为潘多拉固件
  7. js拼接json对象_JS实现合并json对象的方法
  8. dlna和miracast可以共存吗_关于无线显示技术,AirPlay,DLNA,Miracast,WiDi 等有何异同?...
  9. Day1通信基本概念 通信系统模型 通信系统分类与通信方式
  10. 【广工考试笔记】计算机系统结构考试速成笔记