属性可以使用下面的类型:
string
: 任何二进制字节流 (UTF8 不是必须的).
int
: 一个有型整数.
float
: 一个浮点数.
bool
: 真或假.
type
: DataType
非引用类型之一.
shape
: 一个 TensorShapeProto
.
tensor
: 一个 TensorProto
.
list(<type>)
: <type>
列表, 其中 <type>
是上述类型之一. 注意 list(list(<type>))
是无效的.
权威的列表以 op_def_builder.cc:FinalizeAttr
为准.
属性可能有默认值, 一些类型的属性可以有约束条件. 为了定义一个有约束条件的属性, 你可以使用下列的 <attr-type-expr>
形式:
{'<string1>', '<string2>'}
: 属性值必须是一个字符串, 取值可以为 <string1>
或 <string2>
. 值的语法已经暗示了值的类型为 string
, 已经暗示了. 下述语句模拟了一个枚举值:
REGISTER_OP("EnumExample") |
.Attr("e: {'apple', 'orange'}"); |
{<type1>, <type2>}
: 值是 type
类型, 且必须为 <type1>
或 <type2>
之一, 当然 <type1>
和 <type2>
必须都是有效的 tensor 类型. 你无须指定属性的类型为 type
, 而是通过 {...}
语句给出一个类型列表. 例如, 在下面的例子里, 属性 t
的类型必须为 int32
, float
, 或 bool
:
REGISTER_OP("RestrictedTypeExample") |
.Attr("t: {int32, float, bool}"); |
这里有一些常见类型约束条件的快捷方式:
numbertype
: 限制类型为数字类型, 即非 string 非 bool 的类型.
realnumbertype
: 与 numbertype
区别是不支持复杂类型.
quantizedtype
: 与 numbertype
区别是只支持量化数值 (quantized number type).
这些类型的列表在 tensorflow/core/framework/types.h
文件中通过函数定义 (如 NumberTypes()
). 本例中属性 t
必须为某种数字类型:
REGISTER_OP("NumberType") |
.Attr("t: numbertype"); |
tf.number_type(t=tf.int32) # 有效 |
tf.number_type(t=tf.bool) # 无效 |
int >= <n>
: 值必须是一个整数, 且取值大于等于 <n>
, <n>
是一个自然数.
例如, 下列 Op 注册操作指定了属性 a
的取值至少为 2
.
REGISTER_OP("MinIntExample") |
.Attr("a: int >= 2"); |
list(<type>) >= <n>
: 一个 <type>
类型列表, 列表长度必须大于等于 <n>
.
例如, 下面的 Op 注册操作指定属性 a
是一个列表, 列表中的元素类型是 int32
或 float
列表长度至少为3.
REGISTER_OP("TypeListExample") |
.Attr("a: list({int32, float}) >= 3"); |
= <default>
到约束条件末尾, 给一个属性设置默认值 (使其在自动生成的代码里 变成可选属性), 如下:
REGISTER_OP("AttrDefaultExample") |
.Attr("i: int = 0"); |
默认值支持的语法将在最终 GraphDef 定义的 protobuf 表示中被使用.
下面是给所有类型赋予默认值的例子:
REGISTER_OP("AttrDefaultExampleForAllTypes") |
.Attr("s: string = 'foo'") |
.Attr("i: int = 0") |
.Attr("f: float = 1.0") |
.Attr("b: bool = true") |
.Attr("ty: type = DT_INT32") |
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }") |
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }") |
.Attr("l_empty: list(int) = []") |
.Attr("l_int: list(int) = [2, 3, 5, 7]"); |
DT_*
名称.
对于那些可以使用不同类型输入或产生不同类型输出的 Op, 可以注册 Op 时为输入/输出类型里指定一个属性. 一般紧接着, 会为每一个支持的类型注册一个 OpKernel
.
例如, 除了 int32
外, 想要 ZeroOut
Op 支持 float
, 注册代码如下:
REGISTER_OP("ZeroOut") |
.Attr("T: {float, int32}") |
.Input("to_zero: <b>T</b>") |
.Output("zeroed: <b>T</b>"); |
float
或 int32
, 而且 既然输入和输出制定了同样的类型 T
, 输出也同样如此.
def zero_out(to_zero, name=None): |
"""... |
参数: |
to_zero: 一个 `Tensor`. 必须为下列类型之一: |
`float32`, `int32`. |
name: 操作的名字 (可选). |
返回值: |
一个 `Tensor`, 类型和 `to_zero` 一样. |
""" |
to_zero
是一个 int32
的tensor, 然后 T
将被自动 设置为 int32
(实际上是 DT_INT32
). 那些推导出的属性的名称字母全大写 或采用驼峰命名法.
下面是一个输出类型自动推断的例子, 读者可以对比一下:
REGISTER_OP("StringToNumber") |
.Input("string_tensor: string") |
.Output("output: out_type") |
.Attr("out_type: {float, int32}"); |
.Doc(R"doc( |
Converts each string in the input Tensor to the specified numeric type. |
)doc"); |
def string_to_number(string_tensor, out_type=None, name=None): |
"""将输入 Tensor 中的每一个字符串转化成指定的数字类型 |
参数: |
string_tensor: 一个 `string` 类型的 `Tensor`. |
out_type: 一个可选的 `tf.DType`, 取值为 `tf.float32, tf.int32`. |
默认值是 `tf.float32`. |
name: 操作的名称 (可选). |
返回值: |
一个 `out_type` 类型的 `Tensor`. |
""" |
#include "tensorflow/core/framework/op_kernel.h" |
class ZeroOutInt32Op : public OpKernel { |
// 和之前一样 |
}; |
class ZeroOutFloatOp : public OpKernel { |
public: |
explicit ZeroOutFloatOp(OpKernelConstruction * context) |
: OpKernel(context) {} |
void Compute(OpKernelContext * context) override { |
// 获取输入 tensor |
const Tensor& input_tensor = context->input(0); |
auto input = input_tensor.flat<float>(); |
// 创建一个输出 tensor |
Tensor * output = NULL; |
OP_REQUIRES_OK(context, |
context->allocate_output(0, input_tensor.shape(), &output)); |
auto output_flat = output->template flat<float>(); |
// 设置输出 tensor 的所有元素为 0 |
const int N = input.size(); |
for (int i = 0; i < N; i++) { |
output_flat(i) = 0; |
}<br/> |
// 保留第一个输入值 |
if (N > 0) output_flat(0) = input(0); |
} |
}; |
// 注意, TypeConstraint<int32>("T") 意味着属性 "T" (在上面 Op 注册代码中 |
// 定义的) 必须是 "int32", 才能实例化. |
REGISTER_KERNEL_BUILDER( |
Name("ZeroOut") |
.Device(DEVICE_CPU) |
.TypeConstraint<int32>("T"), |
ZeroOutOpInt32); |
REGISTER_KERNEL_BUILDER( |
Name("ZeroOut") |
.Device(DEVICE_CPU) |
.TypeConstraint<float>("T"), |
ZeroOutFloatOp); |
REGISTER_OP("ZeroOut") |
.Attr("T: {float, int32} = DT_INT32") |
.Input("to_zero: T") |
.Output("zeroed: T") |
double
:
REGISTER_OP("ZeroOut") |
.Attr("T: {float, double, int32}") |
.Input("to_zero: T") |
.Output("zeroed: T"); |
OpKernel
代码, 通常可以写一个 C++ 模板作为替代. 当然, 仍然需要为每一个重载版本定义一个 keneral 注册 (REGISTER\_KERNEL\_BUILDER
调用).
template <typename T>; |
class ZeroOutOp : public OpKernel { |
public: |
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} |
void Compute(OpKernelContext* context) override { |
// 获取输入 tensor |
const Tensor& input_tensor = context->input(0); |
auto input = input_tensor.flat<T>(); |
// 创建一个输出 tensor |
Tensor* output = NULL; |
OP_REQUIRES_OK(context, |
context->allocate_output(0, input_tensor.shape(), &output)); |
auto output_flat = output->template flat<T>(); |
// 设置输出 tensor 的所有元素为 0 |
const int N = input.size(); |
for (int i = 0; i < N; i++) { |
output_flat(i) = 0; |
} |
// Preserve the first input value |
if (N > 0) output_flat(0) = input(0); |
} |
}; |
};<br/> |
// 注意, TypeConstraint<int32>("T") 意味着属性 "T" (在上面 Op 注册代码中 |
// 定义的) 必须是 "int32", 才能实例化. </b> |
REGISTER_KERNEL_BUILDER( |
Name("ZeroOut") |
.Device(DEVICE_CPU) |
.TypeConstraint<int32>("T"), |
ZeroOutOp<int32>); |
REGISTER_KERNEL_BUILDER( |
Name("ZeroOut") |
.Device(DEVICE_CPU) |
.TypeConstraint<float>("T"), |
ZeroOutOp<float>); |
REGISTER_KERNEL_BUILDER( |
Name("ZeroOut") |
.Device(DEVICE_CPU) |
.TypeConstraint<double>("T"), |
ZeroOutOp<double>); |
#include "tensorflow/core/framework/op_kernel.h" |
#define REGISTER_KERNEL(type) \ |
REGISTER_KERNEL_BUILDER( \ |
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
ZeroOutOp<type>) |
REGISTER_KERNEL(int32); |
REGISTER_KERNEL(float); |
REGISTER_KERNEL(double); |
#undef REGISTER_KERNEL |
tensorflow/core/framework/register_types.h
提供的宏:
#include "tensorflow/core/framework/op_kernel.h" |
#include "tensorflow/core/framework/register_types.h" |
REGISTER_OP("ZeroOut") |
.Attr("T: realnumbertype") |
.Input("to_zero: T") |
.Output("zeroed: T"); |
template <typename T> |
class ZeroOutOp : public OpKernel { ... }; |
#define REGISTER_KERNEL(type) \ |
REGISTER_KERNEL_BUILDER( \ |
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
ZeroOutOp<type>) |
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); |
#undef REGISTER_KERNEL |
除了能够使用不同类型的 tensor 作为输入或输出, Op 还支持使用多个 tensor 作为输入或输出.
在接下来的例子里, 属性 T
存储了一个类型列表, 并同时作为输入 in
和输出 out
的类型. 输入和输出均为指定类型的 tensor 列表. 既然输入和输出的类型均为 T
, 它们的 tensor 数量和类型 是一致的.
REGISTER_OP("PolymorphicListExample") |
.Attr("T: list(type)") |
.Input("in: T") |
.Output("out: T"); |
float
和 double
类型的 tensor 列表. 例如, 这个 Op 可接受的 输入类型为 (float, double, float)
的数据, 且在此情况下, 输出类型同样 为 (float, double, float)
.
REGISTER_OP("ListTypeRestrictionExample") |
.Attr("T: list({float, double})") |
.Input("in: T") |
.Output("out: T"); |
REGISTER_OP("IntListInputExample") |
.Attr("N: int") |
.Input("in: N * int32") |
.Output("out: int32"); |
int32
tensor 列表, 并用一个 int
属性 N
来指定列表的长度.
这也可用于类型推断. 在下一个例子中, 输入是一个 tensor 列表, 长度为 "N"
, 类型为 "T"
, 输出是单个 "T"
的 tensor:
REGISTER_OP("SameListInputExample") |
.Attr("N: int") |
.Attr("T: type") |
.Input("in: N * T") |
.Output("out: T"); |
">="
约束来变更:
REGISTER_OP("MinLengthIntListExample") |
.Attr("N: int >= 2") |
.Input("in: N * int32") |
.Output("out: int32"); |
"list(type)"
属性:
REGISTER_OP("MinimumLengthPolymorphicListExample") |
.Attr("T: list(type) >= 3") |
.Input("in: T") |
.Output("out: T"); |