点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发、点赞、留言支持!上一篇文章使用了torchvision中提供的预训练对象检测网络Faster-RCNN实现了常见的对象检测,基于COCO数据集,支持90个类型对象检测,非常的实用。本文将介绍如何使用自定义数据集,使用Faster-RCNN预训练模型实现迁移学习,完成自定义对象检测。

数据集

使用了公开的宠物数据集,下载地址如下:

http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gzhttp://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz

对象检测模型的输入是image图像,需要target信息包括:

boxes:表示标注的矩形左上角与右下角坐标(x1,y1,x2,y2)
labels:表示每个标注框的类别,注意从1开始,0永远表示背景
image_id:数据集中图像索引id值,
area:标注框的面积,COCO评估的时候会用到
iscrowd:当iscrowd=true不会参与模型评估计算

从标注xml文件中读取相关信息,完成解析,自定义一个宠物数据集的代码如下:

 1class PetDataset(Dataset): 2    def __init__(self, root_dir): 3        self.root_dir = root_dir 4        self.transforms = T.Compose([T.ToTensor()]) 5        self.ann_xmls = list(sorted(os.listdir(os.path.join(root_dir, "annotations/xmls")))) 6 7    def __len__(self): 8        return len(self.ann_xmls) 910    def num_of_samples(self):11        return len(self.ann_xmls)1213    def __getitem__(self, idx):14        # load images and bbox15        bbox_xml_path = os.path.join(self.root_dir, "annotations/xmls", self.ann_xmls[idx])1617        # 读取xml18        dom = parse(bbox_xml_path)19        # 获取文档元素对象20        data = dom.documentElement21        # 获取 objects22        objects = data.getElementsByTagName('object')23        node = data.getElementsByTagName('filename')[0]24        file_ame = node.childNodes[0].nodeValue25        image_path = os.path.join(self.root_dir, "images", file_ame)26        img = cv.imread(image_path)2728        # get bounding box coordinates29        boxes = []30        labels = []31        for object_ in objects:32            # 获取标签中内容33            name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue34            if name == "dog":35                labels.append(np.int(1))36            if name == "cat":37                labels.append(np.int(2))3839            bndbox = object_.getElementsByTagName('bndbox')[0]40            xmin = np.float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)41            ymin = np.float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)42            xmax = np.float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)43            ymax = np.float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)44            boxes.append([xmin, ymin, xmax, ymax])4546        boxes = torch.as_tensor(boxes, dtype=torch.float32)47        # there is only one class48        labels = torch.as_tensor(labels, dtype=torch.int64)4950        image_id = torch.tensor([idx])51        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])52        iscrowd = torch.zeros((len(objects),), dtype=torch.int64)5354        target = {}55        target["boxes"] = boxes56        target["labels"] = labels57        target["image_id"] = image_id58        target["area"] = area59        target["iscrowd"] = iscrowd60        img, target = self.transforms(img, target)61        return img, target

顺便说一下,这里输入图像通道顺序是BGR

Faster RCNN模型训练

之前一篇文章中介绍了Faster-RCNN模型与预训练模型使用,这里通过下面的代码加载模型:

num_classes = 2 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=True) device = torch.device('cuda:0') model.to(device)

其中pretrained=False表示训练使用,num_classes 表示对象检测数据集的对象类别,这里只有dog跟cat两个类别,所以num_classes =2

设置好了模型的参数,下面就可以初始化加载数据集,开始正式训练,代码如下:

dataset = PetDataset("D:/pytorch/pet_data")data_loader = torch.utils.data.DataLoader(     dataset, batch_size=4, shuffle=True,  # num_workers=4,     collate_fn=utils.collate_fn)params = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=0.005,                             momentum=0.9, weight_decay=0.0005)lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,                                                step_size=5,                                                gamma=0.1)num_epochs = 8for epoch in range(num_epochs):     train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)     lr_scheduler.step()torch.save(model.state_dict(), "faster_rcnn_vehicle_model.pt")

运行结果如下:

如果你的内存不够猛,训练的时候可能会得到下面这个错误:

回去改一下batch size就好了,如果改成1还有这个错误话,就直接砸机器就对了!

模型推理使用

对训练好的模型,加载模型,然后就可以推理预测了,代码演示如下:

 1image = cv.imread("D:/images/test.jpg") 2blob = transform(image) 3c, h, w = blob.shape 4input_x = blob.view(1, c, h, w) 5output = model(input_x.cuda())[0] 6boxes = output['boxes'].cpu().detach().numpy() 7scores = output['scores'].cpu().detach().numpy() 8labels = output['labels'].cpu().detach().numpy() 9index = 010for x1, y1, x2, y2 in boxes:11    if scores[index] > 0.5:12        cv.rectangle(image, (np.int32(x1), np.int32(y1)),13                     (np.int32(x2), np.int32(y2)), (140, 199, 0), 4, 8, 0)14        label_id = labels[index]15        label_txt = coco_names[str(label_id)]16        cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 1)17    index += 118cv.imshow("Faster-RCNN Pet Detection", image)19cv.imwrite("D:/pet2.png", image)20cv.waitKey(0)21cv.destroyAllWindows()

