点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

大家好,这个是轻松学Pytorch的第20篇的文章分享,主要是给大家分享一下,如何使用数据集基于Mask-RCNN训练一个行人检测与实例分割网络。这个例子是来自Pytorch官方的教程,我这里是根据我自己的实践重新整理跟解读了一下,分享给大家。

Mask-RCNN网络模型

前面一篇已经详细分享了关于模型本身,格式化输入与输出的结果。这里使用的预训练模型是ResNet50作为backbone网络,实现模型的参数微调迁移学习。输入的数据是RGB三通道的,取值范围rescale到0~1之间。

数据集介绍与读取

数据集地址下载地址:

https://www.cis.upenn.edu/~jshi/ped_html/

总计170张图像,345个标签行人,数据集采集自两所大学校园。

标注格式兼容Pascal标注格式。

基于Pytorch的DataSet接口类完成继承与使用,得到完成的数据聚集读取类实现代码如下:

from PIL import Image
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import faster_rcnn.transforms as T
import osclass PennFudanDataset(Dataset):def __init__(self, root_dir):self.root_dir = root_dirself.transforms = T.Compose([T.ToTensor()])self.imgs = list(sorted(os.listdir(os.path.join(root_dir, "PNGImages"))))self.masks = list(sorted(os.listdir(os.path.join(root_dir, "PedMasks"))))def __len__(self):return len(self.imgs)def num_of_samples(self):return len(self.imgs)def __getitem__(self, idx):# load images and bboximg_path = os.path.join(self.root_dir, "PNGImages", self.imgs[idx])mask_path = os.path.join(self.root_dir, "PedMasks", self.masks[idx])img = Image.open(img_path).convert("RGB")mask = Image.open(mask_path)# convert the PIL Image into a numpy arraymask = np.array(mask)# instances are encoded as different colorsobj_ids = np.unique(mask)# first id is the background, so remove itobj_ids = obj_ids[1:]# split the color-encoded mask into a set# of binary masksmasks = mask == obj_ids[:, None, None]# get bounding box coordinates for each masknum_objs = len(obj_ids)boxes = []for i in range(num_objs):pos = np.where(masks[i])xmin = np.min(pos[1])xmax = np.max(pos[1])ymin = np.min(pos[0])ymax = np.max(pos[0])boxes.append([xmin, ymin, xmax, ymax])# convert everything into a torch.Tensorboxes = torch.as_tensor(boxes, dtype=torch.float32)# there is only one classlabels = torch.ones((num_objs,), dtype=torch.int64)masks = torch.as_tensor(masks, dtype=torch.uint8)image_id = torch.tensor([idx])area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])# suppose all instances are not crowdiscrowd = torch.zeros((num_objs,), dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["masks"] = maskstarget["image_id"] = image_idtarget["area"] = areatarget["iscrowd"] = iscrowdif self.transforms is not None:img, target = self.transforms(img, target)return img, targetif __name__ == "__main__":ds = PennFudanDataset("D:/pytorch/PennFudanPed")for i in range(len(ds)):img, target = ds[i]print(i, img.size(), target)device = torch.device('cuda:0')boxes = target["boxes"]xmin, ymin, xmax, ymax = boxes.unbind(1)targets = [{k: v.to(device) for k, v in t.items()} for t in [target]]if i == 3:break

其中:

  • boxes表示的输入标注框

  • labels表示标签,这里0表示背景,1表示行人,两个分类

  • image_id表示图像标识

  • area表示标注框面积

  • mask对象标记,

模型训练

训练数据集,epoch=8,因为我的计算机内存比较小,所有batchSize=1,不然我就会内存爆炸了,训练一定时间后,就好拉,我把模型保存为mask_rcnn_pedestrian_model.pt文件。训练的代码如下:

# 检查是否可以利用GPU
# torch.multiprocessing.freeze_support()
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:print('CUDA is not available.')
else:print('CUDA is available!')# 背景 + 行人
num_classes = 2
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=True)
device = torch.device('cuda:0')
model.to(device)dataset = PennFudanDataset("D:/pytorch/PennFudanPed")
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,  # num_workers=4,collate_fn=utils.collate_fn)params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.1)
num_epochs = 8
for epoch in range(num_epochs):train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)lr_scheduler.step()
torch.save(model.state_dict(), "mask_rcnn_pedestrian_model.pt")

上次训练Faster-RCNN的时候有人跟我说训练时候缺失文件,其实torchvision相关的辅助文件可以从这里下载,地址如下:

https://github.com/pytorch/vision/tree/master/references/detection

这样大家就可以自己去下载拉!

模型使用

当我们完成训练之后,就可以使用模型了,这里有个小小的注意点,当训练的时候我加载数据用的是Image.open方法读取图像,得到的是RGB顺序通道图像。在测试的时候我使用OpenCV来读取图像,得到是BGR顺序,所以需要通道顺序转换一下。千万别忘记。加载导出模型,读取测试图像,完成推理预测完整的代码如下:

