TorchVision 对象检测微调教程

1. 预训练的Mask R-CNN 模型进行微调

我们将说明如何在 torchvision 中使用新功能,以便在自定义数据集上训练实例细分模型。

2. 定义数据集

https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
Penn-Fudan 数据库中对行人检测和分割。 数据集应继承自标准torch.utils.data.Dataset类,并实现__len__和__getitem__。
让我们为此数据集编写一个torch.utils.data.Dataset类。

import os
import numpy as np
import torch
from PIL import Imageclass PennFudanDataset(object):def __init__(self, root, transforms):self.root = rootself.transforms = transforms# load all image files, sorting them to# ensure that they are alignedself.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))def __getitem__(self, idx):# load images ad masksimg_path = os.path.join(self.root, "PNGImages", self.imgs[idx])mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])img = Image.open(img_path).convert("RGB")# note that we haven't converted the mask to RGB,# because each color corresponds to a different instance# with 0 being backgroundmask = Image.open(mask_path)# convert the PIL Image into a numpy arraymask = np.array(mask)# instances are encoded as different colorsobj_ids = np.unique(mask)# first id is the background, so remove itobj_ids = obj_ids[1:]# split the color-encoded mask into a set# of binary masksmasks = mask == obj_ids[:, None, None]# get bounding box coordinates for each masknum_objs = len(obj_ids)boxes = []for i in range(num_objs):pos = np.where(masks[i])xmin = np.min(pos[1])xmax = np.max(pos[1])ymin = np.min(pos[0])ymax = np.max(pos[0])boxes.append([xmin, ymin, xmax, ymax])# convert everything into a torch.Tensorboxes = torch.as_tensor(boxes, dtype=torch.float32)# there is only one classlabels = torch.ones((num_objs,), dtype=torch.int64)masks = torch.as_tensor(masks, dtype=torch.uint8)image_id = torch.tensor([idx])area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])# suppose all instances are not crowdiscrowd = torch.zeros((num_objs,), dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["masks"] = maskstarget["image_id"] = image_idtarget["area"] = areatarget["iscrowd"] = iscrowdif self.transforms is not None:img, target = self.transforms(img, target)return img, targetdef __len__(self):return len(self.imgs)

数据集__getitem__应该返回:

图像:大小为(H, W)的 PIL 图像
目标:包含以下字段的字典
boxes (FloatTensor[N, 4]):[x0, y0, x1, y1]格式的N个边界框的坐标,范围从0至W,从0至H
labels (Int64Tensor[N]):每个边界框的标签。0经常表示背景类
image_id (Int64Tensor[1]):图像标识符。 它在数据集中的所有图像之间应该是唯一的,并在评估过程中使用
area (Tensor[N]):边界框的区域。 在使用 COCO 度量进行评估时,可使用此值来区分小盒子,中盒子和大盒子之间的度量得分。
iscrowd (UInt8Tensor[N]):iscrowd = True 的实例在评估期间将被忽略。
(可选)masks (UInt8Tensor[N, H, W]):每个对象的分割Mask
(可选)keypoints (FloatTensor[N, K, 3]):对于 N 个对象中的每个对象,它包含[x, y, visibility]格式的 K 个关键点,以定义对象。 可见性= 0 表示关键点不可见。 请注意,对于数据扩充,翻转关键点的概念取决于数据表示形式,您可能应该将references/detection/transforms.py修改为新的关键点表示形式
如果您的模型返回上述方法,则它们将使其适用于训练和评估,并将使用pycocotools中的评估脚本

3. 定义模型

在本教程中,我们使用 Mask R-CNN 。

3.1 pre-trained model的微调

The first is when we want to start from a pre-trained model, and just finetune the last layer.
假设您想从在 COCO 上经过预训练的模型开始,并希望针对您的特定类别对其进行微调。 这是一种可行的方法:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor# load a model pre-trained pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

在我们的例子中,由于我们的数据集非常小,我们希望从预训练的模型中进行微调,因此我们将遵循方法 。

这里我们还想计算实例分割掩码,因此我们将使用 Mask R-CNN:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictordef get_model_instance_segmentation(num_classes):# load an instance segmentation model pre-trained pre-trained on COCOmodel = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)# get number of input features for the classifierin_features = model.roi_heads.box_predictor.cls_score.in_features# replace the pre-trained head with a new onemodel.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)# now get the number of input features for the mask classifierin_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channelshidden_layer = 256# and replace the mask predictor with a new onemodel.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,hidden_layer,num_classes)return model

就是这样,这将使model随时可以在您的自定义数据集上进行训练和评估

补充3.2 替换backbone

The other is when we want to replace the backbone of the model with a different one (for faster predictions, for example).

