Faster-RCNN复现

数据准备

主要目标是利用pytorch框架简易复现Faster-RCNN,我选了一个比较简单的数据集VOC2012,这个数据集标注用的xml格式。

数据读取如下:

 def loadXml(path):'''读取原始的XMLlabel文件:return:'''dom = xml.dom.minidom.parse(path)​root = dom.documentElement#获取图片上所有的目标objs = root.getElementsByTagName("object")#获取这张图片的长宽w, h = root.getElementsByTagName("width")[0].childNodes[0].data, root.getElementsByTagName("height")[0].childNodes[0].datainfos = []#把所有目标的信息放到infos中返回for item in objs:c = item.getElementsByTagName("name")[0].childNodes[0].dataxmin = item.getElementsByTagName("xmin")[0].childNodes[0].dataymin = item.getElementsByTagName("ymin")[0].childNodes[0].dataxmax = item.getElementsByTagName("xmax")[0].childNodes[0].dataymax = item.getElementsByTagName("ymax")[0].childNodes[0].datainfos.append((c, int(xmin), int(ymin), int(xmax), int(ymax)))return int(w), int(h), infos

模型配置

关于训练的参数,anchor的size比论文上多两个,提高小目标的检测能力,长宽比还是三个。初始学习率为0.0001,每2轮后乘以0.1。 训练的相关参数配置如下:

 class Config():# anchor的size和长宽比,这里的格式是由pytorch中自带的基础faster_rcnn的输入决定的,参考  # torchvision.models.detection.faster_rcnnanchor_size = ((32, 64, 128, 256, 512),)aspect_ratios = ((0.5, 1, 2),)# 特征网络backbone = "mobilenet_V2"# 目标类型,0是背景,同样是torchvision.models.detection.faster_rcnn的要求cls_label = ['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car','chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike','person', 'pottedplant', 'sheep', 'sofa', 'train','tvmonitor', 'cat'] #VOC所有标签class_num = 20lr = 0.0001epoch = 6batch_size = 2# 模型保存位置FasterRCNN_checkpoints = './checkpoints/FasterRCNN/model.pth'

torchvision中自带的faster_rcnn模块

 from torchvision.models.detection import faster_rcnnfrom torchvision.models.detection.rpn import AnchorGeneratorfrom torchvision.ops import MultiScaleRoIAlign​​# 生成anchoranchor_generator = AnchorGenerator(sizes=cfg.anchor_size, aspect_ratios=cfg.aspect_ratios)​# 设置roipooling,采用ROIAlignroi_pooler = MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)​# 建立FasterRCNN模型net = faster_rcnn.FasterRCNN(backbone,num_classes=cfg.class_num+1,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler)

faster_rcnn创建主要需要三块,一个是anchor的生成器,一个是roipooling方式,还有一个是用于提取特征的backbone。

backbone可以用torch中提供的,也可以自己写的,但是,如果是自己写的,最好先进行预训练,我用没有预训练的自定义的backbone去试了下,效果不佳。

关于MultiScaleRoIAlign的第一个参数为啥是['0'],torchvision.models.detection.faster_rcnn.py中有说明。。说是如果backbone返回多个特征图层时,写入需要计算的图层名字,比如['feat1', 'feat3'],但如果只返回一个tensor,也就是只有一个图层返回,就['0']。

