TensorFlow实现自定义Op方式


Posted in Python on February 04, 2020

『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"
 
// Declares the interface of the "Custom" op: one int32 input named
// custom_input and one int32 output named custom_output.
REGISTER_OP("Custom")  
  .Input("custom_input: int32")
  .Output("custom_output: int32");

2. 为Op实现Compute操作(CPU)或实现kernel(GPU)

#include "tensorflow/core/framework/op_kernel.h"
 
using namespace tensorflow;
 
class CustomOp : public OpKernel{
  public:
  explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
  // 获取输入 tensor.
  const Tensor& input_tensor = context->input(0);
  auto input = input_tensor.flat<int32>();
  // 创建一个输出 tensor.
  Tensor* output_tensor = NULL;
  OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                           &output_tensor));
  auto output = output_tensor->template flat<int32>();
  //进行具体的运算,操作input和output
  //……
 }
};

3. 将实现的kernel注册到TensorFlow系统中

// Registers CustomOp as the CPU kernel implementing the "Custom" op.
REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);

自定义CTCBeamSearchDecoder

该Op对应TensorFlow中的源码部分

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

基于源码修改的Op

#include <algorithm>
#include <vector>
#include <cmath>
 
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
 
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
 
namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;
 
using namespace tensorflow;
 
// Op interface for the customized CTC beam search decoder. In addition to
// beam_width, top_paths and merge_repeated it declares two extra attributes,
// label_selection_size and label_selection_margin.
//
// Outputs are laid out as three consecutive groups of top_paths tensors
// (decoded_indices, decoded_values, decoded_shape) followed by a single
// log_probability tensor.
REGISTER_OP("CTCBeamSearchDecoderWithParam")
    .Input("inputs: float")
    .Input("sequence_length: int32")
    .Attr("beam_width: int >= 1")
    .Attr("top_paths: int >= 1")
    .Attr("merge_repeated: bool = true")
    // The two newly added attributes.
    .Attr("label_selection_size: int >= 0 = 0")
    .Attr("label_selection_margin: float")
    .Output("decoded_indices: top_paths * int64")
    .Output("decoded_values: top_paths * int64")
    .Output("decoded_shape: top_paths * int64")
    .Output("log_probability: float")
    .SetShapeFn([](InferenceContext* c) {
      // `inputs` must be rank 3 and `sequence_length` rank 1.
      ShapeHandle inputs_shape;
      ShapeHandle seq_len_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_len_shape));

      // Dimension 1 of `inputs` and dimension 0 of `sequence_length` must
      // agree; their merged value is the batch size.
      DimensionHandle batch;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(inputs_shape, 1),
                                  c->Dim(seq_len_shape, 0), &batch));

      int32 num_paths;
      TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &num_paths));

      // Fill the three per-path output groups in one pass:
      //   [0, num_paths)             decoded_indices: [?, 2] matrices
      //   [num_paths, 2*num_paths)   decoded_values:  [?] vectors
      //   [2*num_paths, 3*num_paths) decoded_shape:   [2] vectors
      const ShapeHandle dense_shape = c->Vector(2);
      for (int p = 0; p < num_paths; ++p) {
        c->set_output(p, c->Matrix(InferenceContext::kUnknownDim, 2));
        c->set_output(num_paths + p, c->Vector(InferenceContext::kUnknownDim));
        c->set_output(2 * num_paths + p, dense_shape);
      }

      // log_probability: [batch, num_paths].
      c->set_output(3 * num_paths, c->Matrix(batch, num_paths));
      return Status::OK();
    });
 
// Short alias for Eigen's thread-pool device type.
using CPUDevice = Eigen::ThreadPoolDevice;
 
// Returns the largest value in row `r` of `m` and stores the column index of
// that value in `*c`. When several columns share the maximum, the lowest
// column index is reported. CHECK-fails if `m` has no columns.
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
                    int* c) {
  *c = 0;
  const auto num_cols = m.dimension(1);
  CHECK_LT(0, num_cols);

  int best_col = 0;
  float best_val = m(r, 0);
  for (int col = 1; col < num_cols; ++col) {
    const float candidate = m(r, col);
    if (candidate > best_val) {
      best_val = candidate;
      best_col = col;
    }
  }

  *c = best_col;
  return best_val;
}
 
