TensorFlow实现自定义Op方式


Posted in Python onFebruary 04, 2020

『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"
 
REGISTER_OP("Custom")  
  .Input("custom_input: int32")
  .Output("custom_output: int32");

2. 为Op实现Compute操作(CPU)或实现kernel(GPU)

#include "tensorflow/core/framework/op_kernel.h"
 
using namespace tensorflow;
 
class CustomOp : public OpKernel{
  public:
  explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
  // 获取输入 tensor.
  const Tensor& input_tensor = context->input(0);
  auto input = input_tensor.flat<int32>();
  // 创建一个输出 tensor.
  Tensor* output_tensor = NULL;
  OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                           &output_tensor));
  auto output = output_tensor->template flat<int32>();
  //进行具体的运算,操作input和output
  //……
 }
};

3. 将实现的kernel注册到TensorFlow系统中

REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);

CTCBeamSearchDecoder自定义

该Op对应TensorFlow中的源码部分

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

基于源码修改的Op

#include <algorithm>
#include <vector>
#include <cmath>
 
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
 
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
 
namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;
 
using namespace tensorflow;
 
REGISTER_OP("CTCBeamSearchDecoderWithParam")
  .Input("inputs: float")
  .Input("sequence_length: int32")
  .Attr("beam_width: int >= 1")
  .Attr("top_paths: int >= 1")
  .Attr("merge_repeated: bool = true")
  //新添加了两个参数
  .Attr("label_selection_size: int >= 0 = 0") 
  .Attr("label_selection_margin: float") 
  .Output("decoded_indices: top_paths * int64")
  .Output("decoded_values: top_paths * int64")
  .Output("decoded_shape: top_paths * int64")
  .Output("log_probability: float")
  .SetShapeFn([](InferenceContext* c) {
   ShapeHandle inputs;
   ShapeHandle sequence_length;
 
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));
 
   // Get batch size from inputs and sequence_length.
   DimensionHandle batch_size;
   TF_RETURN_IF_ERROR(
     c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
 
   int32 top_paths;
   TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));
 
   // Outputs.
   int out_idx = 0;
   for (int i = 0; i < top_paths; ++i) { // decoded_indices
    c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));
   }
   for (int i = 0; i < top_paths; ++i) { // decoded_values
    c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));
   }
   ShapeHandle shape_v = c->Vector(2);
   for (int i = 0; i < top_paths; ++i) { // decoded_shape
    c->set_output(out_idx++, shape_v);
   }
   c->set_output(out_idx++, c->Matrix(batch_size, top_paths));
   return Status::OK();
  });
 
typedef Eigen::ThreadPoolDevice CPUDevice;
 
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
          int* c) {
 *c = 0;
 CHECK_LT(0, m.dimension(1));
 float p = m(r, 0);
 for (int i = 1; i < m.dimension(1); ++i) {
  if (m(r, i) > p) {
   p = m(r, i);
   *c = i;
  }
 }
 return p;
}
 
class CTCDecodeHelper {
 public:
 CTCDecodeHelper() : top_paths_(1) {}
 
 inline int GetTopPaths() const { return top_paths_; }
 void SetTopPaths(int tp) { top_paths_ = tp; }
 
 Status ValidateInputsGenerateOutputs(
   OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
   Tensor** log_prob, OpOutputList* decoded_indices,
   OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
  Status status = ctx->input("inputs", inputs);
  if (!status.ok()) return status;
  status = ctx->input("sequence_length", seq_len);
  if (!status.ok()) return status;
 
  const TensorShape& inputs_shape = (*inputs)->shape();
 
  if (inputs_shape.dims() != 3) {
   return errors::InvalidArgument("inputs is not a 3-Tensor");
  }
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
 
  if (max_time == 0) {
   return errors::InvalidArgument("max_time is 0");
  }
  if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
   return errors::InvalidArgument("sequence_length is not a vector");
  }
 
  if (!(batch_size == (*seq_len)->dim_size(0))) {
   return errors::FailedPrecondition(
     "len(sequence_length) != batch_size. ", "len(sequence_length): ",
     (*seq_len)->dim_size(0), " batch_size: ", batch_size);
  }
 
  auto seq_len_t = (*seq_len)->vec<int32>();
 
  for (int b = 0; b < batch_size; ++b) {
   if (!(seq_len_t(b) <= max_time)) {
    return errors::FailedPrecondition("sequence_length(", b, ") <= ",
                     max_time);
   }
  }
 
  Status s = ctx->allocate_output(
    "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
  if (!s.ok()) return s;
 
  s = ctx->output_list("decoded_indices", decoded_indices);
  if (!s.ok()) return s;
  s = ctx->output_list("decoded_values", decoded_values);
  if (!s.ok()) return s;
  s = ctx->output_list("decoded_shape", decoded_shape);
  if (!s.ok()) return s;
 
  return Status::OK();
 }
 
 // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
 Status StoreAllDecodedSequences(
   const std::vector<std::vector<std::vector<int> > >& sequences,
   OpOutputList* decoded_indices, OpOutputList* decoded_values,
   OpOutputList* decoded_shape) const {
  // Calculate the total number of entries for each path
  const int64 batch_size = sequences.size();
  std::vector<int64> num_entries(top_paths_, 0);
 
  // Calculate num_entries per path
  for (const auto& batch_s : sequences) {
   CHECK_EQ(batch_s.size(), top_paths_);
   for (int p = 0; p < top_paths_; ++p) {
    num_entries[p] += batch_s[p].size();
   }
  }
 
  for (int p = 0; p < top_paths_; ++p) {
   Tensor* p_indices = nullptr;
   Tensor* p_values = nullptr;
   Tensor* p_shape = nullptr;
 
   const int64 p_num = num_entries[p];
 
   Status s =
     decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
   if (!s.ok()) return s;
   s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
   if (!s.ok()) return s;
   s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
   if (!s.ok()) return s;
 
   auto indices_t = p_indices->matrix<int64>();
   auto values_t = p_values->vec<int64>();
   auto shape_t = p_shape->vec<int64>();
 
   int64 max_decoded = 0;
   int64 offset = 0;
 
   for (int64 b = 0; b < batch_size; ++b) {
    auto& p_batch = sequences[b][p];
    int64 num_decoded = p_batch.size();
    max_decoded = std::max(max_decoded, num_decoded);
    std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
    for (int64 t = 0; t < num_decoded; ++t, ++offset) {
     indices_t(offset, 0) = b;
     indices_t(offset, 1) = t;
    }
   }
 
   shape_t(0) = batch_size;
   shape_t(1) = max_decoded;
  }
  return Status::OK();
 }
 
 private:
 int top_paths_;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
 
// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {
 public:
 explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
  //从参数列表中读取新添的两个参数
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));
  int top_paths;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
  decode_helper_.SetTopPaths(top_paths);
 }
 
 void Compute(OpKernelContext* ctx) override {
  const Tensor* inputs;
  const Tensor* seq_len;
  Tensor* log_prob = nullptr;
  OpOutputList decoded_indices;
  OpOutputList decoded_values;
  OpOutputList decoded_shape;
  OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
              ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
              &decoded_values, &decoded_shape));
 
  auto inputs_t = inputs->tensor<float, 3>();
  auto seq_len_t = seq_len->vec<int32>();
  auto log_prob_t = log_prob->matrix<float>();
 
  const TensorShape& inputs_shape = inputs->shape();
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
  const int64 num_classes_raw = inputs_shape.dim_size(2);
  OP_REQUIRES(
    ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
    errors::InvalidArgument("num_classes cannot exceed max int"));
  const int num_classes = static_cast<const int>(num_classes_raw);
 
  log_prob_t.setZero();
 
  std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
 
  for (std::size_t t = 0; t < max_time; ++t) {
   input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                batch_size, num_classes);
  }
 
  ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                      &beam_scorer_, 1 /* batch_size */,
                      merge_repeated_);
  //使用传入的两个参数进行Set
  beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);
  Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
  auto input_chip_t = input_chip.flat<float>();
 
  std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
  std::vector<float> log_probs;
 
  // Assumption: the blank index is num_classes - 1
  for (int b = 0; b < batch_size; ++b) {
   auto& best_paths_b = best_paths[b];
   best_paths_b.resize(decode_helper_.GetTopPaths());
   for (int t = 0; t < seq_len_t(b); ++t) {
    input_chip_t = input_list_t[t].chip(b, 0);
    auto input_bi =
      Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
    beam_search.Step(input_bi);
   }
   OP_REQUIRES_OK(
     ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                  &log_probs, merge_repeated_));
 
   beam_search.Reset();
 
   for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
    log_prob_t(b, bp) = log_probs[bp];
   }
  }
 
  OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
              best_paths, &decoded_indices, &decoded_values,
              &decoded_shape));
 }
 
 private:
 CTCDecodeHelper decode_helper_;
 ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
 bool merge_repeated_;
 int beam_width_;
 //新添两个数据成员,用于存储新加的参数
 int label_selection_size;
 float label_selection_margin;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};
 
REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),
            CTCBeamSearchDecoderWithParamOp);

将自定义的Op编译成.so文件

在tensorflow-master目录下新建一个文件夹custom_op

cd custom_op

新建一个BUILD文件,并在其中添加如下代码:

cc_library(
  name = "ctc_decoder_with_param",
  srcs = [
      "new_beamsearch.cc"
      ] +
      glob(["boost_locale/**/*.hpp"]),
  includes = ["boost_locale"],
  copts = ["-std=c++11"],
  deps = ["//tensorflow/core:core",
      "//tensorflow/core/util/ctc",
      "//third_party/eigen3",
  ],
)