faster_rcnn中提供很多参数,可以在建立模型时传入。

 def __init__(self, backbone, num_classes=None,# transform parametersmin_size=800, max_size=1333,image_mean=None, image_std=None,# RPN parametersrpn_anchor_generator=None, rpn_head=None,rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,rpn_nms_thresh=0.7,rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,# Box parametersbox_roi_pool=None, box_head=None, box_predictor=None,box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,box_batch_size_per_image=512, box_positive_fraction=0.25,bbox_reg_weights=None):

训练部分

 from dataset.MyDataset import MyDatasetfrom models.FasterRCNN import FasterRCNNfrom torch.utils.data import DataLoaderfrom Config import Configimport torch.optim as optimimport numpy as npfrom torch.autograd import Variableimport torchfrom utils.DataProcess import barfrom torch.utils.tensorboard import SummaryWriter​​def train():# 调用tensorboard监视训练过程writer = SummaryWriter("runs/summary")cfg = Config()​net = FasterRCNN().net​net.cuda()​trainset = MyDataset('train')​trainLoader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, drop_last=True)​params = [p for p in net.parameters() if p.requires_grad]​opt = optim.Adam(params, lr=cfg.lr)​num_train = len(trainLoader)​for epoch in range(1, cfg.epoch + 1):net.train()# 每2轮学习率下降一次if epoch % 2 == 0:for p in opt.param_groups:p['lr'] *= 0.1for n, data in enumerate(trainLoader):imgs = []gtbox = []for i in range(cfg.batch_size):img_path = data[0][i]gtbox_path = data[1][i]img = trainset.load_pic(img_path)img = Variable(img)gtbbox = np.load(gtbox_path)gtbbox = torch.from_numpy(gtbbox)gtbbox = Variable(gtbbox)if torch.cuda.is_available():gtbbox = gtbbox.cuda()img = img.cuda()imgs.append(img)gtbox.append({"boxes": gtbbox[:, 1:], "labels": gtbbox[:, 0]})opt.zero_grad()data = net(imgs, gtbox)l_cls = data['loss_classifier']l_box = data['loss_box_reg']l_obj = data['loss_objectness']l_box_rpn = data['loss_rpn_box_reg']bar('正在第%d轮训练,loss_cls=%.5f,loss_box=%.5f,loss_obj=%.5f,loss_box_rpn=%.5f' %(epoch, l_cls.data, l_box.data, l_obj.data, l_box_rpn.data), n, num_train)loss = l_cls + l_box + l_obj + l_box_rpn# 添加监视目标writer.add_scalar("loss_classifier", l_cls.data, (epoch - 1) * num_train + n)writer.add_scalar("loss_box_reg", l_box.data, (epoch - 1) * num_train + n)writer.add_scalar("loss_objectness", l_obj.data, (epoch - 1) * num_train + n)writer.add_scalar("loss_rpn_box_reg", l_box_rpn.data, (epoch - 1) * num_train + n)loss.backward()opt.step()torch.save(net, cfg.FasterRCNN_checkpoints)​​if __name__ == '__main__':train()

训练部分用tensorboard记录了一下训练过程。这里注意一下,torch提供的faster_rcnn输入包含两部分,一个是图片集,注意不需要resize,第二个是标注集,标注集是个list,里面元素是Map,包括坐标信息boxes和标签信息labels。至于具体怎么读取数据就看自己怎么设计数据了。训练返回结果包括四个损失,anchor的坐标损失、anchor前背景损失、目标框的坐标损失以及目标分类损失。

模型校验

 net = torch.load(cfg.FasterRCNN_checkpoints)net.eval()***************省略***************img = Variable(img)if torch.cuda.is_available():img = img.cuda()res = net([img])res = res[0]boxes = res["boxes"]labels = res["labels"]scores = res["scores"]***************省略***************

校验部分就省略写了,主要两点,第一点,不需要新建一个faster_rcnn再加载数据,就直接torch.load就完事了,一定要记得net.eval();第二点,输入只需要图片列表。校验时,返回结果为坐标、label和得分。最后根据得分进行一下nms即可。

测试效果

自己用的电脑,显卡比较垃圾,就大概练了练,batch_size最大就到2。大概效果如下

源码:https://github.com/Zou-CM/Faster-RCNN-pytorch

Faster-RCNN简易复现相关推荐

  1. python cnn 实例_学习python的算法-Faster RCNN算法复现

    [实例简介]Faster RCNN算法复现 [实例截图] [核心代码]详细见压缩包 ssdetection-master ├── LICENSE ├── README.md ├── cfgs │    ...

  2. 【目标检测】Faster R-CNN的复现

    文章目录 Faster Rcnn 0. 利用Git下载Code 1. 数据准备 2. 模型加载 3. 模型训练 4. 模型测试 5. 运行demo.py 6. 训练自定义Images文件和对应XML文 ...

  3. 可能是史上最详细-Faster RCNN Pytorch 复现全纪录

    向AI转型的程序员都关注了这个号

  4. faster rcnn论文_【论文解读】精读Faster RCNN

    Faster R-CNN论文链接: Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 推荐代 ...

  5. pytorch 实现Faster R-cnn从头开始(一)

    前言 从本章开始就要进入学习faster rcnn的复现了,深入了解目标检测的核心,只有知道等多的细节才能有机会创造和改进,代码很多,所以我也是分章节更新.每次学会一个知识点就可以了.我写的有reti ...

  6. 里程碑式成果Faster RCNN复现难?我们试了一下 | 附完整代码

    作者 | 已退逼乎 来源 | 知乎 [导读]2019年以来,除各AI 大厂私有网络范围外,MaskRCNN,CascadeRCNN 成为了支撑很多业务得以开展的基础,而以 Faster RCNN 为基 ...

  7. (目标检测)Faster R-CNN 论文解读+复现

    Faster R-CNN xyang 声明:本篇文章借用了他人理解,如有侵权,请联系,另如需转载,请注明出处 关于最新最全的目标检测论文,可以查看awesome-object-detection &l ...

  8. ResNet、Faster RCNN、Mask RCNN是专利算法吗?盘点何恺明参与发明的专利

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 前段时间OpenCV正式将SIFT算法的实现从Non-free模块移到主库,因SIFT专利到期了(专利 ...

  9. 【Faster RCNN detectron2】detectron2实现Faster RCNN目标检测

    目录 1. 背景介绍 2.安装步骤 3.Faster RCNN目标检测 4. 效果 5.错误解决 6.参考博客 7,下一节代码解析 在上一篇博客记录了 SlowFast的复现过程,slowfast其中 ...

  10. ResNet、Faster RCNN、Mask RCNN 是专利算法吗?盘点何恺明参与发明的专利!

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:我爱计算机视觉,52CV君 AI博士笔记系列推荐 周志华&l ...

最新文章

  1. CAS、原子操作类的应用与浅析及Java8对其的优化
  2. 网络营销外包期间站长如何挖掘用户真实需求探索网络营销外包真谛
  3. livecharts中仪表盘_Vue中使用Echarts仪表盘展示实时数据的实现
  4. Spark读取Kafka因为序列化引起的问题:org.apache.spark.sql.streaming.StreamingQueryException: null
  5. 数据库修改后 前台同步更新 php,PHP实现前台页面与MySQL的数据绑定、同步更新...
  6. 藏在兰州拉面里精益管理秘诀
  7. java List 去除重复元素的五种方式 学习笔记
  8. android+mid播放器,手机midi播放器下载
  9. Java获取时间戳,System.currentTimeMillis() 和 System.nanoTime() 哪个更快?
  10. 仿微信朋友圈发表图片拖拽和删除功能
  11. 恶意代码分析实战Lab3-1
  12. 吴恩达2022机器学习课程评测来了!
  13. [虚拟机] 如何让VMware上的虚拟机识别到U盘
  14. 招聘网站分析-智联招聘网的爬虫设计与实现
  15. 计算机用户名的数值数据是什么,计算机数据最基本的单位是什么?
  16. 华为机考1-54题总结
  17. 【银河麒麟国产服务器安装mysql、nginx和docker遇到的问题】(回忆篇)
  18. 预测分析 Python ARIMA模型预测(学习笔记)
  19. mac vscode插件位置
  20. LeetCode - 794 - 有效的井字游戏 - java

热门文章

  1. 『杭电1726』God’s cutter
  2. Crust Network 与京湘豫等地区块链名企、投资人考察广西区块链科创园
  3. gsm无线热点数据采集服务器,GSM无线网络优化及WLAN热点分析工具开发
  4. Unity Timeline 初识
  5. 在苹果手机上实现虹膜识别(通过改装实现)
  6. 有一种爱情叫做冯小刚与徐帆
  7. 【软件分析/静态程序分析学习笔记】5.数据流分析基础(Data Flow Analysis-Foundations)
  8. 尚福林:建立集团诉讼和股东代表诉讼制度
  9. SDS新书的来龙去脉 amp;amp; SDS序言 - 倪光南:众筹出书也是一种创新
  10. 【遥感数字图像处理】实验:遥感专题地图制作经典流程(Erdas版)