本文只是简单的翻译了 https://www.tensorflow.org/extend/adding_an_op 的简单部分,高级部分请移步官网。

可能需要新定义 c++ operation 的几种情况:

  • 现有的 operation 组合不出来你想要的 op
  • 现有的 operation 组合 出来的 operation 十分低效
  • 如果你想要手动融合一些操作。

为了实现你的自定义操作,你需要做一下几件事:

  1. 在 c++ 文件中注册一个新opOp registration 定义了 op 的功能接口,它和 op 的实现是独立的。例如:op registration 定义了 op 的名字和 op的输出输出。它同时也定义了 shape 方法,被用于 tensorshape 接口。
  2. c++ 中实现 opop 的实现称之为 kernel ,它是op 的一个具体实现。对于不同的输入输出类型或者 架构(CPUs,GPUs)可以有不同的 kernel 实现 。
  3. 创建一个 python wrapper(可选的): 这个 wrapper 是一个 公开的 API,用来在 python中创建 opop registration 会生成一个默认的 wrapper,我们可以直接使用或者自己添加一个。
  4. 写一个计算 op 梯度的方法(可选)。
  5. 测试 op:为了方便,我们通常在 python 中测试 op,但是你也可以在 c++ 中进行测试。如果你定义了 gradients,你可以 通过 Python 的 gradient checker 验证他们。 这里有个例子relu_op_test.py ,测试 ReLU-likeop 的 前向和梯度过程。

Define the op’s interface

**You define the interface of an op by registering it with the TensorFlow system. **

在注册 op 的时候,你需要指定:

  • op 的名字
  • op 的输入(名字,类型),op 的输出(名字,类型)
  • docstrings
  • op 可能需要的 一些 attrs

为了演示这个到底怎么工作的,我们来看一个简单的例子:

  • 定义一个 op :输入是一个 int32tensor ,输出是输入的 拷贝,除了第一个元素保留,其它全都置零。

为了创建这个 op 的接口, 我们需要:

  • 创建一个文件,名字为 zero_out.cc. 然后调用 REGISTER_OP 宏,使用这个宏来定义 op 的接口 :
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"using namespace tensorflow;REGISTER_OP("ZeroOut").Input("to_zero: int32").Output("zeroed: int32").SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {c->set_output(0, c->input(0));return Status::OK();});

这个 ZeroOut op 接收一个 int 32tensor 作为输入,输出同样也是一个 int32tensor。这个 op 也使用了一个 shape 方法来确保输入和输出的维度是一样的。例如,如果输入的tensor 的shape 是 [10, 20],那么,这个 shape 方法保证输出的 shape 也是 [10, 20]

注意: op 的名字必须遵循驼峰命名法,而且要保证 op 的名字的唯一性。

Implement the kernel for the op

当你 定义了 op 的接口之后,你可以提供一个或多个 关于op 的实现。

为了实现这些 kernels

  • 创建一个类,继承 OpKernel
  • 重写 OpKernel 类的 Compute 方法
    • Compute 方法提供了一个 类型为 OpKernelContext*context 参数 ,从这里,我们可以访问到一些有用的信息,比如 输入 和 输出 tensor

kernel 代码也放到 之前创建的 zero_out.cc 文件中:

#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;class ZeroOutOp : public OpKernel {public:explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}void Compute(OpKernelContext* context) override {// 获取输入 tensorconst Tensor& input_tensor = context->input(0);auto input = input_tensor.flat<int32>();// 创建输出 tensor, context->allocate_output 用来分配输出内存?Tensor* output_tensor = NULL;OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),&output_tensor));auto output_flat = output_tensor->flat<int32>();// 执行计算操作。const int N = input.size();for (int i = 1; i < N; i++) {output_flat(i) = 0;}// Preserve the first input value if possible.if (N > 0) output_flat(0) = input(0);}
};

在实现了 kernel 之后,就可以将这个注册到 tensorflow 系统中去了。在注册时,你需要对 op 的运行环境指定一些限制。例如,你可能有一个 kernel 代码是给 CPU 用的,另一个是给 GPU用的。通过把下列代码添加到 zero_out.cc 中来完成这个功能:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

