TorchVision中给出了使用ResNet-50-FPN主干(backbone)构建Faster R-CNN的pretrained模型,模型存放位置为https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth,可通过fasterrcnn_resnet50_fpn函数下载,此函数实现在torchvison/models/detection/faster_rcnn.py中,下载后在Ubuntu上存放在~/.cache/torch/hub/checkpoints目录下,在Windows上存放在C:\Users\spring\.cache\torch\hub\checkpoints目录下,其中spring为用户名。

模型的输入是一个tensor列表;每个shape都是[c,h,w];每个shape指定一副图像,并且图像中值的范围为[0,1],即已做过normalized;不同的图像可以有不同的大小,即此模型支持非固定大小图像的输入。

模型的行为取决于它是处于训练模式(training)还是评估模式(evaluation):

(1).在训练期间,模型需要输入tensors和targets(字典列表),包含boxes和labels。

boxes类型为FloatTensor[N,4],其中N为图像数;4为[x1,y1,x2,y2],即ground-truth box的左上和右下角坐标,它们的值要合理范围内。

labels类型为Int64Tensor[N],每个ground-truth box的class label。

(2).在推理(inference)过程中,模型只需要输入tensors,并返回后处理的预测(post-processed predictions),此预测类型为List[Dict[Tensor]],对应每个输入图像。

Dict字段内容除包含boxes和labels外,还包含scores。

scores类型为Tensor[N],每个预测的分值,按照值从大到小的顺序排列。

模型是通过COCO数据集训练获得的,COCO数据集的介绍参考:https://blog.csdn.net/fengbingchun/article/details/121308708

FPN全称为Feature Pyramid Networks,即特征金字塔网络,是一种多尺度的目标检测算法,FPN的介绍参考:https://blog.csdn.net/fengbingchun/article/details/87359191

ResNet即Residual Networks,也称为残差网络,是为了解决深度神经网络的”退化(degradation)”问题。ResNet-50中的50指此网络有50层。ResNet介绍参考:https://blog.csdn.net/fengbingchun/article/details/114167581

Faster R-CNN为目标检测算法,为RPN(Region Proposal Network)和Fast R-CNN的结合。Faster R-CNN介绍参考:https://blog.csdn.net/fengbingchun/article/details/87195597

以下为测试代码:

import torch
from torchvision import models
from torchvision import transforms
import cv2'''
Note: conda pytorch install opencv
windows: conda install opencv # python=3.8.8, opencv=4.0.1
ubuntu: pip3 install opencv-python # python=3.7.11, opencv=4.5.4
'''images_path = "../../data/image/"
images_name = ["1.jpg", "2.jpg", "4.jpg"]
images_data = [] # opencv
tensor_data = [] # pytorch tensorfor name in images_name:img = cv2.imread(images_path + name)print(f"name: {images_path+name}, opencv image shape: {img.shape}") # (w,h,c)images_data.append(img)transform = transforms.Compose([transforms.ToTensor()])tensor = transform(img) # Normalized Tensor image: [0., 1.]print(f"tensor shape: {tensor.shape}, max: {torch.max(tensor)}, min: {torch.min(tensor)}") # (c,h,w)tensor_data.append(tensor)# reference: torchvison/models/detection/faster_rcnn.py
# 使用ResNet-50-FPN(Feature Pyramid Networks, 特征金字塔网络)构建Faster RCNN模型
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
#print(model) # 可查看模型结构
model.eval() # 推理
predictions = model(tensor_data) # result: list: boxes (FloatTensor[N, 4]), labels (Int64Tensor[N]), scores (Tensor[N])
#print(predictions)coco_labels_name = ["unlabeled", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat","traffic light", "fire hydrant", "street sign", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse","sheep", "cow", "elephant", "bear", "zebra", "giraffe", "hat", "backpack", "umbrella", "shoe","eye glasses", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball", "kite", "baseball bat","baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "plate", "wine glass", "cup", "fork", "knife","spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza","donut", "cake", "chair", "couch", "potted plant", "bed", "mirror", "dining table", "window", "desk","toilet", "door", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven","toaster", "sink", "refrigerator", "blender", "book", "clock", "vase", "scissors", "teddy bear", "hair drier","toothbrush", "hair brush"] # len = 92for x in range(len(predictions)):pred = predictions[x]scores = pred["scores"]mask = scores > 0.5 # 只取scores值大于0.5的部分boxes = pred["boxes"][mask].int().detach().numpy() # [x1, y1, x2, y2]labels = pred["labels"][mask]scores = scores[mask]print(f"prediction: boxes:{boxes}, labels:{labels}, scores:{scores}")img = images_data[x]for idx in range(len(boxes)):cv2.rectangle(img, (boxes[idx][0], boxes[idx][1]), (boxes[idx][2], boxes[idx][3]), (255, 0, 0))cv2.putText(img, coco_labels_name[labels[idx]]+" "+str(scores[idx].detach().numpy()), (boxes[idx][0]+10, boxes[idx][1]+10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)cv2.imshow("image", img)cv2.waitKey(1000)cv2.imwrite(images_path+"result_"+images_name[x], img)print("test finish")

说明:

(1).输入图像既可以是彩色图也可以是灰度图,即channel为3或1均可。

(2).输入图像的大小不受限制,一组图像可以大小不一致。

(3).输入图像要求normalized到[0., 1.]。

