caffe 利用Python API 做数据输入层
caffe (Convolutional Architecture for Fast Feature Embedding)
在caffe中,主要使用LMDB提供数据管理,将形形色色的原始数据转换为统一的Key-Value形式存储,便于数据输入层获得这些数据,而且提高了磁盘IO的利用率。
但是,有时我们可以使用python作为网络结构数据的输入层,毕竟python 简单易写。
参考网址:https://chrischoy.github.io/research/caffe-python-layer/
编译选择
想要使用Python Layer,我们需要在编译的时候修改Makefile.config中的pyhon 的选项。
WITH_PYTHON_LAYER :=1
修改完之后,重新编译:
#在caffe_root依次运行
make clean
make
make pycaffe
如果还是 import caffe错误,则将以下几句加入到Python文件的最前面,再导入caffe库:
import sys
sys.path.append("/home/yonghuming/caffe-master/python")
sys.path.append("/home/yonghuming/caffe-master/python/caffe")
或者是修改全局变量:
sudo vim /etc/profile
添加全局变量:
export PYTHONPATH=/home/yonghu/caffe/python
最后:source /etc/profile
先看一个官方例程
此历程所在位置为$caffe_root/examples/pycaffe下
配置文件linreg.prototxt
layer {type: 'Python'name: 'loss'top: 'loss'bottom: 'ipx'bottom: 'ipy'python_param {# the module name -- usually the filename -- that needs to be in $PYTHONPATHmodule: 'pyloss'# the layer name -- the class name in the modulelayer: 'EuclideanLossLayer'}# set loss weight so Caffe knows this is a loss layer.# since PythonLayer inherits directly from Layer, this isn't automatically# known to Caffeloss_weight: 1
}
此配置文件中的loss层就是用Python写的层。
module: 'pyloss'layer: 'EuclideanLossLayer'
module指的是python文件的名字;layer值得是Python文件中的类。
# the module name -- usually the filename -- that needs to be in $PYTHONPATH
其中$PYTHONPATH指的是$caffe_root/python , 只要把python文件放到此目录下就可以了。
再看一下pyloss.py文件
import caffe
import numpy as npclass EuclideanLossLayer(caffe.Layer):"""Compute the Euclidean Loss in the same manner as the C++ EuclideanLossLayerto demonstrate the class interface for developing layers in Python."""def setup(self, bottom, top):# check input pairif len(bottom) != 2:raise Exception("Need two inputs to compute distance.")def reshape(self, bottom, top):# check input dimensions matchif bottom[0].count != bottom[1].count:raise Exception("Inputs must have the same dimension.")# difference is shape of inputsself.diff = np.zeros_like(bottom[0].data, dtype=np.float32)# loss output is scalartop[0].reshape(1)def forward(self, bottom, top):self.diff[...] = bottom[0].data - bottom[1].datatop[0].data[...] = np.sum(self.diff**2) / bottom[0].num / 2.def backward(self, top, propagate_down, bottom):for i in range(2):if not propagate_down[i]:continueif i == 0:sign = 1else:sign = -1bottom[i].diff[...] = sign * self.diff / bottom[i].num
使用Python layer 做数据输入层
此处我导入的数据为28*28*6的数组,也就是数据类型为:
print type(data)
print np.shape(data)# <type 'numpy.ndarray'>
# (28, 28, 6)
图片也是类似,利用opencv读入的图片也是数组结构:
import numpy as np
import cv2
data = cv2.imread('1.jpg')
print type(data)
#<type 'numpy.ndarray'>
print np.shape(data)
# <type 'numpy.ndarray'>
# (375, 500, 3)
但是一定要注意三色通道问题!
此时我的配置文件中有关Python层的定义:
name: "LeNet"
layer {name:"Data"type: "Python"top: "data"top: "label"include {phase: TRAIN}python_param {module:"dataLayer"layer: "Custom_Data_Layer"param_str: '{"batch_size":64, "im_shape":28, "src_file":"data/input/train"}'}
}
param_str为参数。
此时py文件:
import caffe
import numpy as np
import os
import random
def GetTupleList(src_file, dirtag):subDirTuples = []folder = os.path.join(src_file, dirtag)fns = os.listdir(folder)if dirtag=='pos':tag = 0elif dirtag=='neg':tag = 1else:raise Exception('Invalid dirtag {}'.format(str(dirtag)))for fn in fns:path = os.path.join(folder, fn)data = np.load(path)subDirTuples.append((data, np.array([tag])))return subDirTuplesdef readSrcFile(src_file):posTuples = GetTupleList(src_file, 'pos')print(len(posTuples))negTuples = GetTupleList(src_file, 'neg')print(len(negTuples))imgTuples = posTuples + negTuplesreturn imgTuplesclass Custom_Data_Layer(caffe.Layer):def setup(self, bottom, top):# Check top shapeif len(top) != 2:raise Exception("Need to define tops (data and label1)")# Check bottom shapeif len(bottom) != 0:raise Exception("Do not define a bottom")# Read parametersparams = eval(self.param_str)src_file = params["src_file"]self.batch_size = params["batch_size"]self.im_shape = params["im_shape"]top[0].reshape(self.batch_size, 6, self.im_shape, self.im_shape)top[1].reshape(self.batch_size, 1)self.imgTuples = readSrcFile(src_file)self._cur = 0 # use this to check if we need to restart the list of imagesdef forward(self, bottom, top):for itt in range(self.batch_size):# Use the batch loader to load the next imageim, label = self.load_next_image()# Here we could preprocess the image# Add directly to the top blobim_data = np.reshape(im, (6,28,28))#注意!!np.reshape()top[0].data[itt, ...] = im_data top[1].data[itt, ...] = labeldef load_next_image(self):# If we have finished forwarding all images, then an epoch has finished# and it is time to start a new oneif self._cur == len(self.imgTuples):self._cur = 0random.shuffle(self.imgTuples)im, label = self.imgTuples[self._cur]self._cur += 1return im, labeldef reshape(self, bottom, top):"""There is no need to reshape the data, since the input is of fixed size(img shape and batch size)"""passdef backward(self, bottom, top):"""This layer does not back propagate"""pass
完成!
caffe 利用Python API 做数据输入层相关推荐
- 学习Python:做数据科学还是网站开发?
本文的英文原文地址是:Python for Data Science vs Python for Web Development 译者:EarlGrey@codingpy 译者一开始在Python日报 ...
- python怎么读取sav格式_利用Python读取外部数据文件
利用Python读取外部数据文件 [color=rgb(0, 0, 0) !important]刘顺祥 [color=rgb(0, 0, 0) !important]摘要: 不论是数据分析,数据可视化 ...
- python数据预测_利用Python编写一个数据预测工具
利用Python编写一个数据预测工具 发布时间:2020-11-07 17:12:20 来源:亿速云 阅读:96 这篇文章运用简单易懂的例子给大家介绍利用Python编写一个数据预测工具,内容非常详细 ...
- python学习音频-机器学习利用Python进行音频数据增强
2019-09-24 机器学习利用Python进行音频数据增强 数据增强通常用于机器学习和深度学习,以便在训练产生大量数据后获得良好的性能. 在这篇文章中,我将展示如何用一些音频增强技术使用输入音频文 ...
- python能处理nc文件吗_利用python如何处理nc数据详解
前言 这两天帮一个朋友处理了些 nc 数据,本以为很简单的事情,没想到里面涉及到了很多的细节和坑,无论是"知难行易"还是"知易行难"都不能充分的说明问题,还是& ...
- python处理nc数据_利用python如何处理nc数据详解
利用python如何处理nc数据详解 来源:中文源码网 浏览: 次 日期:2018年9月2日 [下载文档: 利用python如何处理nc数据详解.txt ] (友情提示:右键点上行txt ...
- python json 转csv_利用python将json数据转换为csv格式的方法
假设.json文件中存储的数据为: {"type": "Point", "link": "http://www.dianping. ...
- Java学习笔记之[ 利用扫描仪Scanner进行数据输入 ]
/*********数据的输入********/ /**利用扫描仪Scanner进行数据输入 怎么使用扫描仪Scanner *1.放在类声明之前,引入扫描仪 import java.util.Scan ...
- 利用Python读取外部数据文件
不论是数据分析,数据可视化,还是数据挖掘,一切的一切全都是以数据作为最基础的元素.利用Python进行数据分析,同样最重要的一步就是如何将数据导入到Python中,然后才可以实现后面的数据分析.数据可 ...
- 初学者笔记(三):利用python列表做一个最简单的垃圾分类
系列文章目录 初学者笔记(一):利用python求100的因数 初学者笔记(二):利用python输出一个1-100的奇数列表 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目 ...
最新文章
- SQLI DUMB SERIES-5
- ajax中的同步异步
- 42岁老码农找工作记录
- 【转载】VC遍历文件夹下所有文件和文件夹
- 在idea配置jetty和创建(包、文件)javaWeb以及Servlet简单实现
- word文档压缩文件大小
- 【CTF】明御攻防实验平台 crypto 鸡藕椒盐味 wp--海明校验码
- 用python解决放苹果问题_放苹果问题(组合数学经典)
- python列表原地交换nums[i], nums[nums[i]] = nums[nums[i]], nums[i]的解决方法
- Bsgrid表格插入日期表头
- 数据挖掘之关联规则挖掘的一些定义
- Android-实现图文混排编辑
- 使用pyecharts绘制中国历代都城的分布图
- 机器学习模型的集成方法总结:Bagging, Boosting, Stacking, Voting, Blending
- 基于django的微信小程序搭建
- 算法——求某个数的质因数
- p什么水管_pvc管规格-P是什么意思pvc给水管规格中T?pvc给 – 手机爱问
- Go-ICP: A Globally Optimal Solutionto 3D ICP Point-Set Registration(2016)
- 第一章 eNSP学习(1-5)
- 【Python数据分析学习实例】对学生成绩单和信息进行整合以及数据分析
热门文章
- 利用openssl进行base64的编码与解码
- JAVA JDK 、Maven、IDEA安装
- Linux服务器---配置apache支持用户认证
- 非阻塞IO发送http请求
- 数据结构:二维ST表
- Java:下拉列表绑定后台数据
- POJ 3669 简单BFS
- js 去空格 和 获得字节数
- nw.js桌面软件开发系列 第0.1节 HTML5和桌面软件开发的碰撞
- enumerateObjectsUsingBlock 、for 、for(... in ...) 的区别 性能测试