用PyTorch完成手写数字识别

数据集为MNIST

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2# 下载训练集
train_dataset = datasets.MNIST(root='./num/',train=True,transform=transforms.ToTensor(),download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',train=False,transform=transforms.ToTensor(),download=True)# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linearclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),nn.MaxPool2d(2, 2))self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),nn.MaxPool2d(2, 2))self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),nn.BatchNorm1d(120), nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(120, 84),nn.BatchNorm1d(84),nn.ReLU(),nn.Linear(84, 10))# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return xdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(),lr=LR,
)epoch = 1
if __name__ == '__main__':for epoch in range(epoch):sum_loss = 0.0for i, data in enumerate(train_loader):inputs, labels = datainputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()optimizer.zero_grad()  #将梯度归零outputs = net(inputs)  #将数据传入网络进行前向运算loss = criterion(outputs, labels)  #得到损失函数loss.backward()  #反向传播optimizer.step()  #通过梯度做一步参数更新# print(loss)sum_loss += loss.item()if i % 100 == 99:print('[%d,%d] loss:%.03f' %(epoch + 1, i + 1, sum_loss / 100))sum_loss = 0.0net.eval()  #将模型变换为测试模式correct = 0total = 0for data_test in test_loader:images, labels = data_testimages, labels = Variable(images).cuda(), Variable(labels).cuda()output_test = net(images)_, predicted = torch.max(output_test, 1)total += labels.size(0)correct += (predicted == labels).sum()print("correct1: ", correct)print("Test acc: {0}".format(correct.item() /len(test_dataset)))

参考教程

PyTorch手写数字识别(MNIST数据集)
https://blog.csdn.net/weixin_44613063/article/details/90815082

用PyTorch完成手写数字识别相关推荐

  1. 使用Pytorch实现手写数字识别

    使用Pytorch实现手写数字识别 1. 思路和流程分析 流程: 准备数据,这些需要准备DataLoader 构建模型,这里可以使用torch构造一个深层的神经网络 模型的训练 模型的保存,保存模型, ...

  2. 用PyTorch进行手写数字识别

    目录 数据准备 网络模型 完整实现 数据准备 torch.utils.data.Datasets是PyTorch用来表示数据集的类,它是用PyTorch进行手写数字识别的关键. 下面是加载mnist数 ...

  3. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  4. pytorch实现手写数字识别_送源码!人工智能实现:识别图片中的手写数字,值得收藏...

    作者|小林同学 关注<高手杰瑞>,每天有不一样的实用小教程发布哦! 哈喽,大家好我是杰瑞.今天我给大家带来一个用机器学习的方法来实现手写数字识别的教程,就像C语言中输出的那一行" ...

  5. 使用PyTorch进行手写数字识别,在20 k参数中获得99.5%的精度。

    In this article we'll build a simple convolutional neural network in PyTorch and train it to recogni ...

  6. Pytorch CNN 手写数字识别 0-9

    使用的软件是pycharm 环境是在anaconda下创的虚拟环境pytorch 整个过程大体为,在画板手写数字,用python代码实现手写数字的批量生成,定义超参数,创建数据集包括训练集和数据集,创 ...

  7. 使用Pytorch实现手写数字识别(Mnist数据集)

    目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...

  8. pytorch实现手写数字识别_Paddle和Pytorch实现MNIST手写数字集识别对比

    一.简介 1. Paddle PaddlePaddle是百度自主研发的集深度学习核心框架.工具组件和服务平台为一体的技术领先.功能完备的开源深度学习平台,有全面的官方支持的工业级应用模型,涵盖自然语言 ...

  9. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  10. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

最新文章

  1. Python-ORM实战
  2. 从顶会论文看多模态预训练研究进展
  3. NBU备份之一 Windows操作系统BMR的配置
  4. 剑桥大学申请start up签证的有用的网站
  5. python读取调用摄像头并将读取视频写入视频文件
  6. android Json详解
  7. 他,是数学天才,是多复变解析函数的创始人
  8. bat批处理执行python_.bat批处理添加Python任务
  9. python引入导入自定义模块和外部文件
  10. POJ1741 Tree(点分治)
  11. 51Nod-1051 最大子矩阵和【最大子段和+DP】
  12. android学习笔记25——事件处理Handler
  13. (转)自定义listView及其adapter动态刷新
  14. 4G模块UICC逻辑通道入口+CGLA
  15. 只需要几行代码就可以轻松实现OCR图片转文字
  16. 近期币圈与美股的相关性
  17. 国内主流API市场分析报告
  18. fc模拟器安卓版_【SFC】魂斗罗3-异形战争模拟器情怀通关2020_EVOS
  19. 基于Websocket协议的即时通讯系统设计与实现
  20. java中倒出pdf增加高宽_java – 如何扩展PDF的页面大小以添加水印...

热门文章

  1. SDRAM控制器设计(9)用读写FIFO优化及仿真验证
  2. stm32神舟I号开发板下的六子棋开发
  3. 读书笔记——晶体管电路设计
  4. 网站建设合同- 范文格式
  5. java3d读取3ds文件,基于Java3D与3DSMAX的虚拟校园设计
  6. 北京地铁站经纬度收集
  7. DNN硬件加速器设计4 -- Co-Design and Benchmarking Metrics(MIT)
  8. 打印时电脑蓝屏或重启的解决办法
  9. 希捷ST31000528AS Disk Boot Failure, Insert System Disk and Press Enter和飞利浦的193ei显示器亮的问题
  10. 如何在虚拟机安装鸿蒙os,VirtualBox安装教程