import torchvision
import torch
import cv2 as cv
import numpy as npmodel = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=2, pretrained_backbone=True)
model.load_state_dict(torch.load("./mask_rcnn_pedestrian_model.pt"))
model.eval()
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])# 使用GPU
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:model.cuda()def object_detection__demo():frame = cv.imread("D:/images/pedestrian_02.png")frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)blob = transform(frame)c, h, w = blob.shapeinput_x = blob.view(1, c, h, w)output = model(input_x.cuda())[0]boxes = output['boxes'].cpu().detach().numpy()scores = output['scores'].cpu().detach().numpy()labels = output['labels'].cpu().detach().numpy()index = 0frame = cv.cvtColor(frame, cv.COLOR_RGB2BGR)for x1, y1, x2, y2 in boxes:if scores[index] > 0.9:print("score: ", scores[index])cv.rectangle(frame, (np.int32(x1), np.int32(y1)), (np.int32(x2), np.int32(y2)), (0, 0, 255), 2, 8, 0)index += 1cv.imshow("Mask-RCNN Demo", frame)cv.imwrite("D:/pedestrian_02mask_rcnn.png", frame)cv.waitKey(0)cv.destroyAllWindows()if __name__ == "__main__":object_detection__demo()

测试了几张张图像,运行结果分别如下:

没想到效果这么好,真的很靠谱!真的实例分割模型,明显提升了检测效果。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

轻松学Pytorch – 行人检测Mask-RCNN模型训练与使用相关推荐

  1. 轻松学Pytorch – 人脸五点landmark提取网络训练与使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,本文是轻松学Pytorch系列文章第十篇,本文将介绍如何使 ...

  2. rcnn代码实现_轻松学Pytorch实现自定义对象检测器

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发.点赞.留言支 ...

  3. pytorch argmax_轻松学Pytorch使用ResNet50实现图像分类

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的 ...

  4. 终极指南:构建用于检测汽车损坏的Mask R-CNN模型(附Python演练)

    阅读时间将近11分钟 介绍 计算机视觉领域的应用继续令人惊叹着.从检测视频中的目标到计算人群中的人数,计算机视觉似乎没有无法克服的挑战. 这篇文章的目的是建立一个自定义Mask R-CNN模型,可以检 ...

  5. 数据集制作_轻松学Pytorch自定义数据集制作与使用

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...

  6. 人脸检测和行人检测2:YOLOv5实现人脸检测和行人检测(含数据集和训练代码)

    人脸检测和行人检测2:YOLOv5实现人脸检测和行人检测(含数据集和训练代码) 目录 人脸检测和行人检测2:YOLOv5实现人脸检测和行人检测(含数据集和训练代码) 1. 前言 2. 人脸检测和行人检 ...

  7. 【Pytorch神经网络理论篇】 33 基于图片内容处理的机器视觉:目标检测+图片分割+非极大值抑制+Mask R-CNN模型

    基于图片内容的处理任务,主要包括目标检测.图片分割两大任务. 1 目标检测 目标检测任务的精度相对较高,主要是以检测框的方式,找出图片中目标物体所在的位置.目标检测任务的模型运算量相对较小,速度相对较 ...

  8. 轻松学Pytorch – 年龄与性别预测

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,上周太忙,没有更新Pytorch轻松学系列文章,但是我还是 ...

  9. Mask R-CNN 模型

    数据准备 要训练 Mask R-CNN 实例分割模型,我们首先要准备图像的掩模(mask),使用标注工具 labelme(支持 Windows 和 Ubuntu,使用 (sudo) pip insta ...

最新文章

  1. ngnix有版本要求吗_联想小新15 2020款值得入手吗?性能怎么样?不可不看的秘密...
  2. [蓝桥杯][2013年第四届真题]危险系数(暴力+dfs)
  3. 漫步数理统计十三——特殊的期望
  4. mybatis源码环境搭建
  5. 客厅的WiFi在主卧收不到,什么方法简单便宜?
  6. 继承(1)----《.NET 2.0面向对象编程揭秘 》学习
  7. c# 获取docx中的内容
  8. linux红帽7修改时间,CentOS 7 and RedHat 7 时间同步即chrony服务配置
  9. spring Aop中切入点和连接点什么关系?
  10. mysql查询自然周_Hive和MySQL中自然周保持一致的方法
  11. Android程序中重启系统,Android调用系统关机与重启功能
  12. 4.2.5 预测分析法与预测分析表的构造
  13. ng-template、ng-container、ng-content 的用法
  14. UVA1616 Caravan Robbers
  15. Win10 Chinese输入法修复/note
  16. 安装Properties Editor
  17. 常用封装电阻的常用电阻阻值
  18. finalcut内存不足_final cut pro 内存不足可以更改缓存空间吗 final cut pr
  19. 19、会员中心 - 小程序端开发 - 微擎小程序模块应用开发
  20. Conmi的正确答案——linux/ubuntu安装web运维工具(Cockpit)

热门文章

  1. 同样是AI技术,为什么只有一加6称得上“全速”旗舰?
  2. Java 8 一行代码解决了空指针问题,太厉害了!
  3. 详解微服务技术中进程间通信
  4. 2020 最烂密码 TOP 200 大曝光!
  5. 让人头痛的大事务问题到底要如何解决?
  6. 图解 SQL 中 JOIN 的各种用法
  7. @AI开发者:薅资源,赢大奖,零成本体验AI开发,这场大赛等你来战!
  8. 收藏!PyTorch常用代码段合集
  9. 机器学习数学基础:常见分布与假设检验
  10. 如何阅读一份深度学习项目代码?