运行结果如下:

 推荐阅读 

轻松学Pytorch–环境搭建与基本语法

Pytorch轻松学-构建浅层神经网络

轻松学pytorch-构建卷积神经网络

轻松学Pytorch –构建循环神经网络

轻松学Pytorch-使用卷积神经网络实现图像分类

轻松学Pytorch-自定义数据集制作与使用

轻松学Pytorch-Pytorch可视化

轻松学Pytorch–Visdom可视化

轻松学Pytorch – 全局池化层详解

轻松学Pytorch – 人脸五点landmark提取网络训练与使用

轻松学Pytorch – 年龄与性别预测

轻松学Pytorch –车辆类型与颜色识别

轻松学Pytorch-全卷积神经网络实现表情识别

使用OpenVINO加速Pytorch表情识别模型

轻松学pytorch – 使用多标签损失函数训练卷积网络

轻松学Pytorch-使用ResNet50实现图像分类

OpenCV4.4 + YOLOv4 真的可以运行了…..

轻松学Pytorch –使用torchvision实现对象检测

伏久者,飞必高

开先者,谢独早

rcnn代码实现_轻松学Pytorch实现自定义对象检测器相关推荐

  1. 数据集制作_轻松学Pytorch自定义数据集制作与使用

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...

  2. celeba数据集_轻松学 Pytorch 使用DCGAN实现数据复制

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 DCGAN Ian J. Goodfellow首次提出了GAN之后,生成对抗只是神经网络还不是深度卷积神经网络 ...

  3. 轻松学Pytorch – 行人检测Mask-RCNN模型训练与使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,这个是轻松学Pytorch的第20篇的文章分享,主要是给大 ...

  4. pytorch argmax_轻松学Pytorch使用ResNet50实现图像分类

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的 ...

  5. 轻松学Pytorch – 人脸五点landmark提取网络训练与使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,本文是轻松学Pytorch系列文章第十篇,本文将介绍如何使 ...

  6. c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)

    关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...

  7. 轻松学Pytorch – 年龄与性别预测

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,上周太忙,没有更新Pytorch轻松学系列文章,但是我还是 ...

  8. pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...

  9. data后缀文件解码_小白学PyTorch | 17 TFrec文件的创建与读取

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 ...

最新文章

  1. 如何查看已安装的CentOS版本信息
  2. python 读取txt文件为字典_python将txt文件读取为字典的示例
  3. Oracle select 基础查询语句 day02
  4. linux 查看数据库和表 mysql 命令
  5. 关于Linux C multiple definition of‘XXX’的问题
  6. 51nod 1577 线性基
  7. 三维点云数据处理软件供技术原理说明_基于Geomagic Studio的点云数据处理三维建模技术...
  8. ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
  9. 学习用PySide写界面
  10. Java使用ffmpeg将视频转为Mp4格式
  11. ubuntu系统打开.chm文件方式
  12. python怎么让图片旋转45度_python – 有没有办法将matplotlib图旋转45度?
  13. 题目 2322: 大鱼吃小鱼
  14. 合作模式歌利亚机器人_《歌利亚》画面战斗及机器人制作试玩图文心得 歌利亚好玩吗...
  15. vue3 使用element表格导出excel表格(带图片)
  16. js 递归创建文件夹
  17. RK3288刷机教程:安装Ubuntu 16.04
  18. for循环 for循环嵌套
  19. 非真,亦非假——20世纪数学悖论入侵机器学习
  20. win11系统没有触屏怎么办 Windows11没有触屏的解决方法

热门文章

  1. FLOATER:更加灵活的Transformer位置编码!
  2. android第三方launcher,目前Android平台最好的Launcher
  3. mysql查询字段数据是否有空格_mysql查询条件字段值末尾有空格也能查到数据问题...
  4. python往mysql存入数据_Python向mysql存入数据出错.
  5. 多任务学习(MTL)在转化率预估上的应用
  6. Python学习相关文档
  7. Excel关于宏的运用
  8. 深度为你解答怎么避免域名被微信拦截,微信域名防封需要注意哪些问题?
  9. Rainbond 5.0正式发布, 支持对接管理已有Kubernetes集群...
  10. 22.Windows及linux下gerapy使用