注意:你实现的 OpKernel 的实例可能会被并行访问,所以,请确保 Compute方法是线程安全的。保证访问 类成员的 方法都加上 mutex。或者更好的选择是,不要通过 类成员来分享 状态。考虑使用 ResourceMgr 来追踪状态。

Multi-threaded CPU kernels

多线程主要由 work shard 搞定。work shard

GPU kernels

请移步官网

Build the op library

使用系统编译器 编译 定义的 op

我们可以使用 系统上的 c++ 编译器 g++ 或者 clang 来编译 zero_out.cc 。二进制的 PIP 包 已经将编译所需的 头文件 和 库 安装到了系统上。Tensorflowpython library 提供了一个用来获取 头文件目录的函数 get_include。下面是这个函数在 ubuntu 上的输出:

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python2.7/site-packages/tensorflow/include'

假设你已经安装好了 g++ ,你可以使用 下面一系列的命令 将你的 op 编译成一个 动态库。

TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -I $TF_INC -O2

如果你的 g++ 版本>5.0 的话,加上这个参数 -D_GLIBCXX_USE_CXX11_ABI=0

Use the op in Python

Tensorflow 的 python 接口提供了 tf.load_op_library 函数用来加载动态 library,同时将 op 注册到tensorflow 框架上。load_op_library 返回一个 python module,它包含了 opkernelpython wrapper 。因此,一旦你编译好了一个 op,就可以使用下列代码通过 python来执行它:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
with tf.Session(''):zero_out_module.zero_out([[1, 2], [3, 4]]).eval()# Prints
array([[1, 0], [0, 0]], dtype=int32)

记住:生成的函数的名字是 snake_case name。如果在c++文件中, op 的名字是ZeroOut,那么在python 中,名字是 zero_out

完整的代码在文章的最后

Verify that the op works

一个验证你的自定义的op是否正确工作的一个好的方法是 为它写一个测试文件。创建一个 zero_out_op_test.py 文件,内容为:

import tensorflow as tfclass ZeroOutTest(tf.test.TestCase):def testZeroOut(self):zero_out_module = tf.load_op_library('./zero_out.so')with self.test_session():result = zero_out_module.zero_out([5, 4, 3, 2, 1])self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])if __name__ == "__main__":tf.test.main()

然后运行这个 test

代码

