训练好的模型,使用load
1、直接将读取到的数据放到模型里面出现以下错误提示

原因:图片格式不对,将图片转化为torch.tensor格式

transf = transforms.ToTensor()
image = transf(image)

2、改过格式以后将图片放入模型中如下所示

原因:这个时候图片维度不对。

image = torch.randn(1,1,28,28)

3、核心代码
我把预测代码写到网络类中,关键是predict函数

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Conv2d(1,6,5)self.pool1 = nn.MaxPool2d(2,2)self.conv2 = nn.Conv2d(6,16,5)self.pool2 = nn.MaxPool2d(2,2)self.fc1 = nn.Linear(16*4*4,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.view(-1,16*4*4)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef predict(self, image):#这个地方是核心image = cv2.resize(image, (28, 28))_, image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY_INV)transf = transforms.ToTensor()image = transf(image)image = torch.randn(1, 1, 28, 28)output = self(image)a, predict = torch.max(output.data, 1)print('预测结果为',int(predict[0]))

使用MNIST数据集训练出来的模型预测自己手写数据相关推荐

  1. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  2. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...

  3. DL之DNN:利用MultiLayerNetExtend模型【6*100+ReLU+SGD,dropout】对Mnist数据集训练来抑制过拟合

    DL之DNN:利用MultiLayerNetExtend模型[6*100+ReLU+SGD,dropout]对Mnist数据集训练来抑制过拟合 目录 输出结果 设计思路 核心代码 更多输出 输出结果 ...

  4. DL之DNN:利用MultiLayerNet模型【6*100+ReLU+SGD,weight_decay】对Mnist数据集训练来抑制过拟合

    DL之DNN:利用MultiLayerNet模型[6*100+ReLU+SGD,weight_decay]对Mnist数据集训练来抑制过拟合 目录 输出结果 设计思路 核心代码 更多输出 输出结果 设 ...

  5. DL之DNN:利用MultiLayerNet模型【6*100+ReLU+SGD】对Mnist数据集训练来理解过拟合现象

    DL之DNN:利用MultiLayerNet模型[6*100+ReLU+SGD]对Mnist数据集训练来理解过拟合现象 导读 自定义少量的Mnist数据集,利用全连接神经网络MultiLayerNet ...

  6. Keras之DNN:利用DNN算法【Input(8)→12+8(relu)→O(sigmoid)】利用糖尿病数据集训练、评估模型(利用糖尿病数据集中的八个参数特征预测一个0或1结果)

    Keras之DNN:利用DNN算法[Input(8)→12+8(relu)→O(sigmoid)]利用糖尿病数据集训练.评估模型(利用糖尿病数据集中的八个参数特征预测一个0或1结果) 目录 输出结果 ...

  7. Keras : 利用卷积神经网络CNN对图像进行分类,以mnist数据集为例建立模型并预测

    我是本期目录酱! 引入 计算机视觉 图像特征 如何区分图像的类别 卷积神经网络 卷积Convolution 卷积层 池化Pooling 卷积神经网络 以mnist数据集为例建立模型并预测 简单分析 p ...

  8. Mnist数据集训练-手写数字的识别

    mnist数据集是一套手写体数字的图像数据集,包含60000个训练样本和10,000个测试集,由纽约大学的Yann LeCun等人维护.它包含各种手写数字图片: 本次实验我们直接将其下载好放在相应文件 ...

  9. paddle - crowdHuman数据集训练人体识别模型

    paddle - crowdHuman数据集训练人体识别模型 数据集annotation crowdhuman的odgt文件各项意义 转换为paddle yolo的格式 输入哪些数据? 输出模型 数据 ...

  10. mnist数据集在caffe(windows)上的训练与测试及对自己手写数字的分类

    以下出自http://www.cnblogs.com/yixuan-xu/p/5858595.html 我按照大神的运算步骤完全正确,只是为了加深理解,自己又重新写了一遍,详情请看上述大神博客. 对m ...

最新文章

  1. java 有没有类似于 requests 爬虫_大数据时代,怎么能不了解“爬虫”是什么?
  2. 机器学习数据预处理之缺失值:众数(mode)填充
  3. ETH Zurich提出新型网络「ROAD-Net」,解决语义分割域适配问题
  4. linux运维基础篇 unit14
  5. python生成字母图片_Python 模拟动态产生字母验证码图片功能
  6. 如何禁止Linux内核的-O2编译选项【转】
  7. Linux系统怎么挂载安卓手机,NFS挂载Android文件系统
  8. 解决 wget 使用 https 下载报错的问题
  9. 【数据库】Mysql的REPLACE()函数替换字符串
  10. thinkphp手机版小说网站源码
  11. 数学趣题——猴子吃桃问题
  12. can't find which disk is full
  13. jQuery动画之显示隐藏动画
  14. php让iframe 重定向,利用可以在iframe中嵌入网页进行重定向
  15. 机器学习(Machine Learning)深度学习(Deep Learning)资料
  16. 计算机组策略恢复,Win10重置组策略编辑器的方法
  17. springBoot整合ElasticSearch【代码直接复制可用】(超级详细)
  18. Android 桌面Widget (小组件)开发详解
  19. 计算机概论二进制加法,计算机科学概论二进制
  20. arcgis用python字段自动编号_属性表字段自动编号

热门文章

  1. UNIX-LINUX编程实践教程-第五章-实例代码注解-echostate.c
  2. [AutoSAR] BSW模块的服务层,重点关注OS部分
  3. ORB-SLAM 解读(二) ORB描述子如何实现旋转不变性
  4. 为什么B+树适合做索引
  5. HanLP: Han Language Processing
  6. 第四季-专题14-串口驱动程序设计
  7. 第三季-第10课-时间编程
  8. python笔记23-unittest单元测试之mock
  9. 【298天】每日项目总结系列036(2017.11.30)
  10. CentOS7/RHEL7 systemd详解