IR 中的张量表示¶
ONNX IR 提供了 onnx_ir.TensorProtocol 接口,用于使用不同的数据结构作为张量的底层数据。除了传统的 onnx.TensorProto,您还可以使用 np.ndarray、torch.Tensor、jax.Array 以及几乎任何其他方式来表示图中的张量。这使得它们可以通过相同的 TensorProtocol 接口进行访问和序列化,而无需在初始化期间产生额外的复制。
The TensorProtocol¶
onnx_ir.TensorProtocol 定义了一个用于表示张量的只读接口。实现此接口的张量类具有 name、shape、dtype、size、nbytes 和 metadata_props 等属性,以描述张量的基本属性。此外,它还应该实现两个方法 numpy 和 __array__,它们将从底层数据生成等效的 NumPy 数组。
注意
在与初始化器、常量值和张量属性交互时,最好假定使用 TensorProtocol,并且只有在需要时才使用 isinstance 检查具体的类。
张量类¶
ir.TensorProtoTensor¶
我们使用 onnx_ir.TensorProtoTensor 作为 proto 的包装器,以实现 onnx_ir.TensorProtocol 接口。您可以像往常一样访问 shape、dtype 等。仅在调用 numpy() 时才会产生复制。
注意
可以直接初始化 onnx_ir.TensorProtoTensor,如下所示。但是,通常建议使用 onnx_ir.serde.deserialize_tensor,因为它处理所有类型的 TensorProto(例如,onnx_ir.TensorProtoTensor 不处理外部张量)。请参阅 从 TensorProto 转换并返回 以获取示例。
import onnx
import onnx_ir as ir
tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3])
tensor = ir.TensorProtoTensor(tensor_proto)
print("tensor: ", tensor) # TensorProtoTensor<INT16,[3]>(name='tensor')
print("shape: ", tensor.shape) # ir.Shape([3])
print("dtype: ", tensor.dtype) # ir.DataType.INT16
print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization
print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00'
print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16)
tensor: TensorProtoTensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name='tensor')
shape: [3]
dtype: INT16
True
tobytes: b'\x01\x00\x02\x00\x03\x00'
numpy: [1 2 3]
ir.ExternalTensor¶
外部存储在磁盘上的张量数据通常很大,加载时会占用内存。onnx_ir.ExternalTensor 类使用内存映射来避免将张量加载到内存中。您可以使用该张量作为正常的 NumPy 数组,且内存使用量最小。
请参阅 onnx_ir.serde.deserialize_tensor,以查找将 onnx.TensorProto 转换为 onnx_ir.ExternalTensor 的示例。
ir.Tensor¶
onnx_ir.Tensor 是 NumPy 数组兼容数组对象(如 np.ndarray 和 torch.Tensor)的包装器。它最适合创建内存中的张量,而无需将其转换为 TensorProto,以减少转换开销。
提示
如果数组对象定义了 __array__ 方法,则它是兼容的。
要从数组创建张量,只需使用 NumPy 数组初始化它
tensor = ir.Tensor(np.random.rand(1, 2))
初始化器将从数组中获取 dtype 和 shape 信息。
要从 NumPy 数组以外的对象创建张量,您需要指定 dtype
import torch
import onnx_ir as ir
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)
tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)
print(tensor.numpy()) # array([1., 2., 3.], dtype=float16)
[1. 2. 3.]
字符串张量¶
使用 onnx_ir.StringTensor 创建字符串张量。
稀疏张量¶
稀疏张量尚未支持,但已列入我们的路线图。
从 TensorProto 转换并返回¶
在以下场景中,我们展示了如何从 TensorProto 转换为 onnx_ir.Tensor,执行一些计算,然后将其转换回 onnx_ir.Tensor,最后转换为 TensorProto。
import onnx_ir as ir
import onnx
import numpy as np
# 1. Create the TensorProto
proto = onnx.helper.make_tensor(
"tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6]
)
# 2. Create an IR Tensor from the Protobuf message
tensor = ir.serde.deserialize_tensor(proto)
# Note that we get a TensorProtoTensor that implements the TensorProtocol
print("tensor:", tensor) # TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')
print("tensor.numpy():", tensor.numpy()) # [[1. 2. 3.]
# [4. 5. 6.]]
print("tensor.tobytes():", tensor.tobytes()) # b'\x00<\x00@\x00B\x00D\x00E\x00F'
# 3. Do computation using numpy
mean = tensor.numpy().mean(axis=0)
print("mean:", mean) # array([2.5, 3.5, 4.5], dtype=float16)
# 4. Create a Tensor from the ndarray. Note that we use ir.Tensor
tensor_mean = ir.Tensor(mean)
print("tensor_mean:", tensor_mean) # Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name='')
# 5. Obtain the TensorProto from ir.Tensor
mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean)
print("mean_tensor_proto:", mean_tensor_proto)
print(
"onnx.numpy_helper.to_array(mean_tensor_proto):",
onnx.numpy_helper.to_array(mean_tensor_proto)
# array([2.5, 3.5, 4.5], dtype=float16)
)
# You can obtain the bytes data as well
print("tensor_mean.tobytes():", tensor_mean.tobytes())
print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes())
# Explore other methods defined by TensorProtocol:
print("\n# Explore other methods defined by TensorProtocol:")
print("tensor_mean.shape:", tensor_mean.shape)
print("tensor_mean.dtype:", tensor_mean.dtype)
print("tensor_mean.name:", tensor_mean.name)
print("tensor_mean.doc_string:", tensor_mean.doc_string)
print("tensor_mean.raw:", tensor_mean.raw)
print("tensor_mean.metadata_props:", tensor_mean.metadata_props)
print("tensor_mean.size:", tensor_mean.size)
print("tensor_mean.nbytes:", tensor_mean.nbytes)
print("tensor_mean.raw:", tensor_mean.raw)
tensor: TensorProtoTensor<FLOAT16,[2,3]>(array([[1., 2., 3.], [4., 5., 6.]], dtype=float16), name='tensor')
tensor.numpy(): [[1. 2. 3.]
[4. 5. 6.]]
tensor.tobytes(): b'\x00<\x00@\x00B\x00D\x00E\x00F'
mean: [2.5 3.5 4.5]
tensor_mean: Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name=None)
mean_tensor_proto: dims: 3
data_type: 10
raw_data: "\000A\000C\200D"
onnx.numpy_helper.to_array(mean_tensor_proto): [2.5 3.5 4.5]
tensor_mean.tobytes(): b'\x00A\x00C\x80D'
Bytes same as proto: True
# Explore other methods defined by TensorProtocol:
tensor_mean.shape: [3]
tensor_mean.dtype: FLOAT16
tensor_mean.name: None
tensor_mean.doc_string: None
tensor_mean.raw: [2.5 3.5 4.5]
tensor_mean.metadata_props: {}
tensor_mean.size: 3
tensor_mean.nbytes: 6
tensor_mean.raw: [2.5 3.5 4.5]
使用非原生 NumPy dtype:bfloat16、float8、int4¶
onnx_ir.Tensor.numpy() 生成张量值的 NumPy 数组表示。当张量的 dtype 为 NumPy 不支持的 BFLOAT16、FLOAT8[...] 或 [U]INT4 时,我们使用 ml_dtypes 包中的 dtype。
uint4/int4 总是解包的;tobyte() 生成打包表示,符合预期。
onnx_ir.Tensor 的初始化要求 NumPy 数组遵循以下类型约束,或者具有 ml_dtypes dtype。
对于(解包的)int4,使用
int8,符号位扩展到 8 位。对于(解包的)uint4,使用
uint8。对于 float8 等 8 位数据类型,使用
uint8。对于 bfloat16,使用
uint16。
以下示例展示了如何创建 FLOAT8E4M3FN 张量,转换其值,并创建新张量以存储转换后的值。
import onnx_ir as ir
import numpy as np
array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy()) # [0.00195312 0.00585938]
# Compute
times_100 = tensor.numpy() * np.array(100, dtype=tensor.numpy().dtype)
print("times_100:", times_100)
# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100) # array([ True, True])
Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
tensor.numpy(): [0.00195312 0.00585938]
times_100: [0.1875 0.5625]
new_tensor: Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
new_tensor == times_100 [ True True]
高级用法¶
子类化 onnx_ir.Tensor 以实现更高效的访问和更广泛的 dtype 支持¶
onnx_ir.Tensor 内部将任何与数组兼容的对象转换为 NumPy 数组,以在 tobytes() 中生成字节表示。由于额外的转换,这可能会效率低下。它也限制了对 NumPy 不支持的 dtype(如 bfloat16)的支持,因为 __array__ 方法将失败。
为了完全支持来自其他框架的数组,通常最好创建专门的类来处理它们。下面的 TorchTensor 类演示了如何子类化 onnx_ir.Tensor 来处理 PyTorch 张量。
import ctypes
import numpy.typing as npt
import torch
import onnx_ir as ir
class TorchTensor(ir.Tensor):
def __init__(
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
):
# Pass the tensor as the raw data to ir.Tensor's constructor
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
torch.bfloat16: ir.DataType.BFLOAT16,
torch.bool: ir.DataType.BOOL,
torch.complex128: ir.DataType.COMPLEX128,
torch.complex64: ir.DataType.COMPLEX64,
torch.float16: ir.DataType.FLOAT16,
torch.float32: ir.DataType.FLOAT,
torch.float64: ir.DataType.DOUBLE,
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
torch.int8: ir.DataType.INT8,
torch.uint8: ir.DataType.UINT8,
torch.uint16: ir.DataType.UINT16,
torch.uint32: ir.DataType.UINT32,
torch.uint64: ir.DataType.UINT64,
}
super().__init__(
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
)
def numpy(self) -> npt.NDArray:
self.raw: torch.Tensor
if self.dtype == ir.DataType.BFLOAT16:
return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
if self.dtype in {
ir.DataType.FLOAT8E4M3FN,
ir.DataType.FLOAT8E4M3FNUZ,
ir.DataType.FLOAT8E5M2,
ir.DataType.FLOAT8E5M2FNUZ,
}:
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
return self.raw.numpy(force=True)
def __array__(self, dtype = None, copy: bool | None = None) -> npt.NDArray:
del copy # Unused, but needed for the signature
if dtype is None:
return self.numpy()
return self.numpy().__array__(dtype)
def tobytes(self) -> bytes:
# Implement tobytes to support native PyTorch types so we can use types like bloat16
# Reading from memory directly is also more efficient because
# it avoids copying to a NumPy array
import torch._subclasses.fake_tensor
with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
# Disable any fake mode so calling detach() etc. will return a real tensor
tensor = self.raw.detach().cpu().contiguous()
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access
raise TypeError(
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
"with a tensor backed by real data using ONNXProgram.apply_weights() "
"or save the model without initializers by setting include_initializers=False."
)
return bytes(
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
tensor.data_ptr()
)
)
# Test the implementation
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
tensor = TorchTensor(torch_tensor)
print("tensor: ", tensor)
print("numpy: ", tensor.numpy())
print("tobytes: ", tensor.tobytes()) # b'\x80?\x00@@@'
print("nbytes: ", tensor.nbytes) # 6
tensor: TorchTensor<BFLOAT16,[3]>(tensor([1., 2., 3.], dtype=torch.bfloat16), name=None)
numpy: [1 2 3]
tobytes: b'\x80?\x00@@@'
nbytes: 6
上面的 TorchTensor 类实现了 tobytes(),以便在张量序列化为 ONNX 文件/TensorProto 时生成正确的字节表示。该类还实现了 __array__() 方法,以返回 NumPy 不支持的类型的位表示。这样,分析通道仍然可以对这些值执行计算。
与不同框架的计算¶
由于 onnx_ir.Tensor 实现了 __array__ 方法和 __dlpack__ 方法,其内容可以在不复制的情况下与计算框架共享。例如:
import onnx_ir as ir
# We can call numpy methods directly on ir.Tensor
import numpy as np
print(np.multiply(ir.Tensor(np.array([1, 2])), 42)) # array([42., 84.])
# We can transfer arrays to different frameworks
import jax.numpy as jnp
import jax
import torch
# Create ir.Tensor
jax_array = jnp.array([10., 20.])
ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT)
torch_tensor = torch.tensor([30., 40.])
ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
# Use numpy for computation
print(np.multiply(ir_tensor_jax, ir_tensor_torch)) # array([300., 800.], dtype=float32)
# Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same
jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch)
print(jax_array_from_ir + jax_array) # [40. 60.]
# Use PyTorch for computation
torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax)
print(torch_tensor_from_ir - torch_tensor) # tensor([-20., -20.])
# They can all be serialized into TensorProto
proto = ir.serde.serialize_tensor(ir_tensor_jax)
print(type(proto)) # <class 'onnx.onnx_ml_pb2.TensorProto'>
print(proto)
# The value is exactly the same as jax_array
print(ir.serde.deserialize_tensor(proto).numpy()) # [10. 20.]
[42. 84.]
[300. 800.]
[40. 60.]
tensor([-20., -20.])
<class 'onnx.onnx_ml_pb2.TensorProto'>
dims: 2
data_type: 1
raw_data: "\000\000 A\000\000\240A"
[10. 20.]
如果您正在图上创建需要对具体值进行计算的通道,这尤其有用。您可以自由使用您喜欢的框架来创建通道。即使下游通道利用了其他计算框架,包含新创建的 onnx_ir.Tensor 的转换图也将兼容。