1 案例基本工具概述

1.1 数据集简介

Imagenet数据集共有1000个类别,表明该数据集上的预训练模型最多可以输出1000种不同的分类结果。

  • Imagenet数据集是目前深度学习图像领域应用得非常多的一个领域,关于图像分类、定位、检测等研究工作大多基于此数据集展开。
  • Imagenet数据集文档详细,有专门的团队维护,使用非常方便,在计算机视觉领域研究论文中应用非常广,几乎成为了目前深度学习图像领域算法性能检验的“标准”数据集。
  • Imagenet数据集有1400多万幅图片,涵盖2万多个类别,其中有超过百万的图片有明确的类别标注和图像中物体位置的标注。

1.2 预训练模型

PyTorch中提供了许多在可以被直接加载到模型中并进行器的eNet数据集上训练好的模型,这些模型叫作预训练模型预测。

1.2.1 预训练模型简介

预训练模型都存放在PyTorch的torchvision库中。torchvision库是非常强大的PyTorch视觉处理库,包括分类、目标检测、语义分割等多种计算机视觉任务的预训练模型,还包括图片处理、锚点计算等很多基础工具。

1.2.2 预训练模型简介

2 代码实战

2.1 案例概述

实例描述,将ImageNet数据集上的预训练模型ResNet18加抗到内存,并使用该模型对图片进行分类预测。

2.2 代码实现:下载并加载预训练模型-----ResNetModel.py(第1部分)

from PIL import Image
import matplotlib.pyplot as plt
import json
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import models,transforms # 引入torchvision库
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 1.1 下载并加载预训练模型:引入基础库,并使用torchvision库中的API下载模型。
# Tip:本例使用的中文标签总类别为1001类,索引值为0的类为None,代表未知分类;英文标签总类注意别为1000类,没有None类。
#      因为PyTorch中的模型是在英文标签中训练的,所以在读取中文标签时,还需要将索引值加1
model = models.resnet18(pretrained=True) # True代表要下载模型 ,返回一个具有18层的ResNet模型
model = model.eval()

2.3 代码实现:加载标签并对输入数据进行预处理-----ResNetModel.py(第2部分)

# 1.2 加载标签并对输入数据进行预处理
labels_path = './models_2/code_01/imagenet_class_index.json' # 处理英文标签
with open(labels_path) as json_data:idx_to_labels = json.load(json_data)def getone(onestr):return onestr.replace(',','')
with open('models_2/code_01/中文标签.csv','r+') as f:zh_labels = list(map(getone,list(f)))print(len(zh_labels),type(zh_labels),zh_labels[:5]) # 显示输出中文标签transform = transforms.Compose([transforms.Resize(256), # 将输入图像的尺寸修改为256×256transforms.CenterCrop(224), # 沿中心裁剪得224×224transforms.ToTensor(),transforms.Normalize(   # 图片归一化参数:对图片按照指定的均值与方差进行归一化处理,必须要与模型实际训练的预处理方式一样。mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]
)

2.4 使用模型进行预测

2.4.1 代码操作概述

打开一个图片文件,并将其输入模型进行预测,同时输出预测结果。

2.4.2 代码实现:使用模型进行预测 -----ResNetModel.py(第3部分)

# 1.3 使用模型进行预测
# -------start-------- 将四通道中代表透明通道的维度A去掉,变为4通道的图片
def preimg(img): # 图片预处理函数:if img.mode == 'RGBA': # 实现兼容RGBA格式的图片信息ch = 4print('ch',ch)a = np.asarray(img)[:,:,:3]img = Image.fromarray(a)return imgim = preimg(Image.open('models_2/code_01/book.png')) # 载入图片
transforms_img = transform(im)  # 调整图片大小
inputimg = transforms_img.unsqueeze(0) # 增加批次维度
# -------end-------- 将四通道中代表透明通道的维度A去掉,变为4通道的图片output = model(inputimg) # 输入模型
output = F.softmax(output,dim=1)  # 获取结果# 从预测结果中取前3名
prediction_score , pred_label_idx = torch.topk(output,3)
prediction_score  = prediction_score.detach().numpy()[0] # 获取结果概率
pred_label_idx = pred_label_idx.detach().numpy()[0] # 获得结果ID
predicted_label = idx_to_labels[str(pred_label_idx[0])][1]#取出标签名称
predicted_label_zh = zh_labels[pred_label_idx[0] + 1 ] #取出中文标签名称
print(' 预测结果:', predicted_label,predicted_label_zh,'预测分数:', prediction_score[0])

2.5 预测结果可视化

2.5.1 可视化代码概述

将预测结果以图的方式显示出来。

2.5.2 代码实战:预测结果可视化-----ResNetModel.py(第4部分)

