Skip to content

Commit

Permalink
[runtime][python] Add IRPA entry conversion to/from numpy (#19492)
Browse files Browse the repository at this point in the history
Add iterop between numpy ndarray and parameter index. This is an
adaptation of the original in IREE Turbine
https://github.com/iree-org/iree-turbine/blob/142c8a5044a4fedb43a11229f462363b05743b23/iree/turbine/aot/params.py
The goal is to maintain compatibility with IRPA files that were already
generated with IREE Turbine.
At some point we can refactor the IREE Turbine side to use this
implementation.

Signed-off-by: Boian Petkantchin <[email protected]>
  • Loading branch information
sogartar authored Dec 18, 2024
1 parent 4e29bbb commit 078c3ec
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 5 deletions.
1 change: 0 additions & 1 deletion runtime/bindings/python/iree/runtime/array_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ def asdevicearray(
(np.float16, HalElementType.FLOAT_16),
(np.float32, HalElementType.FLOAT_32),
(np.float64, HalElementType.FLOAT_64),
(np.float16, HalElementType.FLOAT_16),
(np.int32, HalElementType.SINT_32),
(np.int64, HalElementType.SINT_64),
(np.int16, HalElementType.SINT_16),
Expand Down
170 changes: 166 additions & 4 deletions runtime/bindings/python/iree/runtime/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@

import array
from functools import reduce
import numpy
import json
import numpy as np
from os import PathLike
from pathlib import Path

from ._binding import ParameterIndex
from ._binding import ParameterIndex, ParameterIndexEntry

__all__ = [
"parameter_index_add_numpy_ndarray",
"parameter_index_entry_as_numpy_flat_ndarray",
"parameter_index_entry_as_numpy_ndarray",
"SplatValue",
"save_archive_file",
]
Expand All @@ -23,7 +26,7 @@
class SplatValue:
def __init__(
self,
pattern: Union[array.array, numpy.ndarray],
pattern: Union[array.array, np.ndarray],
count: Union[Sequence[int], int],
):
if hasattr(pattern, "shape"):
Expand Down Expand Up @@ -64,3 +67,162 @@ def save_archive_file(entries: dict[str, Union[Any, SplatValue]], file_path: Pat
else:
index.add_buffer(key, value)
index.create_archive_file(str(file_path))


def parameter_index_add_numpy_ndarray(
index: ParameterIndex, name: str, array: np.ndarray
):
"""Adds an ndarray to the index."""
metadata = _make_tensor_metadata(array)
# 0d arrays are special in both torch/numpy in different ways that makes
# it hard to reliably get a memory view of their contents. Since we
# know that 0d is always small, we just force a copy when in numpy
# land and that seems to get it on the happy path.
# See: https://github.com/iree-org/iree-turbine/issues/29
if len(array.shape) == 0:
flat_array = array.copy()
else:
flat_array = np.ascontiguousarray(array).view(np.uint8)
index.add_buffer(name, flat_array, metadata=metadata)


def parameter_index_entry_as_numpy_flat_ndarray(
index_entry: ParameterIndexEntry,
) -> np.ndarray:
"""Accesses the contents as a uint8 flat tensor.
If it is a splat, then the tensor will be a view of the splat pattern.
Raises a ValueError on unsupported entries.
"""
if index_entry.is_file:
wrapper = np.array(index_entry.file_view, copy=False)
elif index_entry.is_splat:
wrapper = np.array(index_entry.splat_pattern, copy=True)
else:
raise ValueError(f"Unsupported ParameterIndexEntry: {index_entry}")

return wrapper


def parameter_index_entry_as_numpy_ndarray(
index_entry: ParameterIndexEntry,
) -> np.ndarray:
"""Returns a tensor viewed with appropriate shape/dtype from metadata.
Raises a ValueError if unsupported.
"""

# Decode metadata.
versioned_metadata = index_entry.metadata.decode()
metadata_parts = versioned_metadata.split(_metadata_version_separator, maxsplit=1)
if len(metadata_parts) == 1:
raise ValueError(
(
f'Invalid metadata for parameter index entry "{index_entry.key}".'
f' Expected format version prefix not found in "{metadata_parts[0][:100]}".'
)
)
format_version = metadata_parts[0]
metadata = metadata_parts[1]
if (
format_version != _metadata_version
and format_version != _metadata_iree_turbine_version
):
raise ValueError(
(
f'Unsupported metadata format version "{format_version}" for parameter '
'index entry "{index_entry.key}": Cannot convert to tensor'
)
)
d = json.loads(metadata)
try:
type_name = d["type"]
if d["type"] != "Tensor":
raise ValueError(
f"Metadata for parameter entry {index_entry.key} is not a Tensor ('{type_name}')"
)
dtype_name = d["dtype"]
shape = d["shape"]
except KeyError as e:
raise ValueError(f"Bad metadata for parameter entry {index_entry.key}") from e

# Unpack/validate.
try:
dtype = _NAME_TO_DTYPE[dtype_name]
except KeyError:
raise ValueError(f"Unknown dtype name '{dtype_name}'")
try:
shape = [int(d) for d in shape]
except ValueError as e:
raise ValueError(f"Illegal shape for parameter entry {index_entry.key}") from e

t = parameter_index_entry_as_numpy_flat_ndarray(index_entry)
return t.view(dtype=dtype).reshape(shape)


_DTYPE_TO_NAME = (
(np.float16, "float16"),
(np.float32, "float32"),
(np.float64, "float64"),
(np.int32, "int32"),
(np.int64, "int64"),
(np.int16, "int16"),
(np.int8, "int8"),
(np.uint32, "uint32"),
(np.uint64, "uint64"),
(np.uint16, "uint16"),
(np.uint8, "uint8"),
(np.bool_, "bool"),
(np.complex64, "complex64"),
(np.complex128, "complex128"),
)

_NAME_TO_DTYPE: dict[str, np.dtype] = {
name: np_dtype for np_dtype, name in _DTYPE_TO_NAME
}


def _map_dtype_to_name(dtype) -> str:
for match_dtype, dtype_name in _DTYPE_TO_NAME:
if match_dtype == dtype:
return dtype_name

raise KeyError(f"Numpy dtype {dtype} not found.")


_metadata_version = "TENSORv0"
"""Magic number to identify the format version.
The current version that will be used when adding tensors to a parameter index."""

_metadata_iree_turbine_version = "PYTORCH"
"""There are files created with IREE Turbine that use this prefix.
This is here to maintain the ability to load such files."""

_metadata_version_separator = ":"
"""The separator between the format version and the actual metadata.
The metadata has the following format <format-version><separator><metadata>"""


def _make_tensor_metadata(t: np.ndarray) -> str:
"""Makes a tensor metadata blob that can be used to reconstruct the tensor."""
dtype = t.dtype
dtype_name = _map_dtype_to_name(dtype)
is_complex = np.issubdtype(dtype, np.complexfloating)
is_floating_point = np.issubdtype(dtype, np.floating)
is_signed = np.issubdtype(dtype, np.signedinteger)
dtype_desc = {
"class_name": type(dtype).__name__,
"is_complex": is_complex,
"is_floating_point": is_floating_point,
"is_signed": is_signed,
"itemsize": dtype.itemsize,
}
d = {
"type": "Tensor",
"dtype": dtype_name,
"shape": list(t.shape),
"dtype_desc": dtype_desc,
}
encoded = f"{_metadata_version}{_metadata_version_separator}{json.dumps(d)}"
return encoded
41 changes: 41 additions & 0 deletions runtime/bindings/python/tests/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,47 @@ def verify_archive(file_path: Path):
verify_archive(file_path)
gc.collect()

def testParameterIndexEntryFromToNumpy(self):
array = np.array([[1, 2], [3, 4]], dtype=np.int32)
index = rt.ParameterIndex()
key = "key"
rt.parameter_index_add_numpy_ndarray(index, key, array)
assert index.items()[0][0] == key
index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(
index.items()[0][1]
)
np.testing.assert_equal(index_entry_as_array, array)

def testParameterIndexEntryFromToNumpyZeroDims(self):
array = np.array(1234, dtype=np.int32)
index = rt.ParameterIndex()
key = "key"
rt.parameter_index_add_numpy_ndarray(index, key, array)
assert index.items()[0][0] == key
index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(
index.items()[0][1]
)
np.testing.assert_equal(index_entry_as_array, array)

def testParameterIndexEntryFromIreeTurbine(self):
"""Verify that we are able to load a tensor from IRPA generated with IREE
Turbine.
We want to maintain backward compatibility with existing IRPA files."""
index = rt.ParameterIndex()
irpa_path = str(
Path(__file__).resolve().parent
/ "testdata"
/ "tensor_saved_with_iree_turbine.irpa"
)
index.load(irpa_path)
items = index.items()
assert len(items) == 1
key, entry = items[0]
assert key == "the_torch_tensor"
index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(entry)
expected_array = np.array([1, 2, 3, 4], dtype=np.uint8)
np.testing.assert_array_equal(index_entry_as_array, expected_array, strict=True)

def testFileHandleWrap(self):
fh = rt.FileHandle.wrap_memory(b"foobar")
view = fh.host_allocation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from iree.turbine.aot import ParameterArchiveBuilder
import torch

archive = ParameterArchiveBuilder()
tensor = torch.tensor([1, 2, 3, 4], dtype=torch.uint8)
archive.add_tensor("the_torch_tensor", tensor)
archive.save("tensor_saved_with_iree_turbine.irpa")
Binary file not shown.

0 comments on commit 078c3ec

Please sign in to comment.