ONNX 形状推断¶
ONNX 提供了 ONNX 图上的形状推断的可选实现。此实现涵盖了每个核心运算符,并提供了可扩展性接口。因此,您可以选择在图上调用现有的形状推断功能,或者为自定义运算符定义形状推断实现(或两者兼而有之!)。形状推断函数作为 OpSchema 对象的成员存储。
在 ONNX 1.10 版本中,符号生成和传播以及形状数据传播被添加到 ONNX 图级别形状推断中。详细提案在此
背景¶
请参阅 IR.md 的 本节,了解静态张量形状的回顾。特别是,静态张量形状(由 TensorShapeProto 表示)与运行时张量形状不同。此功能通常用于静态(即编译时)不知道确切运行时张量形状的情况。
具有未定义
shape字段的Tensor用于表示未知秩的张量。具有定义
shape的Tensor表示已知秩的张量。TensorShapeProto的每个Dimension可以具有已知的整数值(由dim_value字段表示),或者可以具有由符号标识符(dim_param字段)表示的未知值,或者可以没有定义任何字段(在这种情况下,它表示一个匿名的未知值)。
调用形状推断¶
形状推断可以通过 C++ 或 Python 调用。Python API 及其示例在此描述。
C++ API 由单个函数组成
shape_inference::InferShapes(
ModelProto& m,
const ISchemaRegistry* schema_registry);
第一个参数是要执行形状推断的 ModelProto,它会原地使用形状信息进行注释。第二个参数是可选的。
限制¶
形状推断不保证完整。特别是,一些动态行为会阻塞形状推断的流程,例如将 Reshape 转换为动态提供的形状。此外,并非所有运算符都要求具有形状推断实现。
形状推断仅适用于常量和简单变量。它不支持包含变量的算术表达式。例如,形状为 (5, 2) 和 (7, 2) 的张量上的 Concat 可以推断产生形状为 (12, 2) 的结果,但形状为 (5, 2) 和 (N, 2) 的张量上的 Concat 将只产生 (M, 2),而不是包含 N+5 的表示。请注意,不同的未知符号值将被传播,因此这里的 M 表示一个未知量,它与 M 的其他出现相同。
这些限制是当前实现的特性,而不是基本约束——如果您需要更高级的功能,请告诉我们!
为运算符实现形状推断¶
您可以为运算符的 Schema 添加形状推断函数
OpSchema& Opschema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
InferenceFunction 在 shape_inference.h 中定义,以及核心接口结构 InferenceContext 和各种辅助方法。InferenceContext 是提供给推断函数的核心结构。它允许访问有关运算符输入的信息,并且还允许写入推断出的信息。
要查看大量示例,请在代码库中搜索 TypeAndShapeInferenceFunction 的出现。其中一个相对复杂的实现在 onnx/defs/tensor/defs.cc 中是 Concat 的实现。
请注意以下几点,以便在为运算符实现形状推断方法时避免常见错误
在访问任何输入的
shape之前,代码必须检查形状是否可用。如果不可用,则应将其视为秩未知且已适当地处理的动态张量。通常,形状推断逻辑由对hasInputShape或hasNInputShapes的调用守护。在访问任何维度的
dim_value或dim_param之前,代码必须检查这些字段是否有值。特别是,代码必须处理维度可能没有静态已知值的可能性。
shape_inference.h 中有几个实用函数可以处理各种常见情况。
对于必须具有固定秩的输入,使用
checkInputRank。(请参见RoiAlign的推断作为示例。)当多个输入维度预期相同,并且当输入维度传播到特定输出维度时,可以使用
unifyInputDim和unifyDim和updateOutputShape。(请参见RoiAlign的推断作为示例。)当使用算术从输入维度计算输出维度时,可以在符号维度上使用重载运算符
*和/。(请参见SpaceToDepth的推断作为示例。)
这些实用程序可以安全地处理缺失的形状和维度。
示例:考虑一个简单的矩阵乘法运算符,它期望形状为 [M,K] 和 [K,N] 的输入,并返回形状为 [M,N] 的输出。这可以编码如下
// Check that input 0 has rank 2 (if its rank is known).
checkInputRank(ctx, 0, 2);
// Check that input 1 has rank 2 (if its rank is known).
checkInputRank(ctx, 1, 2);
Dim M, K, N;
// Check various dimensions, handling missing dimensions/shapes safely.
unifyInputDim(ctx, 0, 0, M);
unifyInputDim(ctx, 0, 1, K);
unifyInputDim(ctx, 1, 0, K);
unifyInputDim(ctx, 1, 1, N);
updateOutputShape(ctx, 0, {M. N});