import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置train = True, # 载入训练集transform=transforms.ToTensor(), # 把数据变成tensor类型download = True # 下载)
# 测试集
test_data = datasets.MNIST(root="./",train = False,transform=transforms.ToTensor(),download = True)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels = dataprint(inputs.shape)print(labels.shape)break


  in_channel: 输入数据的通道数,例RGB图片通道数为3,灰色图通道数为1;

out_channel: 输出数据的通道数,这个根据模型调整;

kennel_size: 卷积核大小,可以是int,或tuple;kennel_size=2,意味着卷积大小2, kennel_size=(2,3),意味着卷积在第一维度大小为2,在第二维度大小为3;

stride:步长,默认为1,与kennel_size类似,stride=2,意味在所有维度步长为2, stride=(2,3),意味着在第一维度步长为2,意味着在第二维度步长为3;

padding: 零填充 如:3X3的卷积窗口就填充1圈零,5X5的卷积窗口就填充2圈零,7X7的卷积窗口就填充3圈零

# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 初始化# nn.Conv2d(1,32,5,1,2): 通道数,输出,卷积窗口 步长,填充几圈0  激活函数relu   最大池化窗口2*2self.conv1 = nn.Sequential(nn.Conv2d(1,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2,2)) # 卷积层self.conv2 = nn.Sequential(nn.Conv2d(32,64,5,1,2),nn.ReLU(),nn.MaxPool2d(2,2)) # 卷积层self.fc1 = nn.Sequential(nn.Linear(64*7*7,500),nn.Dropout(p=0.5),nn.ReLU()) # 全连接层 全连接层 features_in其实就是输入的神经元个数,features_out就是输出神经元个数 64*7*7,1000 64个特征图 大小7*7  输出500个特征图self.fc2 = nn.Sequential(nn.Linear(500,10),nn.Softmax(dim=1)) # 全连接层def forward(self,x):# torch.Size([64, 1, 28, 28])  # 卷积中需要传入4维  批次大小  图像通道数 图片大小x = self.conv1(x)x = self.conv2(x)# torch.Size([64, 1, 28, 28]) -> (64,784)x = x.view(x.size()[0],-1) # 4维变2维 (在全连接层做计算只能2维)x = self.fc1(x)x = self.fc2(x)return x
# 定义模型
model = Net()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.5)# 随机梯度下降
# 定义模型训练和测试的方法
def train():# 模型的训练状态model.train()for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 交叉熵代价函数out(batch,C:类别的数量),labels(batch)loss = mse_loss(out,labels)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():# 模型的测试状态model.eval()correct = 0 # 测试集准确率for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Test acc:{0}".format(correct.item()/len(test_data)))correct = 0for i,data in enumerate(train_loader): # 训练集准确率# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):print("epoch:",epoch)train()test()

PyTorch基础-使用卷积神经网络CNN实现手写数据集识别-07相关推荐

  1. 深蓝学院第三章:基于卷积神经网络(CNN)的手写数字识别实践

    参看之前篇章的用全连接神经网络去做手写识别:https://blog.csdn.net/m0_37957160/article/details/114105389?spm=1001.2014.3001 ...

  2. 手搓卷积神经网络(CNN)进行手写数字识别(python)

    前言: 本文属于学习笔记性质.为了让自己更深入地理解卷积神经网络,我只用numpy.pandas等几个库手搓了一个识别MNIST数字的CNN.500张图单次训练,准确率70-80%. 注意: 1.代码 ...

  3. PyTorch入门一:卷积神经网络实现MNIST手写数字识别

    先给出几个入门PyTorch的好的资料: PyTorch官方教程(中文版):http://pytorch123.com <动手学深度学习>PyTorch版:https://github.c ...

  4. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  5. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

    今天我们将使用 Pytorch 来实现 LeNet-5 模型,并用它来解决 MNIST数据集的识别. 正文开始! 一.使用 LeNet-5 网络结构创建 MNIST 手写数字识别分类器 MNIST是一 ...

  6. 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

    大家好,我是红色石头! 在上一篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 详细介绍了卷积神经网络 LeNet-5 的理论部分.今天我们将使用 Pytorch 来实现 LeNet-5 ...

  7. Tensorflow之 CNN卷积神经网络的MNIST手写数字识别

    点击"阅读原文"直接打开[北京站 | GPU CUDA 进阶课程]报名链接 作者,周乘,华中科技大学电子与信息工程系在读. 前言 tensorflow中文社区对官方文档进行了完整翻 ...

  8. pyqt5手写板+pytorch卷积神经网络,实现手写数字识别软件

    卷积神经网络的结构 #定义网络结构 #不是le-net5的结构 class Net(nn.Module):def __init__(self):super(Net, self).__init__()# ...

  9. 神经网络学习(三)比较详细 卷积神经网络原理、手写字体识别(卷积网络实现)

    之前写了一篇基于minist数据集(手写数字0-9)的全连接层神经网络,识别率(85%)并不高,这段时间学习了一些卷积神经网络的知识又实践了一把, 识别率(96%左右)确实上来了 ,下面把我的学习过程 ...

最新文章

  1. python基础知识点大全-【python基础学习】基础重点难点知识汇总
  2. Java的ArrayList集合_JAVA之ArrayList集合
  3. python中单个和批量增加更新的mysql(没有则插入,有则更新)
  4. uid(组件id) = userId + appId (android多用户)
  5. android theme 错误,为什么修改android:theme就崩溃,求助
  6. git config设置用户名_git从安装到多账户操作一套搞定(二)多账户使用
  7. Java开发人员常用网站收录
  8. mysql 5.7日志配置_mysql-5.7日志设置
  9. live2d模型导入unity报错 live2dsdk与Cubism下载 live2dSDKforUnity使用手册
  10. xp计算机限制打开u盘,处置xp系统电脑限制使用u盘的解决方法
  11. 联想用u盘重装系统步骤_练习联想使用u盘重装win7教程
  12. 海南大学计算机网络空间安全学院研究生,海南大学计算机与网络空间安全学院2021考研调剂公告...
  13. Webpack 配置中的一股清流
  14. Android网络数据JSON解析使用总结
  15. 对话“1024程序员节”嘉宾 ——RT-Thread 创始人熊谱翔
  16. [论文阅读]BiSeNet V2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation
  17. fedora 16 linux 配置 MP3 RMVB 解码器
  18. 萧井陌java_萧井陌编程入门指南
  19. 基于Java的学生学费支付系统
  20. 贴吧顶贴php脚步,【技术贴安卓按键精灵】贴吧顶贴脚本源码分享

热门文章

  1. Basic的Json与Xml
  2. java修饰符总结,java访问修饰符总结
  3. java换脸_随意换脸 · ink-image/api Wiki · GitHub
  4. linux date 小写h,linux date 命令详解[转载]
  5. 查看约束信息_【华智产品汇】育种信息安全的守护者——华智育种管家
  6. linux java程序控制台日志输出
  7. android art虚拟机安装,Android中art虚拟机启动流程
  8. php常用操作数组函数,PHP常见数组函数用法小结
  9. oracle删除判断是否存在,oracle创建表之前判断表是否存在,如果存在则删除已有表...
  10. php 组合模式,php设计模式(十三)透明组合模式