// Helper used by the CTC decoder kernel. It validates the op inputs,
// allocates/fetches the op outputs, and packs decoded label sequences into
// (indices, values, shape) output triples -- one triple per requested path.
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  // Number of paths to emit per batch entry.
  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  // Fetches the "inputs" and "sequence_length" tensors, checks their shapes,
  // allocates "log_probability" as [batch_size, top_paths] and fetches the
  // three output lists.
  //
  // Returns InvalidArgument if `inputs` is not rank 3, if max_time or
  // num_classes is 0, or if `sequence_length` is not a vector; returns
  // FailedPrecondition if len(sequence_length) != batch_size or any sequence
  // length exceeds max_time. Errors from fetching/allocating tensors are
  // propagated unchanged.
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    // inputs is laid out as [max_time, batch_size, num_classes].
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    // A class dimension of size 0 leaves no label to decode; reject it here
    // rather than handing the decoder empty rows.
    if (inputs_shape.dim_size(2) == 0) {
      return errors::InvalidArgument("num_classes is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ", "len(sequence_length): ",
          (*seq_len)->dim_size(0), " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    // Every sequence must fit inside the time dimension of `inputs`.
    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b, ") <= ",
                                          max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }

  // Writes the decoded sequences into the output lists.
  //
  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  // For each path p this allocates:
  //   decoded_indices[p]: [N, 2] rows of (batch index, position in sequence)
  //   decoded_values[p]:  [N]    the decoded labels, batch after batch
  //   decoded_shape[p]:   [2]    {batch_size, longest decoded sequence}
  // where N is the total number of labels decoded for path p over all batch
  // entries. CHECK-fails if a batch entry does not hold exactly top_paths
  // sequences; allocation errors are returned as Status.
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    const int64 batch_size = sequences.size();
    std::vector<int64> num_entries(top_paths_, 0);

    // Total number of decoded labels per path, summed over the batch.
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64 p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64>();
      auto values_t = p_values->vec<int64>();
      auto shape_t = p_shape->vec<int64>();

      int64 max_decoded = 0;
      int64 offset = 0;  // Next free row in indices_t / slot in values_t.

      for (int64 b = 0; b < batch_size; ++b) {
        const auto& p_batch = sequences[b][p];
        const int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        // Only touch values_t when there is something to copy. For an empty
        // sequence `offset` may already equal the tensor size (or the tensor
        // may hold no elements at all), so indexing values_t(offset) would
        // be out of bounds.
        if (num_decoded > 0) {
          std::copy_n(p_batch.begin(), num_decoded, values_t.data() + offset);
        }
        for (int64 t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
 
// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {
 public:
 explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
  //从参数列表中读取新添的两个参数
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));
  int top_paths;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
  decode_helper_.SetTopPaths(top_paths);
 }
 
 void Compute(OpKernelContext* ctx) override {
  const Tensor* inputs;
  const Tensor* seq_len;
  Tensor* log_prob = nullptr;
  OpOutputList decoded_indices;
  OpOutputList decoded_values;
  OpOutputList decoded_shape;
  OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
              ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
              &decoded_values, &decoded_shape));
 
  auto inputs_t = inputs->tensor<float, 3>();
  auto seq_len_t = seq_len->vec<int32>();
  auto log_prob_t = log_prob->matrix<float>();
 
  const TensorShape& inputs_shape = inputs->shape();
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
  const int64 num_classes_raw = inputs_shape.dim_size(2);
  OP_REQUIRES(
    ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
    errors::InvalidArgument("num_classes cannot exceed max int"));
  const int num_classes = static_cast<const int>(num_classes_raw);
 
  log_prob_t.setZero();
 
  std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
 
  for (std::size_t t = 0; t < max_time; ++t) {
   input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                batch_size, num_classes);
  }
 
  ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                      &beam_scorer_, 1 /* batch_size */,
                      merge_repeated_);
  //使用传入的两个参数进行Set
  beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);
  Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
  auto input_chip_t = input_chip.flat<float>();
 
  std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
  std::vector<float> log_probs;
 
  // Assumption: the blank index is num_classes - 1
  for (int b = 0; b < batch_size; ++b) {
   auto& best_paths_b = best_paths[b];
   best_paths_b.resize(decode_helper_.GetTopPaths());
   for (int t = 0; t < seq_len_t(b); ++t) {
    input_chip_t = input_list_t[t].chip(b, 0);
    auto input_bi =
      Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
    beam_search.Step(input_bi);
   }
   OP_REQUIRES_OK(
     ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                  &log_probs, merge_repeated_));
 
   beam_search.Reset();
 
   for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
    log_prob_t(b, bp) = log_probs[bp];
   }
  }
 
  OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
              best_paths, &decoded_indices, &decoded_values,
              &decoded_shape));
 }
 
 private:
 CTCDecodeHelper decode_helper_;
 ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
 bool merge_repeated_;
 int beam_width_;
 //新添两个数据成员,用于存储新加的参数
 int label_selection_size;
 float label_selection_margin;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};
 
