rcnn代码实现_轻松学Pytorch实现自定义对象检测器
点击上方蓝字关注我们
微信公众号:OpenCV学堂
关注获取更多计算机视觉与深度学习知识
大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发、点赞、留言支持!上一篇文章使用了torchvision中提供的预训练对象检测网络Faster-RCNN实现了常见的对象检测,基于COCO数据集,支持90个类型对象检测,非常的实用。本文将介绍如何使用自定义数据集,使用Faster-RCNN预训练模型实现迁移学习,完成自定义对象检测。
数据集
使用了公开的宠物数据集,下载地址如下:
http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gzhttp://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
对象检测模型的输入是image图像,需要target信息包括:
boxes:表示标注的矩形左上角与右下角坐标(x1,y1,x2,y2)
labels:表示每个标注框的类别,注意从1开始,0永远表示背景
image_id:数据集中图像索引id值,
area:标注框的面积,COCO评估的时候会用到
iscrowd:当iscrowd=true不会参与模型评估计算
从标注xml文件中读取相关信息,完成解析,自定义一个宠物数据集的代码如下:
1class PetDataset(Dataset): 2 def __init__(self, root_dir): 3 self.root_dir = root_dir 4 self.transforms = T.Compose([T.ToTensor()]) 5 self.ann_xmls = list(sorted(os.listdir(os.path.join(root_dir, "annotations/xmls")))) 6 7 def __len__(self): 8 return len(self.ann_xmls) 910 def num_of_samples(self):11 return len(self.ann_xmls)1213 def __getitem__(self, idx):14 # load images and bbox15 bbox_xml_path = os.path.join(self.root_dir, "annotations/xmls", self.ann_xmls[idx])1617 # 读取xml18 dom = parse(bbox_xml_path)19 # 获取文档元素对象20 data = dom.documentElement21 # 获取 objects22 objects = data.getElementsByTagName('object')23 node = data.getElementsByTagName('filename')[0]24 file_ame = node.childNodes[0].nodeValue25 image_path = os.path.join(self.root_dir, "images", file_ame)26 img = cv.imread(image_path)2728 # get bounding box coordinates29 boxes = []30 labels = []31 for object_ in objects:32 # 获取标签中内容33 name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue34 if name == "dog":35 labels.append(np.int(1))36 if name == "cat":37 labels.append(np.int(2))3839 bndbox = object_.getElementsByTagName('bndbox')[0]40 xmin = np.float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)41 ymin = np.float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)42 xmax = np.float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)43 ymax = np.float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)44 boxes.append([xmin, ymin, xmax, ymax])4546 boxes = torch.as_tensor(boxes, dtype=torch.float32)47 # there is only one class48 labels = torch.as_tensor(labels, dtype=torch.int64)4950 image_id = torch.tensor([idx])51 area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])52 iscrowd = torch.zeros((len(objects),), dtype=torch.int64)5354 target = {}55 target["boxes"] = boxes56 target["labels"] = labels57 target["image_id"] = image_id58 target["area"] = area59 target["iscrowd"] = iscrowd60 img, target = self.transforms(img, target)61 return img, target
顺便说一下,这里输入图像通道顺序是BGR!
Faster RCNN模型训练
之前一篇文章中介绍了Faster-RCNN模型与预训练模型使用,这里通过下面的代码加载模型:
num_classes = 2 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=True) device = torch.device('cuda:0') model.to(device)
其中pretrained=False表示训练使用,num_classes 表示对象检测数据集的对象类别,这里只有dog跟cat两个类别,所以num_classes =2
设置好了模型的参数,下面就可以初始化加载数据集,开始正式训练,代码如下:
dataset = PetDataset("D:/pytorch/pet_data")data_loader = torch.utils.data.DataLoader( dataset, batch_size=4, shuffle=True, # num_workers=4, collate_fn=utils.collate_fn)params = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)num_epochs = 8for epoch in range(num_epochs): train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10) lr_scheduler.step()torch.save(model.state_dict(), "faster_rcnn_vehicle_model.pt")
运行结果如下:
如果你的内存不够猛,训练的时候可能会得到下面这个错误:
回去改一下batch size就好了,如果改成1还有这个错误话,就直接砸机器就对了!
模型推理使用
对训练好的模型,加载模型,然后就可以推理预测了,代码演示如下:
1image = cv.imread("D:/images/test.jpg") 2blob = transform(image) 3c, h, w = blob.shape 4input_x = blob.view(1, c, h, w) 5output = model(input_x.cuda())[0] 6boxes = output['boxes'].cpu().detach().numpy() 7scores = output['scores'].cpu().detach().numpy() 8labels = output['labels'].cpu().detach().numpy() 9index = 010for x1, y1, x2, y2 in boxes:11 if scores[index] > 0.5:12 cv.rectangle(image, (np.int32(x1), np.int32(y1)),13 (np.int32(x2), np.int32(y2)), (140, 199, 0), 4, 8, 0)14 label_id = labels[index]15 label_txt = coco_names[str(label_id)]16 cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 1)17 index += 118cv.imshow("Faster-RCNN Pet Detection", image)19cv.imwrite("D:/pet2.png", image)20cv.waitKey(0)21cv.destroyAllWindows()
运行结果如下:
推荐阅读
轻松学Pytorch–环境搭建与基本语法
Pytorch轻松学-构建浅层神经网络
轻松学pytorch-构建卷积神经网络
轻松学Pytorch –构建循环神经网络
轻松学Pytorch-使用卷积神经网络实现图像分类
轻松学Pytorch-自定义数据集制作与使用
轻松学Pytorch-Pytorch可视化
轻松学Pytorch–Visdom可视化
轻松学Pytorch – 全局池化层详解
轻松学Pytorch – 人脸五点landmark提取网络训练与使用
轻松学Pytorch – 年龄与性别预测
轻松学Pytorch –车辆类型与颜色识别
轻松学Pytorch-全卷积神经网络实现表情识别
使用OpenVINO加速Pytorch表情识别模型
轻松学pytorch – 使用多标签损失函数训练卷积网络
轻松学Pytorch-使用ResNet50实现图像分类
OpenCV4.4 + YOLOv4 真的可以运行了…..
轻松学Pytorch –使用torchvision实现对象检测
伏久者,飞必高
开先者,谢独早
rcnn代码实现_轻松学Pytorch实现自定义对象检测器相关推荐
- 数据集制作_轻松学Pytorch自定义数据集制作与使用
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...
- celeba数据集_轻松学 Pytorch 使用DCGAN实现数据复制
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 DCGAN Ian J. Goodfellow首次提出了GAN之后,生成对抗只是神经网络还不是深度卷积神经网络 ...
- 轻松学Pytorch – 行人检测Mask-RCNN模型训练与使用
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,这个是轻松学Pytorch的第20篇的文章分享,主要是给大 ...
- pytorch argmax_轻松学Pytorch使用ResNet50实现图像分类
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的 ...
- 轻松学Pytorch – 人脸五点landmark提取网络训练与使用
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,本文是轻松学Pytorch系列文章第十篇,本文将介绍如何使 ...
- c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)
关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...
- 轻松学Pytorch – 年龄与性别预测
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,上周太忙,没有更新Pytorch轻松学系列文章,但是我还是 ...
- pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构
[机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...
- data后缀文件解码_小白学PyTorch | 17 TFrec文件的创建与读取
[机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 ...
最新文章
- 如何查看已安装的CentOS版本信息
- python 读取txt文件为字典_python将txt文件读取为字典的示例
- Oracle select 基础查询语句 day02
- linux 查看数据库和表 mysql 命令
- 关于Linux C multiple definition of‘XXX’的问题
- 51nod 1577 线性基
- 三维点云数据处理软件供技术原理说明_基于Geomagic Studio的点云数据处理三维建模技术...
- ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
- 学习用PySide写界面
- Java使用ffmpeg将视频转为Mp4格式
- ubuntu系统打开.chm文件方式
- python怎么让图片旋转45度_python – 有没有办法将matplotlib图旋转45度?
- 题目 2322: 大鱼吃小鱼
- 合作模式歌利亚机器人_《歌利亚》画面战斗及机器人制作试玩图文心得 歌利亚好玩吗...
- vue3 使用element表格导出excel表格(带图片)
- js 递归创建文件夹
- RK3288刷机教程:安装Ubuntu 16.04
- for循环 for循环嵌套
- 非真,亦非假——20世纪数学悖论入侵机器学习
- win11系统没有触屏怎么办 Windows11没有触屏的解决方法
热门文章
- FLOATER:更加灵活的Transformer位置编码!
- android第三方launcher,目前Android平台最好的Launcher
- mysql查询字段数据是否有空格_mysql查询条件字段值末尾有空格也能查到数据问题...
- python往mysql存入数据_Python向mysql存入数据出错.
- 多任务学习(MTL)在转化率预估上的应用
- Python学习相关文档
- Excel关于宏的运用
- 深度为你解答怎么避免域名被微信拦截,微信域名防封需要注意哪些问题?
- Rainbond 5.0正式发布, 支持对接管理已有Kubernetes集群...
- 22.Windows及linux下gerapy使用