ONNX 模型在 MLIR 编译器基础设施中的表示和参考下推
此项目由 onnx 维护
托管于 GitHub Pages — 主题来自 orderedlist
本文档描述了 --constprop-onnx 传递,该传递用于 ONNX 语言中的操作进行常量传播。
源代码.
给定以下代码
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
%3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
"std.return"(%4) : (tensor<1xf32>) -> ()
}
如果我们调用 onnx-mlir-op --constprop-onnx,我们将得到
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
"std.return"(%0) : (tensor<1xf32>) -> ()
}
ONNXConstantOp 使用 MLIR DenseElementsAttr 来存储常量值。需要注意的是,一旦创建了 DenseElementsAttr,它就会一直存在并消耗内存,直到编译结束。在示例中,三个 ONNXConstantOp 中的所有三个 DenseElementsAttr 都存在直到编译结束。特别是,通过折叠两个 ONNXAddOp 生成的两个 ONNXConstantOp 中的两个中间 DenseElementsAttr 也存在。对于一个真实世界的模型,中间 DenseElementsAttr 的数量会迅速增加,这会导致编译期间的内存占用量很大。
为了避免在 --constprop-onnx 期间为中间 ONNXConstantOp 创建过多的 DenseElementsAttr,我们设计了一种机制,该机制为中间 ONNXConstantOp 动态分配和释放缓冲区,并且仅在常量传播和其他 ONNX 语言传递之后,在降低到 Krnl(或任何其他目标语言)之前创建 DenseElementsAttr。
这是通过自定义属性 DisposableElementsAttr 完成的,该属性在非复杂标量元素类型(布尔型、整数型和浮点型)的常见情况下,充当 DenseElementsAttr 的替代品。DisposableElementsAttr 实现了与 DenseElementsAttr 相同的 ElementsAttr 接口,在大多数情况下,它们在功能上是相同的,并且周围的代码不需要区分。它只需要使用 OnnxElementsAttrBuilder 类和 ElementsAttrHelper 函数来构造和访问 ElementsAttr 实例,以获得内存占用和性能优势。
DisposableElementsAttr 缓冲区的释放发生在编译器传递之间,由 DisposableGarbageCollector 完成,该工具作为“模块”传递(保证“停止世界”,没有其他传递并行执行)之间的“检测”由 PassManager 运行。
DisposableElementsAttr 还提供了其他内存和速度优势,这些优势在类源文件中的注释中有所概述,并在 2022 年 11 月的演示中进行了说明,该演示链接自会议维基页面。
我们使用 MLIR 声明性重写规则 (DRR) 来编写常量传播模式。用于定义模式的 DRR 定义如下所示
class Pattern<
dag sourcePattern,
list<dag> resultPatterns,
list<dag> additionalConstraints = [],
list<dag> supplementalPatterns = [],
dag benefitsAdded = (addBenefit 0)
>;
有关 DRR 的更多信息可以在此处找到。
现在,我们来看一个为 ONNXAddOp 添加常量传播的简单示例。
我们首先将一个模式添加到 ConstProp.td 中。
// Constant Propagation for Add
def AddConstProp : Pat<
// source patten: From add(lhs, rhs).
(ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
(ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
// result pattern: To c = lhs + rhs
(CreateAddOfTwoConst $addOp, $lhs, $rhs),
// Additional constraints: if both lhs and rhs are dense constants.
[(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;
上述模式将通过在编译时添加输入来用一个新的常量替换输入为常量的 ONNXAddOp。要检查输入是否为常量,仅使用 ONNXConstantOp 是不够的,因为常量张量可以是稀疏的,而我们现在仅支持密集常量张量。我们还需要使用 IsFromDenseONNXConstantOp 来检查密集常量张量。
在结果模式中,为了生成 ONNXConstantOp,我们将在编译时添加 lhs 和 rhs,并发出一个 ONNXConstantOp。为了最小化内存占用,此 ONNXConstantOp 具有 DisposableElementsAttr 而不是传统的 DenseElementsAttr。
函数 CreateAddOfTwoConst 将在编译时执行加法并返回一个 ONNXConstantOp。
def CreateAddOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
模式中的函数 CreateAddOfTwoConst 调用 ConstProp.cpp 中的 ConstPropElementwiseBinary,其内容如下。
template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value replacingValue, Value lhsValue, Value rhsValue) {
ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());
// Get lhs and rhs ElementsAttr from the values' defining constant ops.
ElementsAttr lhs = getConstValueElements(lhsValue);
ElementsAttr rhs = getConstValueElements(rhsValue);
Type operandsElemType = lhs.getElementType();
assert(operandsElemType == rhs.getElementType() &&
"all element-wise binary ops have matching operands element types");
OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));
// Construct and return a new ONNXConstantOp with the resultElements attribute.
return createReplacingConstantOp(rewriter, replacingValue, resultElements)
.getResult();
}
其中 OnnxElementsAttrBuilder.combine(...) 根据需要广播 lhs 和 rhs 元素,并构造一个新的(可处置的)ElementsAttr,其元素是二进制函数 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType) 的元素级应用的결果,该函数将 ElementwiseBinaryOp ONNX 操作映射到 C++ 运算符。
有关常量传播的更多信息,请参阅 ConstProp.td 和 ConstProp.cpp。