Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
_enums.DataType.INT2,
_enums.DataType.UINT2,
)
)

Expand Down Expand Up @@ -300,6 +302,16 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
raise TypeError(
f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
)
if dtype == _enums.DataType.INT2:
if array.dtype not in (np.int8, np.uint8, ml_dtypes.int2):
raise TypeError(
f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int2 (not {array.dtype}) for IR data type {dtype}."
)
if dtype == _enums.DataType.UINT2:
if array.dtype not in (np.uint8, ml_dtypes.uint2):
raise TypeError(
f"The numpy array dtype must be uint8 or ml_dtypes.uint2 (not {array.dtype}) for IR data type {dtype}."
)
return

try:
Expand Down Expand Up @@ -347,6 +359,10 @@ def _maybe_view_np_array_with_ml_dtypes(
return array.view(ml_dtypes.uint4)
if dtype == _enums.DataType.FLOAT4E2M1:
return array.view(ml_dtypes.float4_e2m1fn)
if dtype == _enums.DataType.INT2:
return array.view(ml_dtypes.int2)
if dtype == _enums.DataType.UINT2:
return array.view(ml_dtypes.uint2)
return array


Expand All @@ -365,7 +381,7 @@ def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
"""Create a numpy array for the byte representation of the tensor.

This function is used for serializing the tensor to bytes. It handles the
special cases for 4-bit data types and endianness.
special cases for 2-bit and 4-bit data types and endianness.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized we don’t need to handle endianness because these are sub-byte types, so endianness is irrelevant.

"""
array = tensor.numpy()
if tensor.dtype in {
Expand All @@ -375,6 +391,12 @@ def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
}:
# Pack the array into int4
array = _type_casting.pack_4bitx2(array)
elif tensor.dtype in {
_enums.DataType.INT2,
_enums.DataType.UINT2,
}:
# Pack the array into int2
array = _type_casting.pack_2bitx4(array)
else:
assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
if not _IS_LITTLE_ENDIAN:
Expand Down Expand Up @@ -726,6 +748,8 @@ def _load(self):
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
_enums.DataType.INT2,
_enums.DataType.UINT2,
}:
# Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
dt = np.dtype(np.uint8).newbyteorder("<")
Expand Down Expand Up @@ -1051,7 +1075,7 @@ def tofile(self, file) -> None:


class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
"""A tensor that stores 4bit datatypes in packed format.
"""A tensor that stores 2bit and 4bit datatypes in packed format.

