Background

Building on experience running and debugging Faster R-CNN, this post analyzes the tf-faster-rcnn code in detail.

References

Code analysis, part by part

1 Compiling the Cython modules

cd tf-faster-rcnn/lib   # first enter the tf-faster-rcnn/lib directory
make clean
make                    # compile

After a successful build, several .so files appear under tf-faster-rcnn/lib/nms, tf-faster-rcnn/lib/roi_pooling_layer/ and tf-faster-rcnn/lib/utils.

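Note: the .so files are not portable. They are built for the machine that compiled them, so if you copy them to another computer the program will fail. Also, always delete the old .so files with make clean before rebuilding; otherwise the old .so files are used and no new ones are generated. Whenever you run the code on a different machine, delete these .so files first and recompile.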
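To confirm that a rebuild actually happened, it can help to list the compiled extensions and their modification times. The following is a minimal sketch; find_shared_objects is a hypothetical helper, and it assumes you run it from the tf-faster-rcnn root so that the extensions live under lib/.

```python
import os


def find_shared_objects(root):
    """Return sorted (path, mtime) pairs for every .so file under root.

    After `make clean && make`, all mtimes should be recent; an old mtime
    suggests a stale extension left over from another build.
    """
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(".so"):
                path = os.path.join(dirpath, name)
                found.append((path, os.path.getmtime(path)))
    return sorted(found)


if __name__ == "__main__":
    # Assumption: executed from the tf-faster-rcnn root directory.
    for path, mtime in find_shared_objects("lib"):
        print(path, mtime)
```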

2 The data read/write interface for the pascal_voc dataset

2.1 In tf-faster-rcnn, all data-reading interfaces are under tf-faster-rcnn/lib/datasets. Two datasets can be used to train the network, pascal_voc and coco; their read/write interfaces are pascal_voc.py and coco.py in tf-faster-rcnn/lib/datasets.

The project mainly uses the XML files in Annotations, the images in JPEGImages and the txt files in ImageSets/Main (the directory that pascal_voc.py actually reads).

Other files in this directory:

factory.py: a factory that creates imdb objects and returns the dataset used for training and testing;

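imdb.py: the base class for dataset reading and writing. It encapsulates many database operations; the dataset-specific file reading is implemented in subclasses.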
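The factory idea can be sketched as a registry that maps dataset names such as voc_2007_trainval to constructors. This is a minimal sketch of that pattern, not the exact contents of factory.py; PascalVocStub is a hypothetical stand-in for the real pascal_voc class so the example runs without any data on disk.

```python
class PascalVocStub:
    """Hypothetical stand-in for datasets.pascal_voc.pascal_voc."""

    def __init__(self, image_set, year):
        self.name = "voc_{}_{}".format(year, image_set)


_sets = {}
for year in ["2007", "2012"]:
    for split in ["train", "val", "trainval", "test"]:
        name = "voc_{}_{}".format(year, split)
        # Default arguments freeze the current split/year; a plain closure
        # would only see the last loop values when it is finally called.
        _sets[name] = lambda split=split, year=year: PascalVocStub(split, year)


def get_imdb(name):
    """Construct the dataset object registered under name."""
    if name not in _sets:
        raise KeyError("Unknown dataset: {}".format(name))
    return _sets[name]()


def list_imdbs():
    """All registered dataset names."""
    return list(_sets.keys())
```

Because each entry is a callable, the dataset is only constructed (and its files only touched) when get_imdb is called.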

VOCdevkit/

VOCdevkit/VOC2007/

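VOCdevkit/VOC2007/Annotations #one XML file per image; the ground-truth (gt) boxes are given as top-left and bottom-right corner coordinates

VOCdevkit/VOC2007/ImageSets/

VOCdevkit/VOC2007/ImageSets/Layout #three txt files, train.txt, trainval.txt and val.txt, listing the image names of the training set, the train+val set and the validation set (names have no .jpg suffix)

VOCdevkit/VOC2007/ImageSets/Main #the image-set lists that pascal_voc.py actually reads (train.txt, val.txt, trainval.txt; test.txt comes with the test data), plus per-class lists such as aeroplane_train.txt

VOCdevkit/VOC2007/ImageSets/Segmentation

VOCdevkit/VOC2007/JPEGImages #all images, *.jpg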

VOCdevkit/VOC2007/SegmentationClass #segmentations by class

VOCdevkit/VOC2007/SegmentationObject #segmentations by object
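Before constructing pascal_voc, a quick check of the directory layout above can save a failed assert later. A minimal sketch; missing_voc_dirs and REQUIRED_SUBDIRS are hypothetical names, and the list only covers the three directories the loader reads.

```python
import os

# Subdirectories of VOC<year> that pascal_voc.py reads.
REQUIRED_SUBDIRS = ("Annotations", "JPEGImages", os.path.join("ImageSets", "Main"))


def missing_voc_dirs(devkit_path, year="2007"):
    """Return the required VOC subdirectories that do not exist."""
    data_path = os.path.join(devkit_path, "VOC" + year)
    return [sub for sub in REQUIRED_SUBDIRS
            if not os.path.isdir(os.path.join(data_path, sub))]
```

For example, missing_voc_dirs("data/VOCdevkit2007") returns an empty list when the layout is complete.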

2.2 The pascal_voc data read/write interface

The main block (if __name__ == '__main__':) is at the bottom of pascal_voc.py:

if __name__ == '__main__':
    from datasets.pascal_voc import pascal_voc
    d = pascal_voc('trainval', '2007')  # pascal_voc is a class
    res = d.roidb
    from IPython import embed
    embed()

The class pascal_voc used in the main block is defined at the top of pascal_voc.py:

import os
import pickle
import subprocess
import uuid
import xml.etree.ElementTree as ET

import numpy as np
import scipy.sparse

from datasets.imdb import imdb
from .voc_eval import voc_eval
from model.config import cfg


class pascal_voc(imdb):
    def __init__(self, image_set, year, use_diff=False):
        name = 'voc_' + year + '_' + image_set
        if use_diff:
            name += '_diff'
        imdb.__init__(self, name)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path()
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__',  # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.gt_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'use_diff': use_diff,
                       'matlab_eval': False,
                       'rpn_file': None}

        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _get_default_path(self):
        """
        Return the default path where PASCAL VOC is expected to be installed.
        """
        return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                try:
                    roidb = pickle.load(fid)
                except:
                    roidb = pickle.load(fid, encoding='bytes')
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb

    def rpn_roidb(self):
        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)

        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print('loading {}'.format(filename))
        assert os.path.exists(filename), \
            'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
            box_list = pickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

    def _get_comp_id(self):
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
                   else self._comp_id)
        return comp_id

    def _get_voc_results_file_template(self):
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        path = os.path.join(
            self._devkit_path,
            'results',
            'VOC' + self._year,
            'Main',
            filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    # original: dets == [], which is ambiguous when dets is a NumPy array
                    if len(dets) == 0:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in range(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir='output'):
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = True if int(self._year) < 2010 else False
        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
            aps += [ap]
            print(('AP for {} = {:.4f}'.format(cls, ap)))
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print(('Mean AP = {:.4f}'.format(np.mean(aps))))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print(('{:.3f}'.format(ap)))
        print(('{:.3f}'.format(np.mean(aps))))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _do_matlab_eval(self, output_dir='output'):
        print('-----------------------------------------------------')
        print('Computing results with the official MATLAB eval code.')
        print('-----------------------------------------------------')
        path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                            'VOCdevkit-matlab-wrapper')
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
        cmd += '-r "dbstop if error; '
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
            .format(self._devkit_path, self._get_comp_id(),
                    self._image_set, output_dir)
        print(('Running:\n{}'.format(cmd)))
        status = subprocess.call(cmd, shell=True)

    def evaluate_detections(self, all_boxes, output_dir):
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['matlab_eval']:
            self._do_matlab_eval(output_dir)
        if self.config['cleanup']:
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    def competition_mode(self, on):
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True
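_do_python_eval switches between two AP definitions with use_07_metric ("The PASCAL VOC metric changed in 2010"). The AP computation itself lives in voc_eval.py, which is not shown here, so the following is a sketch of the standard VOC average-precision calculation rather than a copy of that file. It assumes rec and prec are NumPy-compatible arrays ordered by descending detection score.

```python
import numpy as np


def voc_ap(rec, prec, use_07_metric=False):
    """Average precision from recall/precision arrays (descending score order)."""
    rec = np.asarray(rec, dtype=np.float64)
    prec = np.asarray(prec, dtype=np.float64)
    if use_07_metric:
        # VOC2007: mean of the best precision at 11 recall levels 0.0, 0.1, ..., 1.0
        ap = 0.0
        for t in np.linspace(0.0, 1.0, 11):
            p = np.max(prec[rec >= t]) if np.any(rec >= t) else 0.0
            ap += p / 11.0
        return ap
    # VOC2010+: area under the monotonically decreasing precision envelope
    mrec = np.concatenate(([0.0], rec, [1.0]))
    mpre = np.concatenate(([0.0], prec, [0.0]))
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])
    # sum rectangle areas wherever recall changes
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))
```

For rec = [0.5, 1.0] and prec = [1.0, 0.5], the 11-point metric gives 8.5/11 (about 0.773) while the area metric gives 0.75, which is why VOC2007 and VOC2012 numbers are not directly comparable.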

__init__ is the constructor. It sets up the paths, class list and options that describe how the pascal_voc dataset is accessed:

def __init__(self, image_set, year, use_diff=False):
    name = 'voc_' + year + '_' + image_set
    if use_diff:
        name += '_diff'
    imdb.__init__(self, name)  # call the parent initializer imdb.__init__() with e.g. 'voc_2007_train'; class imdb is defined in lib/datasets/imdb.py
    self._year = year  # str, the VOC year: '2007' or '2012' (2007 in this walkthrough)
    self._image_set = image_set  # str: 'train', 'test', 'trainval' or 'val', i.e. use the training, test, train+val or validation split ('train' in this walkthrough)
    self._devkit_path = self._get_default_path()  # calls _get_default_path(): data/VOCdevkit2007
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)  # data/VOCdevkit2007/VOC2007
    self._classes = ('__background__',  # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor')
    # all object classes contained in the dataset
    self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
    # Builds the dict {'__background__': 0, 'aeroplane': 1, 'bicycle': 2, 'bird': 3,
    #   'boat': 4, 'bottle': 5, 'bus': 6, 'car': 7, 'cat': 8, 'chair': 9, 'cow': 10,
    #   'diningtable': 11, 'dog': 12, 'horse': 13, 'motorbike': 14, 'person': 15,
    #   'pottedplant': 16, 'sheep': 17, 'sofa': 18, 'train': 19, 'tvmonitor': 20}.
    # self.num_classes is the total number of classes, 21 (background counts as
    # a class); it is a property inherited from lib/datasets/imdb.py
    self._image_ext = '.jpg'  # image file extension
    self._image_index = self._load_image_set_index()  # load the list of sample names
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb
    # With an RPN, the handler reads and returns the images' gt database.
    # gt_roidb does not extract ROIs, because Faster R-CNN proposes ROIs with
    # its RPN; gt_roidb only returns the gt. (Fast R-CNN has no RPN.)
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup': True,
                   'use_salt': True,
                   'use_diff': use_diff,
                   'matlab_eval': False,
                   'rpn_file': None}

    assert os.path.exists(self._devkit_path), \
        'VOCdevkit path does not exist: {}'.format(self._devkit_path)
    # abort if self._devkit_path (the VOCdevkit directory) does not exist
    assert os.path.exists(self._data_path), \
        'Path does not exist: {}'.format(self._data_path)
    # abort if self._data_path (VOCdevkit/VOC2007) does not exist
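The class-to-index dictionary is what turns the name string in each XML object into a column index. The sketch below rebuilds it outside the class; build_class_maps and class_index are hypothetical helpers, and the lower().strip() normalization mirrors the lookup in _load_pascal_annotation. For a custom dataset, replacing VOC_CLASSES changes num_classes accordingly.

```python
VOC_CLASSES = ('__background__',  # always index 0
               'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair',
               'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant',
               'sheep', 'sofa', 'train', 'tvmonitor')


def build_class_maps(classes):
    """Return (name -> index, index -> name) dictionaries."""
    class_to_ind = {name: i for i, name in enumerate(classes)}
    ind_to_class = dict(enumerate(classes))
    return class_to_ind, ind_to_class


def class_index(class_to_ind, xml_name):
    """Look up an XML <name> value the way _load_pascal_annotation does."""
    return class_to_ind[xml_name.lower().strip()]
```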

Method _get_default_path(self)

def _get_default_path(self):
    """
    Return the default path where PASCAL VOC is expected to be installed
    (for pascal_voc this is tf-faster-rcnn/data/VOCdevkit2007).
    """
    return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)  # cfg.DATA_DIR is defined in tf-faster-rcnn/lib/model/config.py

DATA_DIR is defined in tf-faster-rcnn/lib/model/config.py as follows (lines 257-261):

# Root directory of project

__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))

# Data directory

__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))
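Because ROOT_DIR is computed from config.py's own location, the dataset path does not depend on the current working directory. The sketch below reproduces that chain for a given config.py path; default_paths is a hypothetical helper and assumes config.py sits at lib/model/config.py inside the project.

```python
import os.path as osp


def default_paths(config_file, year="2007"):
    """Rebuild ROOT_DIR, DATA_DIR, the devkit path and the VOC data path."""
    root_dir = osp.abspath(osp.join(osp.dirname(config_file), "..", ".."))
    data_dir = osp.abspath(osp.join(root_dir, "data"))
    devkit_path = osp.join(data_dir, "VOCdevkit" + year)  # _get_default_path()
    data_path = osp.join(devkit_path, "VOC" + year)       # self._data_path
    return root_dir, data_dir, devkit_path, data_path
```

For /home/user/tf-faster-rcnn/lib/model/config.py this yields /home/user/tf-faster-rcnn/data/VOCdevkit2007/VOC2007 as the data path.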

Method _load_image_set_index(self)

def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    """
    # Example path to image set file:
    # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    # image_set_file is e.g. tf-faster-rcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/train.txt
    # This file is read because it lists the names (without .jpg) of all
    # images in the train set
    assert os.path.exists(image_set_file), \
        'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:  # read train.txt
        image_index = [x.strip() for x in f.readlines()]  # put its image names into image_index
    return image_index  # names of all images in image_set (without .jpg)

The result is a list of the names of all images in the set self._image_set (note: the names have no .jpg suffix).
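The two steps, reading the image-set list and turning a name into a JPEG path, can be tried out independently of the class. A minimal sketch with hypothetical function names; unlike the original, it skips blank lines so that a stray empty line in train.txt does not produce an empty image name.

```python
import os


def load_image_set_index(data_path, image_set):
    """Names listed in ImageSets/Main/<image_set>.txt (no .jpg suffix)."""
    path = os.path.join(data_path, "ImageSets", "Main", image_set + ".txt")
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]


def image_path_from_index(data_path, index, ext=".jpg"):
    """JPEGImages/<index><ext>, as in pascal_voc.image_path_from_index."""
    return os.path.join(data_path, "JPEGImages", index + ext)
```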

Method gt_roidb(self)

def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    # Name of the .pkl cache file. self.cache_path and self.name are inherited
    # from class imdb, defined in lib/datasets/imdb.py
    if os.path.exists(cache_file):  # the .pkl file exists: this function ran before and wrote it
        with open(cache_file, 'rb') as fid:  # open it
            try:
                roidb = pickle.load(fid)
            except:
                roidb = pickle.load(fid, encoding='bytes')  # load the cached data
        print('{} gt roidb loaded from {}'.format(self.name, cache_file))
        return roidb  # and return it

    # The .pkl file does not exist, so this is the first call. First get the
    # images' gt; _load_pascal_annotation does that for one image.
    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
        pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)  # save the gt in the .pkl file
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb

This function reads and returns the database of the images' ground truth (gt), i.e. it loads the gt of every image.

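For pascal_voc, the gt information is stored in the XML files, and it is also cached in a .pkl file. You do not create this .pkl file by hand: gt_roidb writes it the first time it runs. Caching the gt in a .pkl file means the XML files do not have to be parsed again on every run; loading the single cache file is faster.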
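The caching logic in gt_roidb is a general load-or-build pattern. The sketch below isolates it; load_or_build is a hypothetical helper. One deliberate difference: it rewinds the file before the encoding='bytes' retry, since the first failed load may already have consumed part of the file, and it only catches UnicodeDecodeError (typical for pickles written by Python 2) instead of using a bare except.

```python
import os
import pickle


def load_or_build(cache_file, build_fn):
    """Return cached data if cache_file exists, else build, cache and return it."""
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as fid:
            try:
                return pickle.load(fid)
            except UnicodeDecodeError:
                # Python 2 pickles may need a bytes encoding; restart from byte 0
                fid.seek(0)
                return pickle.load(fid, encoding="bytes")
    data = build_fn()
    with open(cache_file, "wb") as fid:
        pickle.dump(data, fid, pickle.HIGHEST_PROTOCOL)
    return data
```

In pascal_voc terms, build_fn would be the list comprehension over _load_pascal_annotation and cache_file the <name>_gt_roidb.pkl path.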

self.cache_path and self.name are inherited from class imdb, which is defined in tf-faster-rcnn/lib/datasets/imdb.py. The property self.cache_path is defined at lines 77-82 of imdb.py:

@property
def name(self):
    return self._name

@property
def cache_path(self):
    cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))
    if not os.path.exists(cache_path):
        os.makedirs(cache_path)
    return cache_path

self.name is set up at lines 23-35 of imdb.py:

def __init__(self, name, classes=None):  # initializer of class imdb, called from pascal_voc.py
    self._name = name  # the argument passed in: 'voc_2007_train', 'voc_2007_test', 'voc_2007_val' or 'voc_2007_trainval'
    self._num_classes = 0
    if not classes:
        self._classes = []
    else:
        self._classes = classes
    self._image_index = []
    self._obj_proposer = 'gt'
    self._roidb = None
    self._roidb_handler = self.default_roidb
    # Use this dict for storing dataset specific config options
    self.config = {}

@property
def name(self):  # where self.name is defined in class imdb
    return self._name  # returns self._name, set in imdb.py's __init__

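Note: if you change the training set before training again (adding or removing samples), first delete this .pkl file from data/cache (the cache_path shown above). Otherwise the old .pkl file is loaded automatically and no new one is generated.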
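Deleting the stale cache can be scripted. A minimal sketch; remove_gt_cache is a hypothetical helper that assumes the <imdb name>_gt_roidb.pkl naming used by gt_roidb. Note that evaluation keeps its own cache in VOCdevkit/annotations_cache (the cachedir in _do_python_eval), which may also need clearing after the test annotations change.

```python
import glob
import os


def remove_gt_cache(cache_dir, imdb_name=None):
    """Delete cached '<imdb_name>_gt_roidb.pkl' files so gt_roidb() rebuilds them.

    With imdb_name=None, every *_gt_roidb.pkl in cache_dir is removed.
    """
    pattern = (imdb_name or "*") + "_gt_roidb.pkl"
    removed = []
    for path in sorted(glob.glob(os.path.join(cache_dir, pattern))):
        os.remove(path)
        removed.append(path)
    return removed
```

For example, remove_gt_cache("data/cache", "voc_2007_trainval") removes only that dataset's cache.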

Method _load_pascal_annotation(self, index), the actual implementation that reads an image's gt:

def _load_pascal_annotation(self, index):
    """
    Load image and bounding boxes info from XML file in the PASCAL VOC
    format.

    Reads the image information and the gt from an XML file. These XML files
    hold the image and gt information of PASCAL VOC; they are downloaded
    together with the VOC dataset and live in the Annotations folder.
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    # filename is the path of one XML file; index is an image name without
    # suffix, e.g. VOCdevkit2007/VOC2007/Annotations/000005.xml
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
        # Exclude the samples labeled as difficult
        non_diff_objs = [
            obj for obj in objs if int(obj.find('difficult').text) == 0]
        # if len(non_diff_objs) != len(objs):
        #     print 'Removed {} difficult objects'.format(
        #         len(objs) - len(non_diff_objs))
        objs = non_diff_objs
    num_objs = len(objs)  # number of objects in this image

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):  # for each object in the image
        bbox = obj.find('bndbox')
        # pascal_voc XML files give each gt box as its top-left and
        # bottom-right corner coordinates
        # Make pixel indexes 0-based
        x1 = float(bbox.find('xmin').text) - 1
        y1 = float(bbox.find('ymin').text) - 1
        x2 = float(bbox.find('xmax').text) - 1
        y2 = float(bbox.find('ymax').text) - 1
        # Why subtract 1? VOC annotations use 1-based pixel coordinates, while
        # this code works with 0-based ones; _write_voc_results_file adds the 1
        # back when it writes results for the VOCdevkit.
        cls = self._class_to_ind[obj.find('name').text.lower().strip()]
        # look up the class index cls of this object
        boxes[ix, :] = [x1, y1, x2, y2]
        gt_classes[ix] = cls
        overlaps[ix, cls] = 1.0
        seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
        # seg_areas[ix] is the area of this object's gt box in pixels
        # (+1 because both corner pixels belong to the box)

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}
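The parsing steps can be exercised on an in-memory annotation. This is a reduced sketch of _load_pascal_annotation (it omits gt_overlaps and the sparse matrix); parse_voc_xml is a hypothetical name and SAMPLE_XML is a made-up annotation, not a real VOC file. Like the original, it assumes every object has a difficult tag.

```python
import xml.etree.ElementTree as ET

import numpy as np

SAMPLE_XML = """<annotation>
  <object>
    <name>Dog</name>
    <difficult>0</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>person</name>
    <difficult>1</difficult>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>498</ymax></bndbox>
  </object>
</annotation>"""


def parse_voc_xml(xml_text, class_to_ind, use_diff=False):
    """Return 0-based boxes, class indices and inclusive pixel areas."""
    objs = ET.fromstring(xml_text).findall("object")
    if not use_diff:
        objs = [o for o in objs if int(o.find("difficult").text) == 0]
    boxes = np.zeros((len(objs), 4), dtype=np.uint16)
    gt_classes = np.zeros(len(objs), dtype=np.int32)
    for ix, obj in enumerate(objs):
        bb = obj.find("bndbox")
        # VOC pixel coordinates are 1-based; shift to 0-based
        boxes[ix] = [float(bb.find(tag).text) - 1
                     for tag in ("xmin", "ymin", "xmax", "ymax")]
        gt_classes[ix] = class_to_ind[obj.find("name").text.lower().strip()]
    # inclusive extent: a box from x1 to x2 covers x2 - x1 + 1 columns
    widths = boxes[:, 2].astype(np.float32) - boxes[:, 0] + 1
    heights = boxes[:, 3].astype(np.float32) - boxes[:, 1] + 1
    return boxes, gt_classes, widths * heights
```

With the default use_diff=False, the difficult person is dropped and only the dog remains, at 0-based box [47, 239, 194, 370].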

Method image_path_at(self, i)

def image_path_at(self, i):
    """
    Return the absolute path to image i in the image sequence.
    """
    return self.image_path_from_index(self._image_index[i])

Returns the path of the i-th image sample; the actual work is delegated to image_path_from_index(self, index).

Method image_path_from_index(self, index)

def image_path_from_index(self, index):
    """
    Construct an image path from the image's "index" identifier.
    """
    image_path = os.path.join(self._data_path, 'JPEGImages',
                              index + self._image_ext)
    # Path of the image file itself; index is the image name without suffix and
    # _image_ext is the suffix .jpg, e.g. VOCdevkit2007/VOC2007/JPEGImages/000005.jpg
    assert os.path.exists(image_path), \
        'Path does not exist: {}'.format(image_path)
    # abort if the path does not exist
    return image_path

As shown above, pascal_voc.py builds most of its file paths with os.path.join.

3 Changing the model configuration

Editing config.py

All model parameters of tf-faster-rcnn are defined in tf-faster-rcnn/lib/model/config.py.

# Images to use per minibatch
__C.TRAIN.IMS_PER_BATCH = 1  # number of images fed into the Faster R-CNN network per step: 1
# Iterations between snapshots
__C.TRAIN.SNAPSHOT_ITERS = 5000  # during training, save a model snapshot every 5000 iterations
# solver.prototxt specifies the snapshot path prefix, this adds an optional
# infix to yield the path: <prefix>[_<infix>]_iters_XYZ.caffemodel
__C.TRAIN.SNAPSHOT_PREFIX = 'res101_faster_rcnn'  # name prefix of the saved models
# Use RPN to detect objects
__C.TRAIN.HAS_RPN = True  # whether to use the RPN; True means the RPN is used
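The checkpoint names used later in demo.py (for example vgg16_faster_rcnn_iter_70000.ckpt) suggest that snapshots are saved as <SNAPSHOT_PREFIX>_iter_<N>.ckpt every SNAPSHOT_ITERS steps. The helper below is a hypothetical sketch under that assumption; the actual saving logic lives in the training script and may differ, for instance by also saving a final snapshot at the last iteration.

```python
def snapshot_names(prefix, snapshot_iters, max_iters, ext=".ckpt"):
    """Expected checkpoint names, assuming '<prefix>_iter_<N><ext>' every snapshot_iters steps."""
    return ["{}_iter_{:d}{}".format(prefix, it, ext)
            for it in range(snapshot_iters, max_iters + 1, snapshot_iters)]
```

With SNAPSHOT_PREFIX = 'res101_faster_rcnn', SNAPSHOT_ITERS = 5000 and 15000 training iterations, this predicts three checkpoints, the last being res101_faster_rcnn_iter_15000.ckpt.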

demo.py analysis

CLASSES = ('__background__',  # order must match the class order used for training
           'seaurchin', 'scallop', 'seacucumber')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_15000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

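vis_detections: draws the bounding boxes on a test image. im is the test image; class_name is the class name, one of the CLASSES defined above; dets is the array of boxes and scores after non-maximum suppression; thresh is the final score threshold, and only boxes scoring at least thresh are drawn.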

def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    # keep only the dets whose score is at least thresh
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # cv2.imread returns an array of shape (height, width, channels) with the
    # channels in BGR order; matplotlib expects RGB, so reorder the channels
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:  # take bbox and score out of dets
        bbox = dets[i, :4]
        score = dets[i, -1]

        # draw a rectangle from its top-left corner, width and height
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
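The array handling in vis_detections can be checked without a display. The helpers below are hypothetical and only isolate three steps: the score filter, the BGR to RGB channel swap (im[:, :, ::-1] is equivalent to im[:, :, (2, 1, 0)] for three channels), and the corner-to-(origin, width, height) conversion passed to plt.Rectangle.

```python
import numpy as np


def boxes_above_threshold(dets, thresh=0.5):
    """Rows of dets (N x 5: x1, y1, x2, y2, score) whose score is >= thresh."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    return dets[inds]


def bgr_to_rgb(im):
    """Reverse the channel axis of an H x W x 3 image (OpenCV BGR -> RGB)."""
    return im[:, :, ::-1]


def box_to_rect(bbox):
    """(x1, y1, x2, y2, ...) -> ((x1, y1), width, height), as used by plt.Rectangle."""
    x1, y1, x2, y2 = bbox[:4]
    return (x1, y1), x2 - x1, y2 - y1
```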

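demo: runs detection on a test image, applies non-maximum suppression per class, and then calls vis_detections to draw the boxes. Parameters: sess, the TensorFlow session; net, the network used for testing; image_name, the image file name.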

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8  # score threshold: only boxes scoring >= CONF_THRESH are drawn
    NMS_THRESH = 0.3  # NMS overlap threshold, used to remove duplicate boxes
    for cls_ind, cls in enumerate(CLASSES[1:]):
        # enumerate yields the index cls_ind and the name cls of each class
        cls_ind += 1  # because we skipped background
        # take out this class's boxes and scores
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        # stack boxes and scores into dets (N x 5)
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]  # dets after non-maximum suppression
        # draw the boxes
        vis_detections(im, cls, dets, thresh=CONF_THRESH)
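The per-class slicing and the call to nms are the core of demo. The nms used here comes from model.nms_wrapper (compiled code); the sketch below swaps in the widely used pure-NumPy greedy NMS for illustration, with the same +1 pixel convention as the rest of this code base. py_nms and class_dets are hypothetical names.

```python
import numpy as np


def class_dets(scores, boxes, cls_ind):
    """Stack one class's boxes (columns 4*cls_ind .. 4*cls_ind+3) with its scores."""
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
    cls_scores = scores[:, cls_ind]
    return np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)


def py_nms(dets, thresh):
    """Greedy NMS: keep the best box, drop boxes overlapping it by IoU > thresh."""
    x1, y1, x2, y2, scores = (dets[:, k] for k in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # intersection of box i with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[np.where(iou <= thresh)[0] + 1]
    return keep
```

With NMS_THRESH = 0.3, two heavily overlapping boxes of the same class collapse to the higher-scoring one, while a distant box survives.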

parse_args: parses the command-line arguments, namely the network (--net) and the training dataset (--dataset).

def parse_args():
    """Parse input arguments."""
    # create the parser
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    # parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                     choices=NETS.keys(), default='res101')  # default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    # parse the command line and return the arguments as a namespace
    args = parser.parse_args()

    return args

Main block

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    # parse the arguments
    args = parse_args()

    # use the CPU implementation of NMS
    cfg.USE_GPU_NMS = False

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 21,
                            tag='default', anchor_scales=[8, 16, 32])
    # when testing with your own dataset, change 21 to your number of classes
    # (object classes + 1 for background)
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #             '001763.jpg', '004545.jpg']  # default
    im_names = ['000000.jpg']
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

    plt.show()
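Hard-coding 21 (or 4) in create_architecture is easy to get out of sync with CLASSES. A minimal sketch of deriving both the class count and the checkpoint path from the same tables; num_classes, checkpoint_path and checkpoint_exists are hypothetical helpers, and DATASETS is copied from the full listing further down. The count could then be passed as net.create_architecture("TEST", len(CLASSES), ...).

```python
import os

# Tables copied from the demo; adjust them to your own training run.
CLASSES = ('__background__', 'seaurchin', 'scallop', 'seacucumber')
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_15000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',),
            'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}


def num_classes(classes):
    """Class count for create_architecture: object classes plus __background__."""
    return len(classes)


def checkpoint_path(demonet, dataset, output_dir='output'):
    """Rebuild the tfmodel path used in the demo's main block."""
    return os.path.join(output_dir, demonet, DATASETS[dataset][0], 'default',
                        NETS[demonet][0])


def checkpoint_exists(path):
    """The demo checks for the .meta file that accompanies a TF checkpoint."""
    return os.path.isfile(path + '.meta')
```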

After training a model on your own dataset, to run the demo with all classes drawn on the same image, modify the code as follows.

#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',
           'seaurchin', 'scallop', 'seacucumber')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}


# add an ax parameter (the 4th argument)
def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # the following three lines of the original code are commented out
    # im = im[:, :, (2, 1, 0)]
    # fig, ax = plt.subplots(figsize=(12, 12))
    # ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          # edgecolor='red', linewidth=3.5)
                          edgecolor='red', linewidth=1)
            # line width changed from 3.5 to 1, so the red boxes are thinner
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    # the following three lines of the original code are commented out
    # plt.axis('off')
    # plt.tight_layout()
    # plt.draw()


def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3

    # the 3 lines that preceded the for loop in vis_detections are moved here
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        # pass ax to vis_detections as the new 4th argument
        vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)

    # the 3 lines that followed the for loop in vis_detections are moved here
    plt.axis('off')
    plt.tight_layout()
    plt.draw()


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    # parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                     choices=NETS.keys(), default='res101')  # default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    cfg.USE_GPU_NMS = False

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 4,
                            tag='default', anchor_scales=[8, 16, 32])
    # the 2nd argument of net.create_architecture is the number of object
    # classes + 1: this example has 3 object classes, plus background, 4 in total
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #             '001763.jpg', '004545.jpg']  # default
    im_names = ['000337.jpg']  # test image, stored in tf-faster-rcnn-contest/data/demo
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

    plt.show()
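One side effect of the modification above: ax.set_title is still called inside vis_detections, once per class, so the final title names only the last class that had a detection above the threshold. An alternative design is to first collect every class's surviving detections and draw them in a single pass. The sketch below shows only the collection step; collect_detections is a hypothetical helper, and NMS is left out for brevity (in the demo it runs per class before the threshold is applied).

```python
import numpy as np


def collect_detections(scores, boxes, classes, conf_thresh=0.8):
    """Gather (class_name, x1, y1, x2, y2, score) for every foreground class."""
    results = []
    for cls_ind, cls in enumerate(classes[1:], start=1):  # skip __background__
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        for i in np.where(cls_scores >= conf_thresh)[0]:
            x1, y1, x2, y2 = (float(v) for v in cls_boxes[i])
            results.append((cls, x1, y1, x2, y2, float(cls_scores[i])))
    # highest score first, so one drawing loop can label boxes in order
    results.sort(key=lambda r: r[-1], reverse=True)
    return results
```

A single title such as "all classes, score >= 0.8" can then be set once after drawing.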

After training a model on your own dataset, to have demo.py process a batch of test images and draw all classes on the same image, modify the code as follows.

#!/usr/bin/env python

# --------------------------------------------------------

# Tensorflow Faster R-CNN

# Licensed under The MIT License [see LICENSE for details]

# Written by Xinlei Chen, based on code from Ross Girshick

# --------------------------------------------------------

"""

Demo script showing detections in sample images.

See README.md for installation instructions before running.

"""

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import _init_paths

from model.config import cfg

from model.test import im_detect

from model.nms_wrapper import nms

from utils.timer import Timer

import tensorflow as tf

import matplotlib.pyplot as plt

import numpy as np

import os, cv2

import argparse

from nets.vgg16 import vgg16

from nets.resnet_v1 import resnetv1

import scipy.io as sio

import os, sys, cv2

import argparse

import os

import numpy

from PIL import Image #导入Image模块

from pylab import * #导入savetxt模块

CLASSES = ('__background__',

'holothurian', 'echinus', 'scallop', 'starfish')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, ax, thresh=0.5):

"""Draw detected bounding boxes."""

inds = np.where(dets[:, -1] >= thresh)[0]

if len(inds) == 0:

return

#im = im[:, :, (2, 1, 0)]

#fig, ax = plt.subplots(figsize=(12, 12))

#ax.imshow(im, aspect='equal')

for i in inds:

bbox = dets[i, :4]

score = dets[i, -1]

ax.add_patch(

plt.Rectangle((bbox[0], bbox[1]),

bbox[2] - bbox[0],

bbox[3] - bbox[1], fill=False,

#edgecolor='red', linewidth=3.5)

edgecolor='red', linewidth=1)

)

ax.text(bbox[0], bbox[1] - 2,

'{:s} {:.3f}'.format(class_name, score),

bbox=dict(facecolor='blue', alpha=0.5),

fontsize=14, color='white')

ax.set_title(('{} detections with '

'p({} | box) >= {:.1f}').format(class_name, class_name,

thresh),

fontsize=14)

#plt.axis('off')

#plt.tight_layout()

#plt.draw()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""
    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:  # cv2.imread returns None instead of raising on a bad path
        raise IOError('Could not read image {:s}'.format(im_file))

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3

    # OpenCV loads BGR, matplotlib expects RGB; draw every class on one figure
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)

    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def get_imlist(path):  # return the names of all .jpg images in the given folder
    return [f for f in os.listdir(path) if f.endswith('.jpg')]

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    # upstream default: default='res101'
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", len(CLASSES),  # 4 object classes + background = 5
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # upstream default:
    # im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #             '001763.jpg', '004545.jpg']
    # demo() reads each name from cfg.DATA_DIR/demo, so this should be that same folder
    im_names = get_imlist(r"/home/ouc/LiuHongzhi/tf-faster-rcnn-contest -2018/data/demo")
    print(im_names)
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)
        plt.savefig("testfigs/" + im_name)  # the testfigs/ folder must already exist
        plt.close('all')  # free the figure, otherwise open figures pile up over a large batch
    # plt.show()
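In demo(), im_detect() returns one score column per class and four box columns per class, which is why the loop slices boxes[:, 4*cls_ind:4*(cls_ind + 1)] and scores[:, cls_ind]. The NumPy-only sketch below reproduces that slicing on a toy array so it can be run without TensorFlow or a trained model. The array shapes are inferred from the indexing in demo(), the numbers are made up, and class_dets is a hypothetical helper name.

```python
import numpy as np


def class_dets(scores, boxes, cls_ind):
    """Build the (N, 5) [x1, y1, x2, y2, score] array for one class, as demo() does.

    scores: (N, num_classes) score of every proposal for every class
    boxes:  (N, 4 * num_classes) class-specific box for every proposal
    """
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]  # the 4 box columns of this class
    cls_scores = scores[:, cls_ind]                      # the score column of this class
    return np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)


# Toy input: 3 proposals, 5 classes (background + the 4 classes in CLASSES)
num_classes = 5
scores = np.array([[0.1, 0.9, 0.0, 0.0, 0.0],
                   [0.2, 0.1, 0.7, 0.0, 0.0],
                   [0.5, 0.1, 0.1, 0.2, 0.1]])
boxes = np.arange(3 * 4 * num_classes, dtype=np.float32).reshape(3, 4 * num_classes)

dets = class_dets(scores, boxes, cls_ind=1)  # class index 1 = 'holothurian'
print(dets.shape)  # (3, 5)
```

Index 0 is skipped by the demo loop because it is the background class; the resulting dets array is what gets passed to nms() and then filtered by CONF_THRESH.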

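After training a model on your own dataset, if you want demo.py to process a whole batch of test images and write the detections to a text file in a fixed format, modify it as follows.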

#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import numpy as np
import os
import cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',
           'holothurian', 'echinus', 'scallop', 'starfish')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',),
            'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(image_name, class_name, dets, thresh=0.5):
    """Append every detection with score >= thresh to result.txt.

    Each line has the form:
    <image name> <class> <score> <xmin> <ymin> <xmax> <ymax>
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # result.txt is written to the current working directory and opened in
    # append mode, so delete it before a new run; change the path here to
    # save the results somewhere else.
    with open('result.txt', 'a') as fw:
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]
            fw.write(' '.join([str(image_name), class_name, str(score),
                               str(int(bbox[0])), str(int(bbox[1])),
                               str(int(bbox[2])), str(int(bbox[3]))]) + '\n')


def demo(sess, net, image_name):
    """Detect object classes in an image and append the results to result.txt."""
    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:  # cv2.imread returns None instead of raising on a bad path
        raise IOError('Could not read image {:s}'.format(im_file))

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(image_name, cls, dets, thresh=CONF_THRESH)

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    # upstream default: default='res101'
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", len(CLASSES),  # 4 object classes + background = 5
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # Warm up on a dummy gray image: the first im_detect calls are slower,
    # so this keeps that start-up cost out of the timing of the first real image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = im_detect(sess, net, im)

    # Read the test image names, one per line. demo() joins each name directly
    # onto cfg.DATA_DIR/demo, so every line must be the name of an image file
    # in that folder, including its extension.
    list_file = '/home/ouc/LiuHongzhi/tf-faster-rcnn-contest -2018/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt'
    with open(list_file, 'r') as fr:  # the file is closed automatically
        for im_name in fr:
            im_name = im_name.strip()  # drop the newline (and a '\r' from Windows files)
            if not im_name:
                continue  # skip blank lines
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
            print('Demo for data/demo/{}'.format(im_name))
            demo(sess, net, im_name)
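The directory listing earlier in this article notes that VOC ImageSets lists store image names without the .jpg suffix, whereas demo() joins each line directly onto cfg.DATA_DIR/demo and therefore needs the full file name. If your test.txt follows the VOC convention, a slightly more tolerant reader such as the sketch below may help. read_image_list and its ext default are hypothetical, not part of tf-faster-rcnn, and the sketch assumes all test images share one extension.

```python
import os
import tempfile


def read_image_list(list_file, ext='.jpg'):
    """Hypothetical helper: read one image name per line, skip blank lines,
    and append `ext` to names that have no extension (VOC ImageSets style)."""
    names = []
    with open(list_file, 'r') as f:
        for line in f:
            name = line.strip()  # removes '\n' and any '\r' or surrounding spaces
            if not name:
                continue
            if not os.path.splitext(name)[1]:  # no extension given
                name += ext
            names.append(name)
    return names


# Usage on a throwaway list file
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write('000646\n001834.jpg\n\n')
print(read_image_list(tmp.name))  # ['000646.jpg', '001834.jpg']
os.remove(tmp.name)
```

In the main block it would replace the open()/strip loop over test.txt, with each returned name passed to demo(sess, net, im_name).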

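The output in result.txt looks like this (one detection per line: image name, class, score, then xmin ymin xmax ymax):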

000646.jpg echinus 0.9531797 617 89 785 272
000646.jpg echinus 0.94367296 200 272 396 495
000646.jpg echinus 0.9090044 953 259 1112 443
000646.jpg scallop 0.8987418 1508 975 1580 1037
000646.jpg scallop 0.8006968 512 169 580 218
000646.jpg starfish 0.96790546 291 675 390 765
001834.jpg echinus 0.9706842 291 222 365 280
001834.jpg echinus 0.965007 511 161 588 229
001834.jpg echinus 0.95911396 2 184 136 283
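Each line is built in vis_detections by joining the fields with single spaces, so result.txt is easy to post-process. The sketch below mirrors that line format and groups the lines back by image. format_line and parse_results are hypothetical helper names, and the fractional box coordinates in the example are made up to show the truncation done by int().

```python
from collections import defaultdict


def format_line(image_name, class_name, score, bbox):
    """One result.txt line: name, class, score, then the box truncated to ints."""
    return ' '.join([str(image_name), class_name, str(score)] +
                    [str(int(v)) for v in bbox[:4]]) + '\n'


def parse_results(lines):
    """Group result lines by image: {image: [(class, score, (x1, y1, x2, y2)), ...]}."""
    results = defaultdict(list)
    for line in lines:
        parts = line.split()
        if len(parts) != 7:
            continue  # skip blank or malformed lines
        name, cls, score = parts[0], parts[1], float(parts[2])
        box = tuple(int(v) for v in parts[3:])
        results[name].append((cls, score, box))
    return dict(results)


sample = [format_line('000646.jpg', 'echinus', 0.9531797, [617.4, 89.9, 785.2, 272.0]),
          '000646.jpg starfish 0.96790546 291 675 390 765\n']
parsed = parse_results(sample)
print(parsed['000646.jpg'][0])  # ('echinus', 0.9531797, (617, 89, 785, 272))
```

To read the real file, pass an open file object instead of the list, for example parse_results(open('result.txt')).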

4 Supplementary notes

argparse

argparse is Python's standard module for parsing command-line arguments and options; it replaces the deprecated optparse module. It handles command lines such as python parseTest.py input.txt output.txt --user=name --port=8080.

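Usage steps:

1: import argparse
2: parser = argparse.ArgumentParser()
3: parser.add_argument()
4: parser.parse_args()

Explanation: first import the module; then create a parser object; then add the command-line arguments and options you care about, one add_argument call per argument or option; finally call parse_args() to do the parsing.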
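The four steps map onto the example command line above as follows. Here parse_args() is given an explicit list of strings instead of reading sys.argv[1:], so the sketch runs without a shell; the argument names come from that example command and are illustrative only.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description='argparse example')
    parser.add_argument('input')                         # positional argument
    parser.add_argument('output')                        # positional argument
    parser.add_argument('--user', default='guest')       # optional, kept as a string
    parser.add_argument('--port', type=int, default=80)  # optional, converted to int
    return parser


# parse_args() normally reads sys.argv[1:]; an explicit list is passed here instead
args = build_parser().parse_args(
    ['input.txt', 'output.txt', '--user=name', '--port=8080'])
print(args.input, args.output, args.user, args.port)  # input.txt output.txt name 8080
```

The demo scripts' parse_args() uses the same calls with two extra keywords: dest= chooses the attribute name (so --net is stored in args.demo_net), and choices= makes argparse reject any value that is not in NETS.keys() or DATASETS.keys().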

IoU and non-maximum suppression (NMS)

IoU reference
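IoU (intersection over union) of two boxes A and B is area(A ∩ B) / area(A ∪ B). Greedy NMS keeps the highest-scoring box, discards every remaining box whose IoU with it exceeds the threshold, and repeats on what is left. In the scripts above, nms(dets, NMS_THRESH) comes from model.nms_wrapper; with cfg.USE_GPU_NMS = False it runs the compiled Cython cpu_nms built in lib/nms during the make step. The pure-NumPy sketch below follows the same greedy procedure and assumes the inclusive pixel convention (+1 on widths and heights) of the Faster R-CNN Python reference code. It is meant for understanding, not as a drop-in replacement for the compiled version; iou and greedy_nms are illustrative names.

```python
import numpy as np


def iou(a, b):
    """IoU of two boxes [x1, y1, x2, y2] with inclusive pixel coordinates (+1)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1 + 1) * max(0.0, iy2 - iy1 + 1)
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def greedy_nms(dets, thresh):
    """Greedy NMS on an (N, 5) [x1, y1, x2, y2, score] array; returns kept row indices."""
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # indices sorted by score, highest first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # intersection of the current best box with every remaining box
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)
        order = rest[ovr <= thresh]  # drop boxes that overlap box i too much
    return keep


dets = np.array([[10, 10, 50, 50, 0.9],
                 [12, 12, 52, 52, 0.8],      # IoU with box 0 is about 0.83
                 [100, 100, 140, 140, 0.7]],  # no overlap with the others
                dtype=np.float32)
print(greedy_nms(dets, 0.3))  # [0, 2]
```

With NMS_THRESH = 0.3 as in demo(), any box overlapping a higher-scoring box of the same class by more than 30% IoU is removed; a larger threshold keeps more overlapping boxes.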