import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),aspect_ratios=((0.5, 1.0, 2.0),))# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],output_size=7,sampling_ratio=2)# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,num_classes=2,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler)

4. Training and evaluation functions

# Download TorchVision repo to use some files from
# references/detection
git clone https://github.com/pytorch/vision.git
cd vision
git checkout v0.3.0cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

在references/detection/中,我们提供了许多帮助程序功能来简化训练和评估检测模型。 在这里,我们将使用references/detection/engine.py,references/detection/utils.py和references/detection/transforms.py。 只需将它们复制到您的文件夹中并在此处使用它们即可。

Let’s write some helper functions for data augmentation / transformation, which leverages the functions in refereces/detection that we have just copied:

from engine import train_one_epoch, evaluate
import utils
import transforms as Tdef get_transform(train):transforms = []# converts the image, a PIL image, into a PyTorch Tensortransforms.append(T.ToTensor())if train:# during training, randomly flip the training images# and ground-truth for data augmentationtransforms.append(T.RandomHorizontalFlip(0.5))return T.Compose(transforms)

请注意,我们不需要在数据转换中添加均值/标准差归一化或图像缩放,因为这些是由Mask R-CNN模型内部处理的。

查看模型在训练过程中的期望值以及对样本数据的推断时间

遍历数据集之前,最好先查看模型在训练过程中的期望值以及对样本数据的推断时间。

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4,collate_fn=utils.collate_fn)
# For Training
images,targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images,targets)   # Returns losses and detections
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)           # Returns predictions

instantiate 实例化

# use our dataset and defined transformations
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4,collate_fn=utils.collate_fn)data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=4,collate_fn=utils.collate_fn)

instantiate the model and the optimizer

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')# our dataset has two classes only - background and person
num_classes = 2# get the model using our helper function
model = get_instance_segmentation_model(num_classes)
# move model to the right device
model.to(device)# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,momentum=0.9, weight_decay=0.0005)# and a learning rate scheduler which decreases the learning rate by
# 10x every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)

10 epochs


# let's train it for 10 epochs
num_epochs = 10for epoch in range(num_epochs):# train for one epoch, printing every 10 iterationstrain_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)# update the learning ratelr_scheduler.step()# evaluate on the test datasetevaluate(model, data_loader_test, device=device)

what it actually predicts in a test image

# pick one image from the test set
img, _ = dataset_test[0]
# put the model in evaluation mode
model.eval()
with torch.no_grad():prediction = model([img.to(device)])

Printing the prediction shows that we have a list of dictionaries. Each element of the list corresponds to a different image. As we have a single image, there is a single dictionary in the list. The dictionary contains the predictions for the image we passed. In this case, we can see that it contains boxes, labels, masks and scores as fields.

prediction
[{'boxes': tensor([[ 61.7920,  35.8468, 196.2695, 328.1466],[276.3983,  21.7483, 291.1403,  73.4649],[ 79.1629,  42.9354, 201.3314, 207.8434]], device='cuda:0'),'labels': tensor([1, 1, 1], device='cuda:0'),'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0'),'scores': tensor([0.9994, 0.8378, 0.0524], device='cuda:0')}]

convert the image

让我们检查图像和预测的分割蒙版。

为此,我们需要转换图像,该图像已重新缩放为0-1,并且通道已翻转,因此我们将其转换为[C,H,W]格式。

Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

And let’s now visualize the top predicted segmentation mask. The masks are predicted as [N, 1, H, W], where N is the number of predictions, and are probability maps between 0-1.

Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())

