diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py
index 863456228a91..b5859ab1b873 100644
--- a/runtime/bindings/python/iree/runtime/array_interop.py
+++ b/runtime/bindings/python/iree/runtime/array_interop.py
@@ -294,7 +294,6 @@ def asdevicearray(
     (np.float16, HalElementType.FLOAT_16),
     (np.float32, HalElementType.FLOAT_32),
     (np.float64, HalElementType.FLOAT_64),
-    (np.float16, HalElementType.FLOAT_16),
     (np.int32, HalElementType.SINT_32),
     (np.int64, HalElementType.SINT_64),
     (np.int16, HalElementType.SINT_16),
diff --git a/runtime/bindings/python/iree/runtime/io.py b/runtime/bindings/python/iree/runtime/io.py
index 246243a2ca8c..e9d54dae71c7 100644
--- a/runtime/bindings/python/iree/runtime/io.py
+++ b/runtime/bindings/python/iree/runtime/io.py
@@ -8,13 +8,16 @@
 import array
 from functools import reduce
-import numpy
+import json
+import numpy as np
 from os import PathLike
-from pathlib import Path
 
-from ._binding import ParameterIndex
+from ._binding import ParameterIndex, ParameterIndexEntry
 
 __all__ = [
+    "parameter_index_add_numpy_ndarray",
+    "parameter_index_entry_as_numpy_flat_ndarray",
+    "parameter_index_entry_as_numpy_ndarray",
     "SplatValue",
     "save_archive_file",
 ]
 
@@ -23,7 +26,7 @@
 class SplatValue:
     def __init__(
         self,
-        pattern: Union[array.array, numpy.ndarray],
+        pattern: Union[array.array, np.ndarray],
         count: Union[Sequence[int], int],
     ):
         if hasattr(pattern, "shape"):
@@ -64,3 +67,162 @@ def save_archive_file(entries: dict[str, Union[Any, SplatValue]], file_path: Pat
     else:
         index.add_buffer(key, value)
     index.create_archive_file(str(file_path))
+
+
+def parameter_index_add_numpy_ndarray(
+    index: ParameterIndex, name: str, array: np.ndarray
+):
+    """Adds an ndarray to the index."""
+    metadata = _make_tensor_metadata(array)
+    # 0d arrays are special in both torch/numpy in different ways that make
+    # it hard to reliably get a memory view of their contents. Since we
+    # know that 0d arrays are always small, we just force a copy on the
+    # numpy side, which keeps things on the happy path.
+    # See: https://github.com/iree-org/iree-turbine/issues/29
+    if len(array.shape) == 0:
+        flat_array = array.copy()
+    else:
+        flat_array = np.ascontiguousarray(array).view(np.uint8)
+    index.add_buffer(name, flat_array, metadata=metadata)
+
+
+def parameter_index_entry_as_numpy_flat_ndarray(
+    index_entry: ParameterIndexEntry,
+) -> np.ndarray:
+    """Accesses the contents as a flat uint8 tensor.
+
+    File-backed entries are returned as a view; splat patterns are copied.
+
+    Raises a ValueError on unsupported entries.
+    """
+    if index_entry.is_file:
+        wrapper = np.array(index_entry.file_view, copy=False)
+    elif index_entry.is_splat:
+        wrapper = np.array(index_entry.splat_pattern, copy=True)
+    else:
+        raise ValueError(f"Unsupported ParameterIndexEntry: {index_entry}")
+
+    return wrapper
+
+
+def parameter_index_entry_as_numpy_ndarray(
+    index_entry: ParameterIndexEntry,
+) -> np.ndarray:
+    """Returns a tensor viewed with appropriate shape/dtype from metadata.
+
+    Raises a ValueError if unsupported.
+    """
+
+    # Decode metadata.
+    versioned_metadata = index_entry.metadata.decode()
+    metadata_parts = versioned_metadata.split(_metadata_version_separator, maxsplit=1)
+    if len(metadata_parts) == 1:
+        raise ValueError(
+            (
+                f'Invalid metadata for parameter index entry "{index_entry.key}".'
+                f' Expected format version prefix not found in "{metadata_parts[0][:100]}".'
+            )
+        )
+    format_version = metadata_parts[0]
+    metadata = metadata_parts[1]
+    if (
+        format_version != _metadata_version
+        and format_version != _metadata_iree_turbine_version
+    ):
+        raise ValueError(
+            (
+                f'Unsupported metadata format version "{format_version}" for parameter '
+                f'index entry "{index_entry.key}": Cannot convert to tensor'
+            )
+        )
+    d = json.loads(metadata)
+    try:
+        type_name = d["type"]
+        if type_name != "Tensor":
+            raise ValueError(
+                f"Metadata for parameter entry {index_entry.key} is not a Tensor ('{type_name}')"
+            )
+        dtype_name = d["dtype"]
+        shape = d["shape"]
+    except KeyError as e:
+        raise ValueError(f"Bad metadata for parameter entry {index_entry.key}") from e
+
+    # Unpack/validate.
+    try:
+        dtype = _NAME_TO_DTYPE[dtype_name]
+    except KeyError:
+        raise ValueError(f"Unknown dtype name '{dtype_name}'")
+    try:
+        shape = [int(s) for s in shape]
+    except ValueError as e:
+        raise ValueError(f"Illegal shape for parameter entry {index_entry.key}") from e
+
+    t = parameter_index_entry_as_numpy_flat_ndarray(index_entry)
+    return t.view(dtype=dtype).reshape(shape)
+
+
+_DTYPE_TO_NAME = (
+    (np.float16, "float16"),
+    (np.float32, "float32"),
+    (np.float64, "float64"),
+    (np.int32, "int32"),
+    (np.int64, "int64"),
+    (np.int16, "int16"),
+    (np.int8, "int8"),
+    (np.uint32, "uint32"),
+    (np.uint64, "uint64"),
+    (np.uint16, "uint16"),
+    (np.uint8, "uint8"),
+    (np.bool_, "bool"),
+    (np.complex64, "complex64"),
+    (np.complex128, "complex128"),
+)
+
+_NAME_TO_DTYPE: dict[str, np.dtype] = {
+    name: np_dtype for np_dtype, name in _DTYPE_TO_NAME
+}
+
+
+def _map_dtype_to_name(dtype) -> str:
+    for match_dtype, dtype_name in _DTYPE_TO_NAME:
+        if match_dtype == dtype:
+            return dtype_name
+
+    raise KeyError(f"Numpy dtype {dtype} not found.")
+
+
+_metadata_version = "TENSORv0"
+"""Magic number to identify the format version.
+The current version that will be used when adding tensors to a parameter index."""
+
+_metadata_iree_turbine_version = "PYTORCH"
+"""There are files created with IREE Turbine that use this prefix.
+This is here to maintain the ability to load such files."""
+
+_metadata_version_separator = ":"
+"""The separator between the format version and the actual metadata.
+The metadata has the following format: "{format_version}{separator}{json_metadata}"."""
+
+
+def _make_tensor_metadata(t: np.ndarray) -> str:
+    """Makes a tensor metadata blob that can be used to reconstruct the tensor."""
+    dtype = t.dtype
+    dtype_name = _map_dtype_to_name(dtype)
+    is_complex = np.issubdtype(dtype, np.complexfloating)
+    is_floating_point = np.issubdtype(dtype, np.floating)
+    is_signed = np.issubdtype(dtype, np.signedinteger)
+    dtype_desc = {
+        "class_name": type(dtype).__name__,
+        "is_complex": is_complex,
+        "is_floating_point": is_floating_point,
+        "is_signed": is_signed,
+        "itemsize": dtype.itemsize,
+    }
+    d = {
+        "type": "Tensor",
+        "dtype": dtype_name,
+        "shape": list(t.shape),
+        "dtype_desc": dtype_desc,
+    }
+    encoded = f"{_metadata_version}{_metadata_version_separator}{json.dumps(d)}"
+    return encoded
diff --git a/runtime/bindings/python/tests/io_test.py b/runtime/bindings/python/tests/io_test.py
index 026c9b665364..a854cb9d9a2e 100644
--- a/runtime/bindings/python/tests/io_test.py
+++ b/runtime/bindings/python/tests/io_test.py
@@ -96,6 +96,47 @@ def verify_archive(file_path: Path):
         verify_archive(file_path)
         gc.collect()
 
+    def testParameterIndexEntryFromToNumpy(self):
+        array = np.array([[1, 2], [3, 4]], dtype=np.int32)
+        index = rt.ParameterIndex()
+        key = "key"
+        rt.parameter_index_add_numpy_ndarray(index, key, array)
+        assert index.items()[0][0] == key
+        index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(
+            index.items()[0][1]
+        )
+        np.testing.assert_equal(index_entry_as_array, array)
+
+    def testParameterIndexEntryFromToNumpyZeroDims(self):
+        array = np.array(1234, dtype=np.int32)
+        index = rt.ParameterIndex()
+        key = "key"
+        rt.parameter_index_add_numpy_ndarray(index, key, array)
+        assert index.items()[0][0] == key
+        index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(
+            index.items()[0][1]
+        )
+        np.testing.assert_equal(index_entry_as_array, array)
+
+    def testParameterIndexEntryFromIreeTurbine(self):
+        """Verify that we are able to load a tensor from an IRPA file generated
+        with IREE Turbine.
+        We want to maintain backward compatibility with existing IRPA files."""
+        index = rt.ParameterIndex()
+        irpa_path = str(
+            Path(__file__).resolve().parent
+            / "testdata"
+            / "tensor_saved_with_iree_turbine.irpa"
+        )
+        index.load(irpa_path)
+        items = index.items()
+        assert len(items) == 1
+        key, entry = items[0]
+        assert key == "the_torch_tensor"
+        index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(entry)
+        expected_array = np.array([1, 2, 3, 4], dtype=np.uint8)
+        np.testing.assert_array_equal(index_entry_as_array, expected_array, strict=True)
+
     def testFileHandleWrap(self):
         fh = rt.FileHandle.wrap_memory(b"foobar")
         view = fh.host_allocation
diff --git a/runtime/bindings/python/tests/testdata/generate_tensor_saved_with_iree_turbine.py b/runtime/bindings/python/tests/testdata/generate_tensor_saved_with_iree_turbine.py
new file mode 100644
index 000000000000..3a8167fc296c
--- /dev/null
+++ b/runtime/bindings/python/tests/testdata/generate_tensor_saved_with_iree_turbine.py
@@ -0,0 +1,7 @@
+from iree.turbine.aot import ParameterArchiveBuilder
+import torch
+
+archive = ParameterArchiveBuilder()
+tensor = torch.tensor([1, 2, 3, 4], dtype=torch.uint8)
+archive.add_tensor("the_torch_tensor", tensor)
+archive.save("tensor_saved_with_iree_turbine.irpa")
diff --git a/runtime/bindings/python/tests/testdata/tensor_saved_with_iree_turbine.irpa b/runtime/bindings/python/tests/testdata/tensor_saved_with_iree_turbine.irpa
new file mode 100644
index 000000000000..c9a4cea922b8
Binary files /dev/null and b/runtime/bindings/python/tests/testdata/tensor_saved_with_iree_turbine.irpa differ
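
For reference, a minimal end-to-end sketch of the helpers added in this patch (not part of the diff itself). The array values and the "example.irpa" path are illustrative, and it assumes the new io helpers are re-exported through iree.runtime, as the tests above use them:

# Round-trip sketch: numpy ndarray -> ParameterIndex -> IRPA archive -> ndarray.
import numpy as np
import iree.runtime as rt

array = np.arange(12, dtype=np.float32).reshape(3, 4)

# Write: register the array (contents plus TENSORv0 metadata) and save an archive.
index = rt.ParameterIndex()
rt.parameter_index_add_numpy_ndarray(index, "weight", array)
index.create_archive_file("example.irpa")

# Read: load the archive and reconstruct the ndarray from its stored metadata.
loaded = rt.ParameterIndex()
loaded.load("example.irpa")
key, entry = loaded.items()[0]
restored = rt.parameter_index_entry_as_numpy_ndarray(entry)
assert key == "weight"
np.testing.assert_array_equal(restored, array)

The metadata written by parameter_index_add_numpy_ndarray is what allows the shape and dtype to be recovered on load; entries without it (or with an unknown format version) raise ValueError in parameter_index_entry_as_numpy_ndarray.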