首页 > 编程 > Python > 正文

TensorFlow实现自定义Op方式

2020-02-15 21:25:46
字体:
来源:转载
供稿:网友

『写在前面』

以 CTC beam search decoder 为例，简单整理一下在 TensorFlow 中实现自定义 Op 的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"

// Declares the interface of the "Custom" op: one int32 input and one int32
// output. The #include must sit on its own line: any tokens that follow an
// #include directive on the same line are discarded by the preprocessor,
// which would silently drop this registration.
REGISTER_OP("Custom")
    .Input("custom_input: int32")
    .Output("custom_output: int32");

2. 为 Op 实现 kernel：继承 OpKernel 并重写 Compute 方法（CPU 版本直接在 Compute 中完成计算；GPU 版本通常在 Compute 中启动 CUDA kernel）

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

// CPU kernel for the "Custom" op. Compute() allocates an int32 output with the
// same shape as the input and fills it element by element.
class CustomOp : public OpKernel {
 public:
  explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Fetch the input tensor.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create the output tensor with the same shape as the input.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->flat<int32>();

    // Replace this loop with the real computation that reads `input` and
    // writes `output`. The identity copy is only a placeholder; it keeps the
    // freshly allocated output from being returned uninitialized.
    const int64 n = input.size();
    for (int64 i = 0; i < n; ++i) {
      output(i) = input(i);
    }
  }
};

3. 将实现的kernel注册到TensorFlow系统中

// Bind the CustomOp implementation to the "Custom" op on CPU devices.
REGISTER_KERNEL_BUILDER(
    Name("Custom").Device(DEVICE_CPU),
    CustomOp);

CTCBeamSearchDecoder自定义

该 Op 在 TensorFlow 源码中对应的部分如下：

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.h

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

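基于源码修改的 Op（在原 CTCBeamSearchDecoder 的基础上新增 label_selection_size 与 label_selection_margin 两个属性，并通过 SetLabelSelectionParameters 传给解码器）：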

#include <algorithm>#include <vector>#include <cmath> #include "tensorflow/core/util/ctc/ctc_beam_search.h" #include "tensorflow/core/framework/op.h"#include "tensorflow/core/framework/op_kernel.h"#include "tensorflow/core/framework/shape_inference.h"#include "tensorflow/core/kernels/bounds_check.h" namespace tf = tensorflow;using tf::shape_inference::DimensionHandle;using tf::shape_inference::InferenceContext;using tf::shape_inference::ShapeHandle; using namespace tensorflow; REGISTER_OP("CTCBeamSearchDecoderWithParam")  .Input("inputs: float")  .Input("sequence_length: int32")  .Attr("beam_width: int >= 1")  .Attr("top_paths: int >= 1")  .Attr("merge_repeated: bool = true")  //新添加了两个参数  .Attr("label_selection_size: int >= 0 = 0")   .Attr("label_selection_margin: float")   .Output("decoded_indices: top_paths * int64")  .Output("decoded_values: top_paths * int64")  .Output("decoded_shape: top_paths * int64")  .Output("log_probability: float")  .SetShapeFn([](InferenceContext* c) {   ShapeHandle inputs;   ShapeHandle sequence_length;    TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));    // Get batch size from inputs and sequence_length.   DimensionHandle batch_size;   TF_RETURN_IF_ERROR(     c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));    int32 top_paths;   TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));    // Outputs.   
int out_idx = 0;   for (int i = 0; i < top_paths; ++i) { // decoded_indices    c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));   }   for (int i = 0; i < top_paths; ++i) { // decoded_values    c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));   }   ShapeHandle shape_v = c->Vector(2);   for (int i = 0; i < top_paths; ++i) { // decoded_shape    c->set_output(out_idx++, shape_v);   }   c->set_output(out_idx++, c->Matrix(batch_size, top_paths));   return Status::OK();  }); typedef Eigen::ThreadPoolDevice CPUDevice; inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,          int* c) { *c = 0; CHECK_LT(0, m.dimension(1)); float p = m(r, 0); for (int i = 1; i < m.dimension(1); ++i) {  if (m(r, i) > p) {   p = m(r, i);   *c = i;  } } return p;} class CTCDecodeHelper { public: CTCDecodeHelper() : top_paths_(1) {}  inline int GetTopPaths() const { return top_paths_; } void SetTopPaths(int tp) { top_paths_ = tp; }  Status ValidateInputsGenerateOutputs(   OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,   Tensor** log_prob, OpOutputList* decoded_indices,   OpOutputList* decoded_values, OpOutputList* decoded_shape) const {  Status status = ctx->input("inputs", inputs);  if (!status.ok()) return status;  status = ctx->input("sequence_length", seq_len);  if (!status.ok()) return status;   const TensorShape& inputs_shape = (*inputs)->shape();   if (inputs_shape.dims() != 3) {   return errors::InvalidArgument("inputs is not a 3-Tensor");  }   const int64 max_time = inputs_shape.dim_size(0);  const int64 batch_size = inputs_shape.dim_size(1);   if (max_time == 0) {   return errors::InvalidArgument("max_time is 0");  }  if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {   return errors::InvalidArgument("sequence_length is not a vector");  }   if (!(batch_size == (*seq_len)->dim_size(0))) {   return errors::FailedPrecondition(     "len(sequence_length) != batch_size. 
", "len(sequence_length): ",     (*seq_len)->dim_size(0), " batch_size: ", batch_size);  }   auto seq_len_t = (*seq_len)->vec<int32>();   for (int b = 0; b < batch_size; ++b) {   if (!(seq_len_t(b) <= max_time)) {    return errors::FailedPrecondition("sequence_length(", b, ") <= ",                     max_time);   }  }   Status s = ctx->allocate_output(    "log_probability", TensorShape({batch_size, top_paths_}), log_prob);  if (!s.ok()) return s;   s = ctx->output_list("decoded_indices", decoded_indices);  if (!s.ok()) return s;  s = ctx->output_list("decoded_values", decoded_values);  if (!s.ok()) return s;  s = ctx->output_list("decoded_shape", decoded_shape);  if (!s.ok()) return s;   return Status::OK(); }  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b". Status StoreAllDecodedSequences(   const std::vector<std::vector<std::vector<int> > >& sequences,   OpOutputList* decoded_indices, OpOutputList* decoded_values,   OpOutputList* decoded_shape) const {  // Calculate the total number of entries for each path  const int64 batch_size = sequences.size();  std::vector<int64> num_entries(top_paths_, 0);   // Calculate num_entries per path  for (const auto& batch_s : sequences) {   CHECK_EQ(batch_s.size(), top_paths_);   for (int p = 0; p < top_paths_; ++p) {    num_entries[p] += batch_s[p].size();   }  }   for (int p = 0; p < top_paths_; ++p) {   Tensor* p_indices = nullptr;   Tensor* p_values = nullptr;   Tensor* p_shape = nullptr;    const int64 p_num = num_entries[p];    Status s =     decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);   if (!s.ok()) return s;   s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);   if (!s.ok()) return s;   s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);   if (!s.ok()) return s;    auto indices_t = p_indices->matrix<int64>();   auto values_t = p_values->vec<int64>();   auto shape_t = p_shape->vec<int64>();    int64 max_decoded = 0;   int64 offset = 0;    for 
(int64 b = 0; b < batch_size; ++b) {    auto& p_batch = sequences[b][p];    int64 num_decoded = p_batch.size();    max_decoded = std::max(max_decoded, num_decoded);    std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));    for (int64 t = 0; t < num_decoded; ++t, ++offset) {     indices_t(offset, 0) = b;     indices_t(offset, 1) = t;    }   }    shape_t(0) = batch_size;   shape_t(1) = max_decoded;  }  return Status::OK(); }  private: int top_paths_; TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);}; // CTC beam searchclass CTCBeamSearchDecoderWithParamOp : public OpKernel { public: explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {  OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));  OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));  //从参数列表中读取新添的两个参数  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));  int top_paths;  OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));  decode_helper_.SetTopPaths(top_paths); }  void Compute(OpKernelContext* ctx) override {  const Tensor* inputs;  const Tensor* seq_len;  Tensor* log_prob = nullptr;  OpOutputList decoded_indices;  OpOutputList decoded_values;  OpOutputList decoded_shape;  OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(              ctx, &inputs, &seq_len, &log_prob, &decoded_indices,              &decoded_values, &decoded_shape));   auto inputs_t = inputs->tensor<float, 3>();  auto seq_len_t = seq_len->vec<int32>();  auto log_prob_t = log_prob->matrix<float>();   const TensorShape& inputs_shape = inputs->shape();   const int64 max_time = inputs_shape.dim_size(0);  const int64 batch_size = inputs_shape.dim_size(1);  const int64 num_classes_raw = inputs_shape.dim_size(2);  OP_REQUIRES(    ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),    errors::InvalidArgument("num_classes 
cannot exceed max int"));  const int num_classes = static_cast<const int>(num_classes_raw);   log_prob_t.setZero();   std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;   for (std::size_t t = 0; t < max_time; ++t) {   input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,                batch_size, num_classes);  }   ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,                      &beam_scorer_, 1 /* batch_size */,                      merge_repeated_);  //使用传入的两个参数进行Set  beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);  Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));  auto input_chip_t = input_chip.flat<float>();   std::vector<std::vector<std::vector<int> > > best_paths(batch_size);  std::vector<float> log_probs;   // Assumption: the blank index is num_classes - 1  for (int b = 0; b < batch_size; ++b) {   auto& best_paths_b = best_paths[b];   best_paths_b.resize(decode_helper_.GetTopPaths());   for (int t = 0; t < seq_len_t(b); ++t) {    input_chip_t = input_list_t[t].chip(b, 0);    auto input_bi =      Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);    beam_search.Step(input_bi);   }   OP_REQUIRES_OK(     ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,                  &log_probs, merge_repeated_));    beam_search.Reset();    for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {    log_prob_t(b, bp) = log_probs[bp];   }  }   OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(              best_paths, &decoded_indices, &decoded_values,              &decoded_shape)); }  private: CTCDecodeHelper decode_helper_; ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_; bool merge_repeated_; int beam_width_; //新添两个数据成员,用于存储新加的参数 int label_selection_size; float label_selection_margin; TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);}; 
// Make CTCBeamSearchDecoderWithParamOp the CPU implementation of the
// "CTCBeamSearchDecoderWithParam" op.
REGISTER_KERNEL_BUILDER(
    Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),
    CTCBeamSearchDecoderWithParamOp);
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表