pytorch 使用训练好的模型预测新数据
神经网络在进行完训练和测试后,如果达到了较高的正确率的话,我们可以尝试将模型用于预测新数据。总共需要两大部分:神经网络、预测函数(新图片的加载,传入模型、得出结果)。
完整代码
import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module): # 神经网络部分用你自己的def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 2, 1) # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)self.conv3 = nn.Conv2d(64, 128, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(6272, 128) # 6272=128*7*7self.fc2 = nn.Linear(128, 8)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = self.conv3(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)self.output = F.log_softmax(x, dim=1)out1 = xreturn self.output,out1def predict():model = Net()model.load_state_dict(torch.load('test.pt'))torch.no_grad()imgfile = glob.glob(r"") # 输入要预测的图片所在路径print(len(imgfile), imgfile)for i in imgfile:imgfile1 = i.replace("\\", "/")img = cv2.imdecode(np.fromfile(imgfile1, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)img = cv2.resize(img, (64, 64)) # 是否需要resize取决于新图片格式与训练时的是否一致tran = transforms.ToTensor()img = img.reshape((*img.shape, -1))img = tran(img)img = img.unsqueeze(0)outputs, out1 = model(img) # outputs,out1修改为你的网络的输出predicted, index = torch.max(out1, 1)degre = int(index[0])list = [0, 45, -45, -90, 90, 135, -135, 180]print(predicted, list[degre])if __name__ == '__main__':predict()
神经网络部分复制你在训练时定义的神经网络即可,如果模型保存为字典,则需要
model.load_state_dict(torch.load('test.pt'))
新图片的格式需要与训练测试时的图片格式保持一致,所以需要resize,如果新图片为相同格式略过。
最后的list是你样本类别的list,每一类的索引需要与label保持一致,例如:
list = ['裤子', '套衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '短靴']
结果分析
tensor([7.0595], grad_fn=<MaxBackward0>) 45
tensor([11.9538], grad_fn=<MaxBackward0>) -45
tensor([5.8450], grad_fn=<MaxBackward0>) 135
前面的张量tensor代表了各个类别的“概率”中最大的那一个,然后根据最大“概率”所在的位置(index)来找到list所对应的类别,然后输出。
pytorch 使用训练好的模型预测新数据相关推荐
- ML之xgboost:基于xgboost(5f-CrVa)算法对HiggsBoson数据集(Kaggle竞赛)训练实现二分类预测(基于训练好的模型进行新数据预测)
ML之xgboost:基于xgboost(5f-CrVa)算法对HiggsBoson数据集(Kaggle竞赛)训练实现二分类预测(基于训练好的模型进行新数据预测) 目录 输出结果 设计思路 核心代码 ...
- mxnet保存模型,加载模型来预测新数据
mxnet保存模型,以及用模型来预测新数据 我们希望训练好之后的模型,可以保存下来,然后需要预测新数据的时候,就可以拿来用,可以这样做. 我们以线性回归的例子来讲: 1,训练并保存模型. impo ...
- Keras之MLP:利用MLP【Input(8)→(12)(relu)→O(sigmoid+二元交叉)】模型实现预测新数据(利用糖尿病数据集的八个特征实现二分类预测
Keras之MLP:利用MLP[Input(8)→(12)(relu)→O(sigmoid+二元交叉)]模型实现预测新数据(利用糖尿病数据集的八个特征实现二分类预测 目录 输出结果 实现代码 输出结果 ...
- Keras之DNN:利用DNN【Input(8)→(12+8)(relu)→O(sigmoid)】模型实现预测新数据(利用糖尿病数据集的八个特征进行二分类预测
Keras之DNN:利用DNN[Input(8)→(12+8)(relu)→O(sigmoid)]模型实现预测新数据(利用糖尿病数据集的八个特征进行二分类预测 目录 输出结果 设计思路 实现代码 输出 ...
- Keras之DNN:利用DNN算法【Input(8)→12+8(relu)→O(sigmoid)】利用糖尿病数据集训练、评估模型(利用糖尿病数据集中的八个参数特征预测一个0或1结果)
Keras之DNN:利用DNN算法[Input(8)→12+8(relu)→O(sigmoid)]利用糖尿病数据集训练.评估模型(利用糖尿病数据集中的八个参数特征预测一个0或1结果) 目录 输出结果 ...
- 持续学习常用6种方法总结:使ML模型适应新数据的同时保持旧数据的性能
来源:Deep IMBA 本文约4800字,建议阅读9分钟 本文将讨论6种方法,使模型可以在保持旧的性能的同时适应新数据,并避免需要在整个数据集(旧+新)上进行重新训练. 持续学习是指在不忘记从前面的 ...
- ML之回归预测:利用13种机器学习算法对Boston(波士顿房价)数据集【13+1,506】进行回归预测(房价预测)+预测新数据得分
ML之回归预测:利用13种机器学习算法对Boston(波士顿房价)数据集[13+1,506]进行回归预测(房价预测)+预测新数据得分 导读 本文章基于前边的一篇文章,对13种机器学习的回归模型性能比较 ...
- ML之xgboost:利用xgboost算法对Boston(波士顿房价)数据集【特征列分段→独热编码】进行回归预测(房价预测)+预测新数据得分
ML之xgboost:利用xgboost算法对Boston(波士顿房价)数据集[特征列分段→独热编码]进行回归预测(房价预测)+预测新数据得分 导读 对Boston(波士顿房价)数据集进行特征工程,分 ...
- SPSS Modeler 中如何利用训练好的模型进行新数据源的预测?
一.利用训练数据(Demos文件夹中的property_values_train.sav)得到模型块taxable_value: 注意原数据文件中有目标字段taxable_value字段: 需要进行预 ...
最新文章
- [日记]一个人去散步
- 普通幕僚:Ownership意识不足的几种症状
- 无服务器计算将会取代容器?
- hdu 4252(单调栈)
- wxWidgets:wxMiniFrame类用法
- Python 中的 os 模块常见方法
- 3-pycharm找不到库的解决办法
- 基于DispatchProxy打造自定义AOP组件
- Focus Stacking
- Julia : 如何进一步改进操作Redis的效率?
- DNF私服商业服搭建教程
- 关于List转Json的简单方法
- MYSQL不能远程连接
- JPEG格式压缩算法
- esp8266 OTA 云远程更新固件 wifiupdate
- 新收集的WAPPUSH代码,并经过改造
- “二舅”火了,自媒体短视频“爆火”的基本要素,你知道吗?
- Android逆向之玩转Xposed模块以劫持登录为例(实战篇)
- 执行 npm install -g grunt-cli 安装grunt发生错误问题
- 收费企业邮箱与收费个人邮箱区别,你造吗?