MINIST

具体来看,MNIST手写数字数据集包含有60000张图片作为训练集数据,10000张图片作为测试集数据,且每一个训练元素都是28*28像素的手写数字图片,每一张图片代表的是从0到9中的每个数字。

mnist手写体数据集里的标准化参数transforms.Normalize((0.1307,), (0.3081,))

mnist手写体数据集里的标准化参数transforms.Normalize((0.1307,), (0.3081,))_ZJE-CSDN博客_mnist normalize

其中,0.1307和0.3081是mnist数据集的均值和标准差,因为mnist数据值都是灰度图,所以图像的通道数只有一个,因此均值和标准差各一个。要是imagenet数据集的话,由于它的图像都是RGB图像,因此他们的均值和标准差各3个,分别对应其R,G,B值。例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是Imagenet dataset的标准化系数(RGB三个通道对应三组系数)。数据集给出的均值和标准差系数,每个数据集都不同的,都是数据集提供方给出的。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transformsBATCH_SIZE=512     # batch_size即每批训练的样本数量
EPOCHS=20          # 循环次数# 让torch判断是否使用GPU,即device定义为CUDA或CPU
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")train_data = datasets.MNIST('./data',train=True,transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # 标准化] ),download=True)
train_loader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)test_data = datasets.MNIST('./data',train=False,transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # 标准化] ),download=True)
test_loader = DataLoader(test_data,batch_size=BATCH_SIZE,shuffle=True)#定义神经网络
class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()#128*128self.conv1 = nn.Conv2d(1,10,kernel_size=5)  #灰度图片的通道1self.conv2 = nn.Conv2d(10,20,kernel_size=3)self.fc1 = nn.Linear(20*10*10,500)self.fc2 = nn.Linear(500,10)def forward(self,x):in_size = x.size(0)  # in_size 为 batch_size(一个batch中的Sample数)x = self.conv1(x)     #输入batch_size*28*28 输出batch_size*24*24x = F.relu(x)x = F.max_pool2d(x,2,2)     #batch_size*14*14x = self.conv2(x)   #batch_size*20*10*10x = F.relu(x)x = x.view(in_size,-1)  #拉伸 20*10*10=2000x = self.fc1(x)     #输入2000  输出batch*500x= F.relu(x)x = self.fc2(x)# softmaxoutput = F.log_softmax(x, dim=1)     #计算分类后每个数字的概率# 返回值 outputreturn output
#定义优化器
model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())
#训练
def train(model,device,train_loader,optimizer,epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):# print(len(train_loader))# print("train_loader_datasets",len(train_loader.dataset))# print("batch_idx:",batch_idx,"data.shape",data.shape,"data_len:",len(data))data, target = data.to(device), target.to(device)  # CPU转GPUoptimizer.zero_grad()  # 优化器清零output = model(data)  # 由model,计算输出值loss = F.cross_entropy(output, target)  # 计算损失函数losspred = output.max(1,keepdim=True)loss.backward()  # loss反向传播optimizer.step()  # 优化器优化if (batch_idx + 1) % 30 == 0:  # 输出结果print("batch_idx:", batch_idx, "batch_idx * len(data)", batch_idx * len(data), "data_len:", len(data),"len(train_loader):",len(train_loader),"train_loader_datasets",len(train_loader.dataset))print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# ---------------------测试函数------------------------------
# 测试的操作也一样封装成一个函数
def test(model, device, test_loader):model.eval()test_loss = 0                           # 损失函数初始化为0correct = 0                             # correct 计数分类正确的数目with torch.no_grad():           # 表示不反向求导(反向求导为训练过程)for data, target in test_loader:    # 遍历所有的data和targetdata, target = data.to(device), target.to(device)   # CPU -> GPUoutput = model(data)            # output为预测值,由model计算出test_loss += F.cross_entropy(output, target).item()     ### 将一批的损失相加pred = output.max(1, keepdim=True)[1]       ### 找到概率最大的下标# pred = torch.argmax(output,dim=1)# print("pred",pred)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)    # 总损失除数据集总数print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))# 下面开始训练,这里就体现出封装起来的好处了,只要写两行就可以了
# 整个数据集只过一遍
for epoch in range(1,EPOCHS + 1):train(model,DEVICE,train_loader,optimizer,epoch)test(model,DEVICE,test_loader)torch.save(model,"shmodel.pth")

结果

batch_idx: 29 batch_idx * len(data) 14848 data_len: 512 len(train_loader): 118 train_loader_datasets 60000

Train Epoch: 1 [14848/60000 (25%)]  Loss: 0.387592

batch_idx: 59 batch_idx * len(data) 30208 data_len: 512 len(train_loader): 118 train_loader_datasets 60000

Train Epoch: 1 [30208/60000 (50%)]  Loss: 0.187167

