『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"REGISTER_OP("Custom")    .Input("custom_input: int32").Output("custom_output: int32");

2. 为Op实现Compute操作(CPU)或实现kernel(GPU)

#include "tensorflow/core/framework/op_kernel.h"using namespace tensorflow;class CustomOp : public OpKernel{public:explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}void Compute(OpKernelContext* context) override {// 获取输入 tensor.const Tensor& input_tensor = context->input(0);auto input = input_tensor.flat<int32>();// 创建一个输出 tensor.Tensor* output_tensor = NULL;OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),&output_tensor));auto output = output_tensor->template flat<int32>();//进行具体的运算,操作input和output//……}
};

3. 将实现的kernel注册到TensorFlow系统中

REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);

CTCBeamSearchDecoder自定义

该Op对应TensorFlow中的源码部分

  • Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

  • CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

  • Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

基于源码修改的Op

#include <algorithm>
#include <vector>
#include <cmath>#include "tensorflow/core/util/ctc/ctc_beam_search.h"#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;using namespace tensorflow;REGISTER_OP("CTCBeamSearchDecoderWithParam").Input("inputs: float").Input("sequence_length: int32").Attr("beam_width: int >= 1").Attr("top_paths: int >= 1").Attr("merge_repeated: bool = true")//新添加了两个参数.Attr("label_selection_size: int >= 0 = 0") .Attr("label_selection_margin: float") .Output("decoded_indices: top_paths * int64").Output("decoded_values: top_paths * int64").Output("decoded_shape: top_paths * int64").Output("log_probability: float").SetShapeFn([](InferenceContext* c) {ShapeHandle inputs;ShapeHandle sequence_length;TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));// Get batch size from inputs and sequence_length.DimensionHandle batch_size;TF_RETURN_IF_ERROR(c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));int32 top_paths;TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));// Outputs.int out_idx = 0;for (int i = 0; i < top_paths; ++i) {  // decoded_indicesc->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));}for (int i = 0; i < top_paths; ++i) {  // decoded_valuesc->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));}ShapeHandle shape_v = c->Vector(2);for (int i = 0; i < top_paths; ++i) {  // decoded_shapec->set_output(out_idx++, shape_v);}c->set_output(out_idx++, c->Matrix(batch_size, top_paths));return Status::OK();});typedef Eigen::ThreadPoolDevice CPUDevice;inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,int* c) {*c = 0;CHECK_LT(0, m.dimension(1));float p = m(r, 0);for (int i = 1; i < m.dimension(1); ++i) {if (m(r, i) > p) {p = m(r, i);*c = i;}}return p;
}class CTCDecodeHelper {public:CTCDecodeHelper() : top_paths_(1) {}inline int GetTopPaths() const { return top_paths_; }void SetTopPaths(int tp) { top_paths_ = tp; }Status ValidateInputsGenerateOutputs(OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,Tensor** log_prob, OpOutputList* decoded_indices,OpOutputList* decoded_values, OpOutputList* decoded_shape) const {Status status = ctx->input("inputs", inputs);if (!status.ok()) return status;status = ctx->input("sequence_length", seq_len);if (!status.ok()) return status;const TensorShape& inputs_shape = (*inputs)->shape();if (inputs_shape.dims() != 3) {return errors::InvalidArgument("inputs is not a 3-Tensor");}const int64 max_time = inputs_shape.dim_size(0);const int64 batch_size = inputs_shape.dim_size(1);if (max_time == 0) {return errors::InvalidArgument("max_time is 0");}if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {return errors::InvalidArgument("sequence_length is not a vector");}if (!(batch_size == (*seq_len)->dim_size(0))) {return errors::FailedPrecondition("len(sequence_length) != batch_size.  ", "len(sequence_length):  ",(*seq_len)->dim_size(0), " batch_size: ", batch_size);}auto seq_len_t = (*seq_len)->vec<int32>();for (int b = 0; b < batch_size; ++b) {if (!(seq_len_t(b) <= max_time)) {return errors::FailedPrecondition("sequence_length(", b, ") <= ",max_time);}}Status s = ctx->allocate_output("log_probability", TensorShape({batch_size, top_paths_}), log_prob);if (!s.ok()) return s;s = ctx->output_list("decoded_indices", decoded_indices);if (!s.ok()) return s;s = ctx->output_list("decoded_values", decoded_values);if (!s.ok()) return s;s = ctx->output_list("decoded_shape", decoded_shape);if (!s.ok()) return s;return Status::OK();}// sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".Status StoreAllDecodedSequences(const std::vector<std::vector<std::vector<int> > >& sequences,OpOutputList* decoded_indices, OpOutputList* decoded_values,OpOutputList* decoded_shape) const {// Calculate the total number of entries for each pathconst int64 batch_size = sequences.size();std::vector<int64> num_entries(top_paths_, 0);// Calculate num_entries per pathfor (const auto& batch_s : sequences) {CHECK_EQ(batch_s.size(), top_paths_);for (int p = 0; p < top_paths_; ++p) {num_entries[p] += batch_s[p].size();}}for (int p = 0; p < top_paths_; ++p) {Tensor* p_indices = nullptr;Tensor* p_values = nullptr;Tensor* p_shape = nullptr;const int64 p_num = num_entries[p];Status s =decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);if (!s.ok()) return s;s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);if (!s.ok()) return s;s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);if (!s.ok()) return s;auto indices_t = p_indices->matrix<int64>();auto values_t = p_values->vec<int64>();auto shape_t = p_shape->vec<int64>();int64 max_decoded = 0;int64 offset = 0;for (int64 b = 0; b < batch_size; ++b) {auto& p_batch = sequences[b][p];int64 num_decoded = p_batch.size();max_decoded = std::max(max_decoded, num_decoded);std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));for (int64 t = 0; t < num_decoded; ++t, ++offset) {indices_t(offset, 0) = b;indices_t(offset, 1) = t;}}shape_t(0) = batch_size;shape_t(1) = max_decoded;}return Status::OK();}private:int top_paths_;TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {public:explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));//从参数列表中读取新添的两个参数OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));int top_paths;OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));decode_helper_.SetTopPaths(top_paths);}void Compute(OpKernelContext* ctx) override {const Tensor* inputs;const Tensor* seq_len;Tensor* log_prob = nullptr;OpOutputList decoded_indices;OpOutputList decoded_values;OpOutputList decoded_shape;OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(ctx, &inputs, &seq_len, &log_prob, &decoded_indices,&decoded_values, &decoded_shape));auto inputs_t = inputs->tensor<float, 3>();auto seq_len_t = seq_len->vec<int32>();auto log_prob_t = log_prob->matrix<float>();const TensorShape& inputs_shape = inputs->shape();const int64 max_time = inputs_shape.dim_size(0);const int64 batch_size = inputs_shape.dim_size(1);const int64 num_classes_raw = inputs_shape.dim_size(2);OP_REQUIRES(ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),errors::InvalidArgument("num_classes cannot exceed max int"));const int num_classes = static_cast<const int>(num_classes_raw);log_prob_t.setZero();std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;for (std::size_t t = 0; t < max_time; ++t) {input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,batch_size, num_classes);}ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,&beam_scorer_, 1 /* batch_size */,merge_repeated_);//使用传入的两个参数进行Setbeam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));auto input_chip_t = input_chip.flat<float>();std::vector<std::vector<std::vector<int> > > best_paths(batch_size);std::vector<float> log_probs;// Assumption: the blank index is num_classes - 1for (int b = 0; b < batch_size; ++b) {auto& best_paths_b = best_paths[b];best_paths_b.resize(decode_helper_.GetTopPaths());for (int t = 0; t < seq_len_t(b); ++t) {input_chip_t = input_list_t[t].chip(b, 0);auto input_bi =Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);beam_search.Step(input_bi);}OP_REQUIRES_OK(ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,&log_probs, merge_repeated_));beam_search.Reset();for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {log_prob_t(b, bp) = log_probs[bp];}}OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(best_paths, &decoded_indices, &decoded_values,&decoded_shape));}private:CTCDecodeHelper decode_helper_;ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;bool merge_repeated_;int beam_width_;//新添两个数据成员,用于存储新加的参数int label_selection_size;float label_selection_margin;TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),CTCBeamSearchDecoderWithParamOp);

将自定义的Op编译成.so文件

  • 在tensorflow-master目录下新建一个文件夹custom_op
  • cd custom_op
  • 新建一个BUILD文件,并在其中添加如下代码:
cc_library(name = "ctc_decoder_with_param",srcs = ["new_beamsearch.cc"] +glob(["boost_locale/**/*.hpp"]),includes = ["boost_locale"],copts = ["-std=c++11"],deps = ["//tensorflow/core:core","//tensorflow/core/util/ctc","//third_party/eigen3",],
)
  • 编译过程:

1. cd 到 tensorflow-master 目录下
2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param
3. bazel-bin/custom_op 目录下生成 libctc_decoder_with_param.so

在训练(预测)程序中使用自定义的Op

在程序中定义如下的方法:

decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')
def decode_with_param(inputs, sequence_length, beam_width=100,top_paths=1, merge_repeated=True):decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (decode_param_op_module.ctc_beam_search_decoder_with_param(inputs, sequence_length, beam_width=beam_width,top_paths=top_paths, merge_repeated=merge_repeated,label_selection_size=40, label_selection_margin=0.99))return ([tf.SparseTensor(ix, val, shape) for (ix, val, shape)in zip(decoded_ixs, decoded_vals, decoded_shapes)],log_probabilities)

然后就可以像使用tf.nn.ctc_beam_search_decoder一样使用该Op了。

TensorFlow实现自定义Op相关推荐

  1. tensorflow:自定义op

    比官网介绍的更好理解,特此转载 tensorflow:自定义op简单介绍 2017年06月26日 13:32:55 阅读数:6094 tensorflow 自定义 op 本文只是简单的翻译了 http ...

  2. tensorflow:自定义op简单介绍

    本文只是简单的翻译了 https://www.tensorflow.org/extend/adding_an_op 的简单部分,高级部分请移步官网. 可能需要新定义 c++ operation 的几种 ...

  3. tensorflow自定义op:梯度

    暂时并未解决我的问题,但感觉将来会有用,特此转载 . 在使用 tensorflow 的时候,有时不可避免的会需要自定义 op,官方文档对于 定义 op 的前向过程介绍挺详细,但是对于 梯度 的介绍有点 ...

  4. TensorFlow使用Python自定义op和损失函数

    TensorFlow使用Python自定义op和损失函数 TensorFlow是静态图结构,即必须把所有的操作以及网络结构定义好(后来有了动态图功能,即Eager Execution ),在没有用tf ...

  5. TensorFlow之变量OP

    TensorFlow之变量OP TensorFlow变量是表示程序处理的共享持久状态的最佳方法.变量通过 tf.Variable OP类进行操作.变量的特点: 存储持久化 可修改值 可指定被训练 1 ...

  6. stylegan2 示例命令fused_bias_act.cu环境配置异常(无法打开包括文件: “tensorflow/core/framework/op.h”

    在python运行stylegan2示例时,运行过程中,触发fused_bias_act.cu中的异常,可以看到fused_bias_act.cu中实际上是用c/c++写的实现代码. 仔细看异常信息会 ...

  7. TF之LSTM:基于tensorflow框架自定义LSTM算法实现股票历史(1990~2015数据集,6112预测后100+单变量最高)行情回归预测

    TF之LSTM:基于tensorflow框架自定义LSTM算法实现股票历史(1990~2015数据集,6112预测后100+单变量最高)行情回归预测 目录 输出结果 LSTM代码 输出结果 数据集 L ...

  8. Pytorch1.1.0 入门 自定义op(python)

    因为需求,需要调研tensorRT与ONNX关于自定义层的方法.经过之前的调研,首先,关于onnx,开发者手册中的介绍有限,在已知的demo中没有关于onnx自定义层的,详情见TensorRT 5.1 ...

  9. tensorflow自定义op和梯度

    参考资料 官网教程链接 http://www.tensorfly.cn/tfdoc/how_tos/adding_an_op.html#AUTOGENERATED-implement-the-grad ...

最新文章

  1. Directx11 教程(2) 基本的windows应用程序框架(2)
  2. 数据蒋堂 | 这个产品能支持多大数据量?
  3. 清华成果发布 | 广度学习基础计算系统集成平台
  4. php跨平台总结 常用预定义常量
  5. [crypto]-90-crypto的一些术语和思考[inProgress]
  6. 2Boost之UPD,Client and Server
  7. javamail读取并发送完整的html页面
  8. python抓取网站的图片并下载到本地
  9. Kali学习笔记15:防火墙识别、负载均衡识别、WAF识别
  10. 游戏开发中的数学和物理算法(13):点积和叉积
  11. DBS:CUPhone
  12. Word可折叠多级标题
  13. Ubuntu 18.04安装Eclipse教程
  14. 计算机专业答辩需要演示系统么,计算机专业毕业设计答辩流程
  15. system verilog编程题_SystemVerilog通用程序库(下)
  16. 头同尾合十的算法_尾同头合十或头同尾合十等的速算方法word精品
  17. IDM下载器下载百度网盘文件
  18. 小程序楼层索引,将汉字转换为拼音并以首字母排序
  19. Fisher信息量与Cramer-Rao不等式
  20. deepin不安装任何软件实现局域网快速共享文件

热门文章

  1. 认证管理(锐捷网关篇)
  2. 什么是数字证书、公钥私钥
  3. jenkins自动部署到tomcat/weblogic
  4. git libpng warning: iCCP: cHRM chunk does not match sRGB
  5. 我们分析了GitHub上5.46 亿条日志,发现中国开源虽然贡献大但还有这些不足......
  6. 什么是元数据?为何需要元数据?
  7. 在计算机语言中go是什么意思,golang中的断言是什么意思
  8. CTP常见问题系列之一 “CTP : 不合法的登录“
  9. JDBC-----什么是JDBC
  10. 解决在vue中切换图片,gif格式的图片停在最后一帧的问题