神经网络在进行完训练和测试后,如果达到了较高的正确率的话,我们可以尝试将模型用于预测新数据。总共需要两大部分:神经网络、预测函数(新图片的加载,传入模型、得出结果)。


完整代码

import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):  # 神经网络部分用你自己的def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 2, 1)  # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)self.conv3 = nn.Conv2d(64, 128, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(6272, 128)  # 6272=128*7*7self.fc2 = nn.Linear(128, 8)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = self.conv3(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)self.output = F.log_softmax(x, dim=1)out1 = xreturn self.output,out1def predict():model = Net()model.load_state_dict(torch.load('test.pt'))torch.no_grad()imgfile = glob.glob(r"")  # 输入要预测的图片所在路径print(len(imgfile), imgfile)for i in imgfile:imgfile1 = i.replace("\\", "/")img = cv2.imdecode(np.fromfile(imgfile1, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)img = cv2.resize(img, (64, 64))  # 是否需要resize取决于新图片格式与训练时的是否一致tran = transforms.ToTensor()img = img.reshape((*img.shape, -1))img = tran(img)img = img.unsqueeze(0)outputs, out1 = model(img)  # outputs,out1修改为你的网络的输出predicted, index  = torch.max(out1, 1)degre = int(index[0])list = [0, 45, -45, -90, 90, 135, -135, 180]print(predicted, list[degre])if __name__ == '__main__':predict()

神经网络部分复制你在训练时定义的神经网络即可,如果模型保存为字典,则需要

model.load_state_dict(torch.load('test.pt'))

新图片的格式需要与训练测试时的图片格式保持一致,所以需要resize,如果新图片为相同格式略过。

最后的list是你样本类别的list,每一类的索引需要与label保持一致,例如:

list = ['裤子', '套衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '短靴']

结果分析

tensor([7.0595], grad_fn=<MaxBackward0>) 45
tensor([11.9538], grad_fn=<MaxBackward0>) -45
tensor([5.8450], grad_fn=<MaxBackward0>) 135

前面的张量tensor代表了各个类别的“概率”中最大的那一个,然后根据最大“概率”所在的位置(index)来找到list所对应的类别,然后输出。

pytorch 使用训练好的模型预测新数据相关推荐

  1. ML之xgboost:基于xgboost(5f-CrVa)算法对HiggsBoson数据集(Kaggle竞赛)训练实现二分类预测(基于训练好的模型进行新数据预测)

    ML之xgboost:基于xgboost(5f-CrVa)算法对HiggsBoson数据集(Kaggle竞赛)训练实现二分类预测(基于训练好的模型进行新数据预测) 目录 输出结果 设计思路 核心代码 ...

  2. mxnet保存模型,加载模型来预测新数据

    mxnet保存模型,以及用模型来预测新数据 我们希望训练好之后的模型,可以保存下来,然后需要预测新数据的时候,就可以拿来用,可以这样做.  我们以线性回归的例子来讲:  1,训练并保存模型. impo ...

  3. Keras之MLP:利用MLP【Input(8)→(12)(relu)→O(sigmoid+二元交叉)】模型实现预测新数据(利用糖尿病数据集的八个特征实现二分类预测

    Keras之MLP:利用MLP[Input(8)→(12)(relu)→O(sigmoid+二元交叉)]模型实现预测新数据(利用糖尿病数据集的八个特征实现二分类预测 目录 输出结果 实现代码 输出结果 ...

  4. Keras之DNN:利用DNN【Input(8)→(12+8)(relu)→O(sigmoid)】模型实现预测新数据(利用糖尿病数据集的八个特征进行二分类预测

    Keras之DNN:利用DNN[Input(8)→(12+8)(relu)→O(sigmoid)]模型实现预测新数据(利用糖尿病数据集的八个特征进行二分类预测 目录 输出结果 设计思路 实现代码 输出 ...

  5. Keras之DNN:利用DNN算法【Input(8)→12+8(relu)→O(sigmoid)】利用糖尿病数据集训练、评估模型(利用糖尿病数据集中的八个参数特征预测一个0或1结果)

    Keras之DNN:利用DNN算法[Input(8)→12+8(relu)→O(sigmoid)]利用糖尿病数据集训练.评估模型(利用糖尿病数据集中的八个参数特征预测一个0或1结果) 目录 输出结果 ...

  6. 持续学习常用6种方法总结:使ML模型适应新数据的同时保持旧数据的性能

    来源:Deep IMBA 本文约4800字,建议阅读9分钟 本文将讨论6种方法,使模型可以在保持旧的性能的同时适应新数据,并避免需要在整个数据集(旧+新)上进行重新训练. 持续学习是指在不忘记从前面的 ...

  7. ML之回归预测:利用13种机器学习算法对Boston(波士顿房价)数据集【13+1,506】进行回归预测(房价预测)+预测新数据得分

    ML之回归预测:利用13种机器学习算法对Boston(波士顿房价)数据集[13+1,506]进行回归预测(房价预测)+预测新数据得分 导读 本文章基于前边的一篇文章,对13种机器学习的回归模型性能比较 ...

  8. ML之xgboost:利用xgboost算法对Boston(波士顿房价)数据集【特征列分段→独热编码】进行回归预测(房价预测)+预测新数据得分

    ML之xgboost:利用xgboost算法对Boston(波士顿房价)数据集[特征列分段→独热编码]进行回归预测(房价预测)+预测新数据得分 导读 对Boston(波士顿房价)数据集进行特征工程,分 ...

  9. SPSS Modeler 中如何利用训练好的模型进行新数据源的预测?

    一.利用训练数据(Demos文件夹中的property_values_train.sav)得到模型块taxable_value: 注意原数据文件中有目标字段taxable_value字段: 需要进行预 ...

最新文章

  1. [日记]一个人去散步
  2. 普通幕僚:Ownership意识不足的几种症状
  3. 无服务器计算将会取代容器?
  4. hdu 4252(单调栈)
  5. wxWidgets:wxMiniFrame类用法
  6. Python 中的 os 模块常见方法
  7. 3-pycharm找不到库的解决办法
  8. 基于DispatchProxy打造自定义AOP组件
  9. Focus Stacking
  10. Julia : 如何进一步改进操作Redis的效率?
  11. DNF私服商业服搭建教程
  12. 关于List转Json的简单方法
  13. MYSQL不能远程连接
  14. JPEG格式压缩算法
  15. esp8266 OTA 云远程更新固件 wifiupdate
  16. 新收集的WAPPUSH代码,并经过改造
  17. “二舅”火了,自媒体短视频“爆火”的基本要素,你知道吗?
  18. Android逆向之玩转Xposed模块以劫持登录为例(实战篇)
  19. 执行 npm install -g grunt-cli 安装grunt发生错误问题
  20. 收费企业邮箱与收费个人邮箱区别,你造吗?

热门文章

  1. DPDK Mempool
  2. python 三维坐标图
  3. 世界坐标系和图像坐标系的对应关系
  4. 【笔试与面试】中软国际
  5. 求奇数立方和和偶数平方和
  6. weblogic可以安装多个吗_有280多个精密部件的“智能手”,真的可以替代人手吗...
  7. 用 MAX7219 点亮 8*8点阵显示屏(倒不如说是 8*8 LED模块)
  8. 十六进制转字符串或char字符数组
  9. 医院信息系统的业务功能详解
  10. 华强北的AirPods 能用吗?(华强北避坑科普分享)