下面的代码是cnn识别cifar10,如果是cifar100,将数据集的改成cifar100,然后模型的输出神经元10改为100即可。

import torch,torchvision
import torch.nn as nn
import torchvision.transforms as transforms#定义模型
class CNNCifar(nn.Module):def __init__(self):super(CNNCifar,self).__init__()self.feature = nn.Sequential(nn.Conv2d(3,64,3,padding=2),   nn.BatchNorm2d(64),  nn.ReLU(), nn.MaxPool2d(2,2),nn.Conv2d(64,128,3,padding=2), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2,2),nn.Conv2d(128,256,3,padding=1),nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2,2),nn.Conv2d(256,512,3,padding=1),nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d(2,2))self.classifier=nn.Sequential(nn.Flatten(),nn.Linear(2048, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096,4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096,10))def forward(self, x):x = self.feature(x)output = self.classifier(x)return outputnet = CNNCifar()
print(net)#加载数据集
apply_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])train_dataset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download=True,transform=apply_transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False, download=False,transform=apply_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)#定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001,weight_decay=5e-4)#如果有gpu就使用gpu,否则使用cpu
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')
net = net.to(device)#训练模型
print('training on: ',device)
def test(): net.eval()acc = 0.0sum = 0.0loss_sum = 0for batch, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = net(data)loss = criterion(output, target)acc+=torch.sum(torch.argmax(output,dim=1)==target).item()sum+=len(target)loss_sum+=loss.item()print('test  acc: %.2f%%, loss: %.4f'%(100*acc/sum, loss_sum/(batch+1)))def train(): net.train()acc = 0.0sum = 0.0loss_sum = 0for batch, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = net(data)loss = criterion(output, target)loss.backward()optimizer.step()acc +=torch.sum(torch.argmax(output,dim=1)==target).item()sum+=len(target)loss_sum+=loss.item()if batch%200==0:print('\tbatch: %d, loss: %.4f'%(batch, loss.item()))print('train acc: %.2f%%, loss: %.4f'%(100*acc/sum, loss_sum/(batch+1)))for epoch in range(20):print('epoch: %d'%epoch)train()test()

实验结果:

cnn识别cifar10、cifar100(pytorch)相关推荐

  1. CGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)

    完整代码:代码地址https://www.lanzouw.com/iVadvo386ofhttps://www.lanzouw.com/iVadvo386of CGAN比DCGAN更进一步,利用标签信 ...

  2. DCGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)

    代码下载地址下载地址https://www.lanzouw.com/ipl8Yo37qxihttps://www.lanzouw.com/ipl8Yo37qxi Anime数据请在Anime Face ...

  3. 基于深度学习的口罩识别与检测PyTorch实现

    基于深度学习的口罩识别与检测PyTorch实现 1. 设计思路 1.1 两阶段检测器:先检测人脸,然后将人脸进行分类,戴口罩与不戴口罩. 1.2 一阶段检测器:直接训练口罩检测器,训练样本为人脸的标注 ...

  4. cifar10数据集测试有多少张图_pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)...

    首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层: 一,写VGG代码时,首先定义一个 vgg_block(n ...

  5. CNN图像分类Keras代码转换pytorch思路与实现

    tags: Python DL 写在前面 前几天改了一份代码, 是关于深度学习中卷积神经网络的Python代码, 用于解决分类问题. 代码是用TensorFlow的Keras接口写的, 需求是转换成p ...

  6. 动物数据集+动物分类识别训练代码(Pytorch)

    动物数据集+动物分类识别训练代码(Pytorch) 目录 动物数据集+动物分类识别训练代码(Pytorch) 1. 前言 2. Animals-Dataset动物数据集说明 (1)Animals90动 ...

  7. Tensorflow.js||使用 CNN 识别手写数字

    Tensorflow官方的tesorflow.js实操课程 链接为:link 使用 CNN 识别手写数字 文章目录 使用 CNN 识别手写数字 1. 简介 2. 设置操作 3. 加载数据 4. 定义模 ...

  8. CNN识别手写数字-莫烦python

    搭建一个 CNN识别手写数字 前面跟着莫烦python/tensorflow教程完成了神经网络识别手写数字的代码,这一part是cnn识别手写数字的 import tensorflow as tf f ...

  9. 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)

    面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 目录 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 1.面部表情识别方法 2.面部表情识别数据集 ...

最新文章

  1. CentOS7 升级 Git 版本
  2. 20180530更新
  3. Docker 17.12.0 发布
  4. PHP内核——内存管理
  5. ASP.NET 2.0 中实现模板中的数据绑定系列(2)
  6. android 定义固定数组,Android 图片数组定义和读取
  7. Python小白的数学建模课-15.图论的基本概念
  8. WordPress删除重复文章插件
  9. 华为Mate 30系列或下血本采用双主摄方案:CMOS尺寸破纪录
  10. php面试题9(看的时候就应该随手截图做笔记的)
  11. 联想高校AI精英挑战赛总冠军出炉!助力中国迎来智能变革
  12. 软件测试工程师 Linux 十大场景命令使用
  13. 云计算领导者,自主研发虚级化产品,华胜天成IVCS
  14. 算法与数据结构实验题 4.1 伊姐姐数字 game
  15. 抗锯齿_《战地V》深度学习抗锯齿性能测试:对它最友好的竟然是4K
  16. 数据增强操作(旋转、翻转、裁剪、色彩变化、高斯噪声等)
  17. 【linux运维】linux运维常用工具有哪些?
  18. linux安装系统识别不到硬盘,安装系统找不到硬盘解决方法【图文教程】
  19. 谁有《线性系统理论习题与解答》郑大钟
  20. 深善扶贫:深圳弘法寺启动“春风谷雨”送温暖行动

热门文章

  1. 3.1.10 基本分段存储管理方式
  2. Dubbo的Zookeeper版本
  3. Jedis的Spring配置
  4. jQuery的选择器分类
  5. 【数据结构-图】3.图的最短路径的几种算法(图解+演绎)
  6. 【图解Java】这下可以真的弄懂Java IO了~
  7. 【Java】6.1 Java 8增强的包装类
  8. 【PAT笔记】PAT中的散列思想
  9. will not add file alias already exists in index(git上传代码出错)
  10. PROC简单使用用例--VC连接ORACLE