Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using the new core type system #903

Merged
merged 13 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"git_url" : "https://github.com/nv-legate/legate.core.git",
"git_shallow": false,
"always_download": false,
"git_tag" : "b3e280d6212aa2ec0af619b6fee3ac24d069897e"
"git_tag" : "149fa50bce56350e84f3fad4d453b5f5b77b935d"
}
}
}
48 changes: 15 additions & 33 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
cast,
)

import legate.core.types as ty
import numpy as np
import pyarrow # type: ignore [import]
from legate.core import Array
from legate.core import Array, Field
from numpy.core.multiarray import ( # type: ignore [attr-defined]
normalize_axis_index,
)
Expand All @@ -56,7 +54,7 @@
from .coverage import FALLBACK_WARNING, clone_class
from .runtime import runtime
from .types import NdShape
from .utils import deep_apply, dot_modes
from .utils import deep_apply, dot_modes, to_core_dtype

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -174,21 +172,6 @@ def maybe_convert_to_np_ndarray(obj: Any) -> Any:
return obj


# FIXME: we can't give an accurate return type as mypy thinks
# the pyarrow import can be ignored, and can't override the check
# either, because no-any-unimported needs Python >= 3.10. We can
# fix it once we bump up the Python version
def convert_numpy_dtype_to_pyarrow(dtype: np.dtype[Any]) -> Any:
if dtype.kind != "c":
return pyarrow.from_numpy_dtype(dtype)
elif dtype == np.complex64:
return ty.complex64
elif dtype == np.complex128:
return ty.complex128
else:
raise ValueError(f"Unsupported NumPy dtype: {dtype}")


NDARRAY_INTERNAL = {
"__array_finalize__",
"__array_function__",
Expand Down Expand Up @@ -242,9 +225,15 @@ def __init__(
for inp in inputs
if isinstance(inp, ndarray)
]
self._thunk = runtime.create_empty_thunk(
sanitized_shape, dtype, inputs
)
core_dtype = to_core_dtype(dtype)
if core_dtype is not None:
self._thunk = runtime.create_empty_thunk(
sanitized_shape, core_dtype, inputs
)
else:
self._thunk = runtime.create_eager_thunk(
sanitized_shape, dtype
)
else:
self._thunk = thunk
self._legate_data: Union[dict[str, Any], None] = None
Expand Down Expand Up @@ -280,24 +269,17 @@ def _sanitize_shape(
@property
def __legate_data_interface__(self) -> dict[str, Any]:
if self._legate_data is None:
# All of our thunks implement the Legate Store interface
# so we just need to convert our type and stick it in
# a Legate Array
arrow_type = convert_numpy_dtype_to_pyarrow(self.dtype)
# If the thunk is an eager array, we need to convert it to a
# deferred array so we can extract a legate store
deferred_thunk = runtime.to_deferred_array(self._thunk)
# We don't have nullable data for the moment
# until we support masked arrays
array = Array(arrow_type, [None, deferred_thunk.base])
dtype = deferred_thunk.base.type
array = Array(dtype, [None, deferred_thunk.base])
self._legate_data = dict()
self._legate_data["version"] = 1
data = dict()
field = pyarrow.field(
"cuNumeric Array", arrow_type, nullable=False
)
data[field] = array
self._legate_data["data"] = data
field = Field("cuNumeric Array", dtype)
self._legate_data["data"] = {field: array}
return self._legate_data

# Properties for ndarray
Expand Down
54 changes: 7 additions & 47 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
from abc import abstractmethod
from enum import IntEnum, unique
from typing import TYPE_CHECKING, Any, List, Union, cast
from typing import TYPE_CHECKING, Union, cast

import numpy as np
from legate.core import Library, get_legate_runtime
Expand Down Expand Up @@ -197,15 +197,6 @@ class _CunumericSharedLib:
CUNUMERIC_TUNABLE_MAX_EAGER_VOLUME: int
CUNUMERIC_TUNABLE_NUM_GPUS: int
CUNUMERIC_TUNABLE_NUM_PROCS: int
CUNUMERIC_TYPE_POINT1: int
CUNUMERIC_TYPE_POINT2: int
CUNUMERIC_TYPE_POINT3: int
CUNUMERIC_TYPE_POINT4: int
CUNUMERIC_TYPE_POINT5: int
CUNUMERIC_TYPE_POINT6: int
CUNUMERIC_TYPE_POINT7: int
CUNUMERIC_TYPE_POINT8: int
CUNUMERIC_TYPE_POINT9: int
CUNUMERIC_UNARY_OP: int
CUNUMERIC_UNARY_RED: int
CUNUMERIC_UNIQUE: int
Expand Down Expand Up @@ -274,6 +265,12 @@ class _CunumericSharedLib:
def cunumeric_has_curand(self) -> int:
...

@abstractmethod
def cunumeric_register_reduction_op(
self, type_uid: int, elem_type_code: int
) -> None:
...


# Load the cuNumeric library first so we have a shard object that
# we can use to initialize all these configuration enumerations
Expand Down Expand Up @@ -500,13 +497,6 @@ class RandGenCode(IntEnum):
INTEGER = 3


# Match these to CuNumericRedopID in cunumeric_c.h
@unique
class CuNumericRedopCode(IntEnum):
ARGMAX = 1
ARGMIN = 2


# Match these to CuNumericTunable in cunumeric_c.h
@unique
class CuNumericTunable(IntEnum):
Expand Down Expand Up @@ -774,33 +764,3 @@ def reverse(in_string: Union[str, None]) -> str:
return "forward"
else:
return in_string


# Match these to CuNumericTypeCodes in cunumeric_c.h
# we start from POINT2 type since POINT1 is int8 type
_CUNUMERIC_DTYPES: List[tuple[np.dtype[Any], int, int]] = [
(np.dtype("i8, i8"), 16, _cunumeric.CUNUMERIC_TYPE_POINT2),
(np.dtype("i8, i8, i8"), 24, _cunumeric.CUNUMERIC_TYPE_POINT3),
(np.dtype("i8, i8, i8, i8"), 32, _cunumeric.CUNUMERIC_TYPE_POINT4),
(np.dtype("i8, i8, i8, i8, i8"), 40, _cunumeric.CUNUMERIC_TYPE_POINT5),
(
np.dtype("i8, i8, i8, i8, i8, i8"),
48,
_cunumeric.CUNUMERIC_TYPE_POINT6,
),
(
np.dtype("i8, i8, i8, i8, i8, i8, i8"),
56,
_cunumeric.CUNUMERIC_TYPE_POINT7,
),
(
np.dtype("i8, i8, i8, i8, i8, i8, i8, i8"),
64,
_cunumeric.CUNUMERIC_TYPE_POINT8,
),
(
np.dtype("i8, i8, i8, i8, i8, i8, i8, i8, i8"),
72,
_cunumeric.CUNUMERIC_TYPE_POINT9,
),
]
Loading