"Instance-aware Semantic Segmentation via Multi-task Network Cascades" (MNC): running it with your own data
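This post covers running the code of the paper "Instance-aware Semantic Segmentation via Multi-task Network Cascades" on your own data, keeping the data in the same format that the original code expects. It only covers basic use of the code, but walking through the process gives a good picture of how the code is organized and run.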
For my notes on the paper itself, see: http://blog.csdn.net/u011070171/article/details/53071216
Code: https://github.com/daijifeng001/MNC
1. Setting up the code environment
Follow the instructions in the README on GitHub.
1) Download the code
git clone --recursive https://github.com/daijifeng001/MNC.git
2) Install the Python dependencies. If Caffe is already set up, most of these will already be installed; install any missing ones afterwards.
Python packages: numpy, scipy, cython, python-opencv, easydict, yaml.
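To see which of these packages are still missing, a small check like the one below can help. This is a sketch, not part of MNC: the helper name `missing_modules` is hypothetical, and it uses the import names of the packages (python-opencv is imported as `cv2`, cython as `Cython`). Run it with the same Python interpreter that pycaffe will use; it avoids Python-3-only APIs so it also works under Python 2.

```python
# Import names for the packages listed above (these differ from some
# package names: python-opencv -> cv2, cython -> Cython, yaml -> PyYAML).
REQUIRED_MODULES = ["numpy", "scipy", "Cython", "cv2", "easydict", "yaml"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be imported by this interpreter."""
    missing = []
    for name in names:
        try:
            __import__(name)
        except ImportError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules: " + ", ".join(missing))
    else:
        print("All Python dependencies found.")
```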
3) Build the code in the lib directory.
cd $MNC_ROOT/lib
make
4) Go to /MNC/caffe-mnc/ and edit Makefile.config as follows:
# In your Makefile.config, make sure to have this line uncommented
WITH_PYTHON_LAYER := 1
# CUDNN is recommended in building to reduce memory footprint
USE_CUDNN := 1
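Since make silently ignores an option that is still commented out, it can be worth confirming that both settings are really active before building. A minimal sketch (the helper names are hypothetical; it only understands simple `KEY := VALUE` / `KEY = VALUE` lines and ignores inline comments):

```python
import re

# Matches uncommented assignments such as "WITH_PYTHON_LAYER := 1".
_ASSIGN = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*:?=\s*(.*?)\s*$")


def active_settings(config_text):
    """Return {KEY: VALUE} for every uncommented assignment in Makefile.config text."""
    settings = {}
    for line in config_text.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue  # skip blank lines and commented-out options
        match = _ASSIGN.match(stripped)
        if match:
            settings[match.group(1)] = match.group(2)
    return settings


def check_mnc_options(path="Makefile.config"):
    """Return the options MNC needs that are missing or not set to 1."""
    with open(path) as f:
        settings = active_settings(f.read())
    required = ("WITH_PYTHON_LAYER", "USE_CUDNN")
    return [key for key in required if settings.get(key) != "1"]
```

Run `check_mnc_options()` from inside caffe-mnc; an empty list means both options are enabled.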
5) Build Caffe in the caffe-mnc directory.
cd $MNC_ROOT/caffe-mnc
# If you have all of the requirements installed
# and your Makefile.config in place, then simply do:
make -j8 && make pycaffe
6) Run the demo
Download the trained model:
./data/scripts/fetch_mnc_model.sh
Run the demo:
cd $MNC_ROOT
./tools/demo.py
7) Results
The segmentation results are in /MNC/data/demo/: the original images and the result images produced by the demo.
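2. Using your own data
This post only covers running the code and training a model without changing the input data format of the original code; running it with a data format of your own choosing will be summarized in a later post. Even so, running the unmodified code on your own data gives a good overall picture of how the code flows.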
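1) The entry point is in /MNC/experiments/scripts/. Create a new file my_train.sh there with the following content (modeled on the other shell scripts in the same directory).
Absolute paths are recommended; for privacy, the paths in the code posted here are abbreviated to start at /MNC.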
#!/bin/bash
# Usage:
# ./experiments/scripts/my_train.sh [--set ...]
# Example:
# ./experiments/scripts/mnc_5stage.sh \
# --set EXP_DIR foobar RNG_SEED 42 TRAIN.SCALES "[400,500,600,700]"

set -x
set -e

export PYTHONUNBUFFERED="True"

ITERS=20
DATASET_TRAIN=my_train
DATASET_TEST=my_test

LOG="/MNC/experiments/logs/znlog.txt"
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

NET_INIT=/MNC/data/mnc_model/mnc_model.caffemodel.h5

time /MNC/tools/train_net.py --gpu 0 \
  --solver /MNC/models/VGG16/my_train/solver.prototxt \
  --weights ${NET_INIT} \
  --imdb ${DATASET_TRAIN} \
  --iters ${ITERS} \
  --cfg /MNC/experiments/cfgs/VGG16/my_train.yml
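2) The shell script calls /MNC/tools/train_net.py.
Training parameters:
--gpu: the GPU id to train on
--solver: the solver file
--weights: the model used to initialize training (here the downloaded MNC model, NET_INIT)
--imdb: the image database, i.e. the dataset name registered in imdb.py
--iters: the number of training iterations
--cfg: the configuration file; setting EXP_DIR: mytrain in it makes mytrain the directory where the training results are saved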
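To make the mapping between these flags and the values in my_train.sh concrete, here is a hypothetical argparse sketch. It is not the real train_net.py, which may define more options and other defaults; the only detail confirmed by this post is that the dataset name ends up in `args.imdb_name`, which is why `--imdb` uses `dest="imdb_name"` below.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags my_train.sh passes to train_net.py."""
    parser = argparse.ArgumentParser(description="Train an MNC network (sketch)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU device id")
    parser.add_argument("--solver", required=True, help="solver prototxt")
    parser.add_argument("--weights", default=None,
                        help="model used to initialize training")
    # train_net.py reads the dataset name as args.imdb_name
    parser.add_argument("--imdb", dest="imdb_name", default="my_train",
                        help="dataset name registered in imdb.py")
    parser.add_argument("--iters", type=int, default=20,
                        help="number of training iterations")
    parser.add_argument("--cfg", default=None, help="optional YAML config file")
    return parser
```

For example, `build_parser().parse_args(["--solver", "solver.prototxt", "--imdb", "my_train"])` yields `imdb_name == "my_train"`, the key that is later looked up in imdb.py.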
Contents of my_train.yml:
EXP_DIR: mytrain
MASK_SIZE: 21
TRAIN:
  RPN_POST_NMS_TOP_N: 300
  IMS_PER_BATCH: 1
  BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
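my_train.yml only overrides a few keys; everything else comes from the defaults in mnc_config.py. Assuming the loader follows the py-faster-rcnn convention that MNC is derived from (the YAML dict is merged recursively into the default config, unknown keys are rejected and value types must match), the override works roughly like this simplified sketch on plain dicts. `merge_into` is a hypothetical name, and the YAML is assumed to be already parsed (for example with `yaml.safe_load`).

```python
def merge_into(user_cfg, default_cfg, prefix=""):
    """Recursively overwrite default_cfg with the values in user_cfg (in place).

    Unknown keys raise KeyError; a value whose type differs from the
    default's type raises ValueError.
    """
    for key, value in user_cfg.items():
        full_key = prefix + key
        if key not in default_cfg:
            raise KeyError("Unknown config key: " + full_key)
        default = default_cfg[key]
        if isinstance(default, dict):
            # Nested sections such as TRAIN: are merged key by key
            if not isinstance(value, dict):
                raise ValueError("Expected a section for " + full_key)
            merge_into(value, default, full_key + ".")
        else:
            if type(value) is not type(default):
                raise ValueError("Type mismatch for {}: {} vs {}".format(
                    full_key, type(value).__name__, type(default).__name__))
            default_cfg[key] = value
    return default_cfg
```

In practice this means a typo in a key name, or writing `300.0` where an integer is expected, fails loudly at startup instead of being silently ignored.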
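3) train_net.py calls attach_roidb(args.imdb_name) and attach_maskdb(args.imdb_name) to obtain the imdb and the maskdb, respectively.
attach_roidb() and attach_maskdb() are defined in /MNC/lib/db/roidb.py and /MNC/lib/db/maskdb.py, respectively.
Both of them call functions in imdb.py.
We register our own dataset in imdb.py, whose content is as follows (the real paths are absolute; abbreviated ones are shown here):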
# --------------------------------------------------------
# Multitask Network Cascade
# Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
# Copyright (c) 2016, Haozhi Qi
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from datasets.pascal_voc_det import PascalVOCDet
from datasets.pascal_voc_seg import PascalVOCSeg
from datasets.my_dataset_seg import MyDatasetSeg

__sets = {
    'voc_2012_seg_train': (lambda: PascalVOCSeg('train', '2012', '/home/zhuangni/code/idear2/MNC/data/VOCdevkitSDS/')),
    'voc_2012_seg_val': (lambda: PascalVOCSeg('val', '2012', '/home/zhuangni/code/idear2/MNC/data/VOCdevkitSDS')),
    'voc_2007_trainval': (lambda: PascalVOCDet('trainval', '2007')),
    'voc_2007_test': (lambda: PascalVOCDet('test', '2007')),
    # Our own dataset, loaded from /MNC/data/MyDataset
    'my_train': (lambda: MyDatasetSeg('train', '/MNC/data/MyDataset')),
    'my_test': (lambda: MyDatasetSeg('test', '/MNC/data/MyDataset'))
}


def get_imdb(name):
    """ Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()


def list_imdbs():
    return __sets.keys()
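The lambdas in `__sets` matter: imdb.py only stores factories, so a dataset (and its image list) is built only when train_net.py asks for it by name, and the name passed with `--imdb` (my_train here) must match a key exactly. A minimal Python 3 sketch of this registry pattern, with a hypothetical `FakeDataset` standing in for MyDatasetSeg:

```python
class FakeDataset(object):
    """Hypothetical stand-in for a dataset class such as MyDatasetSeg."""

    def __init__(self, image_set, devkit_path):
        self.image_set = image_set
        self.devkit_path = devkit_path


# Factories, not instances: nothing is constructed until get_imdb() is called.
_SETS = {
    "my_train": lambda: FakeDataset("train", "/MNC/data/MyDataset"),
    "my_test": lambda: FakeDataset("test", "/MNC/data/MyDataset"),
}


def get_imdb(name):
    """Build the dataset registered under `name`."""
    if name not in _SETS:
        raise KeyError("Unknown dataset: {}".format(name))
    return _SETS[name]()


def list_imdbs():
    return sorted(_SETS)
```

A misspelled dataset name therefore fails immediately with `KeyError: Unknown dataset`, before any training starts.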
4) In imdb.py, the line
from datasets.my_dataset_seg import MyDatasetSeg
imports the class that loads our dataset.
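In /MNC/lib/datasets/ we create the three files that define the dataset-loading classes: my_dataset.py, my_dataset_det.py and my_dataset_seg.py. Each one can be started as a copy of the corresponding pascal_voc.py, pascal_voc_det.py or pascal_voc_seg.py and then modified.
Of these, my_dataset.py needs little or no modification; in the copies of pascal_voc.py, pascal_voc_det.py and pascal_voc_seg.py, the main changes are the paths and the class names.
The cache_file files referenced in the code do not have to be prepared by hand: they are generated on the first run and reloaded afterwards.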
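The cache logic in gt_roidb() and gt_maskdb() follows a simple load-or-build pattern. A minimal Python 3 sketch of that pattern (`load_or_build` is a hypothetical name; the MNC code uses cPickle under Python 2):

```python
import os
import pickle


def load_or_build(cache_file, build_fn):
    """Return the cached object if cache_file exists, else build and cache it.

    The first call runs build_fn() and pickles the result; later calls
    reload the pickle without re-reading any annotations.
    """
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as fid:
            return pickle.load(fid)
    data = build_fn()
    with open(cache_file, "wb") as fid:
        pickle.dump(data, fid, pickle.HIGHEST_PROTOCOL)
    return data
```

Because the cache file name depends only on the dataset name (for example my_train_gt_roidb.pkl), a cache written before you changed your annotations or class list will keep being reused. Delete the .pkl files in the cache directory whenever the dataset changes.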
Our dataset is modeled on the VOC 2012 dataset used by the original code. Running the following command downloads that dataset into /MNC/data/VOCdevkitSDS/:
./data/scripts/fetch_sbd_data.sh
Following its layout, we build our own dataset:
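Judging from the loader code in my_dataset_det.py and my_dataset_seg.py (with cfg.MNC_MODE enabled), every name listed in <root>/<image_set>.txt needs img/<name>.jpg, inst/<name>.mat and cls/<name>.mat under the dataset root (/MNC/data/MyDataset here). The .mat files are read as GTinst.Segmentation and GTcls.Segmentation, as in SBD. In non-MNC mode, Annotations/<name>.xml is read instead. A small pre-flight check can save a failed run; this is a sketch, and `check_dataset` is a hypothetical helper, not part of MNC:

```python
import os

# Per-image files read by MyDatasetDet / MyDatasetSeg in MNC mode.
REQUIRED = [("img", ".jpg"), ("inst", ".mat"), ("cls", ".mat")]


def check_dataset(root, image_set):
    """List the files missing from a dataset laid out like /MNC/data/MyDataset."""
    list_file = os.path.join(root, image_set + ".txt")
    if not os.path.exists(list_file):
        return [list_file]
    with open(list_file) as f:
        # The real loader keeps blank lines as empty names; we skip them here
        names = [line.strip() for line in f if line.strip()]
    missing = []
    for name in names:
        for subdir, ext in REQUIRED:
            path = os.path.join(root, subdir, name + ext)
            if not os.path.exists(path):
                missing.append(path)
    return missing
```

For example, `check_dataset("/MNC/data/MyDataset", "train")` should return an empty list before you start my_train.sh.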
Contents of my_dataset_det.py:
# --------------------------------------------------------
# Multitask Network Cascade
# Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
# Copyright (c) 2016, Haozhi Qi
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import os
import uuid
import cPickle
import numpy as np
import scipy.sparse
import scipy.io  # needed for scipy.io.loadmat in _load_sbd_annotations
import PIL.Image  # PIL.Image must be imported explicitly for PIL.Image.open
import xml.etree.ElementTree as xmlET
from datasets.my_dataset import MyDataset
from my_config import cfg
from utils.my_eval import my_eval


class MyDatasetDet(MyDataset):
    """
    A subclass for PascalVOC
    """
    def __init__(self, image_set, devkit_path=None):
        MyDataset.__init__(self, image_set)
        self._image_set = image_set
        self._devkit_path = devkit_path
        self._data_path = self._devkit_path
        # Still the 20 PASCAL VOC classes; edit this tuple if your data uses other classes
        self._classes = ('__background__',  # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        # self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'top_k': 2000,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None}

        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        image_path = os.path.join(self._data_path, 'img', index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.
        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} gt roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        num_image = len(self.image_index)
        if cfg.MNC_MODE:
            gt_roidb = [self._load_sbd_annotations(index) for index in xrange(num_image)]
        else:
            gt_roidb = [self._load_pascal_annotations(index) for index in xrange(num_image)]
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)
        return gt_roidb

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        Here the file is <devkit_path>/<image_set>.txt,
        e.g. /MNC/data/MyDataset/train.txt
        """
        image_set_file = os.path.join(self._data_path, self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def append_flipped_rois(self):
        """
        This method is irrelevant with database, so implement here
        Append flipped images to ROI database
        Note this method doesn't actually flip the 'image', it flip
        boxes instead
        """
        cache_file = os.path.join(self.cache_path, self.name + '_' + cfg.TRAIN.PROPOSAL_METHOD + '_roidb_flip.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                flip_roidb = cPickle.load(fid)
            print '{} gt flipped roidb loaded from {}'.format(self.name, cache_file)
        else:
            num_images = self.num_images
            widths = [PIL.Image.open(self.image_path_at(i)).size[0]
                      for i in xrange(num_images)]
            flip_roidb = []
            for i in xrange(num_images):
                boxes = self.roidb[i]['boxes'].copy()
                oldx1 = boxes[:, 0].copy()
                oldx2 = boxes[:, 2].copy()
                boxes[:, 0] = widths[i] - oldx2 - 1
                boxes[:, 2] = widths[i] - oldx1 - 1
                assert (boxes[:, 2] >= boxes[:, 0]).all()
                entry = {'boxes': boxes,
                         'gt_overlaps': self.roidb[i]['gt_overlaps'],
                         'gt_classes': self.roidb[i]['gt_classes'],
                         'flipped': True}
                flip_roidb.append(entry)
            with open(cache_file, 'wb') as fid:
                cPickle.dump(flip_roidb, fid, cPickle.HIGHEST_PROTOCOL)
            print 'wrote gt flipped roidb to {}'.format(cache_file)

        self.roidb.extend(flip_roidb)
        self._image_index *= 2

    def _load_pascal_annotations(self, index):
        """
        Load image and bounding boxes info from XML file
        in the PASCAL VOC format according to image index
        """
        image_name = self._image_index[index]
        filename = os.path.join(self._data_path, 'Annotations', image_name + '.xml')
        tree = xmlET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [obj for obj in objs if int(obj.find('difficult').text) == 0]
            if len(non_diff_objs) != len(objs):
                print 'Removed {} difficult objects'.format(len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros(num_objs, dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        # boxes[ind, :] will be boxes
        # gt_classes[ind] will be the associated class name for this box
        # overlaps[ind, class] will assign 1.0 to ground truth
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0

        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False}

    def _load_sbd_annotations(self, index):
        if index % 1000 == 0:
            print '%d / %d' % (index, len(self._image_index))
        image_name = self._image_index[index]
        inst_file_name = os.path.join(self._data_path, 'inst', image_name + '.mat')
        gt_inst_mat = scipy.io.loadmat(inst_file_name)
        gt_inst_data = gt_inst_mat['GTinst']['Segmentation'][0][0]
        unique_inst = np.unique(gt_inst_data)
        background_ind = np.where(unique_inst == 0)[0]
        unique_inst = np.delete(unique_inst, background_ind)

        cls_file_name = os.path.join(self._data_path, 'cls', image_name + '.mat')
        gt_cls_mat = scipy.io.loadmat(cls_file_name)
        gt_cls_data = gt_cls_mat['GTcls']['Segmentation'][0][0]

        boxes = np.zeros((len(unique_inst), 4), dtype=np.uint16)
        gt_classes = np.zeros(len(unique_inst), dtype=np.int32)
        overlaps = np.zeros((len(unique_inst), self.num_classes), dtype=np.float32)
        for ind, inst_mask in enumerate(unique_inst):
            im_mask = (gt_inst_data == inst_mask)
            im_cls_mask = np.multiply(gt_cls_data, im_mask)
            unique_cls_inst = np.unique(im_cls_mask)
            background_ind = np.where(unique_cls_inst == 0)[0]
            unique_cls_inst = np.delete(unique_cls_inst, background_ind)
            # Each instance must carry exactly one class label
            assert len(unique_cls_inst) == 1
            gt_classes[ind] = unique_cls_inst[0]
            [r, c] = np.where(im_mask > 0)
            boxes[ind, 0] = np.min(c)
            boxes[ind, 1] = np.min(r)
            boxes[ind, 2] = np.max(c)
            boxes[ind, 3] = np.max(r)
            overlaps[ind, unique_cls_inst[0]] = 1.0

        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False}

    """-----------------Evaluation--------------------"""
    def evaluate_detections(self, all_boxes, output_dir):
        self._write_my_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['matlab_eval']:
            raise NotImplementedError
        if self.config['cleanup']:
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_my_results_file_template().format(cls)
                os.remove(filename)

    def _get_comp_id(self):
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
                   else self._comp_id)
        return comp_id

    def _get_my_results_file_template(self):
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        # Note: the results/ directory must already exist under devkit_path
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        path = os.path.join(self._devkit_path, 'results', filename)
        return path

    def _write_my_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print 'Writing {} VOC results file'.format(cls)
            filename = self._get_my_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    if len(dets) == 0:  # dets may be [] or an empty array
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in xrange(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir='output'):
        print '--------------------------------------------------------------'
        print 'Computing results with **unofficial** Python eval code.'
        print 'Results should be very close to the official MATLAB eval code.'
        print 'Recompute with `./tools/reval.py --matlab ...` for your paper.'
        print '--------------------------------------------------------------'
        annopath = os.path.join(self._devkit_path, 'Annotations', '{:s}.xml')
        imagesetfile = os.path.join(self._devkit_path, 'ImageSets', 'Main',
                                    self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # use_07_metric was referenced without being defined; this dataset has
        # no VOC year, so use the VOC2010+ (area under the PR curve) metric
        use_07_metric = False
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_my_results_file_template().format(cls)
            rec, prec, ap = my_eval(filename, annopath, imagesetfile, cls, cachedir,
                                    ovthresh=0.5, use_07_metric=use_07_metric)
            aps += [ap]
            print('AP for {} = {:.4f}'.format(cls, ap))
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
                cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
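The core of _load_sbd_annotations, and therefore of how your own inst/cls label maps become training boxes, is: for every non-zero instance id, find its single class label and the tight bounding box of its pixels. The pure-Python sketch below reproduces that logic on nested lists instead of the .mat arrays, which makes it easy to sanity-check your own annotations (`instances_to_boxes` is a hypothetical name; the real code uses NumPy):

```python
def instances_to_boxes(inst_map, cls_map):
    """Per-instance class and box from SBD-style label maps.

    inst_map[r][c] is an instance id (0 = background) and cls_map[r][c] a
    class index. Returns (class_index, [x1, y1, x2, y2]) tuples with 0-based,
    inclusive pixel coordinates, ordered by instance id as np.unique would.
    """
    pixels = {}
    for r, row in enumerate(inst_map):
        for c, inst_id in enumerate(row):
            if inst_id != 0:
                pixels.setdefault(inst_id, []).append((r, c))
    results = []
    for inst_id in sorted(pixels):
        coords = pixels[inst_id]
        classes = {cls_map[r][c] for r, c in coords} - {0}
        # Same check as the loader: one instance must carry exactly one class
        assert len(classes) == 1, "instance %d has classes %s" % (inst_id, classes)
        rows = [r for r, _ in coords]
        cols = [c for _, c in coords]
        results.append((classes.pop(), [min(cols), min(rows), max(cols), max(rows)]))
    return results
```

If this assertion fires on your data, the loader's `assert len(unique_cls_inst) == 1` would fail on the same image, usually because the inst and cls maps are misaligned.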
Contents of my_dataset_seg.py:
# --------------------------------------------------------
# Multitask Network Cascade
# Written by Haozhi Qi
# Copyright (c) 2016, Haozhi Qi
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import cPickle
import os
import scipy.io as sio
import numpy as np
from datasets.my_dataset_det import MyDatasetDet
from my_config import cfg
from utils.vis_seg import vis_seg
from utils.my_eval import my_eval_sds
import scipy


class MyDatasetSeg(MyDatasetDet):
    """
    A subclass for datasets.imdb.imdb
    This class contains information of ROIDB and MaskDB
    This class implements roidb and maskdb related functions
    """
    def __init__(self, image_set, devkit_path=None):
        MyDatasetDet.__init__(self, image_set, devkit_path)
        self._ori_image_num = len(self._image_index)
        self._comp_id = 'comp6'
        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'top_k': 2000,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None}
        self._data_path = os.path.join(self._devkit_path)
        self._roidb_path = os.path.join(self.cache_path, image_set)

    def image_path_at(self, i):
        image_path = os.path.join(self._data_path, 'img', self._image_index[i] + self._image_ext)
        assert os.path.exists(image_path), 'Path does not exist: {}'.format(image_path)
        return image_path

    def roidb_path_at(self, i):
        if i >= self._ori_image_num:
            return os.path.join(self._roidb_path,
                                self.image_index[i % self._ori_image_num] + '_flip.mat')
        else:
            return os.path.join(self._roidb_path,
                                self.image_index[i] + '.mat')

    def gt_maskdb(self):
        cache_file = os.path.join(self.cache_path, self.name + '_gt_maskdb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                gt_maskdb = cPickle.load(fid)
            print '{} gt maskdb loaded from {}'.format(self.name, cache_file)
        else:
            num_image = len(self.image_index)
            gt_roidbs = self.gt_roidb()
            gt_maskdb = [self._load_sbd_mask_annotations(index, gt_roidbs)
                         for index in xrange(num_image)]
            with open(cache_file, 'wb') as fid:
                cPickle.dump(gt_maskdb, fid, cPickle.HIGHEST_PROTOCOL)
            print 'wrote gt maskdb to {}'.format(cache_file)
        return gt_maskdb

    def _load_image_set_index(self):
        image_set_file = os.path.join(self._data_path, self._image_set + '.txt')
        assert os.path.exists(image_set_file), 'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _load_sbd_mask_annotations(self, index, gt_roidbs):
        """
        Load gt_masks information from SBD's additional data
        """
        if index % 1000 == 0:
            print '%d / %d' % (index, len(self._image_index))
        image_name = self._image_index[index]
        inst_file_name = os.path.join(self._data_path, 'inst', image_name + '.mat')
        gt_inst_mat = scipy.io.loadmat(inst_file_name)
        gt_inst_data = gt_inst_mat['GTinst']['Segmentation'][0][0]
        unique_inst = np.unique(gt_inst_data)
        background_ind = np.where(unique_inst == 0)[0]
        unique_inst = np.delete(unique_inst, background_ind)
        gt_roidb = gt_roidbs[index]
        cls_file_name = os.path.join(self._data_path, 'cls', image_name + '.mat')
        gt_cls_mat = scipy.io.loadmat(cls_file_name)
        gt_cls_data = gt_cls_mat['GTcls']['Segmentation'][0][0]
        gt_masks = []
        for ind, inst_mask in enumerate(unique_inst):
            box = gt_roidb['boxes'][ind]
            im_mask = (gt_inst_data == inst_mask)
            im_cls_mask = np.multiply(gt_cls_data, im_mask)
            unique_cls_inst = np.unique(im_cls_mask)
            background_ind = np.where(unique_cls_inst == 0)[0]
            unique_cls_inst = np.delete(unique_cls_inst, background_ind)
            assert len(unique_cls_inst) == 1
            assert unique_cls_inst[0] == gt_roidb['gt_classes'][ind]
            # Crop the full-image mask to the instance's (inclusive) box
            mask = im_mask[box[1]: box[3]+1, box[0]:box[2]+1]
            gt_masks.append(mask)

        # Also record the maximum dimension to create fixed dimension array when do forwarding
        # (max() raises ValueError if an image contains no instances)
        mask_max_x = max(gt_masks[i].shape[1] for i in xrange(len(gt_masks)))
        mask_max_y = max(gt_masks[i].shape[0] for i in xrange(len(gt_masks)))
        return {
            'gt_masks': gt_masks,
            'mask_max': [mask_max_x, mask_max_y],
            'flipped': False
        }

    def append_flipped_masks(self):
        """
        This method is only accessed when we use maskdb, so implement here
        Append flipped images to mask database
        Note this method doesn't actually flip the 'image', it flip masks instead
        """
        cache_file = os.path.join(self.cache_path, self.name + '_' + cfg.TRAIN.PROPOSAL_METHOD + '_maskdb_flip.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                flip_maskdb = cPickle.load(fid)
            print '{} gt flipped maskdb loaded from {}'.format(self.name, cache_file)
            self.maskdb.extend(flip_maskdb)
            # Need to check this condition since otherwise we may occasionally *4
            if self._image_index == self.num_images:
                self._image_index *= 2
        else:
            # pure image number hold for future development
            # this is useless since append flip mask will only be called once
            num_images = self._ori_image_num
            flip_maskdb = []
            for i in xrange(num_images):
                masks = self.maskdb[i]['gt_masks']
                masks_flip = []
                for mask_ind in xrange(len(masks)):
                    mask_flip = np.fliplr(masks[mask_ind])
                    masks_flip.append(mask_flip)
                entry = {'gt_masks': masks_flip,
                         'mask_max': self.maskdb[i]['mask_max'],
                         'flipped': True}
                flip_maskdb.append(entry)
            with open(cache_file, 'wb') as fid:
                cPickle.dump(flip_maskdb, fid, cPickle.HIGHEST_PROTOCOL)
            print 'wrote gt flipped maskdb to {}'.format(cache_file)
            self.maskdb.extend(flip_maskdb)
            # Need to check this condition since otherwise we may occasionally *4
            if self._image_index == self.num_images:
                self._image_index *= 2

    def visualization_segmentation(self, output_dir):
        vis_seg(self.image_index, self.classes, output_dir, self._data_path)

    # --------------------------- Evaluation ---------------------------
    def evaluate_segmentation(self, all_boxes, all_masks, output_dir):
        self._write_my_seg_results_file(all_boxes, all_masks, output_dir)
        self._py_evaluate_segmentation(output_dir)

    def _write_my_seg_results_file(self, all_boxes, all_masks, output_dir):
        """
        Write results as a pkl file, note this is different from
        detection task since it's difficult to write masks to txt
        """
        # Always reformat result in case of sometimes masks are not
        # binary or is in shape (n, sz*sz) instead of (n, sz, sz)
        all_boxes, all_masks = self._reformat_result(all_boxes, all_masks)
        for cls_inds, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print 'Writing {} VOC results file'.format(cls)
            # Binary mode, since HIGHEST_PROTOCOL pickles are binary
            filename = os.path.join(output_dir, cls + '_det.pkl')
            with open(filename, 'wb') as f:
                cPickle.dump(all_boxes[cls_inds], f, cPickle.HIGHEST_PROTOCOL)
            filename = os.path.join(output_dir, cls + '_seg.pkl')
            with open(filename, 'wb') as f:
                cPickle.dump(all_masks[cls_inds], f, cPickle.HIGHEST_PROTOCOL)

    def _reformat_result(self, boxes, masks):
        num_images = len(self.image_index)
        num_class = len(self.classes)
        reformat_masks = [[[] for _ in xrange(num_images)]
                          for _ in xrange(num_class)]
        for cls_inds in xrange(1, num_class):
            for img_inds in xrange(num_images):
                if len(masks[cls_inds][img_inds]) == 0:
                    continue
                num_inst = masks[cls_inds][img_inds].shape[0]
                reformat_masks[cls_inds][img_inds] = masks[cls_inds][img_inds]\
                    .reshape(num_inst, cfg.MASK_SIZE, cfg.MASK_SIZE)
                reformat_masks[cls_inds][img_inds] = \
                    reformat_masks[cls_inds][img_inds] >= cfg.BINARIZE_THRESH
        all_masks = reformat_masks
        return boxes, all_masks

    def _py_evaluate_segmentation(self, output_dir):
        gt_dir = self._data_path
        imageset_file = os.path.join(gt_dir, self._image_set + '.txt')
        cache_dir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        print '~~~~~~ Evaluation use min overlap = 0.5 ~~~~~~'
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            det_filename = os.path.join(output_dir, cls + '_det.pkl')
            seg_filename = os.path.join(output_dir, cls + '_seg.pkl')
            ap = my_eval_sds(det_filename, seg_filename, gt_dir,
                             imageset_file, cls, cache_dir, self._classes, ov_thresh=0.5)
            aps += [ap]
            print('AP for {} = {:.2f}'.format(cls, ap*100))
        print('Mean AP@0.5 = {:.2f}'.format(np.mean(aps)*100))
        print '~~~~~~ Evaluation use min overlap = 0.7 ~~~~~~'
        aps = []
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            det_filename = os.path.join(output_dir, cls + '_det.pkl')
            seg_filename = os.path.join(output_dir, cls + '_seg.pkl')
            ap = my_eval_sds(det_filename, seg_filename, gt_dir,
                             imageset_file, cls, cache_dir, self._classes, ov_thresh=0.7)
            aps += [ap]
            print('AP for {} = {:.2f}'.format(cls, ap*100))
        print('Mean AP@0.7 = {:.2f}'.format(np.mean(aps)*100))
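append_flipped_rois() flips each box with x1' = W - x2 - 1 and x2' = W - x1 - 1, while append_flipped_masks() only applies np.fliplr to the already-cropped masks and never touches the images. The two stay consistent because cropping the mirrored full-image mask at the mirrored box gives the same result as mirroring the cropped mask. A small pure-Python check of that property (helper names are hypothetical; they mirror the slicing used in _load_sbd_mask_annotations):

```python
def flip_box(box, width):
    """Mirror an inclusive [x1, y1, x2, y2] box in an image of the given width."""
    x1, y1, x2, y2 = box
    return [width - x2 - 1, y1, width - x1 - 1, y2]


def flip_mask(mask):
    """Mirror a 2-D mask (list of rows) left-right, like np.fliplr."""
    return [list(reversed(row)) for row in mask]


def crop_mask(full_mask, box):
    """Crop a full-image mask to its inclusive box, as in _load_sbd_mask_annotations."""
    x1, y1, x2, y2 = box
    return [row[x1:x2 + 1] for row in full_mask[y1:y2 + 1]]


if __name__ == "__main__":
    full = [[0, 1, 1, 0, 0],
            [0, 1, 0, 0, 0]]
    box, width = [1, 0, 2, 1], 5
    same = crop_mask(flip_mask(full), flip_box(box, width)) == flip_mask(crop_mask(full, box))
    print("flipped box/mask consistent: " + str(same))
```

The `- 1` in the box formula comes from the inclusive, 0-based pixel coordinates; dropping it would shift every flipped box by one pixel relative to its mask.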
5) We are now ready to train. Go to /MNC/experiments/scripts/ and run:
./my_train.sh
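6) In addition, the configuration parameters are set in mnc_config.py. If you want your own settings, you can create your own configuration file: add my_config.py in /MNC/lib/, copy the contents of mnc_config.py into it and modify it as needed. Every file that imports the configuration must then be updated as well: train_net.py, roidb.py, maskdb.py, my_dataset.py, my_dataset_det.py, my_dataset_seg.py and /MNC/lib/caffeWrapper/SolverWrapper.py. In each of them, change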
from mnc_config import cfg, cfg_from_file, get_output_dir
to
from my_config import cfg, cfg_from_file, get_output_dir
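Editing seven files by hand is easy to get wrong, so the change can also be scripted. A minimal sketch (`switch_config_import` is a hypothetical helper; the file paths are the abbreviated /MNC paths used in this post and must be adjusted to your absolute MNC location; missing files are skipped):

```python
import io
import os

OLD = "from mnc_config import"
NEW = "from my_config import"

# Abbreviated paths as used in this post; adjust to your MNC root.
FILES = [
    "/MNC/tools/train_net.py",
    "/MNC/lib/db/roidb.py",
    "/MNC/lib/db/maskdb.py",
    "/MNC/lib/datasets/my_dataset.py",
    "/MNC/lib/datasets/my_dataset_det.py",
    "/MNC/lib/datasets/my_dataset_seg.py",
    "/MNC/lib/caffeWrapper/SolverWrapper.py",
]


def switch_config_import(path, old=OLD, new=NEW):
    """Replace the mnc_config import with my_config in one source file.

    Returns the number of replacements; the file is rewritten only if
    something changed.
    """
    with io.open(path, encoding="utf-8") as f:
        text = f.read()
    count = text.count(old)
    if count:
        with io.open(path, "w", encoding="utf-8") as f:
            f.write(text.replace(old, new))
    return count


if __name__ == "__main__":
    for path in FILES:
        if os.path.exists(path):
            print("{}: {} replacement(s)".format(path, switch_config_import(path)))
```

Matching only the `from mnc_config import` prefix also catches files that import just `cfg`, and running it twice is harmless because already-converted files report 0 replacements.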
7) Experimental results
Afterword: the main purpose of this post was to walk once through the workflow of this paper's code; it is also a record of my own learning.