IR 中的张量表示

ONNX IR 提供了 onnx_ir.TensorProtocol 接口,用于使用不同的数据结构作为张量的底层数据。除了传统的 onnx.TensorProto,您还可以使用 np.ndarraytorch.Tensorjax.Array 以及几乎任何其他方式来表示图中的张量。这使得它们可以通过相同的 TensorProtocol 接口进行访问和序列化,而无需在初始化期间产生额外的复制。

The TensorProtocol

onnx_ir.TensorProtocol 定义了一个用于表示张量的只读接口。实现此接口的张量类具有 nameshapedtypesizenbytesmetadata_props 等属性,以描述张量的基本属性。此外,它还应该实现两个方法 numpy__array__,它们将从底层数据生成等效的 NumPy 数组。

注意

在与初始化器、常量值和张量属性交互时,最好假定使用 TensorProtocol,并且只有在需要时才使用 isinstance 检查具体的类。

张量类

ir.TensorProtoTensor

我们使用 onnx_ir.TensorProtoTensor 作为 proto 的包装器,以实现 onnx_ir.TensorProtocol 接口。您可以像往常一样访问 shapedtype 等。仅在调用 numpy() 时才会产生复制。

注意

可以直接初始化 onnx_ir.TensorProtoTensor,如下所示。但是,通常建议使用 onnx_ir.serde.deserialize_tensor,因为它处理所有类型的 TensorProto(例如,onnx_ir.TensorProtoTensor 不处理外部张量)。请参阅 TensorProto 转换并返回 以获取示例。

import onnx
import onnx_ir as ir

tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3])
tensor = ir.TensorProtoTensor(tensor_proto)
print("tensor: ", tensor)  # TensorProtoTensor<INT16,[3]>(name='tensor')
print("shape: ", tensor.shape)  # ir.Shape([3])
print("dtype: ", tensor.dtype)  # ir.DataType.INT16
print(tensor.raw == tensor_proto)  # The raw field is the exact tensor_proto provided at initialization
print("tobytes: ", tensor.tobytes())  # b'\x01\x00\x02\x00\x03\x00'
print("numpy: ", tensor.numpy())  # array([1, 2, 3], dtype=int16)
tensor:  TensorProtoTensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name='tensor')
shape:  [3]
dtype:  INT16
True
tobytes:  b'\x01\x00\x02\x00\x03\x00'
numpy:  [1 2 3]

ir.ExternalTensor

外部存储在磁盘上的张量数据通常很大,加载时会占用内存。onnx_ir.ExternalTensor 类使用内存映射来避免将张量加载到内存中。您可以使用该张量作为正常的 NumPy 数组,且内存使用量最小。

请参阅 onnx_ir.serde.deserialize_tensor,以查找将 onnx.TensorProto 转换为 onnx_ir.ExternalTensor 的示例。

ir.Tensor

onnx_ir.Tensor 是 NumPy 数组兼容数组对象(如 np.ndarraytorch.Tensor)的包装器。它最适合创建内存中的张量,而无需将其转换为 TensorProto,以减少转换开销。

提示

如果数组对象定义了 __array__ 方法,则它是兼容的。

要从数组创建张量,只需使用 NumPy 数组初始化它

tensor = ir.Tensor(np.random.rand(1, 2))

初始化器将从数组中获取 dtype 和 shape 信息。

要从 NumPy 数组以外的对象创建张量,您需要指定 dtype

import torch
import onnx_ir as ir

torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)
tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)
print(tensor.numpy())  # array([1., 2., 3.], dtype=float16)
[1. 2. 3.]

字符串张量

使用 onnx_ir.StringTensor 创建字符串张量。

稀疏张量

稀疏张量尚未支持,但已列入我们的路线图。

TensorProto 转换并返回

在以下场景中,我们展示了如何从 TensorProto 转换为 onnx_ir.Tensor,执行一些计算,然后将其转换回 onnx_ir.Tensor,最后转换为 TensorProto

import onnx_ir as ir
import onnx
import numpy as np

# 1. Create the TensorProto
proto = onnx.helper.make_tensor(
    "tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6]
)

# 2. Create an IR Tensor from the Protobuf message
tensor = ir.serde.deserialize_tensor(proto)
# Note that we get a TensorProtoTensor that implements the TensorProtocol
print("tensor:", tensor)  # TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')
print("tensor.numpy():", tensor.numpy())   # [[1. 2. 3.]
                                           #  [4. 5. 6.]]
print("tensor.tobytes():", tensor.tobytes())  # b'\x00<\x00@\x00B\x00D\x00E\x00F'

# 3. Do computation using numpy
mean = tensor.numpy().mean(axis=0)
print("mean:", mean)  # array([2.5, 3.5, 4.5], dtype=float16)

