加载模型及对测试数据进行预测p41
基于resnet训练flower图像分类模型(p31-p37)上一篇,我改成别的笔记本跑完了。按照老师的步骤,进行加载模型及测试数据预测。
我们之前是冻住了,只训练一层,也可以全部训练,我的显卡太低跑时间太长了,这部分没跑。
1加载训练好的模型
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)# GPU模式
model_ft = model_ft.to(device)# 保存文件的名字
filename='checkpoint.pth'# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
2测试数据预处理
测试数据处理方法需要跟训练时一致才可以。输入的大小是一致,标准化用跟训练数据相同的mean和std,PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换。
def process_image(image_path):# 读取测试数据img = Image.open(image_path)# Resize,thumbnail方法只能进行缩小,所以进行了判断if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width - 224) / 2bottom_margin = (img.height - 224) / 2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin,top_margin))# 相同的预处理方法img = np.array(img) / 255mean = np.array([0.485, 0.456, 0.406]) # provided meanstd = np.array([0.229, 0.224, 0.225]) # provided stdimg = (img - mean) / std# 注意颜色通道应该放在第一个位置img = img.transpose((2, 0, 1))return imgdef imshow(image, ax=None, title=None):"""展示数据"""if ax is None:fig, ax = plt.subplots()# 颜色通道还原image = np.array(image).transpose((1, 2, 0))# 预处理还原mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = std * image + meanimage = np.clip(image, 0, 1)ax.imshow(image)ax.set_title(title)return aximage_path = './data/flower_data/train/3/image_06620.jpg'
img = process_image(image_path)
imshow(img)
在验证集随便选一张花的图片进行测试
3对一个batch的数据进行测试
# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()model_ft.eval()if train_on_gpu:output = model_ft(images.cuda()) #utput表示对一个batch中每一个数据得到其属于各个类别的可能性
else:output = model_ft(images)
output,有8张图片,每个图片有102种分类结果
得到概率最大的那个
_, preds_tensor = torch.max(output, 1) #得到概率最大的那个preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
展示预测结果:
fig=plt.figure(figsize=(20, 12))
columns =4
rows = 2for idx in range (columns*rows):ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])plt.imshow(im_convert(images[idx]))ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show() #绿色名字为预测正确,红色名字为预测错误
看这个图,效果还不错,跟之前模型的85%的准确率比较符合。
加载模型及对测试数据进行预测p41相关推荐
- tensorflow 1.x Saver(保存与加载模型) 预测
20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...
- Pytorch加载模型并进行图像分类预测
目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...
- Keras : 训练minst数据集并加载模型对本地手写图片进行预测
我是本期目录酱 引入 minst数据集介绍 训练模型与测试的py代码分析 训练及测试的py代码(全) 训练及测试结果分析 加载模型并预测本地图片结果 加载模型并预测本地图片py代码(不全) 加载模型并 ...
- tensorflow2实现yolov3并使用opencv4.5.5 DNN加载模型预测
目录 综述 一.什么是YOLO 二.YOLOv3 网络 1.网络结构 2.网络输出解读(前向过程) 2.1.输出特征图尺寸 2.2.锚框和预测 3.训练策略与损失函数(反向过程) 三.tensorfl ...
- mxnet加载模型的params和json文件来预测
导读 有时候我们在使用别人的mxnet预训练模型时,会有两个文件params和json文件,其中params文件中包含的是模型的网络参数,json文件包含的是网络的结构.这里我们以ImageNet的预 ...
- tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)
下面用一个线下回归模型来记载保存模型.加载模型做预测 参考文章: http://blog.csdn.net/thriving_fcl/article/details/71423039 训练一个线下回归 ...
- tensorflow 加载模型
训练模型 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt money=np.array([[109 ...
- keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score
keras分类回归的损失函数与评价指标 目标函数 (1)mean_squared_error / mse 均方误差,常用的目标函数,公式为((y_pred-y_true)**2).mean() (2) ...
- 网页怎么预先加载模型_修补预先训练的语言模型
网页怎么预先加载模型 Can you fill in the words that I've removed from a recent announcement? 您能填写我从最近的公告中删除的词吗 ...
最新文章
- 虚拟化服务器端口用万兆,万兆以太网部署需要注意的方面有哪些
- Fedora 23 U盘启动出现“Failed to load ldlinux.c32”解决
- [CF1082E] Increasing Frequency
- jsp点选框_Jsp单选框
- java指的是什么_java什么是实例意思指的是
- 备忘: MIRACL 大数运算库使用手册
- 复制一段话,发现收费怎么办,下边帮你解决
- 实践解决跨域问题的三种方式剖析
- 微信公众号开发 ----微信网页开发config接口注入(3)
- vue中自定义组件(插件)
- Python微信爬虫_00
- Cocos2d-x 着色器
- Ubuntu18.04安装搜狗输入法无法切换中英文
- 现代战争——僵尸网络的历史 上篇
- 拼多多“超级农货节”收官 阳光玫瑰、琯溪蜜柚上榜“超级水果”
- Microsoft Teams 思维导图的4大好处,你知道怎样创建吗?
- 哪款蓝牙耳机性价比高?双十一蓝牙耳机推荐
- barcode4j CODE128/EAN128生成 不定长 msg值 分隔符
- 考进中科院计算所:我的经历和体会
- Qt使用C++封装qml自定义图形控件(QQuickPaintedItem)
热门文章
- unity3d 如何UI优化和减少DC(DrawCall)
- Vue项目中background-image属性设置方法
- 高德地图获取地址坐标
- 用计算机弹出生僻字的歌,抖音生僻字是什么歌?抖音生僻字歌词注音完整版
- 【Python成长之路】Python爬虫 --requests库爬取网站乱码(\xe4\xb8\xb0\xe5\xaf\x8c\xe7\x9)的解决方法
- CentOS 7 安装 TinyProxy 代理服务器
- git密码重置后如何登录
- 腾讯云申请免费SSL证书
- Java之IK 分词器
- 在VMware上安装macOS