caffe实战之classify.py解析
本文将对caffe/python下的classify.py代码以及相关的classifier.py和io.py进行解析。
一、classify.py
由最后的if__name__ == '__main__': main(sys.argv)代表该文件在命令行下运行,则运行main函数,参数存放在sys.argv中。在main函数定义中,分别判断并存入各类参数,分别如下:
input_file:输入图像,参数为必需。
output_file:输出文件,参数为必需。
--model_def:网络测试结构文件,默认为imagenet的deploy.txt文件
--pretrained_model:网络参数文件,默认为imagenet的bvlc_reference_caffenet.caffemodel文件。
--gpu:是否用gpu计算,action=’store true’表示如果不指定,则默认false,用cpu,否则为true,用gpu。对于一张128*128的灰度图像,cpu前向计算大概20ms,而gpu仅5ms左右。
--center_only:默认false,即对输入图像的裁剪图像做预测,然后将结果进行平均;指定为true,即只取输入图像的中间部分做一次预测。当然,如果指定输入图像和裁剪尺寸一致,那么取中间部分即为原图本身。
--images_dim:输入图像尺寸,只考虑高和宽,默认256*256。
--mean_file:均值文件。注意数据格式是npy文件,即存储为numpy.array格式,维度为(通道,高,宽)。如果仅有通过compute_mean.bin计算的均值文件,需要进行转化。默认均值文件为imagenet的ilsvrc_2012_mean.npy文件。
--input_scale:图像预处理后的缩放系数,发生在减去均值后,默认为1。
--raw_scale:图像预处理前的缩放系数,发生在减去均值前。由于读入的像素值在[0,1]区间,则默认为255.0,使像素在[0,255]区间。
--channel_swap:通道调整,默认为’2,1,0’,因为caffe通过opencv读入的图片通道为BGR,因此必须将RGB-->BGR,即第0个通道和第2个通道交换。
--ext:默认’jpg’,代表如果输入指定为目录,则仅读取后缀名为jpg的文件。
下面几个参数是改进版classify.py中新加的。
--labels_file:标签类别文件,默认为imagenet的synset_words.txt文件。
--print_results:是否打印结果到屏幕,不指定则false,指定为true。
--force_grayscale:是否指定输入为单通道图像,不指定则false,指定为true。
通过args= parser.parse_args()更新,确认最终输入的参数。下面进行分类测试:
# 列表生成式,通过逗号划分维度字符串,并强制转化为int类型。最后为列表。
image_dims = [int(s) for s inargs.images_dim.split(',')]
# 如果指定了均值文件,则加载均值文件
if args.mean_file:
mean =np.load(args.mean_file)
# 如果是灰度图像,则没有通道交换。如果是rgb图像,如果有通道交换,通过逗号划分字符串,强制转化为int类型,存到列表中。
if args.force_grayscale:
channel_swap = None
else:
if args.channel_swap
channel_swap = [int(s) for s inargs.channel_swap.split(',')]
# 如果指定了gpu,则启动gpu模式
if args.gpu:
caffe.set_mode_gpu()
print("GPU mode")
else:
caffe.set_mode_cpu()
print("CPU mode")
# 初始化分类器,见classifier.py
classifier = caffe.Classifier(..)
# 下面就是读取文件的代码,有反映说加载灰度图像会报错的情况,这里给出记载灰度和rgb图像的代码。
if args.force_grayscale:
# 这里的false代表返回单通道图像,见io.py
inputs =[caffe.io.load_image(args.input_file, False)]
else:
inputs = [caffe.io.load_image(args.input_file)]
# inputs用[]括起来,代表用列表存储,所以len(inputs)代表有多少张输入图像。
# 计时,这里以ms为单位
start = time.time() * 1000
# 前向计算,见classifier.py,得到preditions为np数组,行为输入图像张数,列为预测总类别数目
predictions = classifier.predict(inputs,not args.center_only)
print("Done in %.2f ms." %(time.time() * 1000 - start))
print("Predictions : %s"% predictions)
# 打印结果,根据得分排序,给出分数较高的前五类,类名称由labels_file指定。
# print result, add by caisenchuan
if args.print_results:
...
二、Classifier.py
该文件定义了classifier类,包括了初始化函数__init__和predict函数。
1、 __init__:
首先调用了caffe类的初始化函数,并设定了test模式。
接着调用了transformer类,以cifar-10为例,输入为字典{’data’: (1,3,32,32)}。
然后是set_transpose方法:
# 将维度从(32,32,3)转化为(3,32,32),适用于caffe中的处理
self.transformer.set_transpose(in_,(2, 0, 1))
然后调用transformer类的set方法,设置各种参数,具体见下文io.py中的解析。
最后,关于图像维度的定义:
# 裁剪尺寸根据prototxt定义
self.crop_dims =np.array(self.blobs[in_].data.shape[2:])
# 如果没有定义图片尺寸参数,则等于裁剪的尺寸;否则按定义的来
# 一般来说,如果用了裁剪,则图像尺寸>裁剪尺寸
if not image_dims:
image_dims = self.crop_dims
self.image_dims = image_dims
2、 predict:
执行前向计算,预测图像分类的概率。参数为输入以及是否过采样的布尔值。
# 定义inputs_维度(m,h,w,channel)
input_ = np.zeros((len(inputs),
self.image_dims[0],
self.image_dims[1],
inputs[0].shape[2]),
dtype=np.float32)
# 将所有待分类尺寸统一为image_dims尺寸
for ix, in_ in enumerate(inputs):
input_[ix] = caffe.io.resize_image(in_,self.image_dims)
# 如果过采样,则每张图像通过裁剪生成10张图像
# 维度将变为(10*m,h,w,channel)
if oversample:
# Generate center, corner, and mirroredcrops.
input_ = caffe.io.oversample(input_,self.crop_dims)
# 否则,裁剪中心区域。取图像尺寸的中点,然后分别往上往下取裁剪的尺寸长度。
# 以64*64裁剪32*32为例,(64,64)取中点-->(32,32),扩充到四个坐标-->(32,32,32,32),
# 取裁剪尺寸(32,32,32,32)+(-16,-16,16,16)-->(16,16,48,48)
else:
# Take center crop.
center = np.array(self.image_dims) /2.0
crop = np.tile(center, (1, 2))[0] +np.concatenate([
-self.crop_dims / 2.0,
self.crop_dims / 2.0
])
crop = crop.astype(int)
input_ = input_[:, crop[0]:crop[2],crop[1]:crop[3], :]
# 将输入转化为caffe需要的格式,维度变为(m,channel,h,w)
caffe_in =np.zeros(np.array(input_.shape)[[0, 3, 1, 2]],
dtype=np.float32)
# 每张图片都进行预处理,见io.py的preprocess函数
for ix, in_ in enumerate(input_):
caffe_in[ix] =self.transformer.preprocess(self.inputs[0], in_)
# 前向计算,输出为字典,out[‘prob’]为各类概率
out =self.forward_all(**{self.inputs[0]: caffe_in})
predictions = out[self.outputs[0]]
# 如果过采样,需要对每10个预测结果进行平均
if oversample:
predictions =predictions.reshape((len(predictions) / 10, 10, -1))
predictions = predictions.mean(1)
# 返回结果
return predictions
三、io.py
该文件重点介绍预处理类Transformer的成员函数。
1、preprocess
注意到函数的注释部分表明了预处理的全部流程,包括:
转化为单精度;
resize到统一尺寸;
维度转化为(channel,h,w);
通道交换,转化为BGR;
减去均值前缩放;
减去均值;
减去均值后缩放。
重要代码:
...
# 返回[h,w]
in_dims = self.inputs[in_][2:]
# 输入图像和规定尺寸不一样,则resize统一
if caffe_in.shape[:2] != in_dims:
caffe_in = resize_image(caffe_in, in_dims)
# 维度转化
if transpose is not None:
caffe_in = caffe_in.transpose(transpose)
# 通道交换,指的是channel的交换,h和w不变
if channel_swap is not None:
caffe_in = caffe_in[channel_swap, :, :]
# 乘法
if raw_scale is not None:
caffe_in *= raw_scale
# 减法
if mean is not None:
caffe_in -= mean
# 乘法
if input_scale is not None:
caffe_in *= input_scale
return caffe_in
2、load_image,注意color参数默认True
# 利用skimage工具读入图片,默认读入彩色图片,如果as_grey为1,则读入灰度图片;读入值为[0,1]的浮点数
img =skimage.img_as_float(skimage.io.imread(filename,
as_grey=notcolor)).astype(np.float32)
# 保证返回的是三维的数组。
if img.ndim == 2:
# 如果读入只有二维,需要增加维度
img = img[:, :, np.newaxis]
if color:
# 如果是灰度图像却以彩色图片读入,那么扩充为三个通道
img = np.tile(img, (1, 1, 3))
elif img.shape[2] == 4:
#如果有四个通道,去掉第四个通道
img= img[:, :, :3]
# 返回(h,w,3)的数组
return img
还有resize_image,oversample,以及各种set函数,这里就不一一介绍了。所谓caffe的python接口或者matlab接口,都是对caffe的输入预处理以及输出结果的处理,而不用理会网络计算的中间过程。
caffe实战之classify.py解析相关推荐
- Java生鲜电商平台-电商中海量搜索ElasticSearch架构设计实战与源码解析
Java生鲜电商平台-电商中海量搜索ElasticSearch架构设计实战与源码解析 生鲜电商搜索引擎的特点 众所周知,标准的搜索引擎主要分成三个大的部分,第一步是爬虫系统,第二步是数据分析,第三步才 ...
- YOLOv3 代码详解(2) —— 数据处理 dataset.py解析:输入图片增强、制作模型的每层输出的标签
前言: yolo系列的论文阅读 论文阅读 || 深度学习之目标检测 重磅出击YOLOv3 论文阅读 || 深度学习之目标检测yolov2 论文阅读 || 深度学习之目标检测yolov1 该篇讲解的 ...
- SpringCloud之Eureka实战和架构设计解析
SpringCloud之Eureka实战和架构设计解析 Netflix Eureka(后文简称Eureka)是由Netflix开源的一款基于REST的服务发现组件,包括Eureka Server及Eu ...
- 实战 webpack 4 配置解析四
接上篇: 实战 webpack 4 配置解析三 WEBPACK.PROD.JS 解析 现在让我们看看我们的 webpack.prod.js 配置文件,它包含了我们正在处理项目时用于生产构建的所有设置. ...
- 机器学习实战 支持向量机SVM 代码解析
机器学习实战 支持向量机SVM 代码解析 <机器学习实战>用代码实现了算法,理解源代码更有助于我们掌握算法,但是比较适合有一定基础的小伙伴.svm这章代码看起来风轻云淡,实则对于新手来说有 ...
- 4大行业实战案例,深度解析数字化转型升级路径
本篇文章为亿信华辰<4大行业实战案例,深度解析数字化转型升级路径>视频直播稿件. 大家晚上好,欢迎来到小亿直播间!今天主讲的内容是以4个行业的典型应用为背景,给大家讲讲数字化转型的项目是 ...
- 实战 webpack 4 配置解析一
配置 github 仓库:https://github.com/nystudio107/annotated-webpack-4-config 随着Web开发变得越来越复杂,我们需要工具来帮助我们构建现 ...
- PASCAL VOC的评估代码voc_eval.py解析
参考 PASCAL VOC的评估代码voc_eval.py解析 - 云+社区 - 腾讯云 目录 1.读检测的结果 2.解析一幅图像中的目标数 3.计算AP 4.VOC的评估 5.进行python评估 ...
- MMDetection框架的anchor_generators.py解析与船数据解析
anchor_generators.py解析 import mmcv import numpy as np import torch from torch.nn.modules.utils impor ...
最新文章
- 在CentOS 6.3 64bit上安装Apache Trafficserver 4.2.3
- 二叉树的遍历 (递归和非递归实现)
- 3d geometric model website http://www.cse.ohio-state.edu/~tamaldey/
- python 目标检测 训练_YOLOv3目标检测有了TensorFlow实现,可用自己的数据来训练
- openstarck安装指南(图文详解,超小白版本)
- Directx11教程(47) alpha blend(4)-雾的实现
- Eclipse正式代替Oracle接管Java EE
- Android 4.0 开机启动广播
- AR/VR研究框架——迎接AR元年
- 5G网络规划解决方案
- 分享一个好的清理系统垃圾软件
- 计算机改名字sql2008不能登录,Win7电脑修改计算机名称后SQL2008数据库无法登录提示无法连接到load怎么处理...
- 相比传统监控,智慧门店的摄像机有多“能干”
- Team System:基本 Power Tool 工具。
- Android studio成品 记账本(附带文档)
- javascript 代码中的“use strict“;是什么意思
- 【MySQL】创建数据库表
- CEST日期格式转换为 年月日时分秒
- 解决:Win11蓝牙鼠标经常断连问题(亲测有效)
- linux find之exec用法
热门文章
- Java的注解Annotation
- 鸿蒙开发访问webapi,Web API接口
- 百度降权一个月排名怎么恢复?
- 树莓派更换pip源为国内
- 一个随机数引发的血案
- 后端开发学习日志(基础篇)
- 优化Feed流遭遇拦路虎,是谁帮百度打破了“内存墙”?
- 数据分析岗笔试卷——目录索引
- 干货!高德、VPGAME(老干爹)等MongoDB应用实践(暨MongoDB杭州用户会成立
- pandas 排序 给excel_给Excel重度用户准备的Pandas教程:用Pandas逐帧还原20个Excel常用操作...