增加一个新 Op(一)

预备知识:

  • 对 C++ 有一定了解.

如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op 能够兼容原有的库 , 你必须做以下工作:

  • 在一个 C++ 文件中注册新 Op. Op 的注册与实现是相互独立的. 在其注册时描述了 Op 该如何执行. 例如, 注册 Op 时定义了 Op 的名字, 并指定了它的输入和输出.
  • 使用 C++ 实现 Op. 每一个实现称之为一个 "kernel", 可以存在多个 kernel, 以适配不同的架构 (CPU, GPU 等)或不同的输入/输出类型.
  • 创建一个 Python 包装器(wrapper). 这个包装器是创建 Op 的公开 API. 当注册 Op 时, 会自动生成一个默认 默认的包装器. 既可以直接使用默认包装器, 也可以添加一个新的包装器.
  • (可选) 写一个函数计算 Op 的梯度.
  • (可选) 写一个函数, 描述 Op 的输入和输出 shape. 该函数能够允许从 Op 推断 shape.
  • 测试 Op, 通常使用 Pyhton。如果你定义了梯度,你可以使用Python的GradientChecker来测试它。
  • 定义 Op 的接口

    向 TensorFlow 系统注册来定义 Op 的接口. 在注册时, 指定 Op 的名称, 它的输入(类型和名称) 和输出(类型和名称), 和所需要任何 属性的文档说明.

    为了让你有直观的认识, 创建一个简单的 Op 作为例子. 该 Op 接受一个 int32 类型 tensor 作为 输入, 输出这个 tensor 的一个副本, 副本与原 tensor 唯一的区别在于第一个元素被置为 0. 创建 文件tensorflow/core/user_ops/zero_out.cc, 并调用 REGISTER_OP 宏来定义 Op 的接口.

    #include "tensorflow/core/framework/op.h"
    REGISTER_OP("ZeroOut")
        .Input("to_zero: int32")
        .Output("zeroed: int32");
    ZeroOut Op 接受 32 位整型的 tensor to_zero 作为输入, 输出 32 位整型的 tensor zeroed.
  • 为 Op 实现 kernel

    在定义接口之后, 提供一个或多个 Op 的实现. 为这些 kernel 的每一个创建一个对应的类, 继承 OpKernel, 覆盖Compute 方法. Compute 方法提供一个类型为 OpKernelContext* 的参数 context, 用于访问一些有用的信息, 例如输入和输出的 tensor.

    将 kernel 添加到刚才创建的文件中, kernel 看起来和下面的代码类似:

    #include "tensorflow/core/framework/op_kernel.h"
    using namespace tensorflow;
    class ZeroOutOp : public OpKernel {
     public:
      explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
      void Compute(OpKernelContext* context) override {
        // 获取输入 tensor.
        const Tensor& input_tensor = context->input(0);
        auto input = input_tensor.flat<int32>();
       // 创建一个输出 tensor.
        Tensor* output_tensor = NULL;
        OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                         &output_tensor));
        auto output = output_tensor->template flat<int32>();
        // 设置 tensor 除第一个之外的元素均设为 0.
        const int N = input.size();
        for (int i = 1; i < N; i++) {
          output(i) = 0;
        }
        // 尽可能地保留第一个元素的值.
        if (N > 0) output(0) = input(0);
      }
    };
    实现 kernel 后, 将其注册到 TensorFlow 系统中. 注册时, 可以指定该 kernel 运行时的多个约束 条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在 GPU 上运行.

    将下列代码加入到 zero_out.cc 中, 注册 ZeroOut op:

    REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

    一旦创建和重新安装了 TensorFlow , Tensorflow 系统可以在需要时引用和使用该 Op.

    生成客户端包装器

    Python Op 包装器

    当编译 TensorFlow 时, 所有放在 tensorflow/core/user_ops 目录下 的 Op 会自动在 bazel-genfiles/tensorflow/python/ops/gen_user_ops.py 文件 中生成 Python Op 包装器. 通过以下声明, 把那些 Op 引入到tensorflow/python/user_ops/user_ops.py 中:

    from tensorflow.python.ops.gen_user_ops import *
    你可以选择性将部分函数替换为自己的实现. 为此, 首先要隐藏自动生成的代码, 在 tensorflow/python/BUILD 文件中, 将其名字添加到 "user_ops" 的 hidden 列表.
  • tf_gen_op_wrapper_py(
        name = "user_ops",
        hidden = [
            "Fact",
        ],
        require_shape_functions = False,
    )
    紧接着 "Fact" 列出自己的 Op. 然后, 在 tensorflow/python/user_ops/user_ops.py 中添加你的替代实现函数. 通常, 替代实现函数也会调用自动生成函数来真正把 Op 添加 到图中. 被隐藏的自动生成函数位于 gen_user_ops 包中, 名称多了一个下划线前缀 ("_"). 例如:
  • def my_fact():
        """覆盖一个 Op 自动生成代码的示例."""
        return gen_user_ops._fact()
    C++ Op 包装器

    当编译 TensorFlow 时, 所有 tensorflow/core/user_ops 文件夹 下的 Op 会自动创建 C++ Op 包装器. 例如,tensorflow/core/user_ops/zero_out.cc 中的 Op 会自动在 bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc} 中生成包装器.

    tensorflow/cc/ops/standard_ops.h 通过下述申明, 导入用户自定义 Op 自动生成的包装器.

    #include "tensorflow/cc/ops/user_ops.h"

    检查 Op 能否正常工作

    验证已经成功实现 Op 的方式是编写测试程序. 创建文件 tensorflow/python/kernel_tests/zero_out_op_test.py, 包含以下内容:

    import tensorflow as tf
    class ZeroOutTest(tf.test.TestCase):
      def testZeroOut(self):
        with self.test_session():
          result = tf.user_ops.zero_out([5, 4, 3, 2, 1])
          self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
    然后运行测试:
  • $ bazel test tensorflow/python:zero_out_op_test
    验证条件

    上述示例假定 Op 能够应用在任何 shape 的 tensor 上. 如果只想应用到 vector 上 呢? 这意味需要在上述 OpKernel 实现中添加相关的检查.

    void Compute(OpKernelContext* context) override {
       // 获取输入 tensor
        const Tensor& input_tensor = context->input(0);
        OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                    errors::InvalidArgument("ZeroOut expects a 1-D vector."));
        // ...
      }

    OP_REQUIRES 断言的输入是一个 vector, 如果不是 vector, 将设置 InvalidArgument 状态并返回. OP_REQUIRES 宏有三个参数:

    如果想要测试一个函数返回的 Status 对象是否是一个错误, 可以使用 OP_REQUIRES_OK. 这些宏如果检测到错误, 会直接跳出函数, 终止函数执行.

    Op 注册

    属性

    Op 可以有属性, 属性的值在 Op 添加到图中时被设置. 属性值用于配置 Op, 在 kernel 实现中, Op 注册的输入和输出类型中, 均可访问这些属性值. 尽可能地使用输入代替属性, 因为输入的灵活性更高, 例如可以在执行步骤中 中被更改, 可以使用 feed 等等. 属性可用于实现一些输入无法做到的事情, 例如影响 Op 签名 (即输入输出的数量和类型) 的配置或只读配置可以通过属性实现.

    注册 Op 时可以用 Attr 方法指定属性的名称和类型, 以此来定义一个属性, 形式如下:

    <name>: <attr-type-expr>
    <name> 必须以字母开头, 可以由数字, 字母, 下划线组成. <attr-type-expr> 是一个类型表达式, 形式如下:

    例如, 如果想要 ZeroOut Op 保存一个用户索引, 指示该 Op 不仅仅只有一个元素, 你可以注册 Op 如下:

    REGISTER_OP("ZeroOut")
        .Attr("preserve_index: int")
        .Input("to_zero: int32")
        .Output("zeroed: int32");
    你的 kernel 可以在构造函数里, 通过 context 参数访问这个属性:
  • class ZeroOutOp : public OpKernel {
     public:
      explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) {
       // 获取欲保存的索引值
        OP_REQUIRES_OK(context,
                       context->GetAttr("preserve_index", &preserve_index_));
        // 检查 preserve_index 是否为正
        OP_REQUIRES(context, preserve_index_ >= 0,
                    errors::InvalidArgument("Need preserve_index >= 0, got ",
                                            preserve_index_));
      }
      void Compute(OpKernelContext* context) override {
        // ...
    }
     private:
      int preserve_index_;
    };
    该值可以在 Compute 方法中被使用:
  • void Compute(OpKernelContext* context) override {
        // ...
       // 检查 preserve_index 范围是否合法
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                    errors::InvalidArgument("preserve_index out of range"));
        // 设置输出 tensor 所有的元素值为 0
       const int N = input.size();
        for (int i = 0; i < N; i++) {
          output_flat(i) = 0;
        }
        // 保存请求的输入值
       output_flat(preserve_index_) = input(preserve_index_);
      }
    为了维持向后兼容性, 将一个属性添加到一个已有的 Op 时, 必须指定一个默认值:
  • REGISTER_OP("ZeroOut")
         .Attr("preserve_index: int = 0")
         .Input("to_zero: int32")
         .Output("zeroed: int32");
联系我们

邮箱 626512443@qq.com
电话 18611320371(微信)
QQ群 235681453

Copyright © 2015-2022

备案号:京ICP备15003423号-3