caffe (Convolutional Architecture for Fast Feature Embedding)

在caffe中,主要使用LMDB提供数据管理,将形形色色的原始数据转换为统一的Key-Value形式存储,便于数据输入层获得这些数据,而且提高了磁盘IO的利用率。
但是,有时我们可以使用python作为网络结构数据的输入层,毕竟python 简单易写。

参考网址:https://chrischoy.github.io/research/caffe-python-layer/

编译选择

想要使用Python Layer,我们需要在编译的时候修改Makefile.config中的pyhon 的选项。

WITH_PYTHON_LAYER :=1

修改完之后,重新编译:

#在caffe_root依次运行
make clean
make
make pycaffe

如果还是 import caffe错误,则将以下几句加入到Python文件的最前面,再导入caffe库:

import sys
sys.path.append("/home/yonghuming/caffe-master/python")
sys.path.append("/home/yonghuming/caffe-master/python/caffe")

或者是修改全局变量:

sudo vim /etc/profile

添加全局变量:

export PYTHONPATH=/home/yonghu/caffe/python

最后:source /etc/profile

先看一个官方例程

此历程所在位置为$caffe_root/examples/pycaffe下
配置文件linreg.prototxt

layer {type: 'Python'name: 'loss'top: 'loss'bottom: 'ipx'bottom: 'ipy'python_param {# the module name -- usually the filename -- that needs to be in $PYTHONPATHmodule: 'pyloss'# the layer name -- the class name in the modulelayer: 'EuclideanLossLayer'}# set loss weight so Caffe knows this is a loss layer.# since PythonLayer inherits directly from Layer, this isn't automatically# known to Caffeloss_weight: 1
}

此配置文件中的loss层就是用Python写的层。

 module: 'pyloss'layer: 'EuclideanLossLayer'

module指的是python文件的名字;layer值得是Python文件中的类。

    # the module name -- usually the filename -- that needs to be in $PYTHONPATH

其中$PYTHONPATH指的是$caffe_root/python , 只要把python文件放到此目录下就可以了。
再看一下pyloss.py文件

import caffe
import numpy as npclass EuclideanLossLayer(caffe.Layer):"""Compute the Euclidean Loss in the same manner as the C++ EuclideanLossLayerto demonstrate the class interface for developing layers in Python."""def setup(self, bottom, top):# check input pairif len(bottom) != 2:raise Exception("Need two inputs to compute distance.")def reshape(self, bottom, top):# check input dimensions matchif bottom[0].count != bottom[1].count:raise Exception("Inputs must have the same dimension.")# difference is shape of inputsself.diff = np.zeros_like(bottom[0].data, dtype=np.float32)# loss output is scalartop[0].reshape(1)def forward(self, bottom, top):self.diff[...] = bottom[0].data - bottom[1].datatop[0].data[...] = np.sum(self.diff**2) / bottom[0].num / 2.def backward(self, top, propagate_down, bottom):for i in range(2):if not propagate_down[i]:continueif i == 0:sign = 1else:sign = -1bottom[i].diff[...] = sign * self.diff / bottom[i].num

使用Python layer 做数据输入层

此处我导入的数据为28*28*6的数组,也就是数据类型为:

print type(data)
print np.shape(data)# <type 'numpy.ndarray'>
# (28, 28, 6)

图片也是类似,利用opencv读入的图片也是数组结构:

import numpy as np
import cv2
data = cv2.imread('1.jpg')
print type(data)
#<type 'numpy.ndarray'>
print np.shape(data)
# <type 'numpy.ndarray'>
# (375, 500, 3)

但是一定要注意三色通道问题!
此时我的配置文件中有关Python层的定义:

name: "LeNet"
layer {name:"Data"type: "Python"top: "data"top: "label"include {phase: TRAIN}python_param {module:"dataLayer"layer: "Custom_Data_Layer"param_str: '{"batch_size":64, "im_shape":28, "src_file":"data/input/train"}'}
}

param_str为参数。
此时py文件:

import caffe
import numpy as np
import os
import random
def GetTupleList(src_file, dirtag):subDirTuples = []folder = os.path.join(src_file, dirtag)fns = os.listdir(folder)if dirtag=='pos':tag = 0elif dirtag=='neg':tag = 1else:raise Exception('Invalid dirtag {}'.format(str(dirtag)))for fn in fns:path = os.path.join(folder, fn)data = np.load(path)subDirTuples.append((data, np.array([tag])))return subDirTuplesdef readSrcFile(src_file):posTuples = GetTupleList(src_file, 'pos')print(len(posTuples))negTuples = GetTupleList(src_file, 'neg')print(len(negTuples))imgTuples = posTuples + negTuplesreturn imgTuplesclass Custom_Data_Layer(caffe.Layer):def setup(self, bottom, top):# Check top shapeif len(top) != 2:raise Exception("Need to define tops (data and label1)")# Check bottom shapeif len(bottom) != 0:raise Exception("Do not define a bottom")# Read parametersparams = eval(self.param_str)src_file = params["src_file"]self.batch_size = params["batch_size"]self.im_shape = params["im_shape"]top[0].reshape(self.batch_size, 6, self.im_shape, self.im_shape)top[1].reshape(self.batch_size, 1)self.imgTuples = readSrcFile(src_file)self._cur = 0 # use this to check if we need to restart the list of imagesdef forward(self, bottom, top):for itt in range(self.batch_size):# Use the batch loader to load the next imageim, label = self.load_next_image()# Here we could preprocess the image# Add directly to the top blobim_data = np.reshape(im, (6,28,28))#注意!!np.reshape()top[0].data[itt, ...] = im_data top[1].data[itt, ...] = labeldef load_next_image(self):# If we have finished forwarding all images, then an epoch has finished# and it is time to start a new oneif self._cur == len(self.imgTuples):self._cur = 0random.shuffle(self.imgTuples)im, label = self.imgTuples[self._cur]self._cur += 1return im, labeldef reshape(self, bottom, top):"""There is no need to reshape the data, since the input is of fixed size(img shape and batch size)"""passdef backward(self, bottom, top):"""This layer does not back propagate"""pass

完成!

caffe 利用Python API 做数据输入层相关推荐

  1. 学习Python:做数据科学还是网站开发?

    本文的英文原文地址是:Python for Data Science vs Python for Web Development 译者:EarlGrey@codingpy 译者一开始在Python日报 ...

  2. python怎么读取sav格式_利用Python读取外部数据文件

    利用Python读取外部数据文件 [color=rgb(0, 0, 0) !important]刘顺祥 [color=rgb(0, 0, 0) !important]摘要: 不论是数据分析,数据可视化 ...

  3. python数据预测_利用Python编写一个数据预测工具

    利用Python编写一个数据预测工具 发布时间:2020-11-07 17:12:20 来源:亿速云 阅读:96 这篇文章运用简单易懂的例子给大家介绍利用Python编写一个数据预测工具,内容非常详细 ...

  4. python学习音频-机器学习利用Python进行音频数据增强

    2019-09-24 机器学习利用Python进行音频数据增强 数据增强通常用于机器学习和深度学习,以便在训练产生大量数据后获得良好的性能. 在这篇文章中,我将展示如何用一些音频增强技术使用输入音频文 ...

  5. python能处理nc文件吗_利用python如何处理nc数据详解

    前言 这两天帮一个朋友处理了些 nc 数据,本以为很简单的事情,没想到里面涉及到了很多的细节和坑,无论是"知难行易"还是"知易行难"都不能充分的说明问题,还是& ...

  6. python处理nc数据_利用python如何处理nc数据详解

    利用python如何处理nc数据详解 来源:中文源码网    浏览: 次    日期:2018年9月2日 [下载文档:  利用python如何处理nc数据详解.txt ] (友情提示:右键点上行txt ...

  7. python json 转csv_利用python将json数据转换为csv格式的方法

    假设.json文件中存储的数据为: {"type": "Point", "link": "http://www.dianping. ...

  8. Java学习笔记之[ 利用扫描仪Scanner进行数据输入 ]

    /*********数据的输入********/ /**利用扫描仪Scanner进行数据输入 怎么使用扫描仪Scanner *1.放在类声明之前,引入扫描仪 import java.util.Scan ...

  9. 利用Python读取外部数据文件

    不论是数据分析,数据可视化,还是数据挖掘,一切的一切全都是以数据作为最基础的元素.利用Python进行数据分析,同样最重要的一步就是如何将数据导入到Python中,然后才可以实现后面的数据分析.数据可 ...

  10. 初学者笔记(三):利用python列表做一个最简单的垃圾分类

    系列文章目录 初学者笔记(一):利用python求100的因数 初学者笔记(二):利用python输出一个1-100的奇数列表 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目 ...

最新文章

  1. SQLI DUMB SERIES-5
  2. ajax中的同步异步
  3. 42岁老码农找工作记录
  4. 【转载】VC遍历文件夹下所有文件和文件夹
  5. 在idea配置jetty和创建(包、文件)javaWeb以及Servlet简单实现
  6. word文档压缩文件大小
  7. 【CTF】明御攻防实验平台 crypto 鸡藕椒盐味 wp--海明校验码
  8. 用python解决放苹果问题_放苹果问题(组合数学经典)
  9. python列表原地交换nums[i], nums[nums[i]] = nums[nums[i]], nums[i]的解决方法
  10. Bsgrid表格插入日期表头
  11. 数据挖掘之关联规则挖掘的一些定义
  12. Android-实现图文混排编辑
  13. 使用pyecharts绘制中国历代都城的分布图
  14. 机器学习模型的集成方法总结:Bagging, Boosting, Stacking, Voting, Blending
  15. 基于django的微信小程序搭建
  16. 算法——求某个数的质因数
  17. p什么水管_pvc管规格-P是什么意思pvc给水管规格中T?pvc给 – 手机爱问
  18. Go-ICP: A Globally Optimal Solutionto 3D ICP Point-Set Registration(2016)
  19. 第一章 eNSP学习(1-5)
  20. 【Python数据分析学习实例】对学生成绩单和信息进行整合以及数据分析

热门文章

  1. 利用openssl进行base64的编码与解码
  2. JAVA JDK 、Maven、IDEA安装
  3. Linux服务器---配置apache支持用户认证
  4. 非阻塞IO发送http请求
  5. 数据结构:二维ST表
  6. Java:下拉列表绑定后台数据
  7. POJ 3669 简单BFS
  8. js 去空格 和 获得字节数
  9. nw.js桌面软件开发系列 第0.1节 HTML5和桌面软件开发的碰撞
  10. enumerateObjectsUsingBlock 、for 、for(... in ...) 的区别 性能测试