# 4. Create a Tensor from the ndarray. Note that we use ir.Tensor
tensor_mean = ir.Tensor(mean)
print("tensor_mean:", tensor_mean)  # Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name='')

# 5. Obtain the TensorProto from ir.Tensor
mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean)
print("mean_tensor_proto:", mean_tensor_proto)
print(
    "onnx.numpy_helper.to_array(mean_tensor_proto):",
    onnx.numpy_helper.to_array(mean_tensor_proto)
    # array([2.5, 3.5, 4.5], dtype=float16)
)

# You can obtain the bytes data as well
print("tensor_mean.tobytes():", tensor_mean.tobytes())
print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes())

# Explore other methods defined by TensorProtocol:
print("\n# Explore other methods defined by TensorProtocol:")
print("tensor_mean.shape:", tensor_mean.shape)
print("tensor_mean.dtype:", tensor_mean.dtype)
print("tensor_mean.name:", tensor_mean.name)
print("tensor_mean.doc_string:", tensor_mean.doc_string)
print("tensor_mean.raw:", tensor_mean.raw)
print("tensor_mean.metadata_props:", tensor_mean.metadata_props)
print("tensor_mean.size:", tensor_mean.size)
print("tensor_mean.nbytes:", tensor_mean.nbytes)
print("tensor_mean.raw:", tensor_mean.raw)
tensor: TensorProtoTensor<FLOAT16,[2,3]>(array([[1., 2., 3.], [4., 5., 6.]], dtype=float16), name='tensor')
tensor.numpy(): [[1. 2. 3.]
 [4. 5. 6.]]
tensor.tobytes(): b'\x00<\x00@\x00B\x00D\x00E\x00F'
mean: [2.5 3.5 4.5]
tensor_mean: Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name=None)
mean_tensor_proto: dims: 3
data_type: 10
raw_data: "\000A\000C\200D"

onnx.numpy_helper.to_array(mean_tensor_proto): [2.5 3.5 4.5]
tensor_mean.tobytes(): b'\x00A\x00C\x80D'
Bytes same as proto: True

# Explore other methods defined by TensorProtocol:
tensor_mean.shape: [3]
tensor_mean.dtype: FLOAT16
tensor_mean.name: None
tensor_mean.doc_string: None
tensor_mean.raw: [2.5 3.5 4.5]
tensor_mean.metadata_props: {}
tensor_mean.size: 3
tensor_mean.nbytes: 6
tensor_mean.raw: [2.5 3.5 4.5]

使用非原生 NumPy dtype:bfloat16、float8、int4

onnx_ir.Tensor.numpy() 生成张量值的 NumPy 数组表示。当张量的 dtype 为 NumPy 不支持的 BFLOAT16FLOAT8[...][U]INT4 时,我们使用 ml_dtypes 包中的 dtype。

uint4/int4 总是解包的;tobyte() 生成打包表示,符合预期。

onnx_ir.Tensor 的初始化要求 NumPy 数组遵循以下类型约束,或者具有 ml_dtypes dtype。

  • 对于(解包的)int4,使用 int8,符号位扩展到 8 位。

  • 对于(解包的)uint4,使用 uint8

  • 对于 float8 等 8 位数据类型,使用 uint8

  • 对于 bfloat16,使用 uint16

以下示例展示了如何创建 FLOAT8E4M3FN 张量,转换其值,并创建新张量以存储转换后的值。

import onnx_ir as ir
import numpy as np

array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy())  # [0.00195312 0.00585938]

# Compute
times_100 = tensor.numpy() * np.array(100, dtype=tensor.numpy().dtype)
print("times_100:", times_100)

# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100)  # array([ True,  True])
Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
tensor.numpy(): [0.00195312 0.00585938]
times_100: [0.1875 0.5625]
new_tensor: Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
new_tensor == times_100 [ True  True]

高级用法

子类化 onnx_ir.Tensor 以实现更高效的访问和更广泛的 dtype 支持

onnx_ir.Tensor 内部将任何与数组兼容的对象转换为 NumPy 数组,以在 tobytes() 中生成字节表示。由于额外的转换,这可能会效率低下。它也限制了对 NumPy 不支持的 dtype(如 bfloat16)的支持,因为 __array__ 方法将失败。

为了完全支持来自其他框架的数组,通常最好创建专门的类来处理它们。下面的 TorchTensor 类演示了如何子类化 onnx_ir.Tensor 来处理 PyTorch 张量。

import ctypes

import numpy.typing as npt
import torch

import onnx_ir as ir


