Skip to content

Commit

Permalink
Using the new core type system (#903)
Browse files Browse the repository at this point in the history
* Catch up the type code refactoring

* Massive refactoring again for new type codes

* Use core types for points and argvals

* Use an aligned struct for argmin/argmax

* Use Legate field in the data interface implementation

* Use a temporary commit hash for testing

* Catch up the Legate core changes

* Address comments from @ jjwilke

* Update the legate core commit hash
  • Loading branch information
magnatelee authored May 4, 2023
1 parent a1f9fd6 commit 9b9d5e0
Show file tree
Hide file tree
Showing 169 changed files with 901 additions and 1,112 deletions.
2 changes: 1 addition & 1 deletion cmake/versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"git_url" : "https://github.com/nv-legate/legate.core.git",
"git_shallow": false,
"always_download": false,
"git_tag" : "b3e280d6212aa2ec0af619b6fee3ac24d069897e"
"git_tag" : "149fa50bce56350e84f3fad4d453b5f5b77b935d"
}
}
}
48 changes: 15 additions & 33 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
cast,
)

import legate.core.types as ty
import numpy as np
import pyarrow # type: ignore [import]
from legate.core import Array
from legate.core import Array, Field
from numpy.core.multiarray import ( # type: ignore [attr-defined]
normalize_axis_index,
)
Expand All @@ -56,7 +54,7 @@
from .coverage import FALLBACK_WARNING, clone_class
from .runtime import runtime
from .types import NdShape
from .utils import deep_apply, dot_modes
from .utils import deep_apply, dot_modes, to_core_dtype

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -174,21 +172,6 @@ def maybe_convert_to_np_ndarray(obj: Any) -> Any:
return obj


# FIXME: we can't give an accurate return type as mypy thinks
# the pyarrow import can be ignored, and can't override the check
# either, because no-any-unimported needs Python >= 3.10. We can
# fix it once we bump up the Python version
def convert_numpy_dtype_to_pyarrow(dtype: np.dtype[Any]) -> Any:
if dtype.kind != "c":
return pyarrow.from_numpy_dtype(dtype)
elif dtype == np.complex64:
return ty.complex64
elif dtype == np.complex128:
return ty.complex128
else:
raise ValueError(f"Unsupported NumPy dtype: {dtype}")


NDARRAY_INTERNAL = {
"__array_finalize__",
"__array_function__",
Expand Down Expand Up @@ -242,9 +225,15 @@ def __init__(
for inp in inputs
if isinstance(inp, ndarray)
]
self._thunk = runtime.create_empty_thunk(
sanitized_shape, dtype, inputs
)
core_dtype = to_core_dtype(dtype)
if core_dtype is not None:
self._thunk = runtime.create_empty_thunk(
sanitized_shape, core_dtype, inputs
)
else:
self._thunk = runtime.create_eager_thunk(
sanitized_shape, dtype
)
else:
self._thunk = thunk
self._legate_data: Union[dict[str, Any], None] = None
Expand Down Expand Up @@ -280,24 +269,17 @@ def _sanitize_shape(
@property
def __legate_data_interface__(self) -> dict[str, Any]:
if self._legate_data is None:
# All of our thunks implement the Legate Store interface
# so we just need to convert our type and stick it in
# a Legate Array
arrow_type = convert_numpy_dtype_to_pyarrow(self.dtype)
# If the thunk is an eager array, we need to convert it to a
# deferred array so we can extract a legate store
deferred_thunk = runtime.to_deferred_array(self._thunk)
# We don't have nullable data for the moment
# until we support masked arrays
array = Array(arrow_type, [None, deferred_thunk.base])
dtype = deferred_thunk.base.type
array = Array(dtype, [None, deferred_thunk.base])
self._legate_data = dict()
self._legate_data["version"] = 1
data = dict()
field = pyarrow.field(
"cuNumeric Array", arrow_type, nullable=False
)
data[field] = array
self._legate_data["data"] = data
field = Field("cuNumeric Array", dtype)
self._legate_data["data"] = {field: array}
return self._legate_data

# Properties for ndarray
Expand Down
54 changes: 7 additions & 47 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
from abc import abstractmethod
from enum import IntEnum, unique
from typing import TYPE_CHECKING, Any, List, Union, cast
from typing import TYPE_CHECKING, Union, cast

import numpy as np
from legate.core import Library, get_legate_runtime
Expand Down Expand Up @@ -197,15 +197,6 @@ class _CunumericSharedLib:
CUNUMERIC_TUNABLE_MAX_EAGER_VOLUME: int
CUNUMERIC_TUNABLE_NUM_GPUS: int
CUNUMERIC_TUNABLE_NUM_PROCS: int
CUNUMERIC_TYPE_POINT1: int
CUNUMERIC_TYPE_POINT2: int
CUNUMERIC_TYPE_POINT3: int
CUNUMERIC_TYPE_POINT4: int
CUNUMERIC_TYPE_POINT5: int
CUNUMERIC_TYPE_POINT6: int
CUNUMERIC_TYPE_POINT7: int
CUNUMERIC_TYPE_POINT8: int
CUNUMERIC_TYPE_POINT9: int
CUNUMERIC_UNARY_OP: int
CUNUMERIC_UNARY_RED: int
CUNUMERIC_UNIQUE: int
Expand Down Expand Up @@ -274,6 +265,12 @@ class _CunumericSharedLib:
def cunumeric_has_curand(self) -> int:
...

@abstractmethod
def cunumeric_register_reduction_op(
self, type_uid: int, elem_type_code: int
) -> None:
...


# Load the cuNumeric library first so we have a shard object that
# we can use to initialize all these configuration enumerations
Expand Down Expand Up @@ -500,13 +497,6 @@ class RandGenCode(IntEnum):
INTEGER = 3


# Match these to CuNumericRedopID in cunumeric_c.h
@unique
class CuNumericRedopCode(IntEnum):
ARGMAX = 1
ARGMIN = 2


# Match these to CuNumericTunable in cunumeric_c.h
@unique
class CuNumericTunable(IntEnum):
Expand Down Expand Up @@ -774,33 +764,3 @@ def reverse(in_string: Union[str, None]) -> str:
return "forward"
else:
return in_string


# Match these to CuNumericTypeCodes in cunumeric_c.h
# we start from POINT2 type since POINT1 is int8 type
_CUNUMERIC_DTYPES: List[tuple[np.dtype[Any], int, int]] = [
(np.dtype("i8, i8"), 16, _cunumeric.CUNUMERIC_TYPE_POINT2),
(np.dtype("i8, i8, i8"), 24, _cunumeric.CUNUMERIC_TYPE_POINT3),
(np.dtype("i8, i8, i8, i8"), 32, _cunumeric.CUNUMERIC_TYPE_POINT4),
(np.dtype("i8, i8, i8, i8, i8"), 40, _cunumeric.CUNUMERIC_TYPE_POINT5),
(
np.dtype("i8, i8, i8, i8, i8, i8"),
48,
_cunumeric.CUNUMERIC_TYPE_POINT6,
),
(
np.dtype("i8, i8, i8, i8, i8, i8, i8"),
56,
_cunumeric.CUNUMERIC_TYPE_POINT7,
),
(
np.dtype("i8, i8, i8, i8, i8, i8, i8, i8"),
64,
_cunumeric.CUNUMERIC_TYPE_POINT8,
),
(
np.dtype("i8, i8, i8, i8, i8, i8, i8, i8, i8"),
72,
_cunumeric.CUNUMERIC_TYPE_POINT9,
),
]
Loading

0 comments on commit 9b9d5e0

Please sign in to comment.