//zero_out.cc 文件
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;REGISTER_OP("ZeroOut").Input("to_zero: int32").Output("zeroed: int32").SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {c->set_output(0, c->input(0));return Status::OK();});class ZeroOutOp : public OpKernel {public:explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}void Compute(OpKernelContext* context) override {// 将输入 tensor 从 context 中取出。const Tensor& input_tensor = context->input(0);auto input = input_tensor.flat<int32>();// 创建一个 ouput_tensor, 使用 context->allocate_ouput() 给它分配空间。Tensor* output_tensor = NULL;OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),&output_tensor));auto output_flat = output_tensor->flat<int32>();// Set all but the first element of the output tensor to 0.const int N = input.size();for (int i = 1; i < N; i++) {output_flat(i) = 0;}// Preserve the first input value if possible.if (N > 0) output_flat(0) = input(0);}
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
#创建动态链接库的命令
g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 -I $TF_INC -O2

总结

tensorflow 自定义 op 的方法可以总结为:

  1. 写个 diy_op.cc 文件
  2. g++ 把这个文件编译成动态链接库
  3. python 中使用 tf.load_op_library 将库导入。
  4. 就可以使用了。

还有一种方法是用 bazel 编译。

参考资料

https://www.tensorflow.org/extend/adding_an_op

tensorflow:自定义op简单介绍相关推荐

  1. TensorFlow团队:TensorFlow Probability的简单介绍

    文章来源:ATYUN AI平台 在2018年TensorFlow开发者峰会上,我们(TensorFlow团队)宣布发布TensorFlow Probability:一种使机器学习研究人员及相关从业人员 ...

  2. tensorflow自定义op:梯度

    暂时并未解决我的问题,但感觉将来会有用,特此转载 . 在使用 tensorflow 的时候,有时不可避免的会需要自定义 op,官方文档对于 定义 op 的前向过程介绍挺详细,但是对于 梯度 的介绍有点 ...

  3. c++自定义函数简单介绍

    大家好, 今天给大家介绍一下自定义函数. 如有错误请在评论区指出 正文: 1.简单介绍: 函数是一组一起执行一个任务的语句.每个 C++ 程序都至少有一个函数,即主函数 main() ,所有简单的程序 ...

  4. tensorflow自定义op和梯度

    参考资料 官网教程链接 http://www.tensorfly.cn/tfdoc/how_tos/adding_an_op.html#AUTOGENERATED-implement-the-grad ...

  5. php自定义模块,简单介绍OpenCart自定义模块

    OpenCart模块可以自定义模块显示位置.排序.是否开启等功能,用起来十分方便. OpenCart用到的模块管理非常多,首页幻灯.导航.最新商品.特价商品.热卖商品等. 如何自定义一个模块?其实也挺 ...

  6. tensorflow:自定义op

    比官网介绍的更好理解,特此转载 tensorflow:自定义op简单介绍 2017年06月26日 13:32:55 阅读数:6094 tensorflow 自定义 op 本文只是简单的翻译了 http ...

  7. TensorFlow使用Python自定义op和损失函数

    TensorFlow使用Python自定义op和损失函数 TensorFlow是静态图结构,即必须把所有的操作以及网络结构定义好(后来有了动态图功能,即Eager Execution ),在没有用tf ...

  8. TensorFlow实现自定义Op

    『写在前面』 以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程. 基本的流程 1. 定义Op接口 #include "tenso ...

  9. Ubuntu tensorflow自定义GPU版本op节点

    参考:https://blog.csdn.net/qq_27637315/article/details/79114633 windows增加op节点: https://github.com/tens ...

最新文章

  1. 如何用LSTM自编码器进行极端事件预测?(含Python代码)
  2. TCP 三次握手、四手挥手,这样说你能明白吧!
  3. Hibernate5-双向关联-多对多(n:n)
  4. (转)WCF光芒下的Web Service
  5. 美国返还中国文物,阿里谣言粉碎机获奖,教育部规范研究生培养,腾讯严打微信跑分活动,推动降低港澳漫游费,这就是今天的大新闻。...
  6. 认识VLAN,并学会VLAN的划分和网络配置实例
  7. 高性能javascript读书笔记(三.DOM 编程2)
  8. 扩展欧几里得原理与模板
  9. CI中创建你自己的类库
  10. vijos1846 [NOIP2013] 华容道【最短路】
  11. swift 第三方库SwiftyJSON
  12. CSS 框架 Bulma 教程
  13. 校招网工面试经历(持续更新)
  14. AddressSanitizer: heap-buffer-overflow on address 0x602000000534 at pc 0x00000040699d bp 0x7ffce0afd
  15. 专科计算机教育的现状,探析高职计算机专业英语教学现状
  16. Cobaltstrike系列教程(三)beacon详解
  17. 大数据运维 | 集群_监控_CDH_Docker_K8S_两项目_云服务器
  18. Eclipse的架构
  19. 单元测试之verify及使用when打桩时对ArgumentMatchers的使用
  20. 饥荒mod制作教程--物品(食物)该篇主讲贴图--01

热门文章

  1. Unity DOTS 学习笔记2 - 面向数据设计的基本概念(上)
  2. 山西计算机大赛崔奕,2021年中国大学生计算机设计大赛山西省赛评审结果名单公示通知...
  3. 《深入理解计算机系统》漫游指南
  4. java中的LinkedList(链表)与ArrayList(动态数组):(2)尝试简单实现LinkedList
  5. DDR SDRAM芯片DQS的作用以及读写DQS/DQ对齐方式不同的原因
  6. FreeSWITCH权威指南 -- 1.PSTN与VoIP基础(笔记)
  7. 把Excel里的url链接转换为图片显示
  8. 5000元起家,40年4万倍!一个来自贫民窟的亿万富翁
  9. Surface pro 4 使用心得
  10. linux 光功率 模块_光模块及调整光模块输入光功率的方法