1.环境
Ubuntu20.04
Vscode
Cuda 11.2
Pytorch 1.8
2.代码

import time
import torch
import torchvision
from torch import nn,optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.conv = nn.Sequential(nn.Conv2d(1,6,5),nn.Sigmoid(),nn.MaxPool2d(2,2),nn.Conv2d(6,16,5),nn.Sigmoid(),nn.MaxPool2d(2,2))self.fc = nn.Sequential(nn.Linear(16*4*4,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))def forward(self,img):feature = self.conv(img)output = self.fc(feature.view(img.shape[0],-1))return outputnet = LeNet()def load_data_fashion_mnist(batch_size,resize=None,root='~/Datasets/FashionMNIST'):trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root,train=True,download=True,transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root,train=False,download=True,transform=transform)train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=4)test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=4)return train_iter,test_iterdef evaluate_accuracy(data_iter,net,device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):acc_sum,n = 0.0,0with torch.no_grad():for X,y in data_iter:if isinstance(net,torch.nn.Module):net.eval()acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train()else:if('is_training' in net.__code__.co_varnames):acc_sum += (net(X,is_training=False).argmax(dim=1) == y).float().sum().item()else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum/ndef train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs):net = net.to(device)print("training on ",device)loss = torch.nn.CrossEntropyLoss()batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0,0.0,0,time.time()for X,y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat,y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter,net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' %(epoch + 1,train_l_sum/batch_count,train_acc_sum/n,test_acc,time.time()-start))batch_size = 256
train_iter,test_iter = load_data_fashion_mnist(batch_size=batch_size)lr, num_epochs = 0.001, 10
optimizer = torch.optim.Adam(net.parameters(),lr=lr)
train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs)

3.结果

Pytorch学习笔记——LeNet模型相关推荐

  1. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  2. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  3. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  4. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  5. PyTorch学习笔记(三):PyTorch主要组成模块

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  6. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

  7. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  8. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  9. 深度学习入门之PyTorch学习笔记

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 5 循环神经网络 6 生成对抗网络 7 深度学习实战 参考资料 绪论 深度学习如今 ...

最新文章

  1. 嵌入式开发之davinci--- DVRRDK, EZSDK和DVSDK这三者有什么区别
  2. React Redux 的一些基本知识点
  3. 注释驱动的 Spring cache 缓存介绍--转载
  4. 实战分享之专业领域词汇无监督挖掘
  5. nb信号和4g信号_手机信号很强但是4G网络却很卡?学会这三招,立马恢复网速
  6. 【theano-windows】学习笔记一——theano中的变量
  7. 完数c++语言程序_C语言经典100题(19)
  8. 【华为云技术分享】云小课 | SAP S/4HANA高可用之实战演练
  9. 【电脑硬件问题】视频接口和显示器偏色
  10. jquery实现点击元素,如果弹出层隐藏则显示,显示则隐藏
  11. FlashFXP如何破解
  12. u8转完看不到菜单_进入软件后所有菜单栏都不显示
  13. matlab导入数据文件,matlab怎么导入mat数据文件
  14. 【pycharm】复制粘贴快捷键失效
  15. Codeforces Gym 100339B Diversion 树形DP + LCA
  16. 【电力电子】【2013】基于对称分量提取的三电平三相并网变流器电压暂降时的电网同步与控制
  17. 孙溟㠭篆刻作品《叶》
  18. 论软件开发过程RUP及其应用
  19. 麦克风离计算机主机多远,唱歌的时候,嘴离麦克风多远最好听?
  20. 产品说明书—分类宝(WasteSorting)

热门文章

  1. linux(ubuntu)下vi命令(例:sudo vi ~/.bashrc)
  2. 【敏捷办公学习必备软件推荐】
  3. LeetCode 409. 最长回文串(构造最长回文判断)
  4. @FeignClient使用详解
  5. 美国挡不住商汤:仅一周后重启IPO,新增基石投资3.8亿元,年前30号挂牌上市
  6. Python+Vue计算机毕业设计超市积分管理系统o2qyn(源码+程序+LW+部署)
  7. 微信页面在浏览器打开
  8. 2006年安全软件全球纵览(转
  9. oge 封包工具 封包_什么是封包?
  10. Boom 3D序列号授权超赞的音效增强及播放工具