编译过程:

1. cd 到 tensorflow-master 目录下

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. bazel-bin/custom_op 目录下生成 libctc_decoder_with_param.so

在训练(预测)程序中使用自定义的Op

在程序中定义如下的方法:

decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')
def decode_with_param(inputs, sequence_length, beam_width=100,
          top_paths=1, merge_repeated=True):
  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
    decode_param_op_module.ctc_beam_search_decoder_with_param(
      inputs, sequence_length, beam_width=beam_width,
      top_paths=top_paths, merge_repeated=merge_repeated,
      label_selection_size=40, label_selection_margin=0.99))
  return (
    [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
     in zip(decoded_ixs, decoded_vals, decoded_shapes)],
    log_probabilities)

然后就可以像使用tf.nn.ctc_beam_search_decoder一样使用该Op了。

以上这篇TensorFlow实现自定义Op方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxPython框架类和面板类的使用实例
Sep 28 Python
Python实现获取命令行输出结果的方法
Jun 10 Python
基于python requests库中的代理实例讲解
May 07 Python
Python 图像处理: 生成二维高斯分布蒙版的实例
Jul 04 Python
使用python打印十行杨辉三角过程详解
Jul 10 Python
python多任务之协程的使用详解
Aug 26 Python
python爬取Ajax动态加载网页过程解析
Sep 05 Python
python实现按首字母分类查找功能
Oct 31 Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 Python
python使用建议与技巧分享(一)
Aug 17 Python
Python模拟键盘输入自动登录TGP
Nov 27 Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 Python
tensorflow使用指定gpu的方法
Feb 04 #Python
TensorFlow梯度求解tf.gradients实例
Feb 04 #Python
基于TensorFlow中自定义梯度的2种方式
Feb 04 #Python
tensorflow 查看梯度方式
Feb 04 #Python
opencv python图像梯度实例详解
Feb 04 #Python
TensorFlow设置日志级别的几种方式小结
Feb 04 #Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 #Python
You might like
无限级别菜单的实现
2006/10/09 PHP
PHP与MySQL开发中页面乱码的产生与解决
2008/03/27 PHP
php str_pad 函数用法简介
2009/07/11 PHP
解析PayPal支付接口的PHP开发方式
2010/11/28 PHP
php文本转图片自动换行的方法
2013/03/13 PHP
利用PHP实现图片等比例放大和缩小的方法详解
2013/06/06 PHP
php简单备份与还原MySql的方法
2016/05/09 PHP
自己使用js/jquery写的一个定制对话框控件
2014/05/02 Javascript
纯js和css实现渐变色包括静态渐变和动态渐变
2014/05/29 Javascript
jquery实现弹出层效果实例
2015/05/19 Javascript
HTML5使用DeviceOrientation实现摇一摇功能
2015/06/05 Javascript
基于JS实现新闻列表无缝向上滚动实例代码
2016/01/22 Javascript
简单的渐变轮播插件
2017/01/12 Javascript
深入浅析AngularJS中的一次性数据绑定 (bindonce)
2017/05/11 Javascript
js 提取某()特殊字符串长度的实例
2017/12/06 Javascript
如何用input标签和jquery实现多图片的上传和回显功能
2018/05/16 jQuery
jquery+ajax实现上传图片并显示上传进度功能【附php后台接收】
2019/06/06 jQuery
Javascript中window.name属性详解
2020/11/19 Javascript
给Python的Django框架下搭建的BLOG添加RSS功能的教程
2015/04/08 Python
python实现将英文单词表示的数字转换成阿拉伯数字的方法
2015/07/02 Python
Python简单生成8位随机密码的方法
2017/05/24 Python
使用python进行波形及频谱绘制的方法
2019/06/17 Python
基于python的列表list和集合set操作
2019/11/24 Python
python操作yaml说明
2020/04/08 Python
浅析python中的del用法
2020/09/02 Python
一款基于css3和jquery实现的动画显示弹出层按钮教程
2015/01/04 HTML / CSS
HTML5 预加载让页面得以快速呈现
2013/08/13 HTML / CSS
英国首屈一指的票务公司:See Tickets
2019/05/11 全球购物
Moda Operandi官网:美国奢侈品电商,海淘秀场T台同款
2020/05/26 全球购物
weblogic面试题
2016/03/07 面试题
我的梦想演讲稿500字
2014/08/21 职场文书
中层领导干部群众路线对照检查材料思想汇报
2014/10/02 职场文书
小学班主任培训心得体会
2016/01/07 职场文书
关于战胜挫折的名言警句大全!
2019/07/05 职场文书
详解Vue的options
2021/05/15 Vue.js
简单谈谈Python面向对象的相关知识
2021/06/28 Python