.. versionadded:: 0.1.2
"""
Expand All @@ -1077,7 +1101,7 @@ def __init__(
Args:
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
The value MUST be packed in an integer dtype.
dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
dtype: The data type of the tensor. Must be one of INT2, UINT2, INT4, UINT4, FLOAT4E2M1.
shape: The shape of the tensor.
name: The name of the tensor.
doc_string: The documentation string.
Expand All @@ -1092,9 +1116,9 @@ def __init__(
raise TypeError(f"Expected an array compatible object, got {type(value)}")
self._shape = Shape(shape)
self._shape.freeze()
if dtype.bitwidth != 4:
if dtype.bitwidth not in (2, 4):
raise TypeError(
f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
f"PackedTensor only supports INT2, UINT2, INT4, UINT4, FLOAT4E2M1, but got {dtype}"
)
self._dtype = dtype
self._raw = value
Expand All @@ -1104,6 +1128,8 @@ def __init__(
value.dtype == ml_dtypes.float4_e2m1fn
or value.dtype == ml_dtypes.uint4
or value.dtype == ml_dtypes.int4
or value.dtype == ml_dtypes.uint2
or value.dtype == ml_dtypes.int2
):
raise TypeError(
f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
Expand Down
35 changes: 35 additions & 0 deletions src/onnx_ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ def test_init_requires_type_when_value_is_not_np_array(self):
("float8e5m2", np.uint8, ir.DataType.FLOAT8E5M2),
("float8e5m2fnuz", np.uint8, ir.DataType.FLOAT8E5M2FNUZ),
("float8e8m0", np.uint8, ir.DataType.FLOAT8E8M0),
("int2", np.int8, ir.DataType.INT2),
("int2_uint8", np.uint8, ir.DataType.INT2),
("int4", np.int8, ir.DataType.INT4),
("int4_uint8", np.uint8, ir.DataType.INT4),
("uint2", np.uint8, ir.DataType.UINT2),
("uint4", np.uint8, ir.DataType.UINT4),
("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1),
]
Expand Down Expand Up @@ -146,6 +149,38 @@ def test_tobytes(self):
tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
self.assertEqual(tensor.tobytes(), array.tobytes())

def test_tobytes_returns_packed_data_for_int2(self):
    """INT2 values supplied as int8 are serialized four per byte, first element in the low bits."""
    values = np.array([-2, -1, 0, 1, 1, -2, 1], dtype=np.int8)
    # The element count is deliberately not a multiple of 4 so the last byte needs padding
    assert len(values) % 4 != 0
    # -2, -1, 0, 1 => [0b10, 0b11, 0b00, 0b01] => 0b01001110 = 0x4E
    # 1, -2, 1, 0 (padding) => [0b01, 0b10, 0b01, 0b00] => 0b00011001 = 0x19
    expected_bytes = b"\x4e\x19"
    packed_bytes = _core.Tensor(values, dtype=ir.DataType.INT2).tobytes()
    self.assertEqual(packed_bytes, expected_bytes)

def test_tobytes_returns_packed_data_for_int2_ml_dtypes(self):
    """INT2 values supplied as an ml_dtypes.int2 array serialize to the expected packed bytes."""
    values = np.array([-2, -1, 0, 1, 1, -2, 1], dtype=ml_dtypes.int2)
    # The element count is deliberately not a multiple of 4 so the last byte needs padding
    assert len(values) % 4 != 0
    expected_bytes = b"\x4e\x19"
    packed_bytes = _core.Tensor(values, dtype=ir.DataType.INT2).tobytes()
    self.assertEqual(packed_bytes, expected_bytes)

def test_tobytes_returns_packed_data_for_uint2(self):
    """UINT2 values supplied as uint8 are serialized four per byte, first element in the low bits."""
    values = np.array([0, 1, 2, 3, 3, 2, 1], dtype=np.uint8)
    # The element count is deliberately not a multiple of 4 so the last byte needs padding
    assert len(values) % 4 != 0
    # 0, 1, 2, 3 => 0b11100100 = 0xE4
    # 3, 2, 1, 0 (padding) => 0b00011011 = 0x1B
    expected_bytes = b"\xe4\x1b"
    packed_bytes = _core.Tensor(values, dtype=ir.DataType.UINT2).tobytes()
    self.assertEqual(packed_bytes, expected_bytes)

def test_tobytes_returns_packed_data_for_uint2_ml_dtypes(self):
    """UINT2 values supplied as an ml_dtypes.uint2 array serialize to the expected packed bytes."""
    values = np.array([0, 1, 2, 3, 3, 2, 1], dtype=ml_dtypes.uint2)
    # The element count is deliberately not a multiple of 4 so the last byte needs padding
    assert len(values) % 4 != 0
    expected_bytes = b"\xe4\x1b"
    packed_bytes = _core.Tensor(values, dtype=ir.DataType.UINT2).tobytes()
    self.assertEqual(packed_bytes, expected_bytes)

def test_tobytes_returns_packed_data_for_int4(self):
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)
# Test odd sized array
Expand Down
15 changes: 15 additions & 0 deletions src/onnx_ir/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class DataType(enum.IntEnum):
INT4 = 22
FLOAT4E2M1 = 23
FLOAT8E8M0 = 24
INT2 = 25
UINT2 = 26

@classmethod
def from_numpy(cls, dtype: np.dtype) -> DataType:
Expand Down Expand Up @@ -101,6 +103,10 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
return DataType.INT4
if dtype.names == ("float4e2m1",):
return DataType.FLOAT4E2M1
if dtype.names == ("int2",):
return DataType.INT2
if dtype.names == ("uint2",):
return DataType.UINT2
raise TypeError(f"Unsupported numpy data type: {dtype}")

@classmethod
Expand Down Expand Up @@ -329,6 +335,8 @@ def is_integer(self) -> bool:
DataType.UINT64,
DataType.UINT4,
DataType.INT4,
DataType.INT2,
DataType.UINT2,
}

def is_signed(self) -> bool:
Expand All @@ -354,6 +362,7 @@ def is_signed(self) -> bool:
DataType.INT4,
DataType.FLOAT4E2M1,
DataType.FLOAT8E8M0,
DataType.INT2,
}

def is_string(self) -> bool:
Expand Down Expand Up @@ -394,6 +403,8 @@ def __str__(self) -> str:
DataType.INT4: 4,
DataType.FLOAT4E2M1: 4,
DataType.FLOAT8E8M0: 8,
DataType.INT2: 2,
DataType.UINT2: 2,
}


Expand Down Expand Up @@ -423,6 +434,8 @@ def __str__(self) -> str:
np.dtype(ml_dtypes.int4): DataType.INT4,
np.dtype(ml_dtypes.uint4): DataType.UINT4,
np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1,
np.dtype(ml_dtypes.int2): DataType.INT2,
np.dtype(ml_dtypes.uint2): DataType.UINT2,
}

# ONNX DataType to Numpy dtype.
Expand All @@ -442,12 +455,14 @@ def __str__(self) -> str:
DataType.FLOAT4E2M1: "f4e2m1",
DataType.COMPLEX64: "c64",
DataType.COMPLEX128: "c128",
DataType.INT2: "i2",
DataType.INT4: "i4",
DataType.INT8: "i8",
DataType.INT16: "i16",
DataType.INT32: "i32",
DataType.INT64: "i64",
DataType.BOOL: "b8",
DataType.UINT2: "u2",
DataType.UINT4: "u4",
DataType.UINT8: "u8",
DataType.UINT16: "u16",
Expand Down
6 changes: 6 additions & 0 deletions src/onnx_ir/_enums_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def test_enums_are_the_same_as_spec(self):
self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1)
if hasattr(onnx.TensorProto, "FLOAT8E8M0"):
self.assertEqual(_enums.DataType.FLOAT8E8M0, onnx.TensorProto.FLOAT8E8M0)
if hasattr(onnx.TensorProto, "INT2"):
self.assertEqual(_enums.DataType.INT2, onnx.TensorProto.INT2)
if hasattr(onnx.TensorProto, "UINT2"):
self.assertEqual(_enums.DataType.UINT2, onnx.TensorProto.UINT2)
self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)

@parameterized.parameterized.expand(
Expand Down Expand Up @@ -75,6 +79,8 @@ def test_enums_are_the_same_as_spec(self):
("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4),
("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1),
("float8e8m0", np.dtype(ml_dtypes.float8_e8m0fnu), _enums.DataType.FLOAT8E8M0),
("int2", np.dtype(ml_dtypes.int2), _enums.DataType.INT2),
("uint2", np.dtype(ml_dtypes.uint2), _enums.DataType.UINT2),
]
)
def test_from_numpy_takes_np_dtype_and_returns_data_type(
Expand Down
39 changes: 39 additions & 0 deletions src/onnx_ir/_type_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,42 @@ def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArr
result = result[:-1]
result.resize(dims, refcheck=False)
return result


def pack_2bitx4(array: np.ndarray) -> npt.NDArray[np.uint8]:
    """Pack 2-bit values, four per byte, into a flat uint8 array.

    Only the two low-order bits of every element are kept, so signed values are
    stored as their two's-complement bit pattern (for example -2 becomes 0b10).
    The first element of each group of four occupies the lowest two bits of the
    output byte. When the element count is not a multiple of 4, the last byte is
    padded with zero fields. The input array is not modified.

    Args:
        array: The values to pack, of any shape. Elements must be in the valid
            range for int2/uint2.

    Returns:
        A 1-D uint8 array holding ``ceil(array.size / 4)`` packed bytes.
    """
    flat = array.ravel()
    if flat.dtype.itemsize != 1:
        # A raw byte view of a multi-byte dtype would expose every byte of each
        # element instead of one value per element, producing garbage output.
        # Convert element-wise first; in-range values keep their low two bits.
        flat = flat.astype(np.int8)
    size = flat.size
    # Round the element count up to a whole number of bytes; the tail stays zero.
    padded = np.zeros((size + 3) // 4 * 4, dtype=np.uint8)
    padded[:size] = flat.view(np.uint8) & np.uint8(0x03)
    lanes = padded.reshape(-1, 4)
    return (  # type: ignore[return-value]
        lanes[:, 0]
        | (lanes[:, 1] << np.uint8(2))
        | (lanes[:, 2] << np.uint8(4))
        | (lanes[:, 3] << np.uint8(6))
    )


def unpack_2bitx4(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
    """Expand packed 2-bit fields into one uint8 element per field.

    Every input byte carries four 2-bit fields; the field in the lowest two bits
    comes first in the output. Fields are returned as raw bit patterns (0-3) with
    no sign extension applied.

    Args:
        data: A uint8 numpy array of packed bytes.
        dims: The shape of the unpacked result. Trailing fields beyond the
            element count implied by ``dims`` are treated as padding and dropped.

    Returns:
        A numpy array with the dtype of ``data`` (uint8), shaped to ``dims``.
    """
    assert data.dtype == np.uint8, "Input data must be of type uint8"
    field_mask = np.uint8(0x03)
    unpacked = np.empty([data.size * 4], dtype=data.dtype)
    for lane in range(4):
        # Lane k lives in bits [2k, 2k + 1] of each packed byte
        unpacked[lane::4] = (data >> np.uint8(2 * lane)) & field_mask
    element_count = int(np.prod(dims))
    if unpacked.size > element_count:
        # Drop the padding fields added when the count is not a multiple of 4
        unpacked = unpacked[:element_count]
    unpacked.resize(dims, refcheck=False)
    return unpacked
14 changes: 13 additions & 1 deletion src/onnx_ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ def numpy(self) -> np.ndarray:
return _type_casting.unpack_4bitx2(
np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
).view(dtype.numpy())
if dtype.bitwidth == 2:
return _type_casting.unpack_2bitx4(
np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
).view(dtype.numpy())
return np.frombuffer(
self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
).reshape(shape)
Expand All @@ -408,9 +412,11 @@ def numpy(self) -> np.ndarray:
_enums.DataType.FLOAT8E8M0,
_enums.DataType.INT16,
_enums.DataType.INT32,
_enums.DataType.INT2,
_enums.DataType.INT4,
_enums.DataType.INT8,
_enums.DataType.UINT16,
_enums.DataType.UINT2,
_enums.DataType.UINT4,
_enums.DataType.UINT8,
}, f"Unsupported dtype {dtype} for int32_data"
Expand All @@ -426,6 +432,10 @@ def numpy(self) -> np.ndarray:
return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
dtype.numpy()
)
if dtype.bitwidth == 2:
return _type_casting.unpack_2bitx4(array.astype(np.uint8), shape).view(
dtype.numpy()
)
raise ValueError(
f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
)
Expand Down Expand Up @@ -507,11 +517,13 @@ def tobytes(self) -> bytes:
_enums.DataType.FLOAT8E5M2,
_enums.DataType.FLOAT8E5M2FNUZ,
_enums.DataType.FLOAT8E8M0,
_enums.DataType.INT2,
_enums.DataType.INT4,
_enums.DataType.UINT2,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
}:
# uint4 and int4 values are already packed, even when stored as int32
# uint2, uint4, int2 and int4 values are already packed, even when stored as int32
# so we don't need to pack them again
return array.astype(_little_endian_dtype(np.uint8)).tobytes()
assert self.dtype == _enums.DataType.INT32
Expand Down
26 changes: 26 additions & 0 deletions src/onnx_ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,21 @@ def test_tensor_proto_tensor_float8(self, _: str, dtype: int, np_dtype):
("INT32", onnx.TensorProto.INT32),
("INT64", onnx.TensorProto.INT64),
("INT4", onnx.TensorProto.INT4),
("INT2", 25), # INT2 value
]
)
def test_tensor_proto_tensor_int(self, _: str, dtype: int):
# INT2 is not yet supported in ONNX numpy_helper, so we handle it specially
if dtype == 25: # INT2
# Create tensor proto manually since ONNX helper might not support this type yet
data_array = np.array([[-1, 0, 1]], dtype=ml_dtypes.int2)
# Create an IR tensor which will pack the data correctly
ir_tensor = ir.Tensor(data_array)
tensor_proto = serde.to_proto(ir_tensor)
tensor = serde.TensorProtoTensor(tensor_proto)
np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.int2), data_array)
return # Skip remaining tests for INT2 as ONNX doesn't support it yet

tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 4], [-1, 0, 1, 8])
tensor = serde.TensorProtoTensor(tensor_proto)
expected_array = onnx.numpy_helper.to_array(
Expand All @@ -311,9 +323,21 @@ def test_tensor_proto_tensor_int(self, _: str, dtype: int):
("UINT32", onnx.TensorProto.UINT32),
("UINT64", onnx.TensorProto.UINT64),
("UINT4", onnx.TensorProto.UINT4),
("UINT2", 26), # UINT2 value
]
)
def test_tensor_proto_tensor_uint(self, _: str, dtype: int):
# UINT2 is not yet supported in ONNX numpy_helper, so we handle it specially
if dtype == 26: # UINT2
# Create tensor proto manually since ONNX helper might not support this type yet
data_array = np.array([[0, 1, 2, 3]], dtype=ml_dtypes.uint2)
# Create an IR tensor which will pack the data correctly
ir_tensor = ir.Tensor(data_array)
tensor_proto = serde.to_proto(ir_tensor)
tensor = serde.TensorProtoTensor(tensor_proto)
np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.uint2), data_array)
return # Skip remaining tests for UINT2 as ONNX doesn't support it yet

tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 3], [0, 1, 8])
tensor = serde.TensorProtoTensor(tensor_proto)
expected_array = onnx.numpy_helper.to_array(tensor_proto)
Expand Down Expand Up @@ -396,7 +420,9 @@ def test_tensor_proto_tensor_empty_tensor(self):
("FLOAT8E5M2", ir.DataType.FLOAT8E5M2),
("FLOAT8E5M2FNUZ", ir.DataType.FLOAT8E5M2FNUZ),
("FLOAT8E8M0", ir.DataType.FLOAT8E8M0),
("UINT2", ir.DataType.UINT2),
("UINT4", ir.DataType.UINT4),
("INT2", ir.DataType.INT2),
("INT4", ir.DataType.INT4),
("FLOAT4E2M1", ir.DataType.FLOAT4E2M1),
],
Expand Down
Loading
Loading