batch_idx: 89 batch_idx * len(data) 45568 data_len: 512 len(train_loader): 118 train_loader_datasets 60000

Train Epoch: 1 [45568/60000 (75%)]  Loss: 0.141795

Test set: Average loss: 0.0002, Accuracy: 9686/10000 (97%)

batch_size: 512

batch_idx:每次训练的batch的索引,0--117

data:[512, 1, 28, 28]      data_len: 512         512是batch_size

train_loader:118  0-117

train_loader_datasets: 60000                test_loader_datasets 10000

CNN--MINIST相关推荐

  1. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  2. pytorch0.4版的CNN对minist分类

    卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功 ...

  3. C++元编程——CNN进行Minist手写数字识别

    Minist数据来源: MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges 数据的格式如下: C ...

  4. 基于CNN的MINIST手写数字识别项目代码以及原理详解

    文章目录 项目简介 项目下载地址 项目开发软件环境 项目开发硬件环境 前言 一.数据加载的作用 二.Pytorch进行数据加载所需工具 2.1 Dataset 2.2 Dataloader 2.3 T ...

  5. 基于CNn的MINIST手写体识别

    深度学习的上机作业: 基于CNN卷积神经网络的MINIST手写体识别 版本:python-3.9,tensorflow-2.9 目录 MINIST数据集 训练CNN卷积神经网络 使用训练好的模型进行预 ...

  6. 深度卷积网络CNN与图像语义分割

    转载请注明出处:  http://xiahouzuoxin.github.io/notes/html/深度卷积网络CNN与图像语义分割.html 级别1:DL快速上手 级别2:从Caffe着手实践 级 ...

  7. TensorFlow基于minist数据集实现手写字识别实战的三个模型

    手写字识别 model1:输入层→全连接→输出层softmax model2:输入层→全连接→隐含层→全连接→输出层softmax model3:输入层→卷积层1→卷积层2→全连接→dropout层→ ...

  8. matlab 对mnist手写数字数据集进行判决分析_人工智能TensorFlow(十四)MINIST手写数字识别...

    MNIST是一个简单的视觉计算数据集,它是像下面这样手写的数字图片: MNIST 每张图片还额外有一个标签记录了图片上数字是几,例如上面几张图的标签就是:5.0.4.1. MINIST数据 MINIS ...

  9. TensorFlow-CIFAR10 CNN代码分析

    CIFAR 代码组织 代码分析 cifar10_trainpy cifar10py cifar10_evalpy Reference 根据TensorFlow 1.2.1,改了官方版本的报错. CIF ...

  10. AI:IPPR的数学表示-CNN结构/参数分析

    前言:CNN迎接多类的挑战 特定类型的传统PR方法特征提取的方法是固定的,模式函数的形式是固定的,在理论上产生了特定的"局限性" 的,分类准确度可以使用PAC学习理论的方法计算出来 ...

最新文章

  1. 【android】android中activity的生命周期
  2. HttpClient发送Get请求(java)【从新浪云搬运】
  3. 组合数(Combinatorial_Number)
  4. 卡西欧9860连接电脑数据传输_轻松办公好助手,卡西欧STYLISH计算器体验记
  5. python招生海报_从原研哉的哲学中学习海报设计
  6. Abp太重了?轻量化Abp框架
  7. 把变量赋值给寄存器_散装 vs 批发谁效率高?变量访问被ARM架构安排的明明白白...
  8. Intel Sandy Bridge/Ivy Bridge架构/微架构/流水线 (6) - 流水线前端微熔合/宏熔合
  9. 贫血模式or领域模式(转载)
  10. 二分专题(不定期更新)
  11. 《OpenCV 4.5计算机视觉开发实战(基于VC++)》示例代码免费下载
  12. windows update 无法启动 报错87:参数错误的解决方法
  13. 24券创始人杜一楠的失败检讨书:24券是如何毁在我手上的?[转]
  14. https之证书验证
  15. BZOJ2901: 矩阵求和
  16. 如何在文字识别软件ABBYY中创建区域模板,处理大量相同内容?
  17. 不能邮箱登录的网站都是耍流氓【无力吐槽】
  18. python stm32 usb bulk_STM32-USB学习笔记(一) USB基础
  19. OpenWRT路由器——网络打印服务器
  20. 2021年全球与中国生物芯片行业市场规模及发展前景分析

热门文章

  1. Python解析mat文件
  2. 7款可视化工具,提高开发效率必备
  3. 带你轻轻松松了解route-map
  4. 增量式解析大型XML文件
  5. linux下nmon的安装及使用教程
  6. Google式用户体验的十大内在原则
  7. 如何写一首悲伤的原创歌曲?
  8. (一)Redis实战教程之redis简介
  9. mysql元器件数据库_Capture CIS连接元器件数据库系统的方法
  10. 几个最新免费开源的中文语音数据集