预备知识:
如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op 能够兼容原有的库 , 你必须做以下工作:
向 TensorFlow 系统注册来定义 Op 的接口. 在注册时, 指定 Op 的名称, 它的输入(类型和名称) 和输出(类型和名称), 和所需要任何 属性的文档说明.
为了让你有直观的认识, 创建一个简单的 Op 作为例子. 该 Op 接受一个 int32
类型 tensor 作为 输入, 输出这个 tensor 的一个副本, 副本与原 tensor 唯一的区别在于第一个元素被置为 0. 创建 文件tensorflow/core/user_ops/zero_out.cc
, 并调用 REGISTER_OP
宏来定义 Op 的接口.
#include "tensorflow/core/framework/op.h" |
REGISTER_OP("ZeroOut") |
.Input("to_zero: int32") |
.Output("zeroed: int32"); |
ZeroOut
Op 接受 32 位整型的 tensor to_zero
作为输入, 输出 32 位整型的 tensor zeroed
.
在定义接口之后, 提供一个或多个 Op 的实现. 为这些 kernel 的每一个创建一个对应的类, 继承 OpKernel
, 覆盖Compute
方法. Compute
方法提供一个类型为 OpKernelContext*
的参数 context
, 用于访问一些有用的信息, 例如输入和输出的 tensor.
将 kernel 添加到刚才创建的文件中, kernel 看起来和下面的代码类似:
#include "tensorflow/core/framework/op_kernel.h" |
using namespace tensorflow; |
class ZeroOutOp : public OpKernel { |
public: |
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} |
void Compute(OpKernelContext* context) override { |
// 获取输入 tensor. |
const Tensor& input_tensor = context->input(0); |
auto input = input_tensor.flat<int32>(); |
// 创建一个输出 tensor. |
Tensor* output_tensor = NULL; |
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), |
&output_tensor)); |
auto output = output_tensor->template flat<int32>(); |
// 设置 tensor 除第一个之外的元素均设为 0. |
const int N = input.size(); |
for (int i = 1; i < N; i++) { |
output(i) = 0; |
} |
// 尽可能地保留第一个元素的值. |
if (N > 0) output(0) = input(0); |
} |
}; |
将下列代码加入到 zero_out.cc
中, 注册 ZeroOut
op:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp); |
一旦创建和重新安装了 TensorFlow , Tensorflow 系统可以在需要时引用和使用该 Op.
当编译 TensorFlow 时, 所有放在 tensorflow/core/user_ops
目录下 的 Op 会自动在 bazel-genfiles/tensorflow/python/ops/gen_user_ops.py
文件 中生成 Python Op 包装器. 通过以下声明, 把那些 Op 引入到tensorflow/python/user_ops/user_ops.py
中:
from tensorflow.python.ops.gen_user_ops import * |
tensorflow/python/BUILD
文件中, 将其名字添加到 "user_ops"
的 hidden
列表.
tf_gen_op_wrapper_py( |
name = "user_ops", |
hidden = [ |
"Fact", |
], |
require_shape_functions = False, |
) |
"Fact"
列出自己的 Op. 然后, 在 tensorflow/python/user_ops/user_ops.py
中添加你的替代实现函数. 通常, 替代实现函数也会调用自动生成函数来真正把 Op 添加 到图中. 被隐藏的自动生成函数位于 gen_user_ops
包中, 名称多了一个下划线前缀 ("_
"). 例如:
def my_fact(): |
"""覆盖一个 Op 自动生成代码的示例.""" |
return gen_user_ops._fact() |
当编译 TensorFlow 时, 所有 tensorflow/core/user_ops
文件夹 下的 Op 会自动创建 C++ Op 包装器. 例如,tensorflow/core/user_ops/zero_out.cc
中的 Op 会自动在 bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc}
中生成包装器.
tensorflow/cc/ops/standard_ops.h
通过下述申明, 导入用户自定义 Op 自动生成的包装器.
#include "tensorflow/cc/ops/user_ops.h" |
验证已经成功实现 Op 的方式是编写测试程序. 创建文件 tensorflow/python/kernel_tests/zero_out_op_test.py
, 包含以下内容:
import tensorflow as tf |
class ZeroOutTest(tf.test.TestCase): |
def testZeroOut(self): |
with self.test_session(): |
result = tf.user_ops.zero_out([5, 4, 3, 2, 1]) |
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) |
$ bazel test tensorflow/python:zero_out_op_test |
上述示例假定 Op 能够应用在任何 shape 的 tensor 上. 如果只想应用到 vector 上 呢? 这意味需要在上述 OpKernel 实现中添加相关的检查.
void Compute(OpKernelContext* context) override { |
// 获取输入 tensor |
const Tensor& input_tensor = context->input(0); |
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), |
errors::InvalidArgument("ZeroOut expects a 1-D vector.")); |
// ... |
} |
OP_REQUIRES 断言的输入是一个 vector, 如果不是 vector, 将设置 InvalidArgument
状态并返回. OP_REQUIRES
宏有三个参数:
context
: 可以是一个 OpKernelContext
或 OpKernelConstruction
指针 (参见tensorflow/core/framework/op_kernel.h
), 其 SetStatus()
方法将被使用到.
tensorflow/core/public/tensor_shape.h
中有一些验证 tensor shape 的函数.
Status
对象表示, 参见 tensorflow/core/public/status.h
. Status
包含一个类型 (通常是 InvalidArgument
, 但也可以是任何类型) 和一个消息. 构造 一个错误的函数位于tensorflow/core/lib/core/errors.h
中.
如果想要测试一个函数返回的 Status
对象是否是一个错误, 可以使用 OP_REQUIRES_OK
. 这些宏如果检测到错误, 会直接跳出函数, 终止函数执行.
Op 可以有属性, 属性的值在 Op 添加到图中时被设置. 属性值用于配置 Op, 在 kernel 实现中, Op 注册的输入和输出类型中, 均可访问这些属性值. 尽可能地使用输入代替属性, 因为输入的灵活性更高, 例如可以在执行步骤中 中被更改, 可以使用 feed 等等. 属性可用于实现一些输入无法做到的事情, 例如影响 Op 签名 (即输入输出的数量和类型) 的配置或只读配置可以通过属性实现.
注册 Op 时可以用 Attr
方法指定属性的名称和类型, 以此来定义一个属性, 形式如下:
<name>: <attr-type-expr> |
<name>
必须以字母开头, 可以由数字, 字母, 下划线组成. <attr-type-expr>
是一个类型表达式, 形式如下:
例如, 如果想要 ZeroOut
Op 保存一个用户索引, 指示该 Op 不仅仅只有一个元素, 你可以注册 Op 如下:
REGISTER_OP("ZeroOut") |
.Attr("preserve_index: int") |
.Input("to_zero: int32") |
.Output("zeroed: int32"); |
context
参数访问这个属性:
class ZeroOutOp : public OpKernel { |
public: |
explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) { |
// 获取欲保存的索引值 |
OP_REQUIRES_OK(context, |
context->GetAttr("preserve_index", &preserve_index_)); |
// 检查 preserve_index 是否为正 |
OP_REQUIRES(context, preserve_index_ >= 0, |
errors::InvalidArgument("Need preserve_index >= 0, got ", |
preserve_index_)); |
} |
void Compute(OpKernelContext* context) override { |
// ... |
} |
private: |
int preserve_index_; |
}; |
Compute
方法中被使用:
void Compute(OpKernelContext* context) override { |
// ... |
// 检查 preserve_index 范围是否合法 |
OP_REQUIRES(context, preserve_index_ < input.dimension(0), |
errors::InvalidArgument("preserve_index out of range")); |
// 设置输出 tensor 所有的元素值为 0 |
const int N = input.size(); |
for (int i = 0; i < N; i++) { |
output_flat(i) = 0; |
} |
// 保存请求的输入值 |
output_flat(preserve_index_) = input(preserve_index_); |
} |
REGISTER_OP("ZeroOut") |
.Attr("preserve_index: int = 0") |
.Input("to_zero: int32") |
.Output("zeroed: int32"); |