基于resnet训练flower图像分类模型(p31-p37)上一篇,我改成别的笔记本跑完了。按照老师的步骤,进行加载模型及测试数据预测。

我们之前是冻住了,只训练一层,也可以全部训练,我的显卡太低跑时间太长了,这部分没跑。

1加载训练好的模型

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)# GPU模式
model_ft = model_ft.to(device)# 保存文件的名字
filename='checkpoint.pth'# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

2测试数据预处理

测试数据处理方法需要跟训练时一致才可以。输入的大小是一致,标准化用跟训练数据相同的mean和std,PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换。


def process_image(image_path):# 读取测试数据img = Image.open(image_path)# Resize,thumbnail方法只能进行缩小,所以进行了判断if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width - 224) / 2bottom_margin = (img.height - 224) / 2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin,top_margin))# 相同的预处理方法img = np.array(img) / 255mean = np.array([0.485, 0.456, 0.406])  # provided meanstd = np.array([0.229, 0.224, 0.225])  # provided stdimg = (img - mean) / std# 注意颜色通道应该放在第一个位置img = img.transpose((2, 0, 1))return imgdef imshow(image, ax=None, title=None):"""展示数据"""if ax is None:fig, ax = plt.subplots()# 颜色通道还原image = np.array(image).transpose((1, 2, 0))# 预处理还原mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = std * image + meanimage = np.clip(image, 0, 1)ax.imshow(image)ax.set_title(title)return aximage_path = './data/flower_data/train/3/image_06620.jpg'
img = process_image(image_path)
imshow(img)

在验证集随便选一张花的图片进行测试

3对一个batch的数据进行测试

# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()model_ft.eval()if train_on_gpu:output = model_ft(images.cuda())  #utput表示对一个batch中每一个数据得到其属于各个类别的可能性
else:output = model_ft(images)

output,有8张图片,每个图片有102种分类结果

得到概率最大的那个

_, preds_tensor = torch.max(output, 1)  #得到概率最大的那个preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())

展示预测结果:

fig=plt.figure(figsize=(20, 12))
columns =4
rows = 2for idx in range (columns*rows):ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])plt.imshow(im_convert(images[idx]))ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()  #绿色名字为预测正确,红色名字为预测错误

看这个图,效果还不错,跟之前模型的85%的准确率比较符合。

加载模型及对测试数据进行预测p41相关推荐

  1. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  2. Pytorch加载模型并进行图像分类预测

    目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...

  3. Keras : 训练minst数据集并加载模型对本地手写图片进行预测

    我是本期目录酱 引入 minst数据集介绍 训练模型与测试的py代码分析 训练及测试的py代码(全) 训练及测试结果分析 加载模型并预测本地图片结果 加载模型并预测本地图片py代码(不全) 加载模型并 ...

  4. tensorflow2实现yolov3并使用opencv4.5.5 DNN加载模型预测

    目录 综述 一.什么是YOLO 二.YOLOv3 网络 1.网络结构 2.网络输出解读(前向过程) 2.1.输出特征图尺寸 2.2.锚框和预测 3.训练策略与损失函数(反向过程) 三.tensorfl ...

  5. mxnet加载模型的params和json文件来预测

    导读 有时候我们在使用别人的mxnet预训练模型时,会有两个文件params和json文件,其中params文件中包含的是模型的网络参数,json文件包含的是网络的结构.这里我们以ImageNet的预 ...

  6. tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)

    下面用一个线下回归模型来记载保存模型.加载模型做预测 参考文章: http://blog.csdn.net/thriving_fcl/article/details/71423039 训练一个线下回归 ...

  7. tensorflow 加载模型

    训练模型 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt money=np.array([[109 ...

  8. keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score

    keras分类回归的损失函数与评价指标 目标函数 (1)mean_squared_error / mse 均方误差,常用的目标函数,公式为((y_pred-y_true)**2).mean() (2) ...

  9. 网页怎么预先加载模型_修补预先训练的语言模型

    网页怎么预先加载模型 Can you fill in the words that I've removed from a recent announcement? 您能填写我从最近的公告中删除的词吗 ...

最新文章

  1. 虚拟化服务器端口用万兆,万兆以太网部署需要注意的方面有哪些
  2. Fedora 23 U盘启动出现“Failed to load ldlinux.c32”解决
  3. [CF1082E] Increasing Frequency
  4. jsp点选框_Jsp单选框
  5. java指的是什么_java什么是实例意思指的是
  6. 备忘: MIRACL 大数运算库使用手册
  7. 复制一段话,发现收费怎么办,下边帮你解决
  8. 实践解决跨域问题的三种方式剖析
  9. 微信公众号开发 ----微信网页开发config接口注入(3)
  10. vue中自定义组件(插件)
  11. Python微信爬虫_00
  12. Cocos2d-x 着色器
  13. Ubuntu18.04安装搜狗输入法无法切换中英文
  14. 现代战争——僵尸网络的历史 上篇
  15. 拼多多“超级农货节”收官 阳光玫瑰、琯溪蜜柚上榜“超级水果”
  16. Microsoft Teams 思维导图的4大好处,你知道怎样创建吗?
  17. 哪款蓝牙耳机性价比高?双十一蓝牙耳机推荐
  18. barcode4j CODE128/EAN128生成 不定长 msg值 分隔符
  19. 考进中科院计算所:我的经历和体会
  20. Qt使用C++封装qml自定义图形控件(QQuickPaintedItem)

热门文章

  1. unity3d 如何UI优化和减少DC(DrawCall)
  2. Vue项目中background-image属性设置方法
  3. 高德地图获取地址坐标
  4. 用计算机弹出生僻字的歌,抖音生僻字是什么歌?抖音生僻字歌词注音完整版
  5. 【Python成长之路】Python爬虫 --requests库爬取网站乱码(\xe4\xb8\xb0\xe5\xaf\x8c\xe7\x9)的解决方法
  6. CentOS 7 安装 TinyProxy 代理服务器
  7. git密码重置后如何登录
  8. 腾讯云申请免费SSL证书
  9. Java之IK 分词器
  10. 在VMware上安装macOS