增加一个新 Op(二)

属性类型

属性可以使用下面的类型:

  • string: 任何二进制字节流 (UTF8 不是必须的).
  • int: 一个有型整数.
  • float: 一个浮点数.
  • bool: 真或假.
  • typeDataType 非引用类型之一.
  • shape: 一个 TensorShapeProto.
  • tensor: 一个 TensorProto.
  • list(<type>)<type> 列表, 其中 <type> 是上述类型之一. 注意 list(list(<type>)) 是无效的.

权威的列表以 op_def_builder.cc:FinalizeAttr 为准.

默认值和约束条件

属性可能有默认值, 一些类型的属性可以有约束条件. 为了定义一个有约束条件的属性, 你可以使用下列的 <attr-type-expr> 形式:


  • {'<string1>', '<string2>'}: 属性值必须是一个字符串, 取值可以为 <string1> 或 <string2>. 值的语法已经暗示了值的类型为 string, 已经暗示了. 下述语句模拟了一个枚举值:
  • REGISTER_OP("EnumExample")
          .Attr("e: {'apple', 'orange'}");
    • {<type1>, <type2>}: 值是 type 类型, 且必须为 <type1> 或 <type2> 之一, 当然 <type1> 和 <type2> 必须都是有效的 tensor 类型. 你无须指定属性的类型为 type, 而是通过 {...} 语句给出一个类型列表. 例如, 在下面的例子里, 属性 t 的类型必须为 int32float, 或 bool:
    REGISTER_OP("RestrictedTypeExample")
          .Attr("t: {int32, float, bool}");
    • 这里有一些常见类型约束条件的快捷方式:

      • numbertype: 限制类型为数字类型, 即非 string 非 bool 的类型.
      • realnumbertype: 与 numbertype 区别是不支持复杂类型.
      • quantizedtype: 与 numbertype 区别是只支持量化数值 (quantized number type).

    这些类型的列表在 tensorflow/core/framework/types.h 文件中通过函数定义 (如 NumberTypes()). 本例中属性 t 必须为某种数字类型:

    REGISTER_OP("NumberType")
            .Attr("t: numbertype");
    对于这个 Op:
  • tf.number_type(t=tf.int32)  # 有效
    tf.number_type(t=tf.bool)   # 无效
    int >= <n>: 值必须是一个整数, 且取值大于等于 <n><n> 是一个自然数.

    例如, 下列 Op 注册操作指定了属性 a 的取值至少为 2.

    REGISTER_OP("MinIntExample")
          .Attr("a: int >= 2");
    • list(<type>) >= <n>: 一个 <type> 类型列表, 列表长度必须大于等于 <n>.

    例如, 下面的 Op 注册操作指定属性 a 是一个列表, 列表中的元素类型是 int32 或 float列表长度至少为3.

    REGISTER_OP("TypeListExample")
          .Attr("a: list({int32, float}) >= 3");
    通过添加 = <default> 到约束条件末尾, 给一个属性设置默认值 (使其在自动生成的代码里 变成可选属性), 如下:
  • REGISTER_OP("AttrDefaultExample")
        .Attr("i: int = 0");

    默认值支持的语法将在最终 GraphDef 定义的 protobuf 表示中被使用.

    下面是给所有类型赋予默认值的例子:

    REGISTER_OP("AttrDefaultExampleForAllTypes")
       .Attr("s: string = 'foo'")
       .Attr("i: int = 0")
       .Attr("f: float = 1.0")
       .Attr("b: bool = true")
       .Attr("ty: type = DT_INT32")
       .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
       .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
       .Attr("l_empty: list(int) = []")
       .Attr("l_int: list(int) = [2, 3, 5, 7]");
    请特别注意那些类型值里面包含的 DT_* 名称.
  • 多态

    Type Polymorphism

    对于那些可以使用不同类型输入或产生不同类型输出的 Op, 可以注册 Op 时为输入/输出类型里指定一个属性. 一般紧接着, 会为每一个支持的类型注册一个 OpKernel.

    例如, 除了 int32 外, 想要 ZeroOut Op 支持 float, 注册代码如下:

    REGISTER_OP("ZeroOut")
        .Attr("T: {float, int32}")
        .Input("to_zero: <b>T</b>")
        .Output("zeroed: <b>T</b>");
    这段 Op 注册代码现在指定了输入的类型必须为 float 或 int32, 而且 既然输入和输出制定了同样的类型 T, 输出也同样如此.
  • 一个命名建议:{#naming} 输入, 输出, 和属性通常使用 snake_case 命名法. 唯一的例外是属性被用作输入类型或是输入类型的一部分. 当添加到图中时, 这些属性 可以被推断出来, 因此不会出现在 Op 的函数里. 例如, 最后一个 ZeroOut 定义 生成的 Python 函数如下:
  • def zero_out(to_zero, name=None):
       """...
       参数:
         to_zero: 一个 `Tensor`. 必须为下列类型之一:
             `float32`, `int32`.
         name: 操作的名字 (可选).
    
       返回值:
         一个 `Tensor`, 类型和 `to_zero` 一样.
       """
    如果输入的 to_zero 是一个 int32 的tensor, 然后 T 将被自动 设置为 int32 (实际上是 DT_INT32). 那些推导出的属性的名称字母全大写 或采用驼峰命名法.

    下面是一个输出类型自动推断的例子, 读者可以对比一下:

  • REGISTER_OP("StringToNumber")
         .Input("string_tensor: string")
         .Output("output: out_type")
         .Attr("out_type: {float, int32}");
         .Doc(R"doc(
     Converts each string in the input Tensor to the specified numeric type.
     )doc");
    在这种情况下, 用户需要在生成的 Python 代码中指定输出类型.

  • def string_to_number(string_tensor, out_type=None, name=None):
       """将输入 Tensor 中的每一个字符串转化成指定的数字类型
    
       参数:
         string_tensor: 一个 `string` 类型的 `Tensor`.
         out_type: 一个可选的 `tf.DType`, 取值为 `tf.float32, tf.int32`.
           默认值是 `tf.float32`.
         name: 操作的名称 (可选).
    
       返回值:
         一个 `out_type` 类型的 `Tensor`.
       """
    #include "tensorflow/core/framework/op_kernel.h"
    class ZeroOutInt32Op : public OpKernel {
      // 和之前一样
    };
    class ZeroOutFloatOp : public OpKernel {
     public:
      explicit ZeroOutFloatOp(OpKernelConstruction * context)
          : OpKernel(context) {}
      void Compute(OpKernelContext * context) override {
        // 获取输入 tensor
        const Tensor& input_tensor = context->input(0);
        auto input = input_tensor.flat<float>();
        // 创建一个输出 tensor
        Tensor * output = NULL;
        OP_REQUIRES_OK(context,
                        context->allocate_output(0, input_tensor.shape(), &output));
        auto output_flat = output->template flat<float>();
        // 设置输出 tensor 的所有元素为 0
        const int N = input.size();
        for (int i = 0; i &lt; N; i++) {
          output_flat(i) = 0;
        }<br/>
        // 保留第一个输入值
        if (N &gt; 0) output_flat(0) = input(0);
      }
    };
    // 注意, TypeConstraint<int32>("T") 意味着属性 "T" (在上面 Op 注册代码中
    // 定义的) 必须是 "int32", 才能实例化. 
    REGISTER_KERNEL_BUILDER(
        Name("ZeroOut")
        .Device(DEVICE_CPU)
        .TypeConstraint&lt;int32&gt;("T"),
        ZeroOutOpInt32);
    REGISTER_KERNEL_BUILDER(
        Name("ZeroOut")
        .Device(DEVICE_CPU)
        .TypeConstraint<float>("T"),
        ZeroOutFloatOp);
    为了保持向后兼容性, 你在为一个 已有的 op 添加属性时, 必须指定一个默认值:

  • REGISTER_OP("ZeroOut")
      .Attr("T: {float, int32} = DT_INT32")
      .Input("to_zero: T")
      .Output("zeroed: T")
    如果需要添加更多类型, 例如 double:

  • REGISTER_OP("ZeroOut")
        .Attr("T: {float, double, int32}")
        .Input("to_zero: T")
        .Output("zeroed: T");
    为了避免为新增的类型写冗余的 OpKernel 代码, 通常可以写一个 C++ 模板作为替代. 当然, 仍然需要为每一个重载版本定义一个 keneral 注册 (REGISTER\_KERNEL\_BUILDER 调用).

  • template <typename T>;
    class ZeroOutOp : public OpKernel {
     public:
        explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
      void Compute(OpKernelContext* context) override {
        // 获取输入 tensor
         const Tensor& input_tensor = context->input(0);
        auto input = input_tensor.flat<T>();
        // 创建一个输出 tensor
          Tensor* output = NULL;
        OP_REQUIRES_OK(context,
                       context->allocate_output(0, input_tensor.shape(), &output));
        auto output_flat = output->template flat<T>();
        // 设置输出 tensor 的所有元素为 0
       const int N = input.size();
        for (int i = 0; i < N; i++) {
          output_flat(i) = 0;
        }
        // Preserve the first input value
        if (N > 0) output_flat(0) = input(0);
      }
    };
    };<br/>
    // 注意, TypeConstraint<int32>("T") 意味着属性 "T" (在上面 Op 注册代码中
    // 定义的) 必须是 "int32", 才能实例化. </b>
    REGISTER_KERNEL_BUILDER(
        Name("ZeroOut")
        .Device(DEVICE_CPU)
        .TypeConstraint<int32>("T"),
        ZeroOutOp<int32>);
    REGISTER_KERNEL_BUILDER(
        Name("ZeroOut")
        .Device(DEVICE_CPU)
        .TypeConstraint<float>("T"),
        ZeroOutOp<float>);
    REGISTER_KERNEL_BUILDER(
        Name("ZeroOut")
        .Device(DEVICE_CPU)
        .TypeConstraint<double>("T"),
        ZeroOutOp<double>);
    如果有很多重载版本, 可以将注册操作通过一个宏来实现.

  • #include "tensorflow/core/framework/op_kernel.h"
     #define REGISTER_KERNEL(type)                                       \
      REGISTER_KERNEL_BUILDER(                                          \
          Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
          ZeroOutOp<type>)
    REGISTER_KERNEL(int32);
    REGISTER_KERNEL(float);
    REGISTER_KERNEL(double);
     #undef REGISTER_KERNEL
    取决于注册 kernel 使用哪些类型, 你可能可以使用tensorflow/core/framework/register_types.h 提供的宏:

  • #include "tensorflow/core/framework/op_kernel.h"
     #include "tensorflow/core/framework/register_types.h"
    REGISTER_OP("ZeroOut")
        .Attr("T: realnumbertype")
        .Input("to_zero: T")
        .Output("zeroed: T");
    template <typename T>
    class ZeroOutOp : public OpKernel { ... };
     #define REGISTER_KERNEL(type)                                       \
      REGISTER_KERNEL_BUILDER(                                          \
          Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
          ZeroOutOp<type>)
    TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
     #undef REGISTER_KERNEL
    列表输入和输出

    除了能够使用不同类型的 tensor 作为输入或输出, Op 还支持使用多个 tensor 作为输入或输出.

    在接下来的例子里, 属性 T 存储了一个类型列表, 并同时作为输入 in 和输出 out 的类型. 输入和输出均为指定类型的 tensor 列表. 既然输入和输出的类型均为 T, 它们的 tensor 数量和类型 是一致的.

    REGISTER_OP("PolymorphicListExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    可以为列表中可存放的类型设置约束条件. 在下一个例子中, 输入是 float 和 double 类型的 tensor 列表. 例如, 这个 Op 可接受的 输入类型为 (float, double, float) 的数据, 且在此情况下, 输出类型同样 为 (float, double, float).

  • REGISTER_OP("ListTypeRestrictionExample")
        .Attr("T: list({float, double})")
        .Input("in: T")
        .Output("out: T");
    如果想要一个列表中的所有 tensor 是同一类型, 你需要写下列代码:

  • REGISTER_OP("IntListInputExample")
        .Attr("N: int")
        .Input("in: N * int32")
        .Output("out: int32");
    这段代码接受 int32 tensor 列表, 并用一个 int 属性 N 来指定列表的长度.

    这也可用于类型推断. 在下一个例子中, 输入是一个 tensor 列表, 长度为 "N", 类型为 "T", 输出是单个 "T" 的 tensor:

    REGISTER_OP("SameListInputExample")
        .Attr("N: int")
        .Attr("T: type")
        .Input("in: N * T")
        .Output("out: T");
    默认情况下, tensor 列表的最小长度为1. 这个约束条件可以通过 为指定的属性增加一个 ">=" 约束来变更:

  • REGISTER_OP("MinLengthIntListExample")
        .Attr("N: int >= 2")
        .Input("in: N * int32")
        .Output("out: int32");
    同样的语法也适用于 "list(type)" 属性:

  • REGISTER_OP("MinimumLengthPolymorphicListExample")
        .Attr("T: list(type) >= 3")
        .Input("in: T")
        .Output("out: T");


联系我们

邮箱 626512443@qq.com
电话 18611320371(微信)
QQ群 235681453

Copyright © 2015-2024

备案号:京ICP备15003423号-3