点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。可以是说是pytorch中非常有用的模型迁移学习神器。本文将会介绍如何使用torchvison的预训练模型ResNet50实现图像分类。

模型

Torchvision.models包里面包含了常见的各种基础模型架构,主要包括:

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNet v2
ResNeXt
Wide ResNet
MNASNet

这里我选择了ResNet50,基于ImageNet训练的基础网络来实现图像分类, 网络模型下载与加载如下:

model = torchvision.models.resnet50(pretrained=True).eval().cuda()tf = transforms.Compose([            transforms.Resize(256),            transforms.CenterCrop(224),            transforms.ToTensor(),            transforms.Normalize(            mean=[0.485, 0.456, 0.406],            std=[0.229, 0.224, 0.225]        )])

使用模型实现图像分类

这里首先需要加载ImageNet的分类标签,目的是最后显示分类的文本标签时候使用。然后对输入图像完成预处理,使用ResNet50模型实现分类预测,对预测结果解析之后,显示标签文本,完整的代码演示如下:

 1with open('imagenet_classes.txt') as f: 2    labels = [line.strip() for line in f.readlines()] 3 4src = cv.imread("D:/images/space_shuttle.jpg") # aeroplane.jpg 5image = cv.resize(src, (224, 224)) 6image = np.float32(image) / 255.0 7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406)) 8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225)) 9image = image.transpose((2, 0, 1))10input_x = torch.from_numpy(image).unsqueeze(0)11print(input_x.size())12pred = model(input_x.cuda())13pred_index = torch.argmax(pred, 1).cpu().detach().numpy()14print(pred_index)15print("current predict class name : %s"%labels[pred_index[0]])16cv.putText(src, labels[pred_index[0]], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)17cv.imshow("input", src)18cv.waitKey(0)19cv.destroyAllWindows()

运行结果如下:

转ONNX支持

在torchvision中的模型基本上都可以转换为ONNX格式,而且被OpenCV DNN模块所支持,所以,很方便的可以对torchvision自带的模型转为ONNX,实现OpenCV DNN的调用,首先转为ONNX模型,直接使用torch.onnx.export即可转换(还不知道怎么转,快点看前面的例子)。转换之后使用OpenCV DNN调用的代码如下:

 1with open('imagenet_classes.txt') as f: 2    labels = [line.strip() for line in f.readlines()] 3net = cv.dnn.readNetFromONNX("resnet.onnx") 4src = cv.imread("D:/images/messi.jpg")  # aeroplane.jpg 5image = cv.resize(src, (224, 224)) 6image = np.float32(image) / 255.0 7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406)) 8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225)) 9blob = cv.dnn.blobFromImage(image, 1.0, (224, 224), (0, 0, 0), False)10net.setInput(blob)11probs = net.forward()12index = np.argmax(probs)13cv.putText(src, labels[index], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)14cv.imshow("input", src)15cv.waitKey(0)16cv.destroyAllWindows()

运行结果见上图,这里就不再贴了。

 推荐阅读 

轻松学Pytorch–环境搭建与基本语法

Pytorch轻松学-构建浅层神经网络

轻松学pytorch-构建卷积神经网络

轻松学Pytorch –构建循环神经网络

轻松学Pytorch-使用卷积神经网络实现图像分类

轻松学Pytorch-自定义数据集制作与使用

轻松学Pytorch-Pytorch可视化

轻松学Pytorch–Visdom可视化

轻松学Pytorch – 全局池化层详解

轻松学Pytorch – 人脸五点landmark提取网络训练与使用

轻松学Pytorch – 年龄与性别预测

轻松学Pytorch –车辆类型与颜色识别

轻松学Pytorch-全卷积神经网络实现表情识别

使用OpenVINO加速Pytorch表情识别模型

轻松学pytorch – 使用多标签损失函数训练卷积网络

志不强者智不达

言不信者行不果

pytorch argmax_轻松学Pytorch使用ResNet50实现图像分类相关推荐

  1. rcnn代码实现_轻松学Pytorch实现自定义对象检测器

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发.点赞.留言支 ...

  2. 轻松学Pytorch – 行人检测Mask-RCNN模型训练与使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,这个是轻松学Pytorch的第20篇的文章分享,主要是给大 ...

  3. 数据集制作_轻松学Pytorch自定义数据集制作与使用

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...

  4. 轻松学Pytorch – 人脸五点landmark提取网络训练与使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,本文是轻松学Pytorch系列文章第十篇,本文将介绍如何使 ...

  5. 轻松学Pytorch – 年龄与性别预测

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,上周太忙,没有更新Pytorch轻松学系列文章,但是我还是 ...

  6. 轻松学Pytorch–环境搭建与基本语法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 基本思路选择 以前我用过Caffe,用过tensorflow,最近 ...

  7. 轻松学Pytorch –使用torchvision实现对象检测

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,前面一篇文章介绍了torchvision的模型ResNet ...

  8. celeba数据集_轻松学 Pytorch 使用DCGAN实现数据复制

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 DCGAN Ian J. Goodfellow首次提出了GAN之后,生成对抗只是神经网络还不是深度卷积神经网络 ...

  9. 【小白学PyTorch】12.SENet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小白学PyTorch | 10 pytorch常见运算详解 小白学Py ...

最新文章

  1. pandas 设置多重索引_Pandas多重索引使用详解
  2. 正确使用Windows Azure 中的VM Role
  3. 多视图几何总结——单应矩阵和基础矩阵的兼容关系
  4. pytorch torch.item()(返回此张量的值作为标准Python数字。 这仅适用于具有一个元素的张量。)
  5. 如何做一份出色的竞品分析?(一)
  6. WinForm连接数据库
  7. GUI实战|Python做一个文档图片提取软件
  8. 【编译打包】tengine 1.5.2
  9. 全国计算机考试真考题库4,全国计算机等级考试无纸化真考题库试卷二级C--(4)资料.docx...
  10. 知乎高赞:如果你是一个 Java 面试官,你会问哪些问题....
  11. [置顶] VS自带工具:dumpbin的使用
  12. gsonformat插件_没用过这些IDEA插件?怪不得写代码头疼
  13. Java运算符和类型转换
  14. exe模拟器android版,安卓exe模拟器
  15. DJ音乐盒-专注DJ
  16. SPSS Modeler 自动分类器学习笔记
  17. 使用conda安装pytorch时出现问题CondaSSLError: OpenSSL appears to be unavailable on this machine.
  18. NuGet是什么?为什么.NET项目中会有NuGet?如何使用NuGet程序包?
  19. java公倍数_java中如何计算最小公倍数
  20. 电视剧《奋斗》精彩对白节选---(三)

热门文章

  1. c汇编语言程序框架培训,[010][x86汇编语言]学习用户程序的编写(c08.asm)
  2. SCOM Rule 介绍 [SCOM中文系列之六]
  3. Cron表达式 详解
  4. vim 常用快捷键总结
  5. linux-pcap 抓包程序框架
  6. 再议 语法高亮插件的选择
  7. GridView 中添加删除确认提示框
  8. __va_rounded_size
  9. 快两年的时间,我都干了啥
  10. Android系统充电系统介绍-预防手机充电爆炸