预备知识:
如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op 能够兼容原有的库 , 你必须做以下工作:
向 TensorFlow 系统注册来定义 Op 的接口. 在注册时, 指定 Op 的名称, 它的输入(类型和名称) 和输出(类型和名称), 和所需要任何 属性的文档说明.
为了让你有直观的认识, 创建一个简单的 Op 作为例子. 该 Op 接受一个 int32
类型 tensor 作为 输入, 输出这个 tensor 的一个副本, 副本与原 tensor 唯一的区别在于第一个元素被置为 0. 创建 文件tensorflow/core/user_ops/zero_out.cc
, 并调用 REGISTER_OP
宏来定义 Op 的接口.
#include "tensorflow/core/framework/op.h" REGISTER_OP("ZeroOut") .Input("to_zero: int32") .Output("zeroed: int32");
ZeroOut
Op 接受 32 位整型的 tensor to_zero
作为输入, 输出 32 位整型的 tensor zeroed
.
在定义接口之后, 提供一个或多个 Op 的实现. 为这些 kernel 的每一个创建一个对应的类, 继承 OpKernel
, 覆盖Compute
方法. Compute
方法提供一个类型为 OpKernelContext*
的参数 context
, 用于访问一些有用的信息, 例如输入和输出的 tensor.
将 kernel 添加到刚才创建的文件中, kernel 看起来和下面的代码类似:
#include "tensorflow/core/framework/op_kernel.h" using namespace tensorflow; class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // 获取输入 tensor. const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat<int32>(); // 创建一个输出 tensor. Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->template flat<int32>(); // 设置 tensor 除第一个之外的元素均设为 0. const int N = input.size(); for (int i = 1; i < N; i++) { output(i) = 0; } // 尽可能地保留第一个元素的值. if (N > 0) output(0) = input(0); } };实现 kernel 后, 将其注册到 TensorFlow 系统中. 注册时, 可以指定该 kernel 运行时的多个约束 条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在 GPU 上运行.
将下列代码加入到 zero_out.cc
中, 注册 ZeroOut
op:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
一旦创建和重新安装了 TensorFlow , Tensorflow 系统可以在需要时引用和使用该 Op.
当编译 TensorFlow 时, 所有放在 tensorflow/core/user_ops
目录下 的 Op 会自动在 bazel-genfiles/tensorflow/python/ops/gen_user_ops.py
文件 中生成 Python Op 包装器. 通过以下声明, 把那些 Op 引入到tensorflow/python/user_ops/user_ops.py
中:
from tensorflow.python.ops.gen_user_ops import *你可以选择性将部分函数替换为自己的实现. 为此, 首先要隐藏自动生成的代码, 在
tensorflow/python/BUILD
文件中, 将其名字添加到 "user_ops"
的 hidden
列表.
tf_gen_op_wrapper_py( name = "user_ops", hidden = [ "Fact", ], require_shape_functions = False, )紧接着
"Fact"
列出自己的 Op. 然后, 在 tensorflow/python/user_ops/user_ops.py
中添加你的替代实现函数. 通常, 替代实现函数也会调用自动生成函数来真正把 Op 添加 到图中. 被隐藏的自动生成函数位于 gen_user_ops
包中, 名称多了一个下划线前缀 ("_
"). 例如:
def my_fact(): """覆盖一个 Op 自动生成代码的示例.""" return gen_user_ops._fact()C++ Op 包装器
当编译 TensorFlow 时, 所有 tensorflow/core/user_ops
文件夹 下的 Op 会自动创建 C++ Op 包装器. 例如,tensorflow/core/user_ops/zero_out.cc
中的 Op 会自动在 bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc}
中生成包装器.
tensorflow/cc/ops/standard_ops.h
通过下述申明, 导入用户自定义 Op 自动生成的包装器.
#include "tensorflow/cc/ops/user_ops.h"
验证已经成功实现 Op 的方式是编写测试程序. 创建文件 tensorflow/python/kernel_tests/zero_out_op_test.py
, 包含以下内容:
import tensorflow as tf class ZeroOutTest(tf.test.TestCase): def testZeroOut(self): with self.test_session(): result = tf.user_ops.zero_out([5, 4, 3, 2, 1]) self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])然后运行测试:
$ bazel test tensorflow/python:zero_out_op_test验证条件
上述示例假定 Op 能够应用在任何 shape 的 tensor 上. 如果只想应用到 vector 上 呢? 这意味需要在上述 OpKernel 实现中添加相关的检查.
void Compute(OpKernelContext* context) override { // 获取输入 tensor const Tensor& input_tensor = context->input(0); OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), errors::InvalidArgument("ZeroOut expects a 1-D vector.")); // ... }
OP_REQUIRES 断言的输入是一个 vector, 如果不是 vector, 将设置 InvalidArgument
状态并返回. OP_REQUIRES
宏有三个参数:
context
: 可以是一个 OpKernelContext
或 OpKernelConstruction
指针 (参见tensorflow/core/framework/op_kernel.h
), 其 SetStatus()
方法将被使用到.
tensorflow/core/public/tensor_shape.h
中有一些验证 tensor shape 的函数.
Status
对象表示, 参见 tensorflow/core/public/status.h
. Status
包含一个类型 (通常是 InvalidArgument
, 但也可以是任何类型) 和一个消息. 构造 一个错误的函数位于tensorflow/core/lib/core/errors.h
中.
如果想要测试一个函数返回的 Status
对象是否是一个错误, 可以使用 OP_REQUIRES_OK
. 这些宏如果检测到错误, 会直接跳出函数, 终止函数执行.
Op 可以有属性, 属性的值在 Op 添加到图中时被设置. 属性值用于配置 Op, 在 kernel 实现中, Op 注册的输入和输出类型中, 均可访问这些属性值. 尽可能地使用输入代替属性, 因为输入的灵活性更高, 例如可以在执行步骤中 中被更改, 可以使用 feed 等等. 属性可用于实现一些输入无法做到的事情, 例如影响 Op 签名 (即输入输出的数量和类型) 的配置或只读配置可以通过属性实现.
注册 Op 时可以用 Attr
方法指定属性的名称和类型, 以此来定义一个属性, 形式如下:
<name>: <attr-type-expr>
<name>
必须以字母开头, 可以由数字, 字母, 下划线组成. <attr-type-expr>
是一个类型表达式, 形式如下:
例如, 如果想要 ZeroOut
Op 保存一个用户索引, 指示该 Op 不仅仅只有一个元素, 你可以注册 Op 如下:
REGISTER_OP("ZeroOut") .Attr("preserve_index: int") .Input("to_zero: int32") .Output("zeroed: int32");你的 kernel 可以在构造函数里, 通过
context
参数访问这个属性:
class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) { // 获取欲保存的索引值 OP_REQUIRES_OK(context, context->GetAttr("preserve_index", &preserve_index_)); // 检查 preserve_index 是否为正 OP_REQUIRES(context, preserve_index_ >= 0, errors::InvalidArgument("Need preserve_index >= 0, got ", preserve_index_)); } void Compute(OpKernelContext* context) override { // ... } private: int preserve_index_; };该值可以在
Compute
方法中被使用:
void Compute(OpKernelContext* context) override { // ... // 检查 preserve_index 范围是否合法 OP_REQUIRES(context, preserve_index_ < input.dimension(0), errors::InvalidArgument("preserve_index out of range")); // 设置输出 tensor 所有的元素值为 0 const int N = input.size(); for (int i = 0; i < N; i++) { output_flat(i) = 0; } // 保存请求的输入值 output_flat(preserve_index_) = input(preserve_index_); }为了维持向后兼容性, 将一个属性添加到一个已有的 Op 时, 必须指定一个默认值:
REGISTER_OP("ZeroOut") .Attr("preserve_index: int = 0") .Input("to_zero: int32") .Output("zeroed: int32");