// Registers CTCBeamSearchDecoderWithParamOp as the CPU kernel for the
// "CTCBeamSearchDecoderWithParam" op.
REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),
            CTCBeamSearchDecoderWithParamOp);

将自定义的Op编译成.so文件

在tensorflow-master目录下新建一个文件夹custom_op

cd custom_op

新建一个BUILD文件,并在其中添加如下代码:

# Bazel target for the custom op; built as //custom_op:ctc_decoder_with_param.
# NOTE(review): presumably new_beamsearch.cc holds the custom op source shown
# earlier in this article -- confirm the file name.
# NOTE(review): TensorFlow's documented route for a loadable op library is the
# tf_custom_op_library rule -- confirm that this cc_library yields a .so that
# tf.load_op_library accepts.
cc_library(
  name = "ctc_decoder_with_param",
  srcs = [
      "new_beamsearch.cc"
      ] +
      # NOTE(review): the op source shown in this article includes no
      # boost_locale headers -- confirm whether this glob and the matching
      # `includes` entry are needed.
      glob(["boost_locale/**/*.hpp"]),
  includes = ["boost_locale"],
  copts = ["-std=c++11"],
  deps = ["//tensorflow/core:core",
      "//tensorflow/core/util/ctc",
      "//third_party/eigen3",
  ],
)

编译过程:

1. cd 到 tensorflow-master 目录下

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. 编译完成后,会在 bazel-bin/custom_op 目录下生成 libctc_decoder_with_param.so

在训练(预测)程序中使用自定义的Op

在程序中定义如下的方法:

# Load the compiled custom-op library. The path is resolved as given, so the
# .so must be reachable from the working directory (or be an absolute path).
decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')
def decode_with_param(inputs, sequence_length, beam_width=100,
          top_paths=1, merge_repeated=True,
          label_selection_size=40, label_selection_margin=0.99):
  """Runs the custom CTC beam search decoder op.

  Args:
    inputs: float tensor passed to the op as `inputs`; the kernel requires
      rank 3 and reads it as [max_time, batch_size, num_classes].
    sequence_length: int32 vector with one entry per batch element.
    beam_width: forwarded to the op's `beam_width` attribute.
    top_paths: number of decoded paths returned per batch element.
    merge_repeated: forwarded to the op's `merge_repeated` attribute.
    label_selection_size: forwarded to the op's `label_selection_size`
      attribute. Defaults to 40, the value this wrapper used when it was
      hard-coded.
    label_selection_margin: forwarded to the op's `label_selection_margin`
      attribute. Defaults to 0.99, the previously hard-coded value.

  Returns:
    A tuple `(decoded, log_probabilities)`, where `decoded` is a list of
    `top_paths` `tf.SparseTensor` objects built from the op's indices,
    values and shape outputs, and `log_probabilities` is the op's
    `log_probability` output.
  """
  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
    decode_param_op_module.ctc_beam_search_decoder_with_param(
      inputs, sequence_length, beam_width=beam_width,
      top_paths=top_paths, merge_repeated=merge_repeated,
      label_selection_size=label_selection_size,
      label_selection_margin=label_selection_margin))
  return (
    [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
     in zip(decoded_ixs, decoded_vals, decoded_shapes)],
    log_probabilities)

然后就可以像使用tf.nn.ctc_beam_search_decoder一样使用该Op了。

