数据集MNIST


代码:

#1 加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
#2 定义超参数
BATCH_SIZE = 16 #每批处理的数据 GPU好 可以选 64 128
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") #是否用GPU
EPOCHS = 10 #训练数据集的轮次
#3 构建pipeline  transforms,对图像做处理
pipeline = transforms.Compose([transforms.ToTensor(), #将图片转化成tensortransforms.Normalize((0.1307,), (0.3081,)) #正则化:模型出现过拟合现象时,降低模型复杂度
])
#4 下载、加载数据
from torch.utils.data import DataLoader#下载数据集
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)#加载数据
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)
#插入代码 读取MNIST中的图片
with open("./data/raw/train-images-idx3-ubyte","rb") as f:file=f.read()image1 = [int(str(item).encode('ascii'),16) for item in file[16 : 16+784]]
print(image1)import cv2
import numpy as npimage1_np = np.array(image1, dtype=np.uint8).reshape(28,28,1)
print(image1_np.shape)cv2.imwrite("digit,jpg", image1_np)
#5 构建网络模型
class Digit(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1,10,5) # 1:灰度图片的通道 10:输出通道 5:kernel  self.conv2 = nn.Conv2d(10,20,3) #10:输入通道 20:输出通道  3:kernelself.fc1 = nn.Linear(20*10*10,500) #20*10*10:输出通道 500:输出通道self.fc2 = nn.Linear(500,10) #500输入通道 10:输出通道def forward(self,x):input_size = x.size(0) #batch_sizex = self.conv1(x) #输入:batch_size*1*28*28, 输出:batch*10*24*24(28-5+1=24)x = F.relu(x) #保持shape不变, 输出:batch*10*24*24x = F.max_pool2d(x,2,2) #池化层对图片进行压缩(降采样) 输入:batch*10*24*24 输出:batch*10*12*12x = self.conv2(x) #输入:batch*10*12*12 输出:batch*20*10*10(12-3+1=10)x =  F.relu(x) #保持shape不变,输出:batch*20*10*10x = x.view(input_size, -1) #拉平, -1自动计算维度, 20*10*10=2000x = self.fc1(x) #输入:batch*2000 输出:batch*500x = F.relu(x)x = self.fc2(x) #输入:batch*500 输出:batch*10output = F.log_softmax(x, dim=1) #计算分类后,每个数字的概率值return output
#6 定义优化器
model = Digit().to(DEVICE)optimizer = optim.Adam(model.parameters())
#7 定义训练方法
def train_model(model, device, train_loader, optimizer, epoch):#模型训练model.train()for batch_index, (data, target) in enumerate(train_loader):#部署到DEVICE上去data, target = data.to(device), target.to(device)#梯度初始化为0optimizer.zero_grad()#训练后的结果output = model(data)#计算损失loss = F.cross_entropy(output,target) #交叉熵损失#反向传播loss.backward()#参数优化optimizer.step()  if batch_index %3000 == 0: #每3000个数据输出一个结果print("Train Epoch :{} \t Loss : {: .6f}".format(epoch, loss.item())) #输出结果保留6位小数
#8 定义测试方法
def test_model(model, device, test_loader):
#模型验证model.eval()
#正确率correct = 0.0
#测试损失test_loss = 0.0with torch.no_grad(): #不需要计算梯度,也不需要进行反向传播for data, target in test_loader:#部署到device上data, target = data.to(device), target.to(device)#测试数据output = model(data)#计算测试损失test_loss += F.cross_entropy(output, target).item()#找到概率值最大的下标pred = output.argmax(dim=1) #或者写: pred = torch.max(output, dim=1)   pred = output,max(1, keepdim=True)[1]#累计正确的值correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print("Test -- Average loss : {: .4f}, Accuracy : {: .3f}\n". format(test_loss, 100.0*correct / len(test_loader.dataset)))
#9 调用 方法7 、8
for epoch in range(1, EPOCHS +1):train_model(model, DEVICE, train_loader, optimizer, epoch)test_model(model, DEVICE, test_loader)

结果:

PyTorch手写字体识别相关推荐

  1. PyTorch手写字体识别MNIST

    手写字体识别MNIST 1.准备工作 可以看这个老师的视频进行学习,讲解的非常仔细:视频学习 2.项目代码 2.1 导入模块 # 1.加载相关库 import torch import torch.n ...

  2. pytorch CNN手写字体识别

    ## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...

  3. pytorch rnn 实现手写字体识别

    pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...

  4. 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)

    手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...

  5. 第六讲 Keras实现手写字体识别分类

    一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...

  6. 深度学习 卷积神经网络-Pytorch手写数字识别

    深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...

  7. Android Studio编写一个手写字体识别程序

    1.activity_main.xml 的代码 <?xml version="1.0" encoding="utf-8"?> <LinearL ...

  8. 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)

    人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...

  9. python手写字体程序_深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...

  10. 《MATLAB 神经网络43个案例分析》:第19章 基于SVM的手写字体识别

    <MATLAB 神经网络43个案例分析>:第19章 基于SVM的手写字体识别 1. 前言 2. MATLAB 仿真示例 3. 小结 1. 前言 <MATLAB 神经网络43个案例分析 ...

最新文章

  1. 免费资源:Typicons-免费图标字体
  2. QSettings allKeys读取为空分析
  3. GDCM:gdcm::SOPClassUIDToIOD的测试程序
  4. 详细讲解在Spring中进行集成测试AbstractDependencyInjectionSpringContextTests
  5. 'module' object has no attribute 'Env'
  6. 海量数据切分抽取的实践场景(r11笔记第43天)
  7. IDEA卡顿问题解决-加大内存
  8. 【 数据结构(C语言)】线性表——链表反转
  9. Asterisk的配置详解
  10. java web个人博客开发(一需求获取和需求分析文档)
  11. rs 华为hcip 课件下载_华为路由与交换hcip最新题库
  12. ggplot2如何在R语言中绘制表格
  13. 快速将正式环境的数据同步到本地测试库
  14. 失落的嵌入式 英特尔强推MeeGo意欲何为
  15. OpenCV中图像的深度
  16. 提高信心的十个方法,助你考研坚持到底!
  17. 常用经方的应用体会­
  18. ConcurrentHashMap的实现原理
  19. STM32任意IO模拟8080时序驱动TFTLCD屏
  20. 蠎周刊 188: Jays

热门文章

  1. 顺序表的基本操作实现
  2. 大数据项目实时数据采集流程步骤分析
  3. 开源人脸识别库,face_recognition
  4. 最新四方支付平台源码(PHP版本,完全开源)提供第四方支付系统搭建服务。
  5. 交换机cad图例_网络交换机cad图
  6. php如何让图片自适应屏幕,css如何让图片自适应屏幕大小
  7. idm6.40最新版exe下载器介绍
  8. 空间中点到直线的距离
  9. Typora下载及使用
  10. 1501_FTA失效树分析简介