(4).执行结果仅显示scores值大于0.5的情况。

(5).测试代码中类别数为92而不是80,92=1+11+80。其中1为id为0,label name为unlabeled;11为从COCO中移除的label,如street sign;80为真正的label数,如person。详细参考:https://github.com/nightrome/cocostuff/blob/master/labels.md

(6).结果显示中有冗余的检测框,可以通过NMS(Non-Maximum Suppression)非极大值抑制算法移除。

执行结果如下:以下原始测试图像来自网络

GitHub: GitHub - fengbingchun/PyTorch_Test: PyTorch's usage

TorchVision中使用FasterRCNN+ResNet50+FPN进行目标检测相关推荐

  1. PPDet:减少Anchor-free目标检测中的标签噪声,小目标检测提升明显

    本文转载自AI算法修炼营. 这篇文章收录于BMVC2020,主要的思想是减少anchor-free目标检测中的label噪声,在COCO小目标检测上表现SOTA!性能优于FreeAnchor.Cent ...

  2. [Python图像识别] 四十八.Pytorch构建Faster-RCNN模型实现小麦目标检测

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  3. 【FPN车辆目标检测】数据集获取以及Windows7+TensorFlow+Faster-RCNN+FPN代码环境配置和运行过程实测

    PS 最近在学目标检测想用最新的FPN网络,刚好看到这篇博客https://blog.csdn.net/Angela_qin/article/details/80944604尝试把它复现,说的小白一点 ...

  4. bash 脚本中激活conda环境_ubuntu18.10目标检测算法环境部署+开机自启动脚本创建screen下的web服务...

    内容概要 cuda+cudnn+python环境安装 ubuntu18的开机自启动脚本 screen服务开启 以我个人的实践来看,把python开发的算法封装成webserver的服务供前端程序调用是 ...

  5. libtorch学习笔记(17)- ResNet50 FPN以及如何应用于Faster-RCNN

    什么是FPN FPN,即Feature Pyramid Networks,是一种多尺寸,金字塔结构深度学习网络,使用了FPN的Faster-RCNN,其测试结果超过大部分single-model,包括 ...

  6. Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解

    Torchvision更新到0.3.0后支持了更多的功能,其中新增模块detection中实现了整个faster-rcnn的功能.本博客主要讲述如何通过torchvision和pytorch使用fas ...

  7. 目标检测中的anchor-base与anchor-free

    前言 本文参考目标检测阵营 | Anchor-Base vs Anchor-Free 如何评价zhangshifeng最新的讨论anchor based/ free的论文? - 知乎 基础知识 | 目 ...

  8. 【深度学习】一位算法工程师从30+场秋招面试中总结出的超强面经——目标检测篇(含答案)...

    作者丨灯会 来源丨极市平台 编辑丨极市平台 导读 作者灯会为21届中部985研究生,凭借自己整理的面经,去年在腾讯优图暑期实习,七月份将入职百度cv算法工程师.在去年灰飞烟灭的算法求职季中,经过30+ ...

  9. 深度学习阅读导航 | 04 FPN:基于特征金字塔网络的目标检测

    写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...

最新文章

  1. vector容器中重写sort方法
  2. P4720 【模板】扩展卢卡斯定理/exLucas(无讲解,纯记录模板)
  3. 无需复杂插件即可从Eclipse启动和调试Tomcat
  4. 2020教育OMO模式落地应用研究报告
  5. springboot实现定时任务常用的2种方式
  6. [导入] 用java把页面日期控件写出来
  7. heidisql 命令保存blob_git常用命令总结
  8. win10计算机网络共享设置密码,win10系统清除网络共享密码的方法介绍
  9. 使用Photoshop画一个圆锥体
  10. forever保护node server进程报错(node:8156) Warning: Accessing non-existent property ‘padLevels‘ of module e
  11. 23web app实现上下左右滑动
  12. android刷windows教程视频,蓝魔i9s安卓版刷Windows 8.1系统固件视频教程
  13. 好好学习,持续学习,才能持续赚钱
  14. 5.海康威视-Spring boot下实现抓图并保存
  15. Win10电脑清除运行窗口的历史记录
  16. COMSOL多物理场/FDTD时域有限差分/ RSoft光电器件仿真设计“ 几十种案例解析,助您掌握光电器件仿真模拟
  17. linux用sz下载文件夹,linux上很方便的上传下载文件工具rz和sz使用介绍
  18. ThinkPHP优雅草小程序一键生成运营管理系统源码
  19. 5G小基站国产化超五成,美国芯片仅占1%,难怪美国芯片难卖了
  20. Linux笔记(1)

热门文章

  1. 使用Python,EoN模拟网络中的疾病扩散模型,并结合matplotlib绘图
  2. CV算法复现(分类算法2/6):AlexNet(2012年 Hinton组)
  3. Matlab中bwmorph函数的使用
  4. c语言找出比n小的最大质数,C++ 实现求小于n的最大素数的实例
  5. php接收不到ajax请求参数,我是否需要在ajax请求和接收该请求的php之间编码/解码查询参数?...
  6. php与c 哪个好,C语言和PHP,新手选择哪个比较好?
  7. Unity制作2D动作平台游戏视频教程
  8. MIB in SNMP
  9. pmdk -- libpmemlog 介绍
  10. linux进程间通信:system V 信号量和共享内存实现进程间同步