class TorchTensor(ir.Tensor):
    def __init__(
        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
    ):
        # Pass the tensor as the raw data to ir.Tensor's constructor

        _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
            torch.bfloat16: ir.DataType.BFLOAT16,
            torch.bool: ir.DataType.BOOL,
            torch.complex128: ir.DataType.COMPLEX128,
            torch.complex64: ir.DataType.COMPLEX64,
            torch.float16: ir.DataType.FLOAT16,
            torch.float32: ir.DataType.FLOAT,
            torch.float64: ir.DataType.DOUBLE,
            torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
            torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
            torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
            torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
            torch.int16: ir.DataType.INT16,
            torch.int32: ir.DataType.INT32,
            torch.int64: ir.DataType.INT64,
            torch.int8: ir.DataType.INT8,
            torch.uint8: ir.DataType.UINT8,
            torch.uint16: ir.DataType.UINT16,
            torch.uint32: ir.DataType.UINT32,
            torch.uint64: ir.DataType.UINT64,
        }
        super().__init__(
            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
        )

    def numpy(self) -> npt.NDArray:
        self.raw: torch.Tensor
        if self.dtype == ir.DataType.BFLOAT16:
            return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
        if self.dtype in {
            ir.DataType.FLOAT8E4M3FN,
            ir.DataType.FLOAT8E4M3FNUZ,
            ir.DataType.FLOAT8E5M2,
            ir.DataType.FLOAT8E5M2FNUZ,
        }:
            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())

        return self.raw.numpy(force=True)

    def __array__(self, dtype = None, copy: bool | None = None) -> npt.NDArray:
        del copy  # Unused, but needed for the signature
        if dtype is None:
            return self.numpy()
        return self.numpy().__array__(dtype)

    def tobytes(self) -> bytes:
        # Implement tobytes to support native PyTorch types so we can use types like bloat16
        # Reading from memory directly is also more efficient because
        # it avoids copying to a NumPy array
        import torch._subclasses.fake_tensor

        with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
            # Disable any fake mode so calling detach() etc. will return a real tensor
            tensor = self.raw.detach().cpu().contiguous()

        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):  # pylint: disable=protected-access
            raise TypeError(
                f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
                "with a tensor backed by real data using ONNXProgram.apply_weights() "
                "or save the model without initializers by setting include_initializers=False."
            )

        return bytes(
            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
                tensor.data_ptr()
            )
        )

# Test the implementation
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
tensor = TorchTensor(torch_tensor)
print("tensor: ", tensor)
print("numpy: ", tensor.numpy())
print("tobytes: ", tensor.tobytes())  # b'\x80?\x00@@@'
print("nbytes: ", tensor.nbytes)  # 6
tensor:  TorchTensor<BFLOAT16,[3]>(tensor([1., 2., 3.], dtype=torch.bfloat16), name=None)
numpy:  [1 2 3]
tobytes:  b'\x80?\x00@@@'
nbytes:  6

上面的 TorchTensor 类实现了 tobytes(),以便在张量序列化为 ONNX 文件/TensorProto 时生成正确的字节表示。该类还实现了 __array__() 方法,以返回 NumPy 不支持的类型的位表示。这样,分析通道仍然可以对这些值执行计算。

与不同框架的计算

由于 onnx_ir.Tensor 实现了 __array__ 方法和 __dlpack__ 方法,其内容可以在不复制的情况下与计算框架共享。例如:

import onnx_ir as ir

# We can call numpy methods directly on ir.Tensor
import numpy as np
print(np.multiply(ir.Tensor(np.array([1, 2])), 42))  # array([42., 84.])

# We can transfer arrays to different frameworks
import jax.numpy as jnp
import jax
import torch

# Create ir.Tensor
jax_array = jnp.array([10., 20.])
ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT)
torch_tensor = torch.tensor([30., 40.])
ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)

# Use numpy for computation
print(np.multiply(ir_tensor_jax, ir_tensor_torch))  # array([300., 800.], dtype=float32)

# Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same
jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch)
print(jax_array_from_ir + jax_array)  # [40. 60.]

# Use PyTorch for computation
torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax)
print(torch_tensor_from_ir - torch_tensor)  # tensor([-20., -20.])

# They can all be serialized into TensorProto
proto = ir.serde.serialize_tensor(ir_tensor_jax)
print(type(proto))  # <class 'onnx.onnx_ml_pb2.TensorProto'>
print(proto)

# The value is exactly the same as jax_array
print(ir.serde.deserialize_tensor(proto).numpy())  # [10. 20.]
[42. 84.]
[300. 800.]
[40. 60.]
tensor([-20., -20.])
<class 'onnx.onnx_ml_pb2.TensorProto'>
dims: 2
data_type: 1
raw_data: "\000\000 A\000\000\240A"

[10. 20.]

如果您正在图上创建需要对具体值进行计算的通道,这尤其有用。您可以自由使用您喜欢的框架来创建通道。即使下游通道利用了其他计算框架,包含新创建的 onnx_ir.Tensor 的转换图也将兼容。