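I. Code preparation

This guide is based on PyTorch.

Mask Scoring R-CNN code reference: https://github.com/zjhuang22/maskscoring_rcnn

maskrcnn-benchmark: https://github.com/facebookresearch/maskrcnn-benchmark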

II. Environment setup

1. Create a PyTorch environment with conda:

conda create -n pytorch python=3.7.4
conda activate pytorch
conda install ipython
conda install -c pytorch pytorch-nightly torchvision=0.2.1 cudatoolkit=9.0 # Note: must be 9.0
pip install numpy scipy ninja yacs cython matplotlib tqdm opencv-python

Note: torchvision must be version 0.2.1 (pip install torchvision==0.2.1); otherwise you will hit AttributeError: 'list' object has no attribute 'resize' (issue #45).

See https://github.com/zjhuang22/maskscoring_rcnn/issues/45

2. Install cocoapi and apex:

export INSTALL_DIR=$PWD
# install pycocotools
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install
# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext

3. Install the detection code:

# install PyTorch Detection
cd $INSTALL_DIR
# maskrcnn-benchmark (use this clone instead if you want the plain Mask R-CNN code)
#git clone https://github.com/facebookresearch/maskrcnn-benchmark.git
git clone https://github.com/zjhuang22/maskscoring_rcnn
cd maskscoring_rcnn
python setup.py build develop
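
Before moving on, it can help to confirm that everything installed above is importable from the active environment. The following is a minimal sketch, not part of either repository: check_packages is a hypothetical helper, and the package names in the list are the import names I assume the steps above produce.

```python
import importlib


def check_packages(names):
    """Map each import name to its version string, "unknown", or None if it cannot be imported."""
    report = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except Exception:
            # A broken build can fail with errors other than ImportError,
            # so any failure is reported as "not usable".
            report[name] = None
            continue
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    # Assumed import names for the packages installed in this section.
    wanted = ["torch", "torchvision", "pycocotools", "apex", "maskrcnn_benchmark"]
    for name, version in check_packages(wanted).items():
        status = "MISSING" if version is None else version
        print(f"{name:20s} {status}")
```

Any line reported as MISSING points at the install step to repeat.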

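III. Data preparation

The training data is annotated with labelme and has to be converted to COCO format. The usual tool is a labelme2coco.py script; many versions of it are available online and can be used as-is.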

# -*- coding:utf-8 -*-
import glob
import json
import os
import sys

import numpy as np
import PIL.Image
import PIL.ImageDraw
from labelme import utils


class MyEncoder(json.JSONEncoder):
    """Serialize NumPy scalars and arrays as plain JSON values."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)


class labelme2coco(object):
    def __init__(self, labelme_json=None, save_json_path='./tran.json'):
        '''
        :param labelme_json: list of paths to all labelme JSON files
        :param save_json_path: where to write the COCO JSON file
        '''
        self.labelme_json = labelme_json or []
        self.save_json_path = save_json_path
        self.images = []
        self.categories = []
        self.annotations = []
        self.label = []
        self.annID = 1
        self.height = 0
        self.width = 0
        self.save_json()

    def data_transfer(self):
        for num, json_file in enumerate(self.labelme_json):
            with open(json_file, 'r') as fp:
                data = json.load(fp)  # load the labelme JSON file
            self.images.append(self.image(data, num))
            for shapes in data['shapes']:
                label = shapes['label']
                if label not in self.label:
                    self.categories.append(self.category(label))
                    self.label.append(label)
                points = shapes['points']
                if len(points) == 2:
                    # Rectangle annotations store only two opposite corners.
                    # Expand them to four corners, walking around the box so
                    # that the polygon does not cross itself.
                    (x0, y0), (x1, y1) = points
                    points = [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]
                self.annotations.append(self.annotation(points, label, num))
                self.annID += 1

    def image(self, data, num):
        image = {}
        # Decode the image embedded in the JSON; reading data['imagePath']
        # from disk would work as well.
        img = utils.img_b64_to_arr(data['imageData'])
        height, width = img.shape[:2]
        image['height'] = height
        image['width'] = width
        image['id'] = num + 1
        image['file_name'] = data['imagePath'].split('/')[-1]
        self.height = height
        self.width = width
        return image

    def category(self, label):
        category = {}
        category['supercategory'] = 'Cancer'
        category['id'] = len(self.label) + 1  # 0 is reserved for background
        category['name'] = label
        return category

    def annotation(self, points, label, num):
        annotation = {}
        annotation['segmentation'] = [list(np.asarray(points).flatten())]
        annotation['iscrowd'] = 0
        annotation['image_id'] = num + 1
        annotation['bbox'] = list(map(float, self.getbbox(points)))
        annotation['area'] = annotation['bbox'][2] * annotation['bbox'][3]
        annotation['category_id'] = self.getcatid(label)  # note: the original source defaulted to 1
        annotation['id'] = self.annID
        return annotation

    def getcatid(self, label):
        for category in self.categories:
            if label == category['name']:
                return category['id']
        return 1

    def getbbox(self, points):
        mask = self.polygons_to_mask([self.height, self.width], points)
        return self.mask2box(mask)

    def mask2box(self, mask):
        '''
        Compute the bounding box of a mask.
        mask: [h, w] array of 0/1 values, where 1 marks the object.
        Only the rows and columns holding 1s are needed: the top-left and
        bottom-right positions give the box.
        '''
        index = np.argwhere(mask == 1)
        rows = index[:, 0]
        cols = index[:, 1]
        # top-left corner
        left_top_r = np.min(rows)  # y
        left_top_c = np.min(cols)  # x
        # bottom-right corner
        right_bottom_r = np.max(rows)
        right_bottom_c = np.max(cols)
        # [x1, y1, w, h], the COCO bbox format
        return [left_top_c, left_top_r,
                right_bottom_c - left_top_c,
                right_bottom_r - left_top_r]

    def polygons_to_mask(self, img_shape, polygons):
        mask = np.zeros(img_shape, dtype=np.uint8)
        mask = PIL.Image.fromarray(mask)
        xy = list(map(tuple, polygons))
        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
        return np.array(mask, dtype=bool)

    def data2coco(self):
        data_coco = {}
        data_coco['images'] = self.images
        data_coco['categories'] = self.categories
        data_coco['annotations'] = self.annotations
        return data_coco

    def save_json(self):
        self.data_transfer()
        self.data_coco = self.data2coco()
        # write the COCO JSON file; indent=4 keeps it readable
        with open(self.save_json_path, 'w') as fp:
            json.dump(self.data_coco, fp, indent=4, cls=MyEncoder)


if __name__ == '__main__':
    src_folder = os.path.abspath(sys.argv[1])
    # collect all labelme JSON files in the source folder
    labelme_json = sorted(glob.glob(src_folder + '/*.json'))
    labelme2coco(labelme_json, sys.argv[2])
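
To make the target format concrete: a labelme rectangle stores two opposite corners, while a COCO annotation holds a flat polygon in segmentation and a bbox as [x, y, width, height]. The helper below is a hypothetical illustration and not part of the conversion script. It works directly on the coordinates, whereas the script derives the bbox from a rasterized mask, so the two can differ slightly for non-integer coordinates.

```python
def rect_to_polygon_and_bbox(p0, p1):
    """Turn two opposite rectangle corners into a COCO polygon and bbox.

    p0, p1: [x, y] corners in any order.
    Returns (polygon, bbox): polygon is a flat list of x, y values walking
    around the rectangle; bbox is [x_min, y_min, width, height].
    """
    x_min, x_max = sorted((p0[0], p1[0]))
    y_min, y_max = sorted((p0[1], p1[1]))
    polygon = [x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max]
    bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
    return polygon, bbox


polygon, bbox = rect_to_polygon_and_bbox([110, 70], [10, 20])
print(polygon)  # [10, 20, 110, 20, 110, 70, 10, 70]
print(bbox)     # [10, 20, 100, 50]
```

Walking the corners in order matters for the mask branch: a polygon that jumps between diagonal corners crosses itself and fills a bowtie instead of a box.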

Create a datasets folder and run the conversion script:

pip install labelme scikit-image
cd datasets
mkdir annotations
# convert
python labelme2coco.py xxx_train annotations/xxx_train.json
python labelme2coco.py xxx_test annotations/xxx_test.json
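
A quick sanity check on the generated files can save a failed training run later. The sketch below is a hypothetical validate_coco helper with checks chosen for illustration; it only looks at the three top-level keys the conversion script writes and at the cross-references between them.

```python
import json


def validate_coco(coco):
    """Return a list of human-readable problems in a COCO-style dict (empty if none)."""
    problems = []
    for key in ("images", "categories", "annotations"):
        if key not in coco:
            problems.append(f"missing top-level key: {key}")
    if problems:
        return problems

    image_ids = {img["id"] for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}
    for ann in coco["annotations"]:
        if ann["image_id"] not in image_ids:
            problems.append(f"annotation {ann['id']}: unknown image_id {ann['image_id']}")
        if ann["category_id"] not in category_ids:
            problems.append(f"annotation {ann['id']}: unknown category_id {ann['category_id']}")
        bbox = ann["bbox"]
        if len(bbox) != 4 or bbox[2] <= 0 or bbox[3] <= 0:
            problems.append(f"annotation {ann['id']}: degenerate bbox {bbox}")
    return problems


def validate_coco_file(path):
    """Load a COCO JSON file and validate it."""
    with open(path) as fp:
        return validate_coco(json.load(fp))


# Small in-memory example; for real data call
# validate_coco_file("annotations/xxx_train.json") instead.
demo = {
    "images": [{"id": 1, "file_name": "a.jpg", "height": 100, "width": 200}],
    "categories": [{"id": 1, "name": "cls1", "supercategory": "Cancer"}],
    "annotations": [
        {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 50]},
        {"id": 2, "image_id": 7, "category_id": 1, "bbox": [0, 0, 0, 5]},
    ],
}
for problem in validate_coco(demo):
    print(problem)
```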

IV. Adjust parameters and train the model

1. Edit the training config under configs

Pick your training config, e2e_ms_rcnn_R_50_FPN_1x.yaml (or e2e_mask_rcnn_R_50_FPN_1x.yaml), and set these entries:

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
DATASETS:
  TRAIN: ("coco_train_xxx",) # 1. set the training and validation sets
  TEST: ("coco_val_xxx",)

2. Edit paths_catalog.py under maskrcnn_benchmark/config and add the dataset paths:

If you use maskscoring_rcnn, use this structure:

"coco_train_xxx": ("xxx_train", "annotations/xxx_train.json"),
"coco_val_xxx": ("xxx_test", "annotations/xxx_test.json"),

If you use maskrcnn-benchmark, use this structure (what matters is that the entry matches the way the catalog reads it):

"coco_train_xxx": {
    "img_dir": "xxx_train",
    "ann_file": "annotations/xxx_train.json"
},
"coco_val_xxx": {
    "img_dir": "xxx_test",
    "ann_file": "annotations/xxx_test.json"
},
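
Both layouts carry the same two pieces of information: an image folder and an annotation file, each relative to the dataset root. The sketch below shows how such an entry resolves to real paths and how to check that they exist. resolve_dataset and missing_paths are hypothetical helpers, not repository code, and DATA_DIR = "datasets" is an assumption that matches the datasets folder created earlier. As far as I can tell, the catalog also selects the COCO loader by looking for "coco" in the dataset name, so keep that substring in your keys and verify this in your copy of paths_catalog.py.

```python
import os

DATA_DIR = "datasets"  # assumed dataset root, relative to the repository


def resolve_dataset(entry, data_dir=DATA_DIR):
    """Return (image_root, annotation_path) for a catalog entry.

    entry is either a tuple (img_dir, ann_file) or a dict with the keys
    "img_dir" and "ann_file".
    """
    if isinstance(entry, dict):
        img_dir, ann_file = entry["img_dir"], entry["ann_file"]
    else:
        img_dir, ann_file = entry
    return os.path.join(data_dir, img_dir), os.path.join(data_dir, ann_file)


def missing_paths(entry, data_dir=DATA_DIR):
    """List the resolved paths of an entry that do not exist on disk."""
    return [path for path in resolve_dataset(entry, data_dir) if not os.path.exists(path)]


tuple_entry = ("xxx_train", "annotations/xxx_train.json")
dict_entry = {"img_dir": "xxx_train", "ann_file": "annotations/xxx_train.json"}
print(resolve_dataset(tuple_entry))
print(resolve_dataset(dict_entry) == resolve_dataset(tuple_entry))  # True
```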

3. Edit defaults.py under maskrcnn_benchmark/config and set the training parameters:

_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 3 # 1. number of classes, background included; COCO uses 81 (80 + 1)
_C.MODEL.RETINANET.NUM_CLASSES = 3 # 1. number of classes, default 81 (only needed when using RetinaNet)
_C.SOLVER.BASE_LR = 0.0005 # 2. learning rate, default 0.001
_C.SOLVER.CHECKPOINT_PERIOD = 1000 # 3. checkpoint interval in iterations, adjust as needed
_C.SOLVER.IMS_PER_BATCH = 4 # 4. training batch size, default 16
_C.TEST.IMS_PER_BATCH = 4 # 5. test batch size, default 8
_C.OUTPUT_DIR = "weights/" # 6. where checkpoints are saved (your own folder)
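
NUM_CLASSES counts the background class, which is why COCO's 80 categories give 81. Rather than counting labels by hand, the value can be read off the converted annotation file. This is a minimal sketch; num_classes_from_coco is a hypothetical helper, and the dict below stands in for a loaded annotations/xxx_train.json.

```python
def num_classes_from_coco(coco):
    """Number of distinct foreground categories plus one for background."""
    return len({cat["id"] for cat in coco["categories"]}) + 1


# Stand-in for json.load(open("annotations/xxx_train.json")).
demo = {"categories": [{"id": 1, "name": "cls1"}, {"id": 2, "name": "cls2"}]}
print(num_classes_from_coco(demo))  # 3
```

Two labelled classes therefore correspond to the value 3 used above.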

4. Run training:

# train
python tools/train_net.py --config-file configs/e2e_ms_rcnn_R_50_FPN_1x.yaml
# evaluate on the test set
python tools/test_net.py --config-file configs/e2e_ms_rcnn_R_50_FPN_1x.yaml

When training is running correctly, the log looks like this:

2019-10-09 16:40:32,881 maskrcnn_benchmark.trainer INFO: eta: 6:26:27  iter: 80  loss: 0.7165 (0.8410)  loss_classifier: 0.0824 (0.1038)  loss_box_reg: 0.0734 (0.0740)  loss_mask: 0.5274 (0.6152)  loss_objectness: 0.0145 (0.0446)  loss_rpn_box_reg: 0.0010 (0.0034)  time: 0.2501 (0.2579)  data: 0.0047 (0.0115)  lr: 0.008800  max mem: 1692
2019-10-09 16:40:37,926 maskrcnn_benchmark.trainer INFO: eta: 6:24:41  iter: 100  loss: 0.6458 (0.8120)  loss_classifier: 0.0851 (0.1024)  loss_box_reg: 0.0731 (0.0735)  loss_mask: 0.4559 (0.5938)  loss_objectness: 0.0120 (0.0388)  loss_rpn_box_reg: 0.0010 (0.0035)  time: 0.2511 (0.2567)  data: 0.0048 (0.0101)  lr: 0.009333  max mem: 1692
2019-10-09 16:40:43,004 maskrcnn_benchmark.trainer INFO: eta: 6:23:54  iter: 120  loss: 0.6602 (0.7915)  loss_classifier: 0.0995 (0.1036)  loss_box_reg: 0.0801 (0.0758)  loss_mask: 0.4405 (0.5739)  loss_objectness: 0.0144 (0.0348)  loss_rpn_box_reg: 0.0010 (0.0033)  time: 0.2527 (0.2563)  data: 0.0048 (0.0093)  lr: 0.009867  max mem: 1692
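
To follow the loss over a long run, the log lines can be parsed into numbers. The sketch below assumes the line layout shown above (iter, then loss with its running average in parentheses, and an lr field); parse_log_line is a hypothetical helper, and the sample is an abridged copy of the first log line.

```python
import re

# "iter: 80  loss: 0.7165 (0.8410)" -> iteration, current loss, running average
ITER_LOSS_RE = re.compile(r"iter: (?P<iter>\d+)\s+loss: (?P<loss>[\d.]+) \((?P<avg>[\d.]+)\)")
LR_RE = re.compile(r"\blr: (?P<lr>[\d.]+)")


def parse_log_line(line):
    """Extract iteration, loss, running-average loss and lr from a trainer log line.

    Returns None for lines that do not contain an iter/loss pair.
    """
    match = ITER_LOSS_RE.search(line)
    if match is None:
        return None
    lr_match = LR_RE.search(line)
    return {
        "iter": int(match.group("iter")),
        "loss": float(match.group("loss")),
        "loss_avg": float(match.group("avg")),
        "lr": float(lr_match.group("lr")) if lr_match else None,
    }


sample = (
    "2019-10-09 16:40:32,881 maskrcnn_benchmark.trainer INFO: eta: 6:26:27  iter: 80  "
    "loss: 0.7165 (0.8410)  loss_classifier: 0.0824 (0.1038)  lr: 0.008800  max mem: 1692"
)
print(parse_log_line(sample))
```

Collecting these dicts over a whole log file gives the data for a loss curve.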

Note: you may hit an import error in maskrcnn_benchmark/utils/model_zoo.py; changing its imports to the following fixes it:

from torch.hub import _download_url_to_file
from torch.hub import urlparse
from torch.hub import HASH_REGEX

V. Prediction

1. Set WEIGHT in the yaml file to your own trained weights:

The file here is e2e_ms_rcnn_R_50_FPN_1x.yaml

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "weights/model_0080000.pth"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256

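2. If you use maskrcnn-benchmark, edit predictor.py in the demo folder and replace the class labels with your own.

If you use Mask Scoring R-CNN, copy that file over and edit it as shown here (if it complains about missing utilities, copy the corresponding files under util as well):

class COCODemo(object):
    # COCO categories for pretty print
    CATEGORIES = [
        "__background",
        "cls1",
        "cls2",
    ]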
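
The order of CATEGORIES has to match the label indices the model predicts. The sketch below builds the list from the training annotation file, under the assumption that the dataset loader maps category ids in ascending order to labels 1..N (this is how I understand the COCO dataset class in maskrcnn-benchmark to behave, so check it against your copy). categories_for_demo is a hypothetical helper.

```python
def categories_for_demo(coco):
    """Build a CATEGORIES list: background first, then names sorted by category id."""
    ordered = sorted(coco["categories"], key=lambda cat: cat["id"])
    return ["__background"] + [cat["name"] for cat in ordered]


# Stand-in for json.load(open("annotations/xxx_train.json")).
demo = {"categories": [{"id": 2, "name": "cls2"}, {"id": 1, "name": "cls1"}]}
print(categories_for_demo(demo))  # ['__background', 'cls1', 'cls2']
```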

Create a new file predict.py in the demo folder to run prediction:

#!/usr/bin/env python
# coding=UTF-8
import os
import sys

import cv2
from maskrcnn_benchmark.config import cfg
from predictor import COCODemo

# 1. the config file you modified
config_file = "../configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"

# 2. configuration
cfg.merge_from_file(config_file)  # merge the config file
cfg.merge_from_list(["MODEL.MASK_ON", True])  # enable the mask branch
cfg.merge_from_list(["MODEL.DEVICE", "cuda"])  # or use the CPU: ["MODEL.DEVICE", "cpu"]

coco_demo = COCODemo(
    cfg,
    min_image_size=800,
    confidence_threshold=0.5,  # 3. confidence threshold
)

# test
if __name__ == '__main__':
    in_folder = os.path.abspath(sys.argv[1])  # '../datasets/test_images/'
    out_folder = os.path.abspath(sys.argv[2])  # "../datasets/test_images_out/"
    if not os.path.exists(out_folder):
        os.makedirs(out_folder)

    for file_name in os.listdir(in_folder):
        if not file_name.endswith(('jpg', 'png')):
            continue

        # load file
        img_path = os.path.join(in_folder, file_name)
        image = cv2.imread(img_path)
        if image is None:  # unreadable file, skip it
            continue

        # method 1: get the rendered OpenCV image directly
        #predictions = coco_demo.run_on_opencv_image(image)
        #save_path = os.path.join(out_folder, file_name)
        #cv2.imwrite(save_path, predictions)

        # method 2: get the raw predictions
        predictions = coco_demo.compute_prediction(image)
        top_predictions = coco_demo.select_top_predictions(predictions)

        # draw only the detections above the confidence threshold
        img = coco_demo.overlay_boxes(image, top_predictions)
        img = coco_demo.overlay_mask(img, top_predictions)
        img = coco_demo.overlay_class_names(img, top_predictions)
        save_path = os.path.join(out_folder, file_name)
        cv2.imwrite(save_path, img)

        # print results
        boxes = top_predictions.bbox.numpy()
        labels = top_predictions.get_field("labels").numpy()
        scores = top_predictions.get_field("scores").numpy()
        masks = top_predictions.get_field("mask").numpy()

        for i in range(len(boxes)):
            print('box:', i, ' label:', labels[i])
            x1, y1, x2, y2 = [int(round(float(x))) for x in boxes[i]]
            print('x1,y1,x2,y2:', x1, y1, x2, y2)

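3. Run the prediction script to get the results. predict.py reads the input and output folders from its first two arguments, and its config path is relative to the demo folder, so run it from there:

cd demo
CUDA_VISIBLE_DEVICES=0 python predict.py ../datasets/test_images/ ../datasets/test_images_out/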
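
The script prints raw label indices. To get readable output, the indices can be looked up in the same CATEGORIES list used by predictor.py. A minimal sketch, assuming boxes are in x1, y1, x2, y2 order as printed by predict.py; describe_detections is a hypothetical helper and the numbers are made up for illustration.

```python
def describe_detections(boxes, labels, scores, categories):
    """Format detections as "name score [x1, y1, x2, y2]" strings.

    boxes, labels and scores may be plain lists or NumPy arrays of equal length;
    categories is the CATEGORIES list with "__background" at index 0.
    """
    lines = []
    for box, label, score in zip(boxes, labels, scores):
        x1, y1, x2, y2 = (int(round(float(v))) for v in box)
        name = categories[int(label)]
        lines.append(f"{name} {float(score):.2f} [{x1}, {y1}, {x2}, {y2}]")
    return lines


CATEGORIES = ["__background", "cls1", "cls2"]
for line in describe_detections([[10.4, 20.6, 110.2, 70.9]], [1], [0.9312], CATEGORIES):
    print(line)  # cls1 0.93 [10, 21, 110, 71]
```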
