下面的代码是cnn是被MNIST,如果识别Fashion-MNIST,可以将数据集换成Fashion-MNIST即可。

第一个全连接的输入神经元个数如何确定,可以参考我的另一篇博客。即nn.lInear(1600,128)的中数字1600如何确定的?

import torch,torchvision
import torch.nn as nn#定义模型
class CNNMnist(nn.Module):def __init__(self):super(CNNMnist,self).__init__()self.feature = nn.Sequential(nn.Conv2d(1,32,3), nn.ReLU(), nn.MaxPool2d(2,2),nn.Conv2d(32,64,3), nn.ReLU(), nn.MaxPool2d(2,2))self.classifier=nn.Sequential(nn.Flatten(),nn.Linear(1600, 128),nn.ReLU(),nn.Linear(128,10))def forward(self, x):x = self.feature(x)output = self.classifier(x)return outputnet = CNNMnist()#加载数据集
apply_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])train_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True,transform=apply_transform)
test_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=False, download=False,transform=apply_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)#定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#如果有gpu就使用gpu,否则使用cpu
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')
net = net.to(device)#训练模型
print('training on: ',device)def test(test_loader): net.eval()acc = 0.0sum = 0.0loss_sum = 0for batch, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = net(data)loss = criterion(output, target)acc+=torch.sum(torch.argmax(output,dim=1)==target).item()sum+=len(target)loss_sum+=loss.item()print('test acc: %.2f%%, loss: %.4f'%(100*acc/sum, loss_sum/(batch+1)))def train(): net.train()loss_sum = 0for batch, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = net(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch%200==0:print('\tbatch: %d, loss: %.4f'%(batch, loss.item()))for epoch in range(5):print('epoch: %d'%epoch)train()test(test_loader)

实验结果:

cnn识别mnist、Fashion-MNIST(pytorch)相关推荐

  1. cnn识别cifar10、cifar100(pytorch)

    下面的代码是cnn识别cifar10,如果是cifar100,将数据集的改成cifar100,然后模型的输出神经元10改为100即可. import torch,torchvision import ...

  2. 【人工智能项目】Fashion Mnist识别实验

    [人工智能项目]Fashion Mnist识别实验 本次主要通过四个方法对fashion mnist进行识别实验,主要为词袋模型.hog特征.mlp多层感知器和cnn卷积神经网络.那么话不多说,走起来 ...

  3. tensorflow2.0 CNN fashion MNIST图像分类

    基于 CNN的 fashion MNIST图像分类 fashion MNIST图像分类 数据集简介 数据的预处理 CNN简介和构建 模型部分代码 CNN实验结果 致谢 fashion MNIST图像分 ...

  4. fashionmnist数据集_Keras实现Fashion MNIST数据集分类

    本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...

  5. Pytorch初学实战(一):基于的CNN的Fashion MNIST图像分类

    1.引言 1.1.什么是Pytorch PyTorch是一个开源的Python机器学习库. 1.2.什么是CNN 卷积神经网络(Convolutional Neural Networks)是一种深度学 ...

  6. TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%

    TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...

  7. python cnn程序_python cnn训练(针对Fashion MNIST数据集)

    本文将和大家一起一步步尝试对Fashion MNIST数据集进行调参,看看每一步对模型精度的影响.(调参过程中,基础模型架构大致保持不变) 废话不多说,先上任务: 模型的主体框架如下(此为拿到的原始代 ...

  8. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  9. TensorFlow 学习(六)时尚(衣服、鞋、包等) Fashion MNIST识别

    使用 jupyter notebook 笔记,导入需要的包 # TensorFlow and tf.keras import tensorflow as tf from tensorflow impo ...

  10. 3层-CNN卷积神经网络预测MNIST数字

    3层-CNN卷积神经网络预测MNIST数字 本文创建一个简单的三层卷积网络来预测 MNIST 数字.这个深层网络由两个带有 ReLU 和 maxpool 的卷积层以及两个全连接层组成. MNIST 由 ...

最新文章

  1. 蓝桥杯-递归求二项式系数值(java)
  2. golang设置运行CPU数量及sync.Mutex全局互斥锁的使用示例
  3. android flux 与mvp,使用 MVP 时在设计上的考量
  4. 7段均衡器最佳调节图_超高级的吉他均衡器 更细腻的控制 你值得拥有
  5. [转载]Windowsnbsp;Servernbsp;2008nbsp;R2nbsp;之二十五ADnbsp;RMS信任策略
  6. C语言指针概念全面解析
  7. 失业了又不想进厂打工,怎么办
  8. kali linux下sqlmap使用教程
  9. Android电视远程桌面,180元让电视变电脑 远程桌面终端评测
  10. 【极乐净土mmd】动作+镜头数据下载
  11. 能测电机温度和振动在线测量工具——温振变送器
  12. 远程桌面访问软件:TeamViewer
  13. 单点登录系统设计分析
  14. 通过5个概念 一文弄明白DAO
  15. 打开程序员心理B面,这些黑红话题他们亲自回应丨1024特辑
  16. Java Web实训项目:西蒙购物网1
  17. How to Rerun Failed Tests in JUnit?
  18. 创业融资路演PPT模板
  19. idea 2020.3更新后如何实现run parallel
  20. 什么是RC高通滤波电路

热门文章

  1. 2021春季每日一题【week8 未完结】
  2. 第三届“传智杯”全国大学生IT技能大赛(初赛B组)【C++】
  3. 红米路由器ac2100怎样设置ipv6_红米(Redmi)路由器AC2100手机怎么设置?
  4. 没有到主机的路由_网络基础知识:UDP协议之路由跟踪
  5. 一文读懂 volatile 关键字
  6. 信息系统项目管理知识--信息安全
  7. 蓝桥杯java第八届第八题--包子凑数
  8. 03-JDBC学习手册:JDBC中几个重要接口和异常处理
  9. (SpringMVC)拦截器
  10. (仿头条APP项目)7.首页标签页完善和微头条页面设计实现