维度标记

维度标记是一项实验性尝试,旨在为张量轴提供语义描述,进而提供类型并根据这些类型执行验证步骤。

动机

这种机制的动机可以通过一个简单的例子来说明。在下面的线性神经网络规范中,我们假设一个 NCHW 模型输入

input_in_NCHW -> Transpose(input, perm=[0, 2, 1, 3]) -> AveragePool(input, ...)

在这个神经网络中,用户错误地构建了一个神经网络,它将 NCHW 输入转置为奇怪的 NHCW 格式,并通过假设 NCHW 输入格式的空间池化。尽管这显然是一个错误,但现有基础设施不会向用户报告错误。这应该让严重依赖类型检查作为程序正确性保证组成部分的程序员深感不安。本提案旨在解决当前神经网络规范范式中固有的类型检查真空。

本提案由三个关键组件组成:标记定义、标记传播和标记验证,每个组件都将详细讨论。

标记定义

首先,我们为张量类型定义了一组类型。这些类型是根据以下原则定义的

  1. 足够精细,以消除潜在的陷阱。例如,动机部分中说明的上述示例要求我们区分通道维度和空间特征维度,以确保 AveragePool 运算执行的正确性。

  2. 足够粗粒度,以减轻用户的心理负担。例如,在上述示例中,区分宽度维度和高度维度的需求明显较少,因为池化和卷积等操作通常不对各种空间维度进行区分。因此,我们将所有空间维度概括为特征维度。

  3. 作为第 2 条的重要推论,与模型无关。例如,循环神经网络 (RNN) 中特征维度的语义与卷积神经网络 (CNN) 中空间维度的语义几乎无法区分,因此我们允许用户和开发人员将两者描述为特征维度。

具体来说,在我们的第一个提案中,我们定义了以下一组标准标记

  1. DATA_BATCH 描述训练数据的批次维度。这对应于更常用张量格式表示法 NCHW 中的 N 维度。

  2. DATA_CHANNEL 描述训练数据的通道维度。这对应于 C 维度。

  3. DATA_TIME 描述时间维度。

  4. DATA_FEATURE 描述特征维度。这对应于 HW 维度或 RNN 中的特征维度。

  5. FILTER_IN_CHANNEL 描述滤波器输入通道维度。此维度与输入图像特征图的通道维度相同(在大小上)。

  6. FILTER_OUT_CHANNEL 描述滤波器输出通道维度。此维度与输出图像特征图的通道维度相同(在大小上)。

  7. FILTER_SPATIAL 描述滤波器空间维度。

标记传播

当操作相对于其输入张量置换、销毁或创建维度时,会发生标记传播。在这种情况下,我们将实现定制的、特定于操作的函数,以根据输入张量维度标记推断输出张量维度标记。发生标记传播的一个示例操作是转置操作,其中输出维度标记推断的伪代码可以表述为输入维度标记的函数

for i, j in enumerate(perm):
    out_dim_denotaion[i] = in_dim_denotation[j]

标记验证

当操作期望其输入以特定格式到达时,会发生标记验证。发生标记验证的一个示例操作是 AveragePool 操作,其中输入(如果用维度标记注释)在 2D 情况下的标记应为 [DATA_BATCH, DATA_CHANNEL, DATA_FEATURE, DATA_FEATURE]。如果预期维度标记与实际维度标记之间存在不匹配,则应报告错误。

类型标记

有关如何描述图像和其他类型的更多详细信息,请参阅类型标记文档