以上这篇TensorFlow实现自定义Op方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux下安装easy_install的方法
Feb 10 Python
使用python删除nginx缓存文件示例(python文件操作)
Mar 26 Python
Python 3.x 连接数据库示例(pymysql 方式)
Jan 19 Python
python生成器,可迭代对象,迭代器区别和联系
Feb 04 Python
python实现读Excel写入.txt的方法
Apr 29 Python
解决Python pandas plot输出图形中显示中文乱码问题
Dec 12 Python
详解Python中的内建函数,可迭代对象,迭代器
Apr 29 Python
pandas read_excel()和to_excel()函数解析
Sep 19 Python
Python模拟伯努利试验和二项分布代码实例
May 27 Python
Python中openpyxl实现vlookup函数的实例
Oct 28 Python
使用python实现学生信息管理系统
Feb 25 Python
Django migrate报错的解决方案
May 20 Python
tensorflow使用指定gpu的方法
Feb 04 #Python
TensorFlow梯度求解tf.gradients实例
Feb 04 #Python
基于TensorFlow中自定义梯度的2种方式
Feb 04 #Python
tensorflow 查看梯度方式
Feb 04 #Python
opencv python图像梯度实例详解
Feb 04 #Python
TensorFlow设置日志级别的几种方式小结
Feb 04 #Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 #Python
You might like
缅甸的咖啡简史
2021/03/04 咖啡文化
在IIS7.0下面配置PHP 5.3.2运行环境的方法
2010/04/13 PHP
yii操作session实例简介
2014/07/31 PHP
PHP+mysql+ajax轻量级聊天室实现方法详解
2016/10/17 PHP
PHP函数用法详解【初始化、嵌套、内置函数等】
2020/06/02 PHP
javascript显示隐藏层比较不错的方法分析
2008/09/30 Javascript
ext checkboxgroup 回填数据解决
2009/08/21 Javascript
(function($){...})(jQuery)的意思
2010/07/22 Javascript
谈谈JavaScript中的函数与闭包
2013/04/14 Javascript
JavaScript模块规范之AMD规范和CMD规范
2015/10/27 Javascript
Bootstrap模态框水平垂直居中与增加拖拽功能
2016/11/09 Javascript
jQueryUI 拖放排序遇到滚动条时有可能无法执行排序的小bug及解决方案
2016/12/19 Javascript
JavaScript表单验证的两种实现方法
2017/02/11 Javascript
JS实现复选框的全选和批量删除功能
2017/04/05 Javascript
微信小程序开发之数据存储 参数传递 数据缓存
2017/04/13 Javascript
浅谈vue路径优化之resolve
2017/10/13 Javascript
浅谈Koa服务限流方法实践
2017/10/23 Javascript
vue移动端轻量级的轮播组件实现代码
2018/07/12 Javascript
了解JavaScript中的选择器
2019/05/24 Javascript
小程序登录之支付宝授权的实现示例
2019/12/13 Javascript
vue中js判断长时间不操作界面自动退出登录(推荐)
2020/01/22 Javascript
JavaScript鼠标悬停事件用法解析
2020/05/15 Javascript
[01:19:33]DOTA2-DPC中国联赛 正赛 iG vs VG BO3 第一场 2月2日
2021/03/11 DOTA
Python Sleep休眠函数使用简单实例
2015/02/02 Python
详解在Python和IPython中使用Docker
2015/04/28 Python
python函数局部变量用法实例分析
2015/08/04 Python
分析python切片原理和方法
2017/12/19 Python
Windows下anaconda安装第三方包的方法小结(tensorflow、gensim为例)
2018/04/05 Python
使用python itchat包爬取微信好友头像形成矩形头像集的方法
2019/02/21 Python
python导包的几种方法(自定义包的生成以及导入详解)
2019/07/15 Python
Python持续监听文件变化代码实例
2020/07/22 Python
CSS的background属性及CSS3的背景图片设置总结
2016/06/13 HTML / CSS
马来西亚航空官方网站:Malaysia Airlines
2017/07/28 全球购物
节约电力资源的建议书
2014/03/12 职场文书
浅谈MySQL之select优化方案
2021/08/07 MySQL
python绘制云雨图raincloud plot
2022/08/05 Python