# 1.4 预测结果可视化
#可视化处理,创建一个1行2列的子图
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 8))
fig.sca(ax1)                #设置第一个轴是ax1
ax1.imshow(im)              #第一个子图显示原始要预测的图片
#设置第二个子图为预测的结果,按概率取前3名
barlist = ax2.bar(range(3), [i for i in prediction_score])
barlist[0].set_color('g')         #颜色设置为绿色
#预测结果前3名的柱状图
plt.sca(ax2)
plt.ylim([0, 1.1])
#竖直显示Top3的标签
plt.xticks(range(3), [idx_to_labels[str(i)][1][:15] for i in pred_label_idx ], rotation='vertical')
fig.subplots_adjust(bottom=0.2)    #调整第二个子图的位置
plt.show()                          #显示图像

结果输出:

3  代码总览ResNetModel.py

from PIL import Image
import matplotlib.pyplot as plt
import json
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import models,transforms # 引入torchvision库
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 1.1 下载并加载预训练模型:引入基础库,并使用torchvision库中的API下载模型。
# Tip:本例使用的中文标签总类别为1001类,索引值为0的类为None,代表未知分类;英文标签总类注意别为1000类,没有None类。
#      因为PyTorch中的模型是在英文标签中训练的,所以在读取中文标签时,还需要将索引值加1
model = models.resnet18(pretrained=True) # True代表要下载模型 ,返回一个具有18层的ResNet模型
model = model.eval()# 1.2 还在标签并对输入数据进行预处理
labels_path = './models_2/code_01/imagenet_class_index.json' # 处理英文标签
with open(labels_path) as json_data:idx_to_labels = json.load(json_data)def getone(onestr):return onestr.replace(',','')
with open('models_2/code_01/中文标签.csv','r+') as f:zh_labels = list(map(getone,list(f)))print(len(zh_labels),type(zh_labels),zh_labels[:5]) # 显示输出中文标签transform = transforms.Compose([transforms.Resize(256), # 将输入图像的尺寸修改为256×256transforms.CenterCrop(224), # 沿中心裁剪得224×224transforms.ToTensor(),transforms.Normalize(   # 图片归一化参数:对图片按照指定的均值与方差进行归一化处理,必须要与模型实际训练的预处理方式一样。mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]
)# 1.3 使用模型进行预测
# -------start-------- 将四通道中代表透明通道的维度A去掉,变为4通道的图片
def preimg(img): # 图片预处理函数:if img.mode == 'RGBA': # 实现兼容RGBA格式的图片信息ch = 4print('ch',ch)a = np.asarray(img)[:,:,:3]img = Image.fromarray(a)return imgim = preimg(Image.open('models_2/code_01/book.png')) # 载入图片
transforms_img = transform(im)  # 调整图片大小
inputimg = transforms_img.unsqueeze(0) # 增加批次维度
# -------end-------- 将四通道中代表透明通道的维度A去掉,变为4通道的图片output = model(inputimg) # 输入模型
output = F.softmax(output,dim=1)  # 获取结果# 从预测结果中取前3名
prediction_score , pred_label_idx = torch.topk(output,3)
prediction_score  = prediction_score.detach().numpy()[0] # 获取结果概率
pred_label_idx = pred_label_idx.detach().numpy()[0] # 获得结果ID
predicted_label = idx_to_labels[str(pred_label_idx[0])][1]#取出标签名称
predicted_label_zh = zh_labels[pred_label_idx[0] + 1 ] #取出中文标签名称
print(' 预测结果:', predicted_label,predicted_label_zh,'预测分数:', prediction_score[0])# 1.4 预测结果可视化
#可视化处理,创建一个1行2列的子图
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 8))
fig.sca(ax1)                #设置第一个轴是ax1
ax1.imshow(im)              #第一个子图显示原始要预测的图片
#设置第二个子图为预测的结果,按概率取前3名
barlist = ax2.bar(range(3), [i for i in prediction_score])
barlist[0].set_color('g')         #颜色设置为绿色
#预测结果前3名的柱状图
plt.sca(ax2)
plt.ylim([0, 1.1])
#竖直显示Top3的标签
plt.xticks(range(3), [idx_to_labels[str(i)][1][:15] for i in pred_label_idx ], rotation='vertical')
fig.subplots_adjust(bottom=0.2)    #调整第二个子图的位置
plt.show()                          #显示图像