09月28日 pytorch与resnet(三)预训练的Mask R-CNN 模型进行微调相关推荐

  1. 面试经历---阿里游戏(2020年09月28日晚上7点视频面试)

    9月28日晚上进行了一次视频面试,阿里广州游戏部门,下面说下这次面试的情况 1.自我介绍 介绍了做过的项目,面试官就围绕做过的项目进行深挖. 2.redis的集群方式 如果节点挂掉怎么办? 单个节点的 ...

  2. 2005年09月28日  日本,东京  晴朗

    又一次故地重游,来这里出差,前段日子公司限制上网,所以一直都没有机会来这里看望大家. 咳,在日本觉得没有国内那么随便,也没有朝鲜冷面可以吃. 这次来日本,也赶上了国内吃月饼的节日,而我去吃不到,公司发 ...

  3. 扒一扒HTTPS网站的内幕[2015年09月29日]

    扒一扒HTTPS网站的内幕 野狗 2015年09月28日发布 作者:王继波  野狗科技运维总监,曾在360.TP-Link从事网络运维相关工作,在网站性能优化.网络协议研究上经验丰富. 野狗官博:ht ...

  4. 互联网晚报 | 12月22日 星期三 | 乐视宣布涨薪;小米12系列官宣12月28日发布;好未来推出全新品牌美校...

    ‍ ‍今日看点 ✦ 中国移动:A股IPO发行价57.58元/股,预计募资额560亿元 ✦ 小米12系列手机官宣12月28日正式发布,首发三款机型 ✦ 乐视宣布涨薪,全员信称今年实现经营利润和现金流双平 ...

  5. 10月28日人工智能讲师叶梓为各工科院校老师进行了为期三天的人工智能培训

    10月28日人工智能讲师叶梓为各工科院校老师进行了为期三天的人工智能培训,培训过程中人工智能讲师叶梓与各高校老师就人工智能前沿热点进行热烈的讨论. 根据人力资源和社会保障部办公厅<关于印发专业技 ...

  6. 2017年09月19日泰国清迈曼谷普吉岛三地游

    2017年9月19日,早起亚航飞吉隆坡,转机去清迈.亚航是廉价航空,机上餐食座位.行李托运等都收费,每个人限带7KG行李上飞机,一般是1个小号行李箱+1个背包登机.到了清迈机场后打的150泰铢到清迈古 ...

  7. 《问道》1月28日三区组体验1.43新版

    <问道>1月28日三区组体验1.43新版   发布日期: 2010-1-26 所属类别: 游戏公告   <问道>1.43版神兵"天鉴",经过数月紧锣密鼓的精 ...

  8. PANDAS 数据合并与重塑(concat篇) 原创 2016年09月13日 19:26:30 47784 pandas作者Wes McKinney 在【PYTHON FOR DATA ANALYS

    PANDAS 数据合并与重塑(concat篇) 原创 2016年09月13日 19:26:30 标签: 47784 编辑 删除 pandas作者Wes McKinney 在[PYTHON FOR DA ...

  9. 解密谷歌机器学习工程最佳实践——机器学习43条军规 翻译 2017年09月19日 10:54:58 98310 本文是对Rules of Machine Learning: Best Practice

    解密谷歌机器学习工程最佳实践--机器学习43条军规 翻译 2017年09月19日 10:54:58 983 1 0 本文是对Rules of Machine Learning: Best Practi ...

  10. 个人空间岁末大回报活动12月28日获奖名单

    个人空间岁末大回报: 动手就有C币拿!活动已于15日启动,非常感谢各位网友的大力支持和积极参与,个人空间的所有工作人员在这祝大家好运,希望你们每天都能拿到C币存入社区银行! 欢迎各位获奖者去自己的银行 ...

最新文章

  1. [C#项目开源] MongoDB 可视化管理工具 (2011年10月-至今)
  2. 什么检索是借助计算机技术进行自动标引的,自动文献检索系统
  3. 腾讯员工人均年薪84.7万,马化腾:员工心理健康最重要
  4. 全部编程皆为Web编程
  5. owncloud 配置mysql_傻瓜式搭建私人网络硬盘——owncloud安装指南
  6. 【直播讲座】用友摩天联合光环国际,听国学学项目管理
  7. Express中使用ejs新建项目以及ejs中实现传参、局部视图include、循环列表数据的使用
  8. empinfo Oracle数据库,Oracle数据库中相关技术详细操作
  9. easyui 布局自适应
  10. php iso 8859 1 解码,关于php:Apache的默认编码是ISO-8859-1,但网站是UTF-8?
  11. Python从命令行参数和配置文件获取信息
  12. 单例模式 - 双锁机制
  13. js面向对象编程(三)非构造函数的继承(转载)
  14. Linux之top命令
  15. 电商产品经理:电商后台系统
  16. 马士兵老师Java虚拟机调优
  17. Spring源码解析系列汇总
  18. c语言大学题库pdf,C语言试题库(完整版)..pdf
  19. POJ - 1625 Censored!
  20. 2021年美容师(初级)考试题及美容师(初级)报名考试

热门文章

  1. 利用MapShaper将.shp文件转换成JSON文件
  2. EXCEL IF、AND以及OR函数的嵌套使用
  3. oracle 流标和sql效率,Oracle 中流标使用实例
  4. 使用 Redis 实现一个轻量级的搜索引擎,牛x啊 !
  5. php mysql 组件_Ubuntu20.04安装apache、mysql、php、phpmyadmin、wordpress(一)
  6. php实现观看记录,PHP实现浏览历史记录
  7. python-学生管理系统--8-排序功能模块
  8. XPath 基本语法
  9. python中strptime函数_python datetime中strptime用法详解
  10. 安卓开发 实现文字渐变效果_AI教程!用网格工具做渐变字效