【Pytorch神经网络实战案例】23 使用ImagNet的预训练模型识别图片内容相关推荐

  1. 【Pytorch神经网络实战案例】11 循环神经网络结构训练语言模型并进行简单预测

    1 语言模型步骤 简单概述:根据输入内容,继续输出后面的句子. 1.1 根据需求拆分任务 (1)先对模型输入一段文字,令模型输出之后的一个文字. (2)将模型预测出来的文字当成输入,再放到模型里,使模 ...

  2. 【Pytorch神经网络实战案例】21 基于Cora数据集实现Multi_Sample Dropout图卷积网络模型的论文分类

    Multi-sample Dropout是Dropout的一个变种方法,该方法比普通Dropout的泛化能力更好,同时又可以缩短模型的训练时间.XMuli-sampleDropout还可以降低训练集和 ...

  3. 【Pytorch神经网络实战案例】18 最大化深度互信信息模型DIM实现搜索最相关与最不相关的图片

    图片搜索器分为图片的特征提取和匹配两部分,其中图片的特征提取是关键.将使用一种基于无监督模型的提取特征的方法实现特征提取,即最大化深度互信息(DeepInfoMax,DIM)方法. 1 最大深度互信信 ...

  4. 【Pytorch神经网络实战案例】40 TextCNN模型分析IMDB数据集评论的积极与消极

    卷积神经网络不仅在图像视觉领域有很好的效果,而且在基于文本的NLP领域也有很好的效果.TextCN如模型是卷积神经网络用于文本处理方面的一个模型. 在TextCNN模型中,通过多分支卷积技术实现对文本 ...

  5. 【Pytorch神经网络实战案例】28 GitSet模型进行步态与身份识别(CASIA-B数据集)

    1 CASIA-B数据集 本例使用的是预处理后的CASIA-B数据集, 数据集下载网址如下. http://www.cbsr.ia.ac.cn/china/Gait%20Databases%20cH. ...

  6. 【Pytorch神经网络实战案例】29 【代码汇总】GitSet模型进行步态与身份识别(CASIA-B数据集)

    1 GaitSet_DataLoader.py import numpy as np # 引入基础库 import os import torch.utils.data as tordata from ...

  7. 【Pytorch神经网络实战案例】27 MaskR-CNN内置模型实现语义分割

    1 PyTorch中语义分割的内置模型 在torchvision库下的models\segmentation目录中,找到segmentation.Py文件.该文件中存放着PyTorch内置的语义分割模 ...

  8. 【Pytorch神经网络实战案例】34 使用GPT-2模型实现句子补全功能(手动加载)

    1 GPT-2 模型结构 GPT-2的整体结构如下图,GPT-2是以Transformer为基础构建的,使用字节对编码的方法进行数据预处理,通过预测下一个词任务进行预训练的语言模型. 1.1 GPT- ...

  9. 【Pytorch神经网络实战案例】24 基于迁移学习识别多种鸟类(CUB-200数据集)

    1 迁移学习 在实际开发中,常会使用迁移学习将预训练模型中的特征提取能力转移到自己的模型中. 1.1 迁移学习定义 迁移学习指将在一个任务上训练完成的模型进行简单的修改,再用另一个任务的数据继续训练, ...

最新文章

  1. 12-09关于几种排序方式
  2. centos下pg_dump的服务器版本不匹配问题
  3. shell for循环1到100_浅谈Linux下shell 编程的for循环常用的6种结构
  4. python决策树的应用_机器学习-决策树实战应用
  5. 圆你国产数据库DBA之梦,达梦DCA培训考试券免费拿
  6. x的奇幻之旅 (史蒂夫·斯托加茨 著)
  7. android传感器开发与智能设备案例实战_【我的物联网成长记2】设备如何进行选型?...
  8. 【RLchina第二讲】 Foundations of Reinforcement Learning
  9. Illustrator 教程,如何在 Illustrator 中连接路径?
  10. JVM(三)——类结构与类加载器
  11. SouthidcEditor编辑器如何支持上传png图片
  12. 新版淘宝宽屏轮播代码带缩略图
  13. SSM框架原理流程及使用方法
  14. mac怎么禁止某个应用联网?
  15. PostgreSQL全文检索
  16. 解决Shiro+SpringBoot异步任务长时间运行导致的UnknownSessionException错误问题
  17. 让你的 Mac 用上最美的屏保,Aerial 使用教程
  18. 计算机中汉字的顺序用什么牌,中国汉字的写做顺序,你知道吗?
  19. 用python提取word到excel(excel可更新)
  20. 解决Unity3d 图片黑边问题

热门文章

  1. 新年来了,上海求职,路过看看
  2. 2013计算机二级试题,2013全国计算机二级上机考试试题46-100套试题
  3. mysql concat例子_MYSQL中CONCAT详解
  4. 机器学习接口和代码之 线性回归
  5. PyQt5案例汇总(简洁版)
  6. 数组拼接时中间怎么加入空格_【题解二维数组】1123:图像相似度
  7. access 根据id删除数据_小程序云开发之数据库自动备份丨云开发101
  8. android string数组转json_移动端开发基础【20】pages.json的配置项pages
  9. Docker swarm 实战-部署wordpress
  10. js 日期天数相加减,格式化yyyy-MM-dd