diff --git a/cmake/versions.json b/cmake/versions.json index 49abf9ed9..08dd5b6a5 100644 --- a/cmake/versions.json +++ b/cmake/versions.json @@ -5,7 +5,7 @@ "git_url" : "https://github.com/nv-legate/legate.core.git", "git_shallow": false, "always_download": false, - "git_tag" : "b3e280d6212aa2ec0af619b6fee3ac24d069897e" + "git_tag" : "149fa50bce56350e84f3fad4d453b5f5b77b935d" } } } diff --git a/cunumeric/array.py b/cunumeric/array.py index c7822b6b4..53629b4b7 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -31,10 +31,8 @@ cast, ) -import legate.core.types as ty import numpy as np -import pyarrow # type: ignore [import] -from legate.core import Array +from legate.core import Array, Field from numpy.core.multiarray import ( # type: ignore [attr-defined] normalize_axis_index, ) @@ -56,7 +54,7 @@ from .coverage import FALLBACK_WARNING, clone_class from .runtime import runtime from .types import NdShape -from .utils import deep_apply, dot_modes +from .utils import deep_apply, dot_modes, to_core_dtype if TYPE_CHECKING: from pathlib import Path @@ -174,21 +172,6 @@ def maybe_convert_to_np_ndarray(obj: Any) -> Any: return obj -# FIXME: we can't give an accurate return type as mypy thinks -# the pyarrow import can be ignored, and can't override the check -# either, because no-any-unimported needs Python >= 3.10. We can -# fix it once we bump up the Python version -def convert_numpy_dtype_to_pyarrow(dtype: np.dtype[Any]) -> Any: - if dtype.kind != "c": - return pyarrow.from_numpy_dtype(dtype) - elif dtype == np.complex64: - return ty.complex64 - elif dtype == np.complex128: - return ty.complex128 - else: - raise ValueError(f"Unsupported NumPy dtype: {dtype}") - - NDARRAY_INTERNAL = { "__array_finalize__", "__array_function__", @@ -242,9 +225,15 @@ def __init__( for inp in inputs if isinstance(inp, ndarray) ] - self._thunk = runtime.create_empty_thunk( - sanitized_shape, dtype, inputs - ) + core_dtype = to_core_dtype(dtype) + if core_dtype is not None: + self._thunk = runtime.create_empty_thunk( + sanitized_shape, core_dtype, inputs + ) + else: + self._thunk = runtime.create_eager_thunk( + sanitized_shape, dtype + ) else: self._thunk = thunk self._legate_data: Union[dict[str, Any], None] = None @@ -280,24 +269,17 @@ def _sanitize_shape( @property def __legate_data_interface__(self) -> dict[str, Any]: if self._legate_data is None: - # All of our thunks implement the Legate Store interface - # so we just need to convert our type and stick it in - # a Legate Array - arrow_type = convert_numpy_dtype_to_pyarrow(self.dtype) # If the thunk is an eager array, we need to convert it to a # deferred array so we can extract a legate store deferred_thunk = runtime.to_deferred_array(self._thunk) # We don't have nullable data for the moment # until we support masked arrays - array = Array(arrow_type, [None, deferred_thunk.base]) + dtype = deferred_thunk.base.type + array = Array(dtype, [None, deferred_thunk.base]) self._legate_data = dict() self._legate_data["version"] = 1 - data = dict() - field = pyarrow.field( - "cuNumeric Array", arrow_type, nullable=False - ) - data[field] = array - self._legate_data["data"] = data + field = Field("cuNumeric Array", dtype) + self._legate_data["data"] = {field: array} return self._legate_data # Properties for ndarray diff --git a/cunumeric/config.py b/cunumeric/config.py index 98361b0bc..21a7a68e5 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -17,7 +17,7 @@ import os from abc import abstractmethod from enum import IntEnum, unique -from 
typing import TYPE_CHECKING, Any, List, Union, cast +from typing import TYPE_CHECKING, Union, cast import numpy as np from legate.core import Library, get_legate_runtime @@ -197,15 +197,6 @@ class _CunumericSharedLib: CUNUMERIC_TUNABLE_MAX_EAGER_VOLUME: int CUNUMERIC_TUNABLE_NUM_GPUS: int CUNUMERIC_TUNABLE_NUM_PROCS: int - CUNUMERIC_TYPE_POINT1: int - CUNUMERIC_TYPE_POINT2: int - CUNUMERIC_TYPE_POINT3: int - CUNUMERIC_TYPE_POINT4: int - CUNUMERIC_TYPE_POINT5: int - CUNUMERIC_TYPE_POINT6: int - CUNUMERIC_TYPE_POINT7: int - CUNUMERIC_TYPE_POINT8: int - CUNUMERIC_TYPE_POINT9: int CUNUMERIC_UNARY_OP: int CUNUMERIC_UNARY_RED: int CUNUMERIC_UNIQUE: int @@ -274,6 +265,12 @@ class _CunumericSharedLib: def cunumeric_has_curand(self) -> int: ... + @abstractmethod + def cunumeric_register_reduction_op( + self, type_uid: int, elem_type_code: int + ) -> None: + ... + # Load the cuNumeric library first so we have a shard object that # we can use to initialize all these configuration enumerations @@ -500,13 +497,6 @@ class RandGenCode(IntEnum): INTEGER = 3 -# Match these to CuNumericRedopID in cunumeric_c.h -@unique -class CuNumericRedopCode(IntEnum): - ARGMAX = 1 - ARGMIN = 2 - - # Match these to CuNumericTunable in cunumeric_c.h @unique class CuNumericTunable(IntEnum): @@ -774,33 +764,3 @@ def reverse(in_string: Union[str, None]) -> str: return "forward" else: return in_string - - -# Match these to CuNumericTypeCodes in cunumeric_c.h -# we start from POINT2 type since POINT1 is int8 type -_CUNUMERIC_DTYPES: List[tuple[np.dtype[Any], int, int]] = [ - (np.dtype("i8, i8"), 16, _cunumeric.CUNUMERIC_TYPE_POINT2), - (np.dtype("i8, i8, i8"), 24, _cunumeric.CUNUMERIC_TYPE_POINT3), - (np.dtype("i8, i8, i8, i8"), 32, _cunumeric.CUNUMERIC_TYPE_POINT4), - (np.dtype("i8, i8, i8, i8, i8"), 40, _cunumeric.CUNUMERIC_TYPE_POINT5), - ( - np.dtype("i8, i8, i8, i8, i8, i8"), - 48, - _cunumeric.CUNUMERIC_TYPE_POINT6, - ), - ( - np.dtype("i8, i8, i8, i8, i8, i8, i8"), - 56, - _cunumeric.CUNUMERIC_TYPE_POINT7, - ), - ( - np.dtype("i8, i8, i8, i8, i8, i8, i8, i8"), - 64, - _cunumeric.CUNUMERIC_TYPE_POINT8, - ), - ( - np.dtype("i8, i8, i8, i8, i8, i8, i8, i8, i8"), - 72, - _cunumeric.CUNUMERIC_TYPE_POINT9, - ), -] diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index ecac9332a..c8527d4d3 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -48,7 +48,6 @@ Bitorder, ConvertCode, CuNumericOpCode, - CuNumericRedopCode, RandGenCode, UnaryOpCode, UnaryRedCode, @@ -77,15 +76,10 @@ ) -def _complex_field_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]: - if dtype == np.complex64: - return np.dtype(np.float32) - elif dtype == np.complex128: - return np.dtype(np.float64) - elif dtype == np.complex256: - return np.dtype(np.float128) - else: - assert False +_COMPLEX_FIELD_DTYPES = { + ty.complex64: ty.float32, + ty.complex128: ty.float64, +} def _prod(tpl: Sequence[int]) -> int: @@ -167,8 +161,8 @@ def __init__( UnaryRedCode.PROD: ReductionOp.MUL, UnaryRedCode.MAX: ReductionOp.MAX, UnaryRedCode.MIN: ReductionOp.MIN, - UnaryRedCode.ARGMAX: CuNumericRedopCode.ARGMAX, - UnaryRedCode.ARGMIN: CuNumericRedopCode.ARGMIN, + UnaryRedCode.ARGMAX: ReductionOp.MAX, + UnaryRedCode.ARGMIN: ReductionOp.MIN, UnaryRedCode.CONTAINS: ReductionOp.ADD, UnaryRedCode.COUNT_NONZERO: ReductionOp.ADD, UnaryRedCode.ALL: ReductionOp.MUL, @@ -238,11 +232,10 @@ class DeferredArray(NumPyThunk): def __init__( self, runtime: Runtime, - base: Any, - dtype: np.dtype[Any], + base: Store, numpy_array: Optional[npt.NDArray[Any]] = None, ) -> None: - 
super().__init__(runtime, dtype) + super().__init__(runtime, base.type.to_numpy_dtype()) assert base is not None assert isinstance(base, Store) self.base: Any = base # a Legate Store @@ -275,7 +268,8 @@ def _copy_if_overlapping(self, other: DeferredArray) -> DeferredArray: copy = cast( DeferredArray, self.runtime.create_empty_thunk( - self.shape, self.dtype, inputs=[self] + self.shape, + self.base.type, ), ) copy.copy(self, deep=True) @@ -320,7 +314,7 @@ def construct_ndarray( def imag(self) -> NumPyThunk: result = self.runtime.create_empty_thunk( self.shape, - dtype=_complex_field_dtype(self.dtype), + dtype=_COMPLEX_FIELD_DTYPES[self.base.type], inputs=[self], ) @@ -337,7 +331,7 @@ def imag(self) -> NumPyThunk: def real(self) -> NumPyThunk: result = self.runtime.create_empty_thunk( self.shape, - dtype=_complex_field_dtype(self.dtype), + dtype=_COMPLEX_FIELD_DTYPES[self.base.type], inputs=[self], ) @@ -353,7 +347,7 @@ def real(self) -> NumPyThunk: def conj(self) -> NumPyThunk: result = self.runtime.create_empty_thunk( self.shape, - dtype=self.dtype, + dtype=self.base.type, inputs=[self], ) @@ -491,11 +485,10 @@ def _copy_store(self, store: Any) -> DeferredArray: store_to_copy = DeferredArray( self.runtime, base=store, - dtype=self.dtype, ) store_copy = self.runtime.create_empty_thunk( store_to_copy.shape, - self.dtype, + self.base.type, inputs=[store_to_copy], ) store_copy.copy(store_to_copy, deep=True) @@ -606,7 +599,7 @@ def _has_single_boolean_array( "Unsupported entry type passed to advanced ", "indexing operation", ) - lhs = DeferredArray(self.runtime, store, self.dtype) + lhs = DeferredArray(self.runtime, store) return True, lhs, key[transpose_index] @@ -653,7 +646,7 @@ def _advanced_indexing_with_boolean_array( DeferredArray, self.runtime.create_empty_thunk( out_shape, - rhs.dtype, + rhs.base.type, inputs=[rhs], ), ) @@ -674,12 +667,11 @@ def _advanced_indexing_with_boolean_array( mask = DeferredArray( self.runtime, base=key_store, - dtype=self.dtype, ) rhs.putmask(mask, set_value) return False, rhs, rhs, self else: - out_dtype = rhs.dtype + out_dtype = rhs.base.type # in the case this operation is called for the set_item, we # return Point type field that is later used for # indirect copy operation @@ -699,7 +691,7 @@ def _advanced_indexing_with_boolean_array( task.add_output(out.base) task.add_input(rhs.base) task.add_input(key_store) - task.add_scalar_arg(is_set, bool) + task.add_scalar_arg(is_set, ty.bool_) task.add_scalar_arg(key_dims, ty.int64) task.add_alignment(rhs.base, key_store) task.add_broadcast( @@ -726,7 +718,7 @@ def _advanced_indexing_with_boolean_array( ), ) if not is_set: - out.fill(np.array(0, dtype=out_dtype)) + out.fill(np.array(0, dtype=out_dtype.to_numpy_dtype())) else: for dim in range(rhs.ndim - out_dim): out_tmp = out_tmp.project(rhs.ndim - dim - 1, 0) @@ -841,7 +833,7 @@ def _create_indexing_array( # to apply all the transformations done to `store` to `self` # as well before creating a copy if is_set: - self = DeferredArray(self.runtime, store, self.dtype) + self = DeferredArray(self.runtime, store) # after store is transformed we need to to return a copy of # the store since Copy operation can't be done on # the store with transformation @@ -893,7 +885,6 @@ def _get_view(self, key: Any) -> DeferredArray: return DeferredArray( self.runtime, base=store, - dtype=self.dtype, ) def _broadcast(self, shape: NdShape) -> Any: @@ -921,14 +912,13 @@ def _convert_future_to_regionfield( else: shape = self.shape store = self.context.create_store( - self.dtype, + 
self.base.type, shape=shape, optimize_scalar=False, ) thunk_copy = DeferredArray( self.runtime, base=store, - dtype=self.dtype, ) thunk_copy.copy(self, deep=True) return thunk_copy @@ -951,20 +941,19 @@ def get_item(self, key: Any) -> NumPyThunk: if index_array.base.kind == Future: index_array = index_array._convert_future_to_regionfield() result_store = self.context.create_store( - self.dtype, + self.base.type, shape=index_array.shape, optimize_scalar=False, ) result = DeferredArray( self.runtime, base=result_store, - dtype=self.dtype, ) else: result = self.runtime.create_empty_thunk( index_array.base.shape, - self.dtype, + self.base.type, inputs=[self], ) @@ -984,7 +973,7 @@ def get_item(self, key: Any) -> NumPyThunk: if result.shape == (): input = result result = self.runtime.create_empty_thunk( - (), self.dtype, inputs=[self] + (), self.base.type, inputs=[self] ) task = self.context.create_auto_task(CuNumericOpCode.READ) @@ -1027,7 +1016,6 @@ def set_item(self, key: Any, rhs: Any) -> None: rhs_tmp = DeferredArray( self.runtime, base=rhs_store, - dtype=rhs.dtype, ) rhs_tmp2 = rhs_tmp._convert_future_to_regionfield() rhs_store = rhs_tmp2.base @@ -1085,7 +1073,7 @@ def set_item(self, key: Any, rhs: Any) -> None: if view.base.overlaps(rhs.base): rhs_copy = self.runtime.create_empty_thunk( rhs.shape, - rhs.dtype, + rhs.base.type, inputs=[rhs], ) rhs_copy.copy(rhs, deep=False) @@ -1207,7 +1195,7 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk: tmp_shape += tgt_g result = self.runtime.create_empty_thunk( - tmp_shape, dtype=self.dtype, inputs=[self] + tmp_shape, dtype=self.base.type, inputs=[self] ) src = self.base @@ -1237,8 +1225,8 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk: assert src.shape == tgt.shape - src_array = DeferredArray(self.runtime, src, self.dtype) - tgt_array = DeferredArray(self.runtime, tgt, self.dtype) + src_array = DeferredArray(self.runtime, src) + tgt_array = DeferredArray(self.runtime, tgt) tgt_array.copy(src_array, deep=True) if needs_delinearization and needs_linearization: @@ -1250,9 +1238,9 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk: src_dim += len(tgt_g) assert src.shape == newshape - src_array = DeferredArray(self.runtime, src, self.dtype) + src_array = DeferredArray(self.runtime, src) result = self.runtime.create_empty_thunk( - newshape, dtype=self.dtype, inputs=[self] + newshape, dtype=self.base.type, inputs=[self] ) result.copy(src_array, deep=True) @@ -1276,7 +1264,7 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk: src_dim += diff - result = DeferredArray(self.runtime, src, self.dtype) + result = DeferredArray(self.runtime, src) return result @@ -1303,7 +1291,7 @@ def squeeze( ) if result is self.base: return self - return DeferredArray(self.runtime, result, self.dtype) + return DeferredArray(self.runtime, result) def swapaxes(self, axis1: int, axis2: int) -> DeferredArray: if self.size == 1 or axis1 == axis2: @@ -1315,7 +1303,7 @@ def swapaxes(self, axis1: int, axis2: int) -> DeferredArray: dims[axis1], dims[axis2] = dims[axis2], dims[axis1] result = self.base.transpose(dims) - result = DeferredArray(self.runtime, result, self.dtype) + result = DeferredArray(self.runtime, result) return result @@ -1422,7 +1410,7 @@ def fft( len(set(axes)) != len(axes) or len(axes) != input.ndim or tuple(axes) != tuple(sorted(axes)), - bool, + ty.bool_, ) for ax in axes: task.add_scalar_arg(ax, ty.int64) @@ -1452,7 +1440,7 @@ def _fill(self, value: Any) -> None: task = 
self.context.create_auto_task(CuNumericOpCode.FILL) task.add_output(self.base) task.add_input(value) - task.add_scalar_arg(argval, bool) + task.add_scalar_arg(argval, ty.bool_) task.execute() def fill(self, numpy_array: Any) -> None: @@ -1463,9 +1451,9 @@ def fill(self, numpy_array: Any) -> None: # Have to copy the numpy array because this launch is asynchronous # and we need to make sure the application doesn't mutate the value # so make a future result, this is immediate so no dependence - value = self.runtime.create_scalar(numpy_array.data, self.dtype) + value = self.runtime.create_scalar(numpy_array.data) store = self.context.create_store( - self.dtype, shape=(1,), storage=value, optimize_scalar=True + self.base.type, shape=(1,), storage=value, optimize_scalar=True ) self._fill(store) @@ -1558,7 +1546,7 @@ def contract( # below the tasks do this adjustment internally. if blas_op is not None and lhs_thunk.dtype == np.float16: lhs_thunk = self.runtime.create_empty_thunk( - lhs_thunk.shape, np.dtype(np.float32), inputs=[lhs_thunk] + lhs_thunk.shape, ty.float32, inputs=[lhs_thunk] ) # Clear output array @@ -1706,9 +1694,9 @@ def add_mode( task.add_reduction(lhs, ReductionOp.ADD) task.add_input(rhs1) task.add_input(rhs2) - task.add_scalar_arg(tuple(lhs_dim_mask), (bool,)) - task.add_scalar_arg(tuple(rhs1_dim_mask), (bool,)) - task.add_scalar_arg(tuple(rhs2_dim_mask), (bool,)) + task.add_scalar_arg(tuple(lhs_dim_mask), (ty.bool_,)) + task.add_scalar_arg(tuple(rhs1_dim_mask), (ty.bool_,)) + task.add_scalar_arg(tuple(rhs2_dim_mask), (ty.bool_,)) task.add_alignment(lhs, rhs1) task.add_alignment(lhs, rhs2) task.execute() @@ -1801,7 +1789,7 @@ def _diag_helper( task.add_alignment(diag, matrix) task.add_scalar_arg(naxes, ty.int32) - task.add_scalar_arg(extract, bool) + task.add_scalar_arg(extract, ty.bool_) task.execute() @@ -1840,8 +1828,8 @@ def put(self, indices: Any, values: Any, check_bounds: bool) -> None: task = self.context.create_auto_task(CuNumericOpCode.WRAP) task.add_output(indirect.base) task.add_scalar_arg(shape, (ty.int64,)) - task.add_scalar_arg(True, bool) # has_input - task.add_scalar_arg(check_bounds, bool) + task.add_scalar_arg(True, ty.bool_) # has_input + task.add_scalar_arg(check_bounds, ty.bool_) task.add_input(indices.base) task.add_alignment(indices.base, indirect.base) task.throws_exception(IndexError) @@ -1903,7 +1891,7 @@ def arange(self, start: float, stop: float, step: float) -> None: # Handle the special case of a single value here assert self.shape[0] == 1 array = np.array(start, dtype=self.dtype) - future = self.runtime.create_scalar(array.data, array.dtype) + future = self.runtime.create_scalar(array.data) self.base.set_storage(future) return @@ -1948,7 +1936,7 @@ def transpose( self, axes: Union[None, tuple[int, ...], list[int]] ) -> DeferredArray: result = self.base.transpose(axes) - result = DeferredArray(self.runtime, result, self.dtype) + result = DeferredArray(self.runtime, result) return result @auto_convert("rhs") @@ -1960,7 +1948,7 @@ def trilu(self, rhs: Any, k: int, lower: bool) -> None: task.add_output(lhs) task.add_input(rhs) - task.add_scalar_arg(lower, bool) + task.add_scalar_arg(lower, ty.bool_) task.add_scalar_arg(k, ty.int32) task.add_alignment(lhs, rhs) @@ -1971,13 +1959,13 @@ def trilu(self, rhs: Any, k: int, lower: bool) -> None: def repeat( self, repeats: Any, axis: int, scalar_repeats: bool ) -> DeferredArray: - out = self.runtime.create_unbound_thunk(self.dtype, ndim=self.ndim) + out = self.runtime.create_unbound_thunk(self.base.type, 
ndim=self.ndim) task = self.context.create_auto_task(CuNumericOpCode.REPEAT) task.add_input(self.base) task.add_output(out.base) # We pass axis now but don't use for 1D case (will use for ND case task.add_scalar_arg(axis, ty.int32) - task.add_scalar_arg(scalar_repeats, bool) + task.add_scalar_arg(scalar_repeats, ty.bool_) if scalar_repeats: task.add_scalar_arg(repeats, ty.int64) else: @@ -2046,7 +2034,7 @@ def bincount(self, rhs: Any, weights: Optional[NumPyThunk] = None) -> None: def nonzero(self) -> tuple[NumPyThunk, ...]: results = tuple( - self.runtime.create_unbound_thunk(np.dtype(np.int64)) + self.runtime.create_unbound_thunk(ty.int64) for _ in range(self.ndim) ) @@ -3133,7 +3121,7 @@ def unary_reduction( argred = op in (UnaryRedCode.ARGMAX, UnaryRedCode.ARGMIN) if argred: - argred_dtype = self.runtime.get_arg_dtype(rhs_array.dtype) + argred_dtype = self.runtime.get_argred_type(rhs_array.base.type) lhs_array = self.runtime.create_empty_thunk( lhs_array.shape, dtype=argred_dtype, @@ -3320,7 +3308,7 @@ def where(self, src1: Any, src2: Any, src3: Any) -> None: task.execute() def argwhere(self) -> NumPyThunk: - result = self.runtime.create_unbound_thunk(np.dtype(np.int64), ndim=2) + result = self.runtime.create_unbound_thunk(ty.int64, ndim=2) task = self.context.create_auto_task(CuNumericOpCode.ARGWHERE) @@ -3378,7 +3366,7 @@ def scan( # local sum # storage for local sums accessible temp = self.runtime.create_unbound_thunk( - dtype=self.dtype, ndim=self.ndim + dtype=self.base.type, ndim=self.ndim ) if axis == rhs.ndim - 1: @@ -3388,7 +3376,7 @@ def scan( # swap axes, always performing scan along last axis swapped = rhs.swapaxes(axis, rhs.ndim - 1) input = self.runtime.create_empty_thunk( - swapped.shape, dtype=rhs.dtype, inputs=(rhs, swapped) + swapped.shape, dtype=rhs.base.type, inputs=(rhs, swapped) ) input.copy(swapped, deep=True) output = input @@ -3398,7 +3386,7 @@ def scan( task.add_input(input.base) task.add_output(temp.base) task.add_scalar_arg(op, ty.int32) - task.add_scalar_arg(nan_to_identity, bool) + task.add_scalar_arg(nan_to_identity, ty.bool_) task.add_alignment(input.base, output.base) @@ -3424,7 +3412,7 @@ def scan( self.copy(swapped, deep=True) def unique(self) -> NumPyThunk: - result = self.runtime.create_unbound_thunk(self.dtype) + result = self.runtime.create_unbound_thunk(self.base.type) task = self.context.create_auto_task(CuNumericOpCode.UNIQUE) @@ -3464,7 +3452,7 @@ def searchsorted(self, rhs: Any, v: Any, side: SortSide = "left") -> None: task.add_broadcast(self.base) task.add_alignment(self.base, v.base) - task.add_scalar_arg(is_left, bool) + task.add_scalar_arg(is_left, ty.bool_) task.add_scalar_arg(rhs.size, ty.int64) task.execute() @@ -3577,8 +3565,8 @@ def _wrap(self, src: Any, new_len: int) -> None: task = self.context.create_auto_task(CuNumericOpCode.WRAP) task.add_output(indirect.base) task.add_scalar_arg(src.shape, (ty.int64,)) - task.add_scalar_arg(False, bool) # has_input - task.add_scalar_arg(False, bool) # check bounds + task.add_scalar_arg(False, ty.bool_) # has_input + task.add_scalar_arg(False, ty.bool_) # check bounds task.execute() copy = self.context.create_copy() diff --git a/cunumeric/linalg/cholesky.py b/cunumeric/linalg/cholesky.py index 7e023d2cd..db5a275a4 100644 --- a/cunumeric/linalg/cholesky.py +++ b/cunumeric/linalg/cholesky.py @@ -178,10 +178,10 @@ def tril_single(context: Context, output: Store) -> None: task = context.create_auto_task(CuNumericOpCode.TRILU) task.add_output(output) task.add_input(output) - task.add_scalar_arg(True, 
bool) + task.add_scalar_arg(True, ty.bool_) task.add_scalar_arg(0, ty.int32) # Add a fake task argument to indicate that this is for Cholesky - task.add_scalar_arg(True, bool) + task.add_scalar_arg(True, ty.bool_) task.execute() @@ -194,10 +194,10 @@ def tril(context: Context, p_output: StorePartition, n: int) -> None: task.add_output(p_output) task.add_input(p_output) - task.add_scalar_arg(True, bool) + task.add_scalar_arg(True, ty.bool_) task.add_scalar_arg(0, ty.int32) # Add a fake task argument to indicate that this is for Cholesky - task.add_scalar_arg(True, bool) + task.add_scalar_arg(True, ty.bool_) task.execute() diff --git a/cunumeric/linalg/solve.py b/cunumeric/linalg/solve.py index 8eca91bc8..cec277c4e 100644 --- a/cunumeric/linalg/solve.py +++ b/cunumeric/linalg/solve.py @@ -50,7 +50,7 @@ def solve(output: DeferredArray, a: DeferredArray, b: DeferredArray) -> None: a_copy = cast( DeferredArray, - runtime.create_empty_thunk(a.shape, dtype=a.dtype, inputs=(a,)), + runtime.create_empty_thunk(a.shape, dtype=a.base.type, inputs=(a,)), ) transpose_copy_single(context, a.base, a_copy.base) diff --git a/cunumeric/runtime.py b/cunumeric/runtime.py index f0ad0398b..3fe1ff242 100644 --- a/cunumeric/runtime.py +++ b/cunumeric/runtime.py @@ -21,15 +21,13 @@ import legate.core.types as ty import numpy as np -from legate.core import LEGATE_MAX_DIM, Rect, get_legate_runtime, legion +from legate.core import LEGATE_MAX_DIM, Rect, get_legate_runtime from legate.core.context import Context as LegateContext from typing_extensions import TypeGuard from .config import ( - _CUNUMERIC_DTYPES, BitGeneratorOperation, CuNumericOpCode, - CuNumericRedopCode, CuNumericTunable, cunumeric_context, cunumeric_lib, @@ -39,12 +37,7 @@ from .settings import settings from .thunk import NumPyThunk from .types import NdShape -from .utils import ( - SUPPORTED_DTYPES, - calculate_volume, - find_last_user_stacklevel, - get_arg_dtype, -) +from .utils import calculate_volume, find_last_user_stacklevel, to_core_dtype if TYPE_CHECKING: import numpy.typing as npt @@ -54,6 +47,9 @@ from .array import ndarray +DIMENSION = int + + class Runtime(object): def __init__(self, legate_context: LegateContext) -> None: self.legate_context = legate_context @@ -87,28 +83,26 @@ def __init__(self, legate_context: LegateContext) -> None: # destroy us cunumeric_lib.set_runtime(self) assert cunumeric_lib.shared_object is not None + self.cunumeric_lib = cunumeric_lib.shared_object self.has_curand = cunumeric_lib.shared_object.cunumeric_has_curand() - self._register_dtypes() settings.warn = settings.warn() or settings.test() if self.num_gpus > 0 and settings.preload_cudalibs(): self._load_cudalibs() - def _register_dtypes(self) -> None: - type_system = self.legate_context.type_system - for numpy_type, core_type in SUPPORTED_DTYPES.items(): - type_system.make_alias(np.dtype(numpy_type), core_type) - - for dtype in _CUNUMERIC_DTYPES: - type_system.add_type(dtype[0], dtype[1], dtype[2]) + # Maps dimensions to point types + self._cached_point_types: dict[DIMENSION, ty.Dtype] = dict() + # Maps value types to struct types used in argmin/argmax + self._cached_argred_types: dict[ty.Dtype, ty.Dtype] = dict() - def get_point_type(self, n: int) -> np.dtype[Any]: - type_system = self.legate_context.type_system - point_type = np.dtype(",".join(("i8",) * n)) - if point_type not in type_system: - raise ValueError(f"there is no point type registered for {n}") - return point_type + def get_point_type(self, dim: DIMENSION) -> ty.Dtype: + cached = 
self._cached_point_types.get(dim) + if cached is not None: + return cached + point_dtype = ty.array_type(ty.int64, dim) if dim > 1 else ty.int64 + self._cached_point_types[dim] = point_dtype + return point_dtype def record_api_call( self, name: str, location: str, implemented: bool @@ -131,20 +125,16 @@ def _unload_cudalibs(self) -> None: ) task.execute() - def get_arg_dtype(self, value_dtype: np.dtype[Any]) -> np.dtype[Any]: - arg_dtype = get_arg_dtype(value_dtype) - type_system = self.legate_context.type_system - if arg_dtype not in type_system: - # We assign T's type code to Argval - code = type_system[value_dtype].code - dtype = type_system.add_type(arg_dtype, arg_dtype.itemsize, code) - - for redop in CuNumericRedopCode: - redop_id = self.legate_context.get_reduction_op_id( - redop.value * legion.MAX_TYPE_NUMBER + code - ) - dtype.register_reduction_op(redop, redop_id) - return arg_dtype + def get_argred_type(self, value_dtype: ty.Dtype) -> ty.Dtype: + cached = self._cached_argred_types.get(value_dtype) + if cached is not None: + return cached + argred_dtype = ty.struct_type([ty.int64, value_dtype], True) + self._cached_argred_types[value_dtype] = argred_dtype + self.cunumeric_lib.cunumeric_register_reduction_op( + argred_dtype.uid, value_dtype.code + ) + return argred_dtype def _report_coverage(self) -> None: total = len(self.api_calls) @@ -175,7 +165,6 @@ def destroy(self) -> None: def create_scalar( self, array: Union[memoryview, npt.NDArray[Any]], - dtype: np.dtype[Any], shape: Optional[NdShape] = None, ) -> Future: data = array.tobytes() @@ -188,15 +177,17 @@ def create_wrapped_scalar( dtype: np.dtype[Any], shape: NdShape, ) -> DeferredArray: - future = self.create_scalar(array, dtype, shape) + future = self.create_scalar(array, shape) assert all(extent == 1 for extent in shape) + core_dtype = to_core_dtype(dtype) + assert core_dtype is not None store = self.legate_context.create_store( - dtype, + core_dtype, shape=shape, storage=future, optimize_scalar=True, ) - return DeferredArray(self, store, dtype=dtype) + return DeferredArray(self, store) def bitgenerator_populate_task( self, @@ -272,21 +263,8 @@ def get_next_random_epoch(self) -> int: self.current_random_epoch += 1 return result - def is_point_type(self, dtype: Union[str, np.dtype[Any]]) -> bool: - if ( - isinstance(dtype, str) - and len(dtype) == 6 - and dtype[0:5] == "Point" - ): - return True - else: - return False - def is_supported_type(self, dtype: Union[str, np.dtype[Any]]) -> bool: - if self.is_point_type(dtype): - return dtype in self.legate_context.type_system - else: - return np.dtype(dtype) in self.legate_context.type_system + return to_core_dtype(dtype) is not None def get_numpy_thunk( self, @@ -312,9 +290,7 @@ def get_numpy_thunk( if stores[0] is not None: raise NotImplementedError("Need support for masked arrays") store = stores[1] - if dtype is None: - dtype = np.dtype(array.type.to_pandas_dtype()) - return DeferredArray(self, store, dtype=dtype) + return DeferredArray(self, store) # See if this is a normal numpy array # Make sure to convert numpy matrices to numpy arrays here # as the former doesn't behave quite like the latter @@ -442,7 +418,8 @@ def find_or_create_array_thunk( # Once it's a normal numpy array we can make it into one of our arrays # Check to see if it is a type that we support for doing deferred # execution and big enough to be worth off-loading onto Legion - if self.is_supported_type(array.dtype) and ( + dtype = to_core_dtype(array.dtype) + if dtype is not None and ( defer or not 
self.is_eager_shape(array.shape) or self.has_external_attachment(array) @@ -458,7 +435,7 @@ def find_or_create_array_thunk( # This is not a scalar so make a field store = self.legate_context.create_store( - array.dtype, + dtype, shape=array.shape, optimize_scalar=False, ) @@ -470,7 +447,6 @@ def find_or_create_array_thunk( return DeferredArray( self, store, - dtype=array.dtype, numpy_array=array if share else None, ) @@ -481,24 +457,29 @@ def find_or_create_array_thunk( def create_empty_thunk( self, shape: NdShape, - dtype: np.dtype[Any], + dtype: ty.Dtype, inputs: Optional[Sequence[NumPyThunk]] = None, ) -> NumPyThunk: - if self.is_supported_type(dtype) and not ( - self.is_eager_shape(shape) and self.are_all_eager_inputs(inputs) - ): - store = self.legate_context.create_store( - dtype, shape=shape, optimize_scalar=True - ) - return DeferredArray(self, store, dtype=dtype) - else: - return EagerArray(self, np.empty(shape, dtype=dtype)) + if self.is_eager_shape(shape) and self.are_all_eager_inputs(inputs): + return self.create_eager_thunk(shape, dtype.to_numpy_dtype()) + + store = self.legate_context.create_store( + dtype, shape=shape, optimize_scalar=True + ) + return DeferredArray(self, store) + + def create_eager_thunk( + self, + shape: NdShape, + dtype: np.dtype[Any], + ) -> NumPyThunk: + return EagerArray(self, np.empty(shape, dtype=dtype)) def create_unbound_thunk( - self, dtype: np.dtype[Any], ndim: int = 1 + self, dtype: ty.Dtype, ndim: int = 1 ) -> DeferredArray: store = self.legate_context.create_store(dtype, ndim=ndim) - return DeferredArray(self, store, dtype=dtype) + return DeferredArray(self, store) def is_eager_shape(self, shape: NdShape) -> bool: volume = calculate_volume(shape) diff --git a/cunumeric/sort.py b/cunumeric/sort.py index 93fa63abb..a0503bf92 100644 --- a/cunumeric/sort.py +++ b/cunumeric/sort.py @@ -36,7 +36,7 @@ def sort_flattened( sort_result = cast( "DeferredArray", output.runtime.create_empty_thunk( - flattened.shape, dtype=output.dtype, inputs=(flattened,) + flattened.shape, dtype=output.base.type, inputs=(flattened,) ), ) sort(sort_result, flattened, argsort, stable=stable) @@ -59,7 +59,7 @@ def sort_swapped( swapped_copy = cast( "DeferredArray", output.runtime.create_empty_thunk( - swapped.shape, dtype=input.dtype, inputs=(input, swapped) + swapped.shape, dtype=input.base.type, inputs=(input, swapped) ), ) swapped_copy.copy(swapped, deep=True) @@ -69,7 +69,9 @@ def sort_swapped( sort_result = cast( "DeferredArray", output.runtime.create_empty_thunk( - swapped_copy.shape, dtype=output.dtype, inputs=(swapped_copy,) + swapped_copy.shape, + dtype=output.base.type, + inputs=(swapped_copy,), ), ) sort(sort_result, swapped_copy, argsort, stable=stable) @@ -91,7 +93,7 @@ def sort_task( task.add_input(input.base) if uses_unbound_output: unbound = output.runtime.create_unbound_thunk( - dtype=output.dtype, ndim=1 + dtype=output.base.type, ndim=1 ) task.add_output(unbound.base) else: @@ -103,9 +105,9 @@ def sort_task( elif output.runtime.num_gpus == 0 and output.runtime.num_procs > 1: task.add_cpu_communicator() - task.add_scalar_arg(argsort, bool) # return indices flag + task.add_scalar_arg(argsort, ty.bool_) # return indices flag task.add_scalar_arg(input.base.shape, (ty.int64,)) - task.add_scalar_arg(stable, bool) + task.add_scalar_arg(stable, ty.bool_) task.execute() if uses_unbound_output: diff --git a/cunumeric/utils.py b/cunumeric/utils.py index 64f39a87e..dc40ea190 100644 --- a/cunumeric/utils.py +++ b/cunumeric/utils.py @@ -18,7 +18,7 @@ from functools 
import reduce from string import ascii_lowercase, ascii_uppercase from types import FrameType -from typing import Any, Callable, List, Sequence, Tuple, Union, cast +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import legate.core.types as ty import numpy as np @@ -26,26 +26,27 @@ from .types import NdShape SUPPORTED_DTYPES = { - bool: ty.bool_, - np.bool_: ty.bool_, - np.int8: ty.int8, - np.int16: ty.int16, - np.int32: ty.int32, - int: ty.int64, # np.int is int - np.int64: ty.int64, - np.uint8: ty.uint8, - np.uint16: ty.uint16, - np.uint32: ty.uint32, - np.uint64: ty.uint64, # np.uint is np.uint64 - np.float16: ty.float16, - np.float32: ty.float32, - float: ty.float64, - np.float64: ty.float64, - np.complex64: ty.complex64, - np.complex128: ty.complex128, + np.dtype(np.bool_): ty.bool_, + np.dtype(np.int8): ty.int8, + np.dtype(np.int16): ty.int16, + np.dtype(np.int32): ty.int32, + np.dtype(np.int64): ty.int64, + np.dtype(np.uint8): ty.uint8, + np.dtype(np.uint16): ty.uint16, + np.dtype(np.uint32): ty.uint32, + np.dtype(np.uint64): ty.uint64, + np.dtype(np.float16): ty.float16, + np.dtype(np.float32): ty.float32, + np.dtype(np.float64): ty.float64, + np.dtype(np.complex64): ty.complex64, + np.dtype(np.complex128): ty.complex128, } +def to_core_dtype(dtype: Union[str, np.dtype[Any]]) -> Optional[ty.Dtype]: + return SUPPORTED_DTYPES.get(np.dtype(dtype), None) + + def is_advanced_indexing(key: Any) -> bool: if key is Ellipsis or key is None: # np.newdim case return False @@ -93,30 +94,12 @@ def find_last_user_frames(top_only: bool = True) -> str: return "|".join(get_line_number_from_frame(f) for f in frames) -def is_supported_dtype(dtype: Any) -> bool: - if not isinstance(dtype, np.dtype): - raise TypeError("expected a NumPy dtype") - return dtype.type in SUPPORTED_DTYPES - - def calculate_volume(shape: NdShape) -> int: if len(shape) == 0: return 0 return reduce(lambda x, y: x * y, shape) -def get_arg_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]: - return np.dtype( - [("arg", np.int64), ("arg_value", dtype)], - align=True, - ) - - -def get_arg_value_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]: - dt = dtype.fields["arg_value"][0].type # type: ignore [index] - return cast(np.dtype[Any], dt) - - Modes = Tuple[List[str], List[str], List[str]] diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake index 4592559d3..dd8a60f7e 100644 --- a/cunumeric_cpp.cmake +++ b/cunumeric_cpp.cmake @@ -156,7 +156,7 @@ list(APPEND cunumeric_SOURCES src/cunumeric/stat/bincount.cc src/cunumeric/convolution/convolve.cc src/cunumeric/transform/flip.cc - src/cunumeric/arg.cc + src/cunumeric/arg_redop_register.cc src/cunumeric/mapper.cc src/cunumeric/cephes/chbevl.cc src/cunumeric/cephes/i0.cc @@ -254,8 +254,8 @@ if(Legion_USE_CUDA) src/cunumeric/convolution/convolve.cu src/cunumeric/fft/fft.cu src/cunumeric/transform/flip.cu + src/cunumeric/arg_redop_register.cu src/cunumeric/cudalibs.cu - src/cunumeric/cunumeric.cu ) endif() diff --git a/src/cunumeric/arg.h b/src/cunumeric/arg.h index edfa1d1c3..6f1f55258 100644 --- a/src/cunumeric/arg.h +++ b/src/cunumeric/arg.h @@ -17,7 +17,6 @@ #pragma once #include "legate.h" -#include "cunumeric/cunumeric_c.h" namespace cunumeric { @@ -63,8 +62,6 @@ class ArgmaxReduction { using RHS = Argval; static const Argval identity; - static const int32_t REDOP_ID = - CUNUMERIC_ARGMAX_REDOP * MAX_TYPE_NUMBER + legate::legate_type_code_of; template __CUDA_HD__ inline static void apply(LHS& lhs, RHS rhs) @@ -85,8 +82,6 @@ class ArgminReduction { using RHS = Argval; 
static const Argval identity; - static const int32_t REDOP_ID = - CUNUMERIC_ARGMIN_REDOP * MAX_TYPE_NUMBER + legate::legate_type_code_of; template __CUDA_HD__ inline static void apply(LHS& lhs, RHS rhs) @@ -101,3 +96,5 @@ class ArgminReduction { }; } // namespace cunumeric + +#include "cunumeric/arg.inl" diff --git a/src/cunumeric/arg.inl b/src/cunumeric/arg.inl index b98314d7b..8839ec456 100644 --- a/src/cunumeric/arg.inl +++ b/src/cunumeric/arg.inl @@ -112,34 +112,4 @@ __CUDA_HD__ inline void Argval::apply(const Argval& rhs) } } -#define DECLARE_ARGMAX_IDENTITY(TYPE) \ - template <> \ - const Argval ArgmaxReduction::identity; - -#define DECLARE_ARGMIN_IDENTITY(TYPE) \ - template <> \ - const Argval ArgminReduction::identity; - -#define DECLARE_IDENTITIES(TYPE) \ - DECLARE_ARGMAX_IDENTITY(TYPE) \ - DECLARE_ARGMIN_IDENTITY(TYPE) - -DECLARE_IDENTITIES(__half) -DECLARE_IDENTITIES(float) -DECLARE_IDENTITIES(double) -DECLARE_IDENTITIES(bool) -DECLARE_IDENTITIES(int8_t) -DECLARE_IDENTITIES(int16_t) -DECLARE_IDENTITIES(int32_t) -DECLARE_IDENTITIES(int64_t) -DECLARE_IDENTITIES(uint8_t) -DECLARE_IDENTITIES(uint16_t) -DECLARE_IDENTITIES(uint32_t) -DECLARE_IDENTITIES(uint64_t) -DECLARE_IDENTITIES(complex) - -#undef DECLARE_IDENTITIES -#undef DECLARE_ARGMIN_IDENTITY -#undef DECLARE_ARGMAX_IDENTITY - } // namespace cunumeric diff --git a/src/cunumeric/arg.cc b/src/cunumeric/arg_redop_register.cc similarity index 60% rename from src/cunumeric/arg.cc rename to src/cunumeric/arg_redop_register.cc index 5c400a0d8..5068abaee 100644 --- a/src/cunumeric/arg.cc +++ b/src/cunumeric/arg_redop_register.cc @@ -14,8 +14,7 @@ * */ -#include "cunumeric/arg.h" -#include "cunumeric/arg.inl" +#include "cunumeric/arg_redop_register.h" namespace cunumeric { @@ -46,30 +45,27 @@ DEFINE_IDENTITIES(uint16_t) DEFINE_IDENTITIES(uint32_t) DEFINE_IDENTITIES(uint64_t) +#undef DEFINE_ARGMAX_IDENTITY +#undef DEFINE_ARGMIN_IDENTITY +#undef DEFINE_IDENTITIES + +/*static*/ int32_t register_reduction_op_fn::register_reduction_op_fn::next_reduction_operator_id() +{ + static int32_t next_redop_id = 0; + return next_redop_id++; +} + +} // namespace cunumeric + #ifndef LEGATE_USE_CUDA -#define REGISTER_REDOPS(OP) \ - { \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - } +extern "C" { -void register_reduction_operators(legate::LibraryContext* context) +void cunumeric_register_reduction_op(int32_t type_uid, int32_t _elem_type_code) { - REGISTER_REDOPS(ArgmaxReduction); - REGISTER_REDOPS(ArgminReduction); + auto elem_type_code = static_cast(_elem_type_code); + legate::type_dispatch(elem_type_code, cunumeric::register_reduction_op_fn{}, type_uid); +} } #endif - -} // namespace cunumeric diff --git a/src/cunumeric/arg_redop_register.cu b/src/cunumeric/arg_redop_register.cu new file mode 100644 index 000000000..5a14b0b71 --- /dev/null +++ b/src/cunumeric/arg_redop_register.cu @@ -0,0 +1,26 @@ +/* Copyright 2021-2022 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/arg_redop_register.h" + +extern "C" { + +void cunumeric_register_reduction_op(int32_t type_uid, int32_t _elem_type_code) +{ + auto elem_type_code = static_cast(_elem_type_code); + legate::type_dispatch(elem_type_code, cunumeric::register_reduction_op_fn{}, type_uid); +} +} diff --git a/src/cunumeric/arg_redop_register.h b/src/cunumeric/arg_redop_register.h new file mode 100644 index 000000000..02433da62 --- /dev/null +++ b/src/cunumeric/arg_redop_register.h @@ -0,0 +1,56 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include "legate.h" +#include "cunumeric/cunumeric_c.h" +#include "cunumeric/arg.h" + +namespace cunumeric { + +struct register_reduction_op_fn { + template ::value>* = nullptr> + void operator()(int32_t type_uid) + { + using VAL = legate::legate_type_of; + + auto runtime = legate::Runtime::get_runtime(); + auto context = runtime->find_library("cunumeric"); + { + auto redop_id = + context->register_reduction_operator>(next_reduction_operator_id()); + auto op_kind = static_cast(legate::ReductionOpKind::MAX); + runtime->record_reduction_operator(type_uid, op_kind, redop_id); + } + { + auto redop_id = + context->register_reduction_operator>(next_reduction_operator_id()); + auto op_kind = static_cast(legate::ReductionOpKind::MIN); + runtime->record_reduction_operator(type_uid, op_kind, redop_id); + } + } + + template ::value>* = nullptr> + void operator()(int32_t type_uid) + { + LEGATE_ABORT; + } + + static int32_t next_reduction_operator_id(); +}; + +} // namespace cunumeric diff --git a/src/cunumeric/binary/binary_op.cc b/src/cunumeric/binary/binary_op.cc index c17e3ac49..a718443fa 100644 --- a/src/cunumeric/binary/binary_op.cc +++ b/src/cunumeric/binary/binary_op.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct BinaryOpImplBody { using OP = BinaryOp; using RHS1 = legate_type_of; diff --git a/src/cunumeric/binary/binary_op.cu b/src/cunumeric/binary/binary_op.cu index b1d7ce4df..76177b154 100644 --- a/src/cunumeric/binary/binary_op.cu +++ b/src/cunumeric/binary/binary_op.cu @@ -51,7 +51,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[point] = func(in1[point], in2[point]); } -template +template struct BinaryOpImplBody { using OP = BinaryOp; using RHS1 = legate_type_of; diff --git a/src/cunumeric/binary/binary_op_omp.cc b/src/cunumeric/binary/binary_op_omp.cc index 46452e9f1..53ec582a7 100644 --- a/src/cunumeric/binary/binary_op_omp.cc +++ 
b/src/cunumeric/binary/binary_op_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct BinaryOpImplBody { using OP = BinaryOp; using RHS1 = legate_type_of; diff --git a/src/cunumeric/binary/binary_op_template.inl b/src/cunumeric/binary/binary_op_template.inl index a51276d2a..d32484835 100644 --- a/src/cunumeric/binary/binary_op_template.inl +++ b/src/cunumeric/binary/binary_op_template.inl @@ -25,14 +25,12 @@ namespace cunumeric { using namespace legate; -template +template struct BinaryOpImplBody; template struct BinaryOpImpl { - template ::valid>* = nullptr> + template ::valid>* = nullptr> void operator()(BinaryOpArgs& args) const { using OP = BinaryOp; @@ -64,9 +62,7 @@ struct BinaryOpImpl { BinaryOpImplBody()(func, out, in1, in2, pitches, rect, dense); } - template ::valid>* = nullptr> + template ::valid>* = nullptr> void operator()(BinaryOpArgs& args) const { assert(false); diff --git a/src/cunumeric/binary/binary_op_util.h b/src/cunumeric/binary/binary_op_util.h index 1cb54464b..a0c9540dc 100644 --- a/src/cunumeric/binary/binary_op_util.h +++ b/src/cunumeric/binary/binary_op_util.h @@ -160,18 +160,18 @@ constexpr decltype(auto) reduce_op_dispatch(BinaryOpCode op_code, Functor f, Fna return f.template operator()(std::forward(args)...); } -template +template struct BinaryOp { static constexpr bool valid = false; }; -template +template struct BinaryOp : std::plus> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -187,17 +187,17 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = std::is_integral::value; @@ -210,7 +210,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = std::is_integral::value; @@ -223,7 +223,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = std::is_integral::value; @@ -236,7 +236,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -251,18 +251,18 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { using T = __half; static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -281,17 +281,17 @@ struct BinaryOp { } }; -template +template struct BinaryOp : std::equal_to> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = - CODE == legate::LegateTypeCode::DOUBLE_LT or CODE == legate::LegateTypeCode::COMPLEX128_LT; + CODE == legate::Type::Code::FLOAT64 or CODE == legate::Type::Code::COMPLEX128; BinaryOp(const std::vector& args) {} constexpr T 
operator()(const T& a, const T& b) const @@ -302,7 +302,7 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { using T = complex; static constexpr bool valid = true; BinaryOp(const std::vector& args) {} @@ -314,11 +314,12 @@ struct BinaryOp } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = - not(CODE == legate::LegateTypeCode::BOOL_LT or legate::is_complex::value); + not(CODE == legate::Type::Code::BOOL or legate::is_complex::value); + __CUDA_HD__ BinaryOp() {} BinaryOp(const std::vector& args) {} @@ -337,12 +338,12 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; @@ -370,7 +371,7 @@ static __CUDA_HD__ T _gcd(T a, T b) return a; } -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = std::is_integral::value; @@ -388,7 +389,7 @@ static constexpr T floor_divide_signed(const T& a, const T& b) } using std::floor; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -416,31 +417,31 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = false; BinaryOp(const std::vector& args) {} }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = false; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp : std::greater> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp : std::greater_equal> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -456,18 +457,18 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; -template +template struct BinaryOp { using VAL = legate::legate_type_of; static constexpr bool valid = true; @@ -499,7 +500,7 @@ struct BinaryOp { double atol_{0}; }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = std::is_integral::value; @@ -524,7 +525,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -538,7 +539,7 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { using T = __half; static constexpr bool valid = true; BinaryOp(const std::vector& args) {} @@ -550,10 +551,10 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; - static constexpr bool valid = CODE != BOOL_LT && std::is_integral::value; + static constexpr bool valid = CODE != legate::Type::Code::BOOL && std::is_integral::value; BinaryOp(const std::vector& args) {} @@ -567,19 +568,19 @@ struct BinaryOp { } }; -template +template struct BinaryOp : std::less> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp : std::less_equal> { static constexpr bool valid = true; 
BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -602,18 +603,18 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -635,18 +636,18 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -665,7 +666,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -685,7 +686,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -704,7 +705,7 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -712,7 +713,7 @@ struct BinaryOp { constexpr T operator()(const T& a, const T& b) const { return std::max(a, b); } }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -732,7 +733,7 @@ constexpr T real_mod(const T& a, const T& b) return res; } -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -762,34 +763,34 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = false; BinaryOp(const std::vector& args) {} }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = false; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp : std::multiplies> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_floating_point::value; @@ -804,24 +805,24 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { using T = __half; static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { - return lift(a, b, BinaryOp{}); + return lift(a, b, BinaryOp{}); } }; -template +template struct BinaryOp : std::not_equal_to> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct BinaryOp { using VAL = legate::legate_type_of; static constexpr bool valid = true; @@ -833,14 +834,14 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ __half operator()(const __half& a, const __half& b) const { 
return pow(a, b); } }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ complex operator()(const complex& a, const complex& b) const @@ -850,7 +851,7 @@ struct BinaryOp { }; template <> -struct BinaryOp { +struct BinaryOp { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} __CUDA_HD__ complex operator()(const complex& a, const complex& b) const @@ -859,33 +860,33 @@ struct BinaryOp { } }; -template +template struct BinaryOp { using T = legate::legate_type_of; - static constexpr bool valid = CODE != BOOL_LT && std::is_integral::value; + static constexpr bool valid = CODE != legate::Type::Code::BOOL && std::is_integral::value; BinaryOp(const std::vector& args) {} constexpr decltype(auto) operator()(const T& a, const T& b) const { return a >> b; } }; -template +template struct BinaryOp : std::minus> { static constexpr bool valid = true; BinaryOp(const std::vector& args) {} }; -template +template struct RHS2OfBinaryOp { using type = legate::legate_type_of; }; -template +template struct RHS2OfBinaryOp { using type = int32_t; }; -template +template using rhs2_of_binary_op = typename RHS2OfBinaryOp::type; } // namespace cunumeric diff --git a/src/cunumeric/binary/binary_red.cc b/src/cunumeric/binary/binary_red.cc index 73b72bc8a..5340e5334 100644 --- a/src/cunumeric/binary/binary_red.cc +++ b/src/cunumeric/binary/binary_red.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct BinaryRedImplBody { using OP = BinaryOp; using ARG = legate_type_of; diff --git a/src/cunumeric/binary/binary_red.cu b/src/cunumeric/binary/binary_red.cu index b9cc48d10..98435abd8 100644 --- a/src/cunumeric/binary/binary_red.cu +++ b/src/cunumeric/binary/binary_red.cu @@ -46,7 +46,7 @@ static __global__ void __launch_bounds__(1, 1) copy_kernel(Buffer result, RedAcc out.reduce(0, result.read()); } -template +template struct BinaryRedImplBody { using OP = BinaryOp; using ARG = legate_type_of; diff --git a/src/cunumeric/binary/binary_red_omp.cc b/src/cunumeric/binary/binary_red_omp.cc index cc2713645..891aa7abd 100644 --- a/src/cunumeric/binary/binary_red_omp.cc +++ b/src/cunumeric/binary/binary_red_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct BinaryRedImplBody { using OP = BinaryOp; using ARG = legate_type_of; diff --git a/src/cunumeric/binary/binary_red_template.inl b/src/cunumeric/binary/binary_red_template.inl index d0f180091..4bff8e454 100644 --- a/src/cunumeric/binary/binary_red_template.inl +++ b/src/cunumeric/binary/binary_red_template.inl @@ -25,14 +25,12 @@ namespace cunumeric { using namespace legate; -template +template struct BinaryRedImplBody; template struct BinaryRedImpl { - template ::valid>* = nullptr> + template ::valid>* = nullptr> void operator()(BinaryRedArgs& args) const { using OP = BinaryOp; @@ -68,9 +66,7 @@ struct BinaryRedImpl { BinaryRedImplBody()(func, out, in1, in2, pitches, rect, dense); } - template ::valid>* = nullptr> + template ::valid>* = nullptr> void operator()(BinaryRedArgs& args) const { assert(false); diff --git a/src/cunumeric/bits/packbits.cc b/src/cunumeric/bits/packbits.cc index 027e090c7..99eac967c 100644 --- a/src/cunumeric/bits/packbits.cc +++ b/src/cunumeric/bits/packbits.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct PackbitsImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/bits/packbits.cu 
b/src/cunumeric/bits/packbits.cu index 03757144b..81edb82b4 100644 --- a/src/cunumeric/bits/packbits.cu +++ b/src/cunumeric/bits/packbits.cu @@ -39,7 +39,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[out_p] = pack(in, out_p, in_hi_axis, axis); } -template +template struct PackbitsImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/bits/packbits_omp.cc b/src/cunumeric/bits/packbits_omp.cc index b28199ab3..7e8e05c55 100644 --- a/src/cunumeric/bits/packbits_omp.cc +++ b/src/cunumeric/bits/packbits_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct PackbitsImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/bits/packbits_template.inl b/src/cunumeric/bits/packbits_template.inl index c1a820c73..7e1e68a89 100644 --- a/src/cunumeric/bits/packbits_template.inl +++ b/src/cunumeric/bits/packbits_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct PackbitsImplBody; template struct PackbitsImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& output, Array& input, uint32_t axis) const { using VAL = legate_type_of; @@ -74,9 +74,7 @@ struct PackbitsImpl { axis); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& output, Array& input, uint32_t axis) const { // Unreachable @@ -93,7 +91,7 @@ static void packbits_template(TaskContext& context) auto axis = scalars[0].value(); auto bitorder = scalars[1].value(); - auto code = input.code(); + auto code = input.code(); switch (bitorder) { case Bitorder::BIG: { double_dispatch(input.dim(), code, PackbitsImpl{}, output, input, axis); diff --git a/src/cunumeric/bits/unpackbits_template.inl b/src/cunumeric/bits/unpackbits_template.inl index f0316ac1e..2022a4135 100644 --- a/src/cunumeric/bits/unpackbits_template.inl +++ b/src/cunumeric/bits/unpackbits_template.inl @@ -47,9 +47,7 @@ struct UnpackbitsImpl { UnpackbitsImplBody{}(out, in, in_rect, in_pitches, in_volume, axis); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& output, Array& input, uint32_t axis) const { // Unreachable @@ -66,7 +64,7 @@ static void unpackbits_template(TaskContext& context) auto axis = scalars[0].value(); auto bitorder = scalars[1].value(); - auto code = input.code(); + auto code = input.code(); switch (bitorder) { case Bitorder::BIG: { dim_dispatch(input.dim(), UnpackbitsImpl{}, output, input, axis); diff --git a/src/cunumeric/convolution/convolve.cc b/src/cunumeric/convolution/convolve.cc index b2a2d817f..3827be718 100644 --- a/src/cunumeric/convolution/convolve.cc +++ b/src/cunumeric/convolution/convolve.cc @@ -24,7 +24,7 @@ namespace cunumeric { // algorithm, but it is commented out in favor of the faster one // that is blocked for caches #if 0 -template +template struct ConvolveImplBody { using VAL = legate_type_of; @@ -73,7 +73,7 @@ struct ConvolveImplBody { }; #endif -template +template struct ConvolveImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/convolution/convolve.cu b/src/cunumeric/convolution/convolve.cu index 28688d6c4..66886a96b 100644 --- a/src/cunumeric/convolution/convolve.cu +++ b/src/cunumeric/convolution/convolve.cu @@ -1409,7 +1409,7 @@ struct UseCUFFT { static constexpr bool value = 1 <= DIM && DIM <= 3 && std::is_floating_point::value; }; -template +template struct ConvolveImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/convolution/convolve_omp.cc 
b/src/cunumeric/convolution/convolve_omp.cc index f016c4118..283f4b7b9 100644 --- a/src/cunumeric/convolution/convolve_omp.cc +++ b/src/cunumeric/convolution/convolve_omp.cc @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; -template +template struct ConvolveImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/convolution/convolve_template.inl b/src/cunumeric/convolution/convolve_template.inl index 9d9077eca..f698ebfd4 100644 --- a/src/cunumeric/convolution/convolve_template.inl +++ b/src/cunumeric/convolution/convolve_template.inl @@ -26,12 +26,12 @@ namespace cunumeric { using namespace legate; -template +template struct ConvolveImplBody; template struct ConvolveImpl { - template * = nullptr> + template * = nullptr> void operator()(ConvolveArgs& args) const { using VAL = legate_type_of; @@ -55,7 +55,7 @@ struct ConvolveImpl { ConvolveImplBody()(out, filter, input, root_rect, subrect, filter_rect); } - template * = nullptr> + template * = nullptr> void operator()(ConvolveArgs& args) const { assert(false); diff --git a/src/cunumeric/cuda_help.h b/src/cunumeric/cuda_help.h index b7988f741..f0f0fee85 100644 --- a/src/cunumeric/cuda_help.h +++ b/src/cunumeric/cuda_help.h @@ -20,7 +20,6 @@ #include "core/cuda/cuda_help.h" #include "core/cuda/stream_pool.h" #include "cunumeric/arg.h" -#include "cunumeric/arg.inl" #include "cunumeric/device_scalar_reduction_buffer.h" #include #include diff --git a/src/cunumeric/cunumeric.cc b/src/cunumeric/cunumeric.cc index c631d5a8d..5377c5aeb 100644 --- a/src/cunumeric/cunumeric.cc +++ b/src/cunumeric/cunumeric.cc @@ -30,8 +30,6 @@ static const char* const cunumeric_library_name = "cunumeric"; return registrar; } -extern void register_reduction_operators(LibraryContext* context); - void registration_callback() { ResourceConfig config; @@ -42,9 +40,6 @@ void registration_callback() cunumeric_library_name, config, std::make_unique()); CuNumericRegistrar::get_registrar().register_all_tasks(context); - - // Register our special reduction functions - register_reduction_operators(context); } } // namespace cunumeric diff --git a/src/cunumeric/cunumeric.cu b/src/cunumeric/cunumeric.cu deleted file mode 100644 index 87cd3a85e..000000000 --- a/src/cunumeric/cunumeric.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021-2022 NVIDIA Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -#include "cunumeric.h" -#include "arg.h" -#include "arg.inl" - -namespace cunumeric { - -#define REGISTER_REDOPS(OP) \ - { \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - context->register_reduction_operator>(); \ - } - -void register_reduction_operators(legate::LibraryContext* context) -{ - REGISTER_REDOPS(ArgmaxReduction); - REGISTER_REDOPS(ArgminReduction); -} - -} // namespace cunumeric diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index 8cfcb0ac8..c3145939e 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -319,25 +319,13 @@ enum CuNumericFFTDirection { CUNUMERIC_FFT_FORWARD = -1, CUNUMERIC_FFT_INVERSE = // Match these to Bitorder in config.py enum CuNumericBitorder { CUNUMERIC_BITORDER_BIG = 0, CUNUMERIC_BITORDER_LITTLE = 1 }; -// Match these to CuNumericTypeCodes in config.py -enum CuNumericTypeCodes { - CUNUMERIC_TYPE_POINT1 = MAX_TYPE_NUMBER + 1, - CUNUMERIC_TYPE_POINT2, - CUNUMERIC_TYPE_POINT3, - CUNUMERIC_TYPE_POINT4, - CUNUMERIC_TYPE_POINT5, - CUNUMERIC_TYPE_POINT6, - CUNUMERIC_TYPE_POINT7, - CUNUMERIC_TYPE_POINT8, - CUNUMERIC_TYPE_POINT9, -}; - #ifdef __cplusplus extern "C" { #endif void cunumeric_perform_registration(); bool cunumeric_has_curand(); +void cunumeric_register_reduction_op(int32_t type_uid, int32_t elem_type_code); #ifdef __cplusplus } diff --git a/src/cunumeric/fft/fft.cu b/src/cunumeric/fft/fft.cu index 45b550584..4fb5bfea0 100644 --- a/src/cunumeric/fft/fft.cu +++ b/src/cunumeric/fft/fft.cu @@ -363,7 +363,7 @@ __host__ static inline void cufft_over_axis_r2c_c2r(AccessorWO CHECK_CUFFT(cufftDestroy(plan)); } -template +template struct FFTImplBody { using INPUT_TYPE = legate_type_of; using OUTPUT_TYPE = legate_type_of; diff --git a/src/cunumeric/fft/fft_template.inl b/src/cunumeric/fft/fft_template.inl index de26bd017..1ad05f9af 100644 --- a/src/cunumeric/fft/fft_template.inl +++ b/src/cunumeric/fft/fft_template.inl @@ -27,14 +27,14 @@ using namespace legate; template struct FFTImplBody; template struct FFTImpl { - template ::valid)>* = nullptr> void operator()(FFTArgs& args) const @@ -54,7 +54,7 @@ struct FFTImpl { } // We only support up to 3D FFTs for now - template 3) || !FFT::valid)>* = nullptr> void operator()(FFTArgs& args) const diff --git a/src/cunumeric/fft/fft_util.h b/src/cunumeric/fft/fft_util.h index 04428a1ab..dea461ccc 100644 --- a/src/cunumeric/fft/fft_util.h +++ b/src/cunumeric/fft/fft_util.h @@ -44,45 +44,45 @@ constexpr decltype(auto) fft_dispatch(CuNumericFFTType type, Functor f, Fnargs&& return f.template operator()(std::forward(args)...); } -template +template struct FFT { static constexpr bool valid = false; }; template <> -struct FFT { - static constexpr bool valid = true; - static constexpr LegateTypeCode CODE_OUT = LegateTypeCode::COMPLEX64_LT; +struct FFT { + static constexpr bool valid = true; + static constexpr Type::Code CODE_OUT = Type::Code::COMPLEX64; }; template <> -struct FFT { - static constexpr bool valid = true; - static constexpr LegateTypeCode CODE_OUT = LegateTypeCode::FLOAT_LT; +struct FFT { + static constexpr bool 
valid = true; + static constexpr Type::Code CODE_OUT = Type::Code::FLOAT32; }; template <> -struct FFT { - static constexpr bool valid = true; - static constexpr LegateTypeCode CODE_OUT = LegateTypeCode::COMPLEX64_LT; +struct FFT { + static constexpr bool valid = true; + static constexpr Type::Code CODE_OUT = Type::Code::COMPLEX64; }; template <> -struct FFT { - static constexpr bool valid = true; - static constexpr LegateTypeCode CODE_OUT = LegateTypeCode::COMPLEX128_LT; +struct FFT { + static constexpr bool valid = true; + static constexpr Type::Code CODE_OUT = Type::Code::COMPLEX128; }; template <> -struct FFT { - static constexpr bool valid = true; - static constexpr LegateTypeCode CODE_OUT = LegateTypeCode::DOUBLE_LT; +struct FFT { + static constexpr bool valid = true; + static constexpr Type::Code CODE_OUT = Type::Code::FLOAT64; }; template <> -struct FFT { - static constexpr bool valid = true; - static constexpr LegateTypeCode CODE_OUT = LegateTypeCode::COMPLEX128_LT; +struct FFT { + static constexpr bool valid = true; + static constexpr Type::Code CODE_OUT = Type::Code::COMPLEX128; }; } // namespace cunumeric diff --git a/src/cunumeric/index/advanced_indexing.cc b/src/cunumeric/index/advanced_indexing.cc index b50a8c279..19c1101e2 100644 --- a/src/cunumeric/index/advanced_indexing.cc +++ b/src/cunumeric/index/advanced_indexing.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct AdvancedIndexingImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/advanced_indexing.cu b/src/cunumeric/index/advanced_indexing.cu index 2f5e90d25..5b808a563 100644 --- a/src/cunumeric/index/advanced_indexing.cu +++ b/src/cunumeric/index/advanced_indexing.cu @@ -75,7 +75,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } -template +template struct AdvancedIndexingImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/advanced_indexing_omp.cc b/src/cunumeric/index/advanced_indexing_omp.cc index f82cbdb2a..11cec094e 100644 --- a/src/cunumeric/index/advanced_indexing_omp.cc +++ b/src/cunumeric/index/advanced_indexing_omp.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; -template +template struct AdvancedIndexingImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/advanced_indexing_template.inl b/src/cunumeric/index/advanced_indexing_template.inl index 5178cb35e..fb160adff 100644 --- a/src/cunumeric/index/advanced_indexing_template.inl +++ b/src/cunumeric/index/advanced_indexing_template.inl @@ -24,14 +24,14 @@ namespace cunumeric { using namespace legate; -template +template struct AdvancedIndexingImplBody; template struct AdvancedIndexingImpl { // current implementaion of the ND-output regions requires all regions // to have the same DIM. 
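Note on the pattern running through these hunks: every change here is the same mechanical rename, task functors that were parameterized on LegateTypeCode are now parameterized on legate::Type::Code, and they are still selected through the dispatch helpers visible in this diff (double_dispatch, dim_dispatch). The fragment below is a minimal illustrative sketch of that idiom, not cuNumeric's exact code: MyImpl and MyImplBody are hypothetical names, the argument lists are elided, the real ImplBody structs carry an extra VariantKind parameter omitted here, and the sketch assumes the same legate headers and using namespace legate as the surrounding *_template.inl files.

template <Type::Code CODE, int DIM>
struct MyImplBody {
  using VAL = legate_type_of<CODE>;  // element type recovered from the runtime type code
  void operator()(/* accessors, rect, pitches, ... */) const {}
};

struct MyImpl {
  template <Type::Code CODE, int DIM>
  void operator()(/* task args */) const
  {
    MyImplBody<CODE, DIM>{}(/* ... */);
  }
};

static void my_task_template(int dim, Type::Code code)
{
  // double_dispatch instantiates MyImpl::operator()<CODE, DIM> for the runtime
  // (dim, code) pair, as packbits_template/unpackbits_template do above.
  double_dispatch(dim, code, MyImpl{} /*, task args... */);
}
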
- template + template void operator()(AdvancedIndexingArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/index/choose.cc b/src/cunumeric/index/choose.cc index ecd1f052f..ed4b0a0cf 100644 --- a/src/cunumeric/index/choose.cc +++ b/src/cunumeric/index/choose.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct ChooseImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/choose.cu b/src/cunumeric/index/choose.cu index 1a042f941..5deab68bb 100644 --- a/src/cunumeric/index/choose.cu +++ b/src/cunumeric/index/choose.cu @@ -45,7 +45,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) cho outptr[idx] = choices[indexptr[idx]][idx]; } -template +template struct ChooseImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/choose_omp.cc b/src/cunumeric/index/choose_omp.cc index 14006aa01..19bf12ee2 100644 --- a/src/cunumeric/index/choose_omp.cc +++ b/src/cunumeric/index/choose_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct ChooseImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/choose_template.inl b/src/cunumeric/index/choose_template.inl index 58affeee1..9399f736a 100644 --- a/src/cunumeric/index/choose_template.inl +++ b/src/cunumeric/index/choose_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct ChooseImplBody; template struct ChooseImpl { - template + template void operator()(ChooseArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/index/putmask_template.inl b/src/cunumeric/index/putmask_template.inl index 463f41473..3a85d2044 100644 --- a/src/cunumeric/index/putmask_template.inl +++ b/src/cunumeric/index/putmask_template.inl @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; -template +template struct Putmask { using T = legate_type_of; using IN = AccessorRW; @@ -92,7 +92,7 @@ using namespace legate; template struct PutmaskImpl { - template + template void operator()(PutmaskArgs& args) const { Putmask putmask(args); diff --git a/src/cunumeric/index/repeat.cc b/src/cunumeric/index/repeat.cc index d6c317173..9222d7c5f 100644 --- a/src/cunumeric/index/repeat.cc +++ b/src/cunumeric/index/repeat.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct RepeatImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/repeat.cu b/src/cunumeric/index/repeat.cu index cc78378a5..634050b9d 100644 --- a/src/cunumeric/index/repeat.cu +++ b/src/cunumeric/index/repeat.cu @@ -93,7 +93,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } -template +template struct RepeatImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/repeat_omp.cc b/src/cunumeric/index/repeat_omp.cc index 1e9018e15..9ff130634 100644 --- a/src/cunumeric/index/repeat_omp.cc +++ b/src/cunumeric/index/repeat_omp.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; -template +template struct RepeatImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/index/repeat_template.inl b/src/cunumeric/index/repeat_template.inl index 3b52141f5..d6173dde8 100644 --- a/src/cunumeric/index/repeat_template.inl +++ b/src/cunumeric/index/repeat_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct RepeatImplBody; template struct RepeatImpl { - template + template void operator()(RepeatArgs& args) 
const { using VAL = legate_type_of; diff --git a/src/cunumeric/item/read_template.inl b/src/cunumeric/item/read_template.inl index 66478bdef..b13a00345 100644 --- a/src/cunumeric/item/read_template.inl +++ b/src/cunumeric/item/read_template.inl @@ -28,7 +28,7 @@ struct ReadImplBody; template struct ReadImpl { - template + template void operator()(const Array& out_arr, const Array& in_arr) const { using VAL = legate_type_of; diff --git a/src/cunumeric/item/write_template.inl b/src/cunumeric/item/write_template.inl index 6595f3edd..41b18e01c 100644 --- a/src/cunumeric/item/write_template.inl +++ b/src/cunumeric/item/write_template.inl @@ -28,7 +28,7 @@ struct WriteImplBody; template struct WriteImpl { - template + template void operator()(Array& out_arr, Array& in_arr) const { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/contract.cc b/src/cunumeric/matrix/contract.cc index 47fe3c1b0..eab84e70e 100644 --- a/src/cunumeric/matrix/contract.cc +++ b/src/cunumeric/matrix/contract.cc @@ -33,7 +33,7 @@ using namespace tblis; // to appease the type checker. template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(float* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -65,7 +65,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(double* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -97,7 +97,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(__half* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -133,29 +133,29 @@ struct ContractImplBody { float* rhs2_copy_data = allocate_buffer(rhs2_size); half_tensor_to_float(rhs2_copy_data, rhs2_data, rhs2_ndim, rhs2_shape, rhs2_strides); - ContractImplBody{}(lhs_copy_data, - lhs_ndim, - lhs_shape, - lhs_copy_strides.data(), - lhs_modes, - rhs1_copy_data, - rhs1_ndim, - rhs1_shape, - rhs1_copy_strides.data(), - rhs1_modes, - rhs2_copy_data, - rhs2_ndim, - rhs2_shape, - rhs2_copy_strides.data(), - rhs2_modes, - lhs_overwritable); + ContractImplBody{}(lhs_copy_data, + lhs_ndim, + lhs_shape, + lhs_copy_strides.data(), + lhs_modes, + rhs1_copy_data, + rhs1_ndim, + rhs1_shape, + rhs1_copy_strides.data(), + rhs1_modes, + rhs2_copy_data, + rhs2_ndim, + rhs2_shape, + rhs2_copy_strides.data(), + rhs2_modes, + lhs_overwritable); float_tensor_to_half(lhs_data, lhs_copy_data, lhs_ndim, lhs_shape, lhs_strides); } }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(complex* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -198,7 +198,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(complex* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, diff --git a/src/cunumeric/matrix/contract.cu b/src/cunumeric/matrix/contract.cu index 3748ac7ab..3d4155106 100644 --- a/src/cunumeric/matrix/contract.cu +++ b/src/cunumeric/matrix/contract.cu @@ -152,7 +152,7 @@ __host__ void contract(T* lhs_data, } template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(__half* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -190,7 +190,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(float* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -228,7 +228,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(double* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -266,7 +266,7 @@ 
struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(complex* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -304,7 +304,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(complex* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, diff --git a/src/cunumeric/matrix/contract_omp.cc b/src/cunumeric/matrix/contract_omp.cc index 539ac9a74..698690cfe 100644 --- a/src/cunumeric/matrix/contract_omp.cc +++ b/src/cunumeric/matrix/contract_omp.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace tblis; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(float* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -58,7 +58,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(double* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -90,7 +90,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(__half* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -126,29 +126,29 @@ struct ContractImplBody { float* rhs2_copy_data = allocate_buffer(rhs2_size); half_tensor_to_float(rhs2_copy_data, rhs2_data, rhs2_ndim, rhs2_shape, rhs2_strides); - ContractImplBody{}(lhs_copy_data, - lhs_ndim, - lhs_shape, - lhs_copy_strides.data(), - lhs_modes, - rhs1_copy_data, - rhs1_ndim, - rhs1_shape, - rhs1_copy_strides.data(), - rhs1_modes, - rhs2_copy_data, - rhs2_ndim, - rhs2_shape, - rhs2_copy_strides.data(), - rhs2_modes, - lhs_overwritable); + ContractImplBody{}(lhs_copy_data, + lhs_ndim, + lhs_shape, + lhs_copy_strides.data(), + lhs_modes, + rhs1_copy_data, + rhs1_ndim, + rhs1_shape, + rhs1_copy_strides.data(), + rhs1_modes, + rhs2_copy_data, + rhs2_ndim, + rhs2_shape, + rhs2_copy_strides.data(), + rhs2_modes, + lhs_overwritable); float_tensor_to_half(lhs_data, lhs_copy_data, lhs_ndim, lhs_shape, lhs_strides); } }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(complex* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, @@ -191,7 +191,7 @@ struct ContractImplBody { }; template <> -struct ContractImplBody { +struct ContractImplBody { void operator()(complex* lhs_data, size_t lhs_ndim, int64_t* lhs_shape, diff --git a/src/cunumeric/matrix/contract_template.inl b/src/cunumeric/matrix/contract_template.inl index 1600bd7fe..a7fa69fa1 100644 --- a/src/cunumeric/matrix/contract_template.inl +++ b/src/cunumeric/matrix/contract_template.inl @@ -28,21 +28,21 @@ namespace cunumeric { using namespace legate; -template +template struct ContractImplBody; -template +template struct support_contract : std::false_type {}; template <> -struct support_contract : std::true_type {}; +struct support_contract : std::true_type {}; template <> -struct support_contract : std::true_type {}; +struct support_contract : std::true_type {}; template <> -struct support_contract : std::true_type {}; +struct support_contract : std::true_type {}; template <> -struct support_contract : std::true_type {}; +struct support_contract : std::true_type {}; template <> -struct support_contract : std::true_type {}; +struct support_contract : std::true_type {}; #if 0 // debugging output @@ -77,9 +77,7 @@ void print_ptr(const char* title, const T* vals, size_t len) template struct ContractImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(ContractArgs& args) const { using T = legate_type_of; @@ -195,9 +193,7 @@ struct 
ContractImpl { #endif } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(ContractArgs& args) const { assert(false); diff --git a/src/cunumeric/matrix/diag.cc b/src/cunumeric/matrix/diag.cc index 65d7d5cfb..84140b1af 100644 --- a/src/cunumeric/matrix/diag.cc +++ b/src/cunumeric/matrix/diag.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct DiagImplBody { using VAL = legate_type_of; @@ -52,7 +52,7 @@ struct DiagImplBody { }; // not extract (create a new 2D matrix with diagonal from vector) -template +template struct DiagImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/diag.cu b/src/cunumeric/matrix/diag.cu index 4797cfb56..17b6a2564 100644 --- a/src/cunumeric/matrix/diag.cu +++ b/src/cunumeric/matrix/diag.cu @@ -60,7 +60,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } -template +template struct DiagImplBody { using VAL = legate_type_of; @@ -92,7 +92,7 @@ struct DiagImplBody { }; // not extract (create a new 2D matrix with diagonal from vector) -template +template struct DiagImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/diag_omp.cc b/src/cunumeric/matrix/diag_omp.cc index 986561e25..5d2224d0c 100644 --- a/src/cunumeric/matrix/diag_omp.cc +++ b/src/cunumeric/matrix/diag_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct DiagImplBody { using VAL = legate_type_of; @@ -57,7 +57,7 @@ struct DiagImplBody { }; // not extract (create a new 2D matrix with diagonal from vector) -template +template struct DiagImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/diag_template.inl b/src/cunumeric/matrix/diag_template.inl index d28a9c7c0..f2bf33b3f 100644 --- a/src/cunumeric/matrix/diag_template.inl +++ b/src/cunumeric/matrix/diag_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct DiagImplBody; template struct DiagImpl { - template + template void operator()(DiagArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/dot.cc b/src/cunumeric/matrix/dot.cc index b2b9d03c4..637ab6a4e 100644 --- a/src/cunumeric/matrix/dot.cc +++ b/src/cunumeric/matrix/dot.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct DotImplBody { using VAL = legate_type_of; using ACC = acc_type_of; diff --git a/src/cunumeric/matrix/dot.cu b/src/cunumeric/matrix/dot.cu index c5a3ffd9b..f99047932 100644 --- a/src/cunumeric/matrix/dot.cu +++ b/src/cunumeric/matrix/dot.cu @@ -43,7 +43,7 @@ static __global__ void __launch_bounds__(1, 1) copy_kernel(Buffer result, RedAcc out.reduce(0, result.read()); } -template +template struct DotImplBody { using VAL = legate_type_of; using ACC = acc_type_of; diff --git a/src/cunumeric/matrix/dot_omp.cc b/src/cunumeric/matrix/dot_omp.cc index c71c2c7d9..857ab8f26 100644 --- a/src/cunumeric/matrix/dot_omp.cc +++ b/src/cunumeric/matrix/dot_omp.cc @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; -template +template struct DotImplBody { using VAL = legate_type_of; using ACC = acc_type_of; diff --git a/src/cunumeric/matrix/dot_template.inl b/src/cunumeric/matrix/dot_template.inl index fd16e2783..fae14df13 100644 --- a/src/cunumeric/matrix/dot_template.inl +++ b/src/cunumeric/matrix/dot_template.inl @@ -23,7 +23,7 @@ namespace cunumeric { using namespace legate; -template +template struct DotImplBody; template @@ -41,7 +41,7 @@ using acc_type_of = 
typename AccTypeOf::type; template struct DotImpl { - template + template void operator()(DotArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/gemm.cc b/src/cunumeric/matrix/gemm.cc index eee1205fe..4160bb03e 100644 --- a/src/cunumeric/matrix/gemm.cc +++ b/src/cunumeric/matrix/gemm.cc @@ -47,7 +47,7 @@ static inline void complex_gemm_template( } template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(float* lhs, const float* rhs1, const float* rhs2, int32_t m, int32_t n, int32_t k) { gemm_template(cblas_sgemm, lhs, rhs1, rhs2, m, n, k); @@ -55,7 +55,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()( double* lhs, const double* rhs1, const double* rhs2, int32_t m, int32_t n, int32_t k) { @@ -64,7 +64,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(complex* lhs_, const complex* rhs1_, const complex* rhs2_, @@ -81,7 +81,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(complex* lhs_, const complex* rhs1_, const complex* rhs2_, diff --git a/src/cunumeric/matrix/gemm.cu b/src/cunumeric/matrix/gemm.cu index 7b17c8477..8fff167ff 100644 --- a/src/cunumeric/matrix/gemm.cu +++ b/src/cunumeric/matrix/gemm.cu @@ -62,7 +62,7 @@ static inline void complex_gemm_template( } template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(float* lhs, const float* rhs1, const float* rhs2, int32_t m, int32_t n, int32_t k) { gemm_template(cublasSgemm, lhs, rhs1, rhs2, m, n, k); @@ -70,7 +70,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()( double* lhs, const double* rhs1, const double* rhs2, int32_t m, int32_t n, int32_t k) { @@ -79,7 +79,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(complex* lhs_, const complex* rhs1_, const complex* rhs2_, @@ -96,7 +96,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(complex* lhs_, const complex* rhs1_, const complex* rhs2_, diff --git a/src/cunumeric/matrix/gemm_omp.cc b/src/cunumeric/matrix/gemm_omp.cc index 0af9e6023..69b20c673 100644 --- a/src/cunumeric/matrix/gemm_omp.cc +++ b/src/cunumeric/matrix/gemm_omp.cc @@ -48,7 +48,7 @@ static inline void complex_gemm_template( } template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(float* lhs, const float* rhs1, const float* rhs2, int32_t m, int32_t n, int32_t k) { gemm_template(cblas_sgemm, lhs, rhs1, rhs2, m, n, k); @@ -56,7 +56,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()( double* lhs, const double* rhs1, const double* rhs2, int32_t m, int32_t n, int32_t k) { @@ -65,7 +65,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(complex* lhs_, const complex* rhs1_, const complex* rhs2_, @@ -82,7 +82,7 @@ struct GemmImplBody { }; template <> -struct GemmImplBody { +struct GemmImplBody { void operator()(complex* lhs_, const complex* rhs1_, const complex* rhs2_, diff --git a/src/cunumeric/matrix/gemm_template.inl b/src/cunumeric/matrix/gemm_template.inl index 92d500398..09178be1b 100644 --- a/src/cunumeric/matrix/gemm_template.inl +++ b/src/cunumeric/matrix/gemm_template.inl @@ -23,23 +23,23 @@ namespace cunumeric { using namespace legate; -template +template struct GemmImplBody; -template +template struct support_gemm : 
std::false_type {}; template <> -struct support_gemm : std::true_type {}; +struct support_gemm : std::true_type {}; template <> -struct support_gemm : std::true_type {}; +struct support_gemm : std::true_type {}; template <> -struct support_gemm : std::true_type {}; +struct support_gemm : std::true_type {}; template <> -struct support_gemm : std::true_type {}; +struct support_gemm : std::true_type {}; template struct GemmImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& lhs_array, Array& rhs1_array, Array& rhs2_array) const { using VAL = legate_type_of; @@ -67,7 +67,7 @@ struct GemmImpl { GemmImplBody()(lhs, rhs1, rhs2, m, n, k); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& lhs_array, Array& rhs1_array, Array& rhs2_array) const { assert(false); diff --git a/src/cunumeric/matrix/matmul.cu b/src/cunumeric/matrix/matmul.cu index 7ed65f832..934deb965 100644 --- a/src/cunumeric/matrix/matmul.cu +++ b/src/cunumeric/matrix/matmul.cu @@ -28,7 +28,7 @@ namespace cunumeric { // for this matrix shape and GPU. template <> -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -73,7 +73,7 @@ struct MatMulImplBody { }; template <> -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -114,7 +114,7 @@ struct MatMulImplBody { }; template <> -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -158,7 +158,7 @@ struct MatMulImplBody { }; template <> -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -206,7 +206,7 @@ struct MatMulImplBody { }; template <> -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, diff --git a/src/cunumeric/matrix/matmul_cpu.inl b/src/cunumeric/matrix/matmul_cpu.inl index e059ac384..16e286045 100644 --- a/src/cunumeric/matrix/matmul_cpu.inl +++ b/src/cunumeric/matrix/matmul_cpu.inl @@ -28,7 +28,7 @@ using namespace Legion; using namespace legate; template -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -61,7 +61,7 @@ struct MatMulImplBody { }; template -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -93,7 +93,7 @@ struct MatMulImplBody { }; template -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -138,7 +138,7 @@ struct MatMulImplBody { }; template -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, @@ -176,7 +176,7 @@ struct MatMulImplBody { }; template -struct MatMulImplBody { +struct MatMulImplBody { void operator()(size_t m, size_t n, size_t k, diff --git a/src/cunumeric/matrix/matmul_template.inl b/src/cunumeric/matrix/matmul_template.inl index b1ff0a4ba..967860f53 100644 --- a/src/cunumeric/matrix/matmul_template.inl +++ b/src/cunumeric/matrix/matmul_template.inl @@ -24,35 +24,35 @@ namespace cunumeric { using namespace legate; -template +template struct MatMulImplBody; -template +template struct support_matmul : std::false_type {}; template <> -struct support_matmul : std::true_type { +struct support_matmul : std::true_type { using ACC_TYPE = double; }; template <> -struct support_matmul : std::true_type { +struct support_matmul : std::true_type { using ACC_TYPE = float; }; template <> -struct support_matmul : std::true_type { +struct support_matmul : std::true_type { using 
ACC_TYPE = float; }; template <> -struct support_matmul : std::true_type { +struct support_matmul : std::true_type { using ACC_TYPE = complex; }; template <> -struct support_matmul : std::true_type { +struct support_matmul : std::true_type { using ACC_TYPE = complex; }; template struct MatMulImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(MatMulArgs& args) const { using VAL = legate_type_of; @@ -105,7 +105,7 @@ struct MatMulImpl { args.lhs.is_readable()); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(MatMulArgs& args) const { assert(false); diff --git a/src/cunumeric/matrix/matvecmul.cu b/src/cunumeric/matrix/matvecmul.cu index 53c5f7104..d54e28b48 100644 --- a/src/cunumeric/matrix/matvecmul.cu +++ b/src/cunumeric/matrix/matvecmul.cu @@ -22,7 +22,7 @@ namespace cunumeric { template <> -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, float* lhs, @@ -75,7 +75,7 @@ struct MatVecMulImplBody { }; template <> -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, double* lhs, @@ -122,7 +122,7 @@ struct MatVecMulImplBody { }; template <> -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, float* lhs, @@ -164,7 +164,7 @@ struct MatVecMulImplBody { }; template <> -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, complex* lhs_, @@ -218,7 +218,7 @@ struct MatVecMulImplBody { }; template <> -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, complex* lhs_, diff --git a/src/cunumeric/matrix/matvecmul_cpu.inl b/src/cunumeric/matrix/matvecmul_cpu.inl index 92e99a2c6..6797d701b 100644 --- a/src/cunumeric/matrix/matvecmul_cpu.inl +++ b/src/cunumeric/matrix/matvecmul_cpu.inl @@ -28,7 +28,7 @@ using namespace Legion; using namespace legate; template -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, float* lhs, @@ -46,7 +46,7 @@ struct MatVecMulImplBody { }; template -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, double* lhs, @@ -63,7 +63,7 @@ struct MatVecMulImplBody { }; template -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, float* lhs, @@ -81,13 +81,13 @@ struct MatVecMulImplBody { half_matrix_to_float(mat_copy, mat, m, n, mat_stride); half_vector_to_float(vec_copy, vec, vec_size); - MatVecMulImplBody{}( + MatVecMulImplBody{}( m, n, lhs, mat_copy, vec_copy, n, transpose_mat, lhs_overwritable); } }; template -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, complex* lhs_, @@ -109,7 +109,7 @@ struct MatVecMulImplBody { }; template -struct MatVecMulImplBody { +struct MatVecMulImplBody { void operator()(size_t m, size_t n, complex* lhs_, diff --git a/src/cunumeric/matrix/matvecmul_template.inl b/src/cunumeric/matrix/matvecmul_template.inl index 26c3ba876..547d376d1 100644 --- a/src/cunumeric/matrix/matvecmul_template.inl +++ b/src/cunumeric/matrix/matvecmul_template.inl @@ -24,35 +24,35 @@ namespace cunumeric { using namespace legate; -template +template struct MatVecMulImplBody; -template +template struct support_matvecmul : std::false_type {}; template <> -struct support_matvecmul : std::true_type { +struct support_matvecmul : std::true_type { using ACC_TYPE = double; }; template <> -struct support_matvecmul : std::true_type { +struct support_matvecmul : 
std::true_type { using ACC_TYPE = float; }; template <> -struct support_matvecmul : std::true_type { +struct support_matvecmul : std::true_type { using ACC_TYPE = float; }; template <> -struct support_matvecmul : std::true_type { +struct support_matvecmul : std::true_type { using ACC_TYPE = complex; }; template <> -struct support_matvecmul : std::true_type { +struct support_matvecmul : std::true_type { using ACC_TYPE = complex; }; template struct MatVecMulImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(MatVecMulArgs& args) const { using VAL = legate_type_of; @@ -86,7 +86,7 @@ struct MatVecMulImpl { m, n, lhs, mat, vec, mat_stride, transpose_mat, args.lhs.is_readable()); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(MatVecMulArgs& args) const { assert(false); diff --git a/src/cunumeric/matrix/potrf.cc b/src/cunumeric/matrix/potrf.cc index a49be25da..02ae06246 100644 --- a/src/cunumeric/matrix/potrf.cc +++ b/src/cunumeric/matrix/potrf.cc @@ -25,7 +25,7 @@ namespace cunumeric { using namespace legate; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(float* array, int32_t m, int32_t n) { char uplo = 'L'; @@ -36,7 +36,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(double* array, int32_t m, int32_t n) { char uplo = 'L'; @@ -47,7 +47,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(complex* array, int32_t m, int32_t n) { char uplo = 'L'; @@ -58,7 +58,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(complex* array, int32_t m, int32_t n) { char uplo = 'L'; diff --git a/src/cunumeric/matrix/potrf.cu b/src/cunumeric/matrix/potrf.cu index 0a8bba066..68616525f 100644 --- a/src/cunumeric/matrix/potrf.cu +++ b/src/cunumeric/matrix/potrf.cu @@ -49,7 +49,7 @@ static inline void potrf_template( } template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(float* array, int32_t m, int32_t n) { potrf_template(cusolverDnSpotrf_bufferSize, cusolverDnSpotrf, array, m, n); @@ -57,7 +57,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(double* array, int32_t m, int32_t n) { potrf_template(cusolverDnDpotrf_bufferSize, cusolverDnDpotrf, array, m, n); @@ -65,7 +65,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(complex* array, int32_t m, int32_t n) { potrf_template( @@ -74,7 +74,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(complex* array, int32_t m, int32_t n) { potrf_template(cusolverDnZpotrf_bufferSize, diff --git a/src/cunumeric/matrix/potrf_omp.cc b/src/cunumeric/matrix/potrf_omp.cc index 51e729bc1..d26143a6f 100644 --- a/src/cunumeric/matrix/potrf_omp.cc +++ b/src/cunumeric/matrix/potrf_omp.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(float* array, int32_t m, int32_t n) { char uplo = 'L'; @@ -37,7 +37,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(double* array, int32_t m, int32_t n) { char uplo = 'L'; @@ -48,7 +48,7 @@ struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(complex* array, int32_t m, int32_t n) { char uplo = 'L'; @@ -59,7 +59,7 @@ 
struct PotrfImplBody { }; template <> -struct PotrfImplBody { +struct PotrfImplBody { void operator()(complex* array, int32_t m, int32_t n) { char uplo = 'L'; diff --git a/src/cunumeric/matrix/potrf_template.inl b/src/cunumeric/matrix/potrf_template.inl index afceecfe5..55c782ad0 100644 --- a/src/cunumeric/matrix/potrf_template.inl +++ b/src/cunumeric/matrix/potrf_template.inl @@ -23,23 +23,23 @@ namespace cunumeric { using namespace legate; -template +template struct PotrfImplBody; -template +template struct support_potrf : std::false_type {}; template <> -struct support_potrf : std::true_type {}; +struct support_potrf : std::true_type {}; template <> -struct support_potrf : std::true_type {}; +struct support_potrf : std::true_type {}; template <> -struct support_potrf : std::true_type {}; +struct support_potrf : std::true_type {}; template <> -struct support_potrf : std::true_type {}; +struct support_potrf : std::true_type {}; template struct PotrfImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& array) const { using VAL = legate_type_of; @@ -58,7 +58,7 @@ struct PotrfImpl { PotrfImplBody()(arr, m, n); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& array) const { assert(false); diff --git a/src/cunumeric/matrix/solve.cu b/src/cunumeric/matrix/solve.cu index 3e490bb78..3f3262b15 100644 --- a/src/cunumeric/matrix/solve.cu +++ b/src/cunumeric/matrix/solve.cu @@ -61,7 +61,7 @@ static inline void solve_template(GetrfBufferSize getrf_buffer_size, } template <> -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, float* a, float* b) { solve_template( @@ -70,7 +70,7 @@ struct SolveImplBody { }; template <> -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, double* a, double* b) { solve_template( @@ -79,7 +79,7 @@ struct SolveImplBody { }; template <> -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a, complex* b) { solve_template(cusolverDnCgetrf_bufferSize, @@ -94,7 +94,7 @@ struct SolveImplBody { }; template <> -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a, complex* b) { solve_template(cusolverDnZgetrf_bufferSize, diff --git a/src/cunumeric/matrix/solve_cpu.inl b/src/cunumeric/matrix/solve_cpu.inl index 1c036ee61..7275a2c0e 100644 --- a/src/cunumeric/matrix/solve_cpu.inl +++ b/src/cunumeric/matrix/solve_cpu.inl @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; template -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, float* a, float* b) { auto ipiv = create_buffer(std::min(m, n)); @@ -37,7 +37,7 @@ struct SolveImplBody { }; template -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, double* a, double* b) { auto ipiv = create_buffer(std::min(m, n)); @@ -50,7 +50,7 @@ struct SolveImplBody { }; template -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a_, complex* b_) { auto ipiv = create_buffer(std::min(m, n)); @@ -66,7 +66,7 @@ struct SolveImplBody { }; template -struct SolveImplBody { +struct SolveImplBody { void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a_, complex* b_) { auto ipiv = create_buffer(std::min(m, n)); diff --git a/src/cunumeric/matrix/solve_template.inl b/src/cunumeric/matrix/solve_template.inl 
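The dense linear-algebra tasks in this stretch of the patch (gemm, matmul, matvecmul, potrf, solve, syrk, trsm, contract) all gate their implementations with a small support_* trait that is now keyed on Type::Code instead of LegateTypeCode. A condensed sketch of that idiom follows; support_op and OpImpl are hypothetical stand-ins for the per-task traits and functors, and it assumes <type_traits>, <cassert>, and legate's legate_type_of are available as in the surrounding files.

template <Type::Code CODE>
struct support_op : std::false_type {};
template <>
struct support_op<Type::Code::FLOAT32> : std::true_type {};
template <>
struct support_op<Type::Code::FLOAT64> : std::true_type {};
template <>
struct support_op<Type::Code::COMPLEX64> : std::true_type {};

struct OpImpl {
  // Selected when the element code is supported by the BLAS/LAPACK backend.
  template <Type::Code CODE, std::enable_if_t<support_op<CODE>::value>* = nullptr>
  void operator()(/* Array& array, ... */) const
  {
    using VAL = legate_type_of<CODE>;  // e.g. float for FLOAT32
    // a real body would hand VAL buffers to the VariantKind-specific OpImplBody here
  }

  // Fallback for unsupported codes; unreachable when dispatch is set up correctly.
  template <Type::Code CODE, std::enable_if_t<!support_op<CODE>::value>* = nullptr>
  void operator()(/* Array& array, ... */) const
  {
    assert(false);
  }
};
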
index 3fa48b778..e338b8326 100644 --- a/src/cunumeric/matrix/solve_template.inl +++ b/src/cunumeric/matrix/solve_template.inl @@ -25,23 +25,23 @@ namespace cunumeric { using namespace legate; -template +template struct SolveImplBody; -template +template struct support_solve : std::false_type {}; template <> -struct support_solve : std::true_type {}; +struct support_solve : std::true_type {}; template <> -struct support_solve : std::true_type {}; +struct support_solve : std::true_type {}; template <> -struct support_solve : std::true_type {}; +struct support_solve : std::true_type {}; template <> -struct support_solve : std::true_type {}; +struct support_solve : std::true_type {}; template struct SolveImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& a_array, Array& b_array) const { using VAL = legate_type_of; @@ -95,7 +95,7 @@ struct SolveImpl { SolveImplBody()(m, n, nrhs, a, b); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& a_array, Array& b_array) const { assert(false); diff --git a/src/cunumeric/matrix/syrk.cc b/src/cunumeric/matrix/syrk.cc index 2149bda5c..2fa5bc64c 100644 --- a/src/cunumeric/matrix/syrk.cc +++ b/src/cunumeric/matrix/syrk.cc @@ -33,7 +33,7 @@ static inline void syrk_template(Syrk syrk, VAL* lhs, const VAL* rhs, int32_t m, } template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(float* lhs, const float* rhs, int32_t m, int32_t n) { syrk_template(cblas_ssyrk, lhs, rhs, m, n); @@ -41,7 +41,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(double* lhs, const double* rhs, int32_t m, int32_t n) { syrk_template(cblas_dsyrk, lhs, rhs, m, n); @@ -49,7 +49,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ float*>(lhs_); @@ -64,7 +64,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ double*>(lhs_); diff --git a/src/cunumeric/matrix/syrk.cu b/src/cunumeric/matrix/syrk.cu index d7f38bcdd..1fdbd2ca6 100644 --- a/src/cunumeric/matrix/syrk.cu +++ b/src/cunumeric/matrix/syrk.cu @@ -42,7 +42,7 @@ static inline void syrk_template( } template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(float* lhs, const float* rhs, int32_t m, int32_t n) { syrk_template(cublasSsyrk, lhs, rhs, m, n, static_cast(0)); @@ -50,7 +50,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(double* lhs, const double* rhs, int32_t m, int32_t n) { syrk_template(cublasDsyrk, lhs, rhs, m, n, static_cast(0)); @@ -58,7 +58,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast(lhs_); @@ -69,7 +69,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast(lhs_); diff --git a/src/cunumeric/matrix/syrk_omp.cc b/src/cunumeric/matrix/syrk_omp.cc index 849429aac..b276d71a2 100644 --- a/src/cunumeric/matrix/syrk_omp.cc +++ b/src/cunumeric/matrix/syrk_omp.cc @@ -34,7 +34,7 @@ static inline void syrk_template(Syrk syrk, VAL* lhs, const 
VAL* rhs, int32_t m, } template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(float* lhs, const float* rhs, int32_t m, int32_t n) { syrk_template(cblas_ssyrk, lhs, rhs, m, n); @@ -42,7 +42,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(double* lhs, const double* rhs, int32_t m, int32_t n) { syrk_template(cblas_dsyrk, lhs, rhs, m, n); @@ -50,7 +50,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ float*>(lhs_); @@ -65,7 +65,7 @@ struct SyrkImplBody { }; template <> -struct SyrkImplBody { +struct SyrkImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ double*>(lhs_); diff --git a/src/cunumeric/matrix/syrk_template.inl b/src/cunumeric/matrix/syrk_template.inl index 66490b34e..58ea4abae 100644 --- a/src/cunumeric/matrix/syrk_template.inl +++ b/src/cunumeric/matrix/syrk_template.inl @@ -23,23 +23,23 @@ namespace cunumeric { using namespace legate; -template +template struct SyrkImplBody; -template +template struct support_syrk : std::false_type {}; template <> -struct support_syrk : std::true_type {}; +struct support_syrk : std::true_type {}; template <> -struct support_syrk : std::true_type {}; +struct support_syrk : std::true_type {}; template <> -struct support_syrk : std::true_type {}; +struct support_syrk : std::true_type {}; template <> -struct support_syrk : std::true_type {}; +struct support_syrk : std::true_type {}; template struct SyrkImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& lhs_array, Array& rhs_array) const { using VAL = legate_type_of; @@ -62,7 +62,7 @@ struct SyrkImpl { SyrkImplBody()(lhs, rhs, m, n); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& lhs_array, Array& rhs_array) const { assert(false); diff --git a/src/cunumeric/matrix/tile_template.inl b/src/cunumeric/matrix/tile_template.inl index 89f6edd55..35f8dc967 100644 --- a/src/cunumeric/matrix/tile_template.inl +++ b/src/cunumeric/matrix/tile_template.inl @@ -67,7 +67,7 @@ struct TileImpl { template struct TileDispatch { - template + template void operator()(TileArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/transpose.cc b/src/cunumeric/matrix/transpose.cc index bc829f440..224a36ab2 100644 --- a/src/cunumeric/matrix/transpose.cc +++ b/src/cunumeric/matrix/transpose.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; -template +template struct TransposeImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/transpose.cu b/src/cunumeric/matrix/transpose.cu index 0d4c210e7..5ccd3ef7a 100644 --- a/src/cunumeric/matrix/transpose.cu +++ b/src/cunumeric/matrix/transpose.cu @@ -136,7 +136,7 @@ __global__ static void __launch_bounds__((TILE_DIM * BLOCK_ROWS), MIN_CTAS_PER_S } } -template +template struct TransposeImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/transpose_omp.cc b/src/cunumeric/matrix/transpose_omp.cc index c1750434f..729719242 100644 --- a/src/cunumeric/matrix/transpose_omp.cc +++ b/src/cunumeric/matrix/transpose_omp.cc @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; -template +template struct TransposeImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/transpose_template.inl 
b/src/cunumeric/matrix/transpose_template.inl index 3c8de1f6e..4d695c3cd 100644 --- a/src/cunumeric/matrix/transpose_template.inl +++ b/src/cunumeric/matrix/transpose_template.inl @@ -23,12 +23,12 @@ namespace cunumeric { using namespace legate; -template +template struct TransposeImplBody; template struct TransposeImpl { - template + template void operator()(TransposeArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/trilu.cc b/src/cunumeric/matrix/trilu.cc index 8b44c2517..7d0e55e4f 100644 --- a/src/cunumeric/matrix/trilu.cc +++ b/src/cunumeric/matrix/trilu.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct TriluImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/trilu.cu b/src/cunumeric/matrix/trilu.cu index 2158f9dbf..6a8c7a02b 100644 --- a/src/cunumeric/matrix/trilu.cu +++ b/src/cunumeric/matrix/trilu.cu @@ -50,7 +50,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } -template +template struct TriluImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/trilu_omp.cc b/src/cunumeric/matrix/trilu_omp.cc index 6e9e9598f..b4e2482da 100644 --- a/src/cunumeric/matrix/trilu_omp.cc +++ b/src/cunumeric/matrix/trilu_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct TriluImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/matrix/trilu_template.inl b/src/cunumeric/matrix/trilu_template.inl index 4c7d019a4..ca417f5af 100644 --- a/src/cunumeric/matrix/trilu_template.inl +++ b/src/cunumeric/matrix/trilu_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct TriluImplBody; template struct TriluImpl { - template = 2)>* = nullptr> + template = 2)>* = nullptr> void operator()(TriluArgs& args) const { using VAL = legate_type_of; @@ -59,7 +59,7 @@ struct TriluImpl { } } - template * = nullptr> + template * = nullptr> void operator()(TriluArgs& args) const { assert(false); diff --git a/src/cunumeric/matrix/trsm.cc b/src/cunumeric/matrix/trsm.cc index 32382465d..e61c86981 100644 --- a/src/cunumeric/matrix/trsm.cc +++ b/src/cunumeric/matrix/trsm.cc @@ -49,7 +49,7 @@ static inline void complex_trsm_template(Trsm trsm, VAL* lhs, const VAL* rhs, in } template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(float* lhs, const float* rhs, int32_t m, int32_t n) { trsm_template(cblas_strsm, lhs, rhs, m, n); @@ -57,7 +57,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(double* lhs, const double* rhs, int32_t m, int32_t n) { trsm_template(cblas_dtrsm, lhs, rhs, m, n); @@ -65,7 +65,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ float*>(lhs_); @@ -76,7 +76,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ double*>(lhs_); diff --git a/src/cunumeric/matrix/trsm.cu b/src/cunumeric/matrix/trsm.cu index 05595ee28..8bd5d66c7 100644 --- a/src/cunumeric/matrix/trsm.cu +++ b/src/cunumeric/matrix/trsm.cu @@ -43,7 +43,7 @@ static inline void trsm_template( } template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(float* lhs, const float* rhs, int32_t m, int32_t n) { 
trsm_template(cublasStrsm, lhs, rhs, m, n, 1.0F); @@ -51,7 +51,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(double* lhs, const double* rhs, int32_t m, int32_t n) { trsm_template(cublasDtrsm, lhs, rhs, m, n, 1.0); @@ -59,7 +59,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast(lhs_); @@ -70,7 +70,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast(lhs_); diff --git a/src/cunumeric/matrix/trsm_omp.cc b/src/cunumeric/matrix/trsm_omp.cc index 255a04cf0..2041ec17a 100644 --- a/src/cunumeric/matrix/trsm_omp.cc +++ b/src/cunumeric/matrix/trsm_omp.cc @@ -50,7 +50,7 @@ static inline void complex_trsm_template(Trsm trsm, VAL* lhs, const VAL* rhs, in } template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(float* lhs, const float* rhs, int32_t m, int32_t n) { trsm_template(cblas_strsm, lhs, rhs, m, n); @@ -58,7 +58,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(double* lhs, const double* rhs, int32_t m, int32_t n) { trsm_template(cblas_dtrsm, lhs, rhs, m, n); @@ -66,7 +66,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ float*>(lhs_); @@ -77,7 +77,7 @@ struct TrsmImplBody { }; template <> -struct TrsmImplBody { +struct TrsmImplBody { void operator()(complex* lhs_, const complex* rhs_, int32_t m, int32_t n) { auto lhs = reinterpret_cast<__complex__ double*>(lhs_); diff --git a/src/cunumeric/matrix/trsm_template.inl b/src/cunumeric/matrix/trsm_template.inl index d214aa9dd..28f37ba1b 100644 --- a/src/cunumeric/matrix/trsm_template.inl +++ b/src/cunumeric/matrix/trsm_template.inl @@ -23,23 +23,23 @@ namespace cunumeric { using namespace legate; -template +template struct TrsmImplBody; -template +template struct support_trsm : std::false_type {}; template <> -struct support_trsm : std::true_type {}; +struct support_trsm : std::true_type {}; template <> -struct support_trsm : std::true_type {}; +struct support_trsm : std::true_type {}; template <> -struct support_trsm : std::true_type {}; +struct support_trsm : std::true_type {}; template <> -struct support_trsm : std::true_type {}; +struct support_trsm : std::true_type {}; template struct TrsmImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& lhs_array, Array& rhs_array) const { using VAL = legate_type_of; @@ -62,7 +62,7 @@ struct TrsmImpl { TrsmImplBody()(lhs, rhs, m, n); } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(Array& lhs_array, Array& rhs_array) const { assert(false); diff --git a/src/cunumeric/nullary/arange_template.inl b/src/cunumeric/nullary/arange_template.inl index 97c87ef46..c71b9c44e 100644 --- a/src/cunumeric/nullary/arange_template.inl +++ b/src/cunumeric/nullary/arange_template.inl @@ -31,7 +31,7 @@ struct ArangeImplBody; template struct ArangeImpl { - template + template void operator()(ArangeArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/nullary/eye_template.inl b/src/cunumeric/nullary/eye_template.inl index 2554dd5fd..33cbe6054 100644 --- 
a/src/cunumeric/nullary/eye_template.inl +++ b/src/cunumeric/nullary/eye_template.inl @@ -31,7 +31,7 @@ struct EyeImplBody; template struct EyeImpl { - template + template void operator()(EyeArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/nullary/fill_template.inl b/src/cunumeric/nullary/fill_template.inl index 4c6726242..dc9c2f609 100644 --- a/src/cunumeric/nullary/fill_template.inl +++ b/src/cunumeric/nullary/fill_template.inl @@ -54,7 +54,7 @@ struct FillImpl { FillImplBody{}(out, fill_value, pitches, rect, dense); } - template + template void operator()(FillArgs& args) const { if (args.is_argval) { @@ -71,7 +71,15 @@ template static void fill_template(TaskContext& context) { FillArgs args{context.outputs()[0], context.inputs()[0], context.scalars()[0].value()}; - double_dispatch(args.out.dim(), args.out.code(), FillImpl{}, args); + Type::Code code{args.out.code()}; + if (Type::Code::STRUCT == code) { +#ifdef DEBUG_CUNUMERIC + assert(args.is_argval); +#endif + auto& field_type = static_cast(args.out.type()).field_type(1); + code = field_type.code; + } + double_dispatch(args.out.dim(), code, FillImpl{}, args); } } // namespace cunumeric diff --git a/src/cunumeric/random/rand_template.inl b/src/cunumeric/random/rand_template.inl index db5d4ccee..3b689a728 100644 --- a/src/cunumeric/random/rand_template.inl +++ b/src/cunumeric/random/rand_template.inl @@ -31,7 +31,7 @@ struct RandImplBody; template struct RandImpl { - template ::valid>* = nullptr> void operator()(RandArgs& args) const @@ -53,7 +53,7 @@ struct RandImpl { RandImplBody{}(out, rng, strides, pitches, rect); } - template ::valid>* = nullptr> void operator()(RandArgs& args) const diff --git a/src/cunumeric/random/rand_util.h b/src/cunumeric/random/rand_util.h index b492ccce0..11988c5ec 100644 --- a/src/cunumeric/random/rand_util.h +++ b/src/cunumeric/random/rand_util.h @@ -46,15 +46,15 @@ constexpr decltype(auto) op_dispatch(RandGenCode gen_code, Functor f, Fnargs&&.. 
return f.template operator()(std::forward(args)...); } -template +template struct RandomGenerator { static constexpr bool valid = false; }; -template +template struct RandomGenerator { using RNG = Philox_2x32<10>; - static constexpr bool valid = CODE == legate::LegateTypeCode::DOUBLE_LT; + static constexpr bool valid = CODE == legate::Type::Code::FLOAT64; RandomGenerator(uint32_t ep, const std::vector& args) : epoch(ep) {} @@ -66,10 +66,10 @@ struct RandomGenerator { uint32_t epoch; }; -template +template struct RandomGenerator { using RNG = Philox_2x32<10>; - static constexpr bool valid = CODE == legate::LegateTypeCode::DOUBLE_LT; + static constexpr bool valid = CODE == legate::Type::Code::FLOAT64; RandomGenerator(uint32_t ep, const std::vector& args) : epoch(ep) {} @@ -174,7 +174,7 @@ struct RandomGenerator { uint32_t epoch; }; -template +template struct RandomGenerator { using RNG = Philox_2x32<10>; using VAL = legate::legate_type_of; diff --git a/src/cunumeric/scan/scan_global.cc b/src/cunumeric/scan/scan_global.cc index 753a84bcb..2df4ae14d 100644 --- a/src/cunumeric/scan/scan_global.cc +++ b/src/cunumeric/scan/scan_global.cc @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; -template +template struct ScanGlobalImplBody { using OP = ScanOp; using VAL = legate_type_of; diff --git a/src/cunumeric/scan/scan_global.cu b/src/cunumeric/scan/scan_global.cu index 0be6ef994..ba1c5da9d 100644 --- a/src/cunumeric/scan/scan_global.cu +++ b/src/cunumeric/scan/scan_global.cu @@ -35,7 +35,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[idx] = func(out[idx], scalar); } -template +template struct ScanGlobalImplBody { using OP = ScanOp; using VAL = legate_type_of; diff --git a/src/cunumeric/scan/scan_global_omp.cc b/src/cunumeric/scan/scan_global_omp.cc index cdb80f3d9..3ad989aca 100644 --- a/src/cunumeric/scan/scan_global_omp.cc +++ b/src/cunumeric/scan/scan_global_omp.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; -template +template struct ScanGlobalImplBody { using OP = ScanOp; using VAL = legate_type_of; diff --git a/src/cunumeric/scan/scan_global_template.inl b/src/cunumeric/scan/scan_global_template.inl index 099a357d8..b96007dc2 100644 --- a/src/cunumeric/scan/scan_global_template.inl +++ b/src/cunumeric/scan/scan_global_template.inl @@ -21,12 +21,12 @@ namespace cunumeric { using namespace legate; -template +template struct ScanGlobalImplBody; template struct ScanGlobalImpl { - template + template void operator()(ScanGlobalArgs& args) const { using OP = ScanOp; diff --git a/src/cunumeric/scan/scan_global_util.h b/src/cunumeric/scan/scan_global_util.h index 502b9720c..ce2e8b522 100644 --- a/src/cunumeric/scan/scan_global_util.h +++ b/src/cunumeric/scan/scan_global_util.h @@ -40,16 +40,16 @@ constexpr decltype(auto) op_dispatch(ScanCode op_code, Functor f, Fnargs&&... 
ar return f.template operator()(std::forward(args)...); } -template +template struct ScanOp {}; -template +template struct ScanOp : thrust::plus> { static constexpr int nan_identity = 0; ScanOp() {} }; -template +template struct ScanOp : thrust::multiplies> { static constexpr int nan_identity = 1; ScanOp() {} diff --git a/src/cunumeric/scan/scan_local.cc b/src/cunumeric/scan/scan_local.cc index bfc52d49f..3c49147c1 100644 --- a/src/cunumeric/scan/scan_local.cc +++ b/src/cunumeric/scan/scan_local.cc @@ -26,7 +26,7 @@ namespace cunumeric { using namespace legate; -template +template struct ScanLocalImplBody { using OP = ScanOp; using VAL = legate_type_of; @@ -62,7 +62,7 @@ struct ScanLocalImplBody { } }; -template +template struct ScanLocalNanImplBody { using OP = ScanOp; using VAL = legate_type_of; diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu index 258e0e282..da4053182 100644 --- a/src/cunumeric/scan/scan_local.cu +++ b/src/cunumeric/scan/scan_local.cu @@ -37,7 +37,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) sum_val[0] = out[0]; } -template +template struct ScanLocalImplBody { using OP = ScanOp; using VAL = legate_type_of; @@ -77,7 +77,7 @@ struct ScanLocalImplBody { } }; -template +template struct ScanLocalNanImplBody { using OP = ScanOp; using VAL = legate_type_of; diff --git a/src/cunumeric/scan/scan_local_omp.cc b/src/cunumeric/scan/scan_local_omp.cc index 2b80ab97a..4fb2d2841 100644 --- a/src/cunumeric/scan/scan_local_omp.cc +++ b/src/cunumeric/scan/scan_local_omp.cc @@ -28,7 +28,7 @@ namespace cunumeric { using namespace legate; -template +template struct ScanLocalImplBody { using OP = ScanOp; using VAL = legate_type_of; @@ -64,7 +64,7 @@ struct ScanLocalImplBody { } }; -template +template struct ScanLocalNanImplBody { using OP = ScanOp; using VAL = legate_type_of; diff --git a/src/cunumeric/scan/scan_local_template.inl b/src/cunumeric/scan/scan_local_template.inl index c016873bb..154a86b35 100644 --- a/src/cunumeric/scan/scan_local_template.inl +++ b/src/cunumeric/scan/scan_local_template.inl @@ -21,16 +21,16 @@ namespace cunumeric { using namespace legate; -template +template struct ScanLocalImplBody; -template +template struct ScanLocalNanImplBody; template struct ScanLocalImpl { // Case where NANs are transformed - template ::value || legate::is_complex::value)>* = nullptr> @@ -56,7 +56,7 @@ struct ScanLocalImpl { ScanLocalNanImplBody()(func, out, in, args.sum_vals, pitches, rect); } // Case where NANs are as is - template ::value || legate::is_complex::value))>* = nullptr> diff --git a/src/cunumeric/scan/scan_local_util.h b/src/cunumeric/scan/scan_local_util.h index 0cfbacb00..b62db7a83 100644 --- a/src/cunumeric/scan/scan_local_util.h +++ b/src/cunumeric/scan/scan_local_util.h @@ -52,16 +52,16 @@ constexpr decltype(auto) op_dispatch(ScanCode op_code, return f.template operator()(std::forward(args)...); } -template +template struct ScanOp {}; -template +template struct ScanOp : thrust::plus> { static constexpr int nan_identity = 0; ScanOp() {} }; -template +template struct ScanOp : thrust::multiplies> { static constexpr int nan_identity = 1; ScanOp() {} diff --git a/src/cunumeric/search/argwhere.cc b/src/cunumeric/search/argwhere.cc index a3eed173d..a787c2f4c 100644 --- a/src/cunumeric/search/argwhere.cc +++ b/src/cunumeric/search/argwhere.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct ArgWhereImplBody { using VAL = legate_type_of; diff --git 
a/src/cunumeric/search/argwhere.cu b/src/cunumeric/search/argwhere.cu index d4131ca6d..09819aca7 100644 --- a/src/cunumeric/search/argwhere.cu +++ b/src/cunumeric/search/argwhere.cu @@ -41,7 +41,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } -template +template struct ArgWhereImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/search/argwhere_omp.cc b/src/cunumeric/search/argwhere_omp.cc index 3cea7fbd0..51555b684 100644 --- a/src/cunumeric/search/argwhere_omp.cc +++ b/src/cunumeric/search/argwhere_omp.cc @@ -23,7 +23,7 @@ namespace cunumeric { using namespace legate; -template +template struct ArgWhereImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/search/argwhere_template.inl b/src/cunumeric/search/argwhere_template.inl index 5c0a57e5e..5c1a91a85 100644 --- a/src/cunumeric/search/argwhere_template.inl +++ b/src/cunumeric/search/argwhere_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct ArgWhereImplBody; template struct ArgWhereImpl { - template + template void operator()(ArgWhereArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/search/nonzero.cc b/src/cunumeric/search/nonzero.cc index 93f869b88..5e2da5113 100644 --- a/src/cunumeric/search/nonzero.cc +++ b/src/cunumeric/search/nonzero.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct NonzeroImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/search/nonzero.cu b/src/cunumeric/search/nonzero.cu index 92fcb5047..38cfa8480 100644 --- a/src/cunumeric/search/nonzero.cu +++ b/src/cunumeric/search/nonzero.cu @@ -39,7 +39,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } -template +template struct NonzeroImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/search/nonzero_omp.cc b/src/cunumeric/search/nonzero_omp.cc index 690202aee..e07fb5170 100644 --- a/src/cunumeric/search/nonzero_omp.cc +++ b/src/cunumeric/search/nonzero_omp.cc @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; -template +template struct NonzeroImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/search/nonzero_template.inl b/src/cunumeric/search/nonzero_template.inl index fde09c76a..fb9935535 100644 --- a/src/cunumeric/search/nonzero_template.inl +++ b/src/cunumeric/search/nonzero_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct NonzeroImplBody; template struct NonzeroImpl { - template + template void operator()(NonzeroArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/set/unique.cc b/src/cunumeric/set/unique.cc index ed0f28f49..7aa09d0e5 100644 --- a/src/cunumeric/set/unique.cc +++ b/src/cunumeric/set/unique.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct UniqueImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/set/unique.cu b/src/cunumeric/set/unique.cu index 6d67eecae..302077c5f 100644 --- a/src/cunumeric/set/unique.cu +++ b/src/cunumeric/set/unique.cu @@ -139,7 +139,7 @@ static Piece tree_reduce(Array& output, return my_piece; } -template +template struct UniqueImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/set/unique_omp.cc b/src/cunumeric/set/unique_omp.cc index 411fda749..37a86582b 100644 --- a/src/cunumeric/set/unique_omp.cc +++ b/src/cunumeric/set/unique_omp.cc @@ -23,7 +23,7 @@ namespace cunumeric { using namespace legate; 
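Note: the hunks above and below all make the same mechanical substitution. Kernel-body functors such as ArgWhereImplBody, NonzeroImplBody and UniqueImplBody stay templated on a type code, but the non-type parameter changes from the old LegateTypeCode enum to legate::Type::Code, while legate_type_of<CODE> continues to map that code to a concrete element type at the dispatch boundary. The following is a minimal, self-contained sketch of that runtime-code-to-compile-time-type pattern; the Code enum, type_of trait, dispatch helper and SumImpl functor are stand-ins invented for this note, not the real legate or cuNumeric declarations.

// Illustrative only: a dynamic type code is switched into a template
// argument, and the functor recovers the concrete element type from it,
// mirroring how the *Impl functors in this diff use legate_type_of<CODE>.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum class Code : std::int32_t { FLOAT32, FLOAT64, INT64 };

// Stand-in for legate_type_of<CODE>: maps a code to its C++ element type.
template <Code CODE> struct type_of;
template <> struct type_of<Code::FLOAT32> { using type = float; };
template <> struct type_of<Code::FLOAT64> { using type = double; };
template <> struct type_of<Code::INT64>   { using type = std::int64_t; };

// Plays the role of an ImplBody/Impl functor: typed work behind a Code.
struct SumImpl {
  template <Code CODE>
  double operator()(const void* data, std::size_t n) const
  {
    using VAL = typename type_of<CODE>::type;
    auto ptr  = static_cast<const VAL*>(data);
    double acc = 0.0;
    for (std::size_t i = 0; i < n; ++i) acc += static_cast<double>(ptr[i]);
    return acc;
  }
};

// Stand-in for the dispatch helpers: turn a runtime Code into a template arg.
template <typename Functor, typename... Args>
decltype(auto) dispatch(Code code, Functor f, Args&&... args)
{
  switch (code) {
    case Code::FLOAT32:
      return f.template operator()<Code::FLOAT32>(std::forward<Args>(args)...);
    case Code::FLOAT64:
      return f.template operator()<Code::FLOAT64>(std::forward<Args>(args)...);
    case Code::INT64:
      return f.template operator()<Code::INT64>(std::forward<Args>(args)...);
  }
  assert(false);
  return f.template operator()<Code::FLOAT64>(std::forward<Args>(args)...);
}

int main()
{
  std::vector<float> xs{1.5f, 2.5f};
  std::cout << dispatch(Code::FLOAT32, SumImpl{}, xs.data(), xs.size()) << "\n";  // 4
}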
-template +template struct UniqueImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/set/unique_reduce.cc b/src/cunumeric/set/unique_reduce.cc index d18db95a1..29442e371 100644 --- a/src/cunumeric/set/unique_reduce.cc +++ b/src/cunumeric/set/unique_reduce.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct UniqueReduceImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/set/unique_reduce_template.inl b/src/cunumeric/set/unique_reduce_template.inl index 9a0fb4415..5a6a3aab0 100644 --- a/src/cunumeric/set/unique_reduce_template.inl +++ b/src/cunumeric/set/unique_reduce_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct UniqueReduceImplBody; template struct UniqueReduceImpl { - template + template void operator()(Array& output, std::vector& input_arrs) { using VAL = legate_type_of; diff --git a/src/cunumeric/set/unique_template.inl b/src/cunumeric/set/unique_template.inl index fe3046756..1ab1a7e1f 100644 --- a/src/cunumeric/set/unique_template.inl +++ b/src/cunumeric/set/unique_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct UniqueImplBody; template struct UniqueImpl { - template + template void operator()(Array& output, Array& input, std::vector& comms, diff --git a/src/cunumeric/sort/searchsorted.cc b/src/cunumeric/sort/searchsorted.cc index 174deb333..6b8fdb4cd 100644 --- a/src/cunumeric/sort/searchsorted.cc +++ b/src/cunumeric/sort/searchsorted.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct SearchSortedImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/searchsorted.cu b/src/cunumeric/sort/searchsorted.cu index c62892e8c..5f98b0259 100644 --- a/src/cunumeric/sort/searchsorted.cu +++ b/src/cunumeric/sort/searchsorted.cu @@ -64,7 +64,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) if (upper_bound > 0) { output_reduction.reduce(v_point, upper_bound + global_offset); } } -template +template struct SearchSortedImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/searchsorted_omp.cc b/src/cunumeric/sort/searchsorted_omp.cc index 115c14214..6c695494c 100644 --- a/src/cunumeric/sort/searchsorted_omp.cc +++ b/src/cunumeric/sort/searchsorted_omp.cc @@ -23,7 +23,7 @@ namespace cunumeric { using namespace legate; -template +template struct SearchSortedImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/searchsorted_template.inl b/src/cunumeric/sort/searchsorted_template.inl index 30acf5202..8ccd0661f 100644 --- a/src/cunumeric/sort/searchsorted_template.inl +++ b/src/cunumeric/sort/searchsorted_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct SearchSortedImplBody; template struct SearchSortedImpl { - template + template void operator()(SearchSortedArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/sort.cc b/src/cunumeric/sort/sort.cc index 517865fc7..3835a3598 100644 --- a/src/cunumeric/sort/sort.cc +++ b/src/cunumeric/sort/sort.cc @@ -28,7 +28,7 @@ namespace cunumeric { using namespace legate; -template +template struct SortImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/sort.cu b/src/cunumeric/sort/sort.cu index 0851f0229..0e5dc39f9 100644 --- a/src/cunumeric/sort/sort.cu +++ b/src/cunumeric/sort/sort.cu @@ -42,14 +42,14 @@ namespace cunumeric { -template +template struct 
support_cub : std::true_type {}; template <> -struct support_cub : std::false_type {}; +struct support_cub : std::false_type {}; template <> -struct support_cub : std::false_type {}; +struct support_cub : std::false_type {}; -template ::value>* = nullptr> +template ::value>* = nullptr> void local_sort(const legate_type_of* values_in, legate_type_of* values_out, const int64_t* indices_in, @@ -69,7 +69,7 @@ void local_sort(const legate_type_of* values_in, } } -template ::value>* = nullptr> +template ::value>* = nullptr> void local_sort(const legate_type_of* values_in, legate_type_of* values_out, const int64_t* indices_in, @@ -566,7 +566,7 @@ struct negative_plus : public thrust::binary_function ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template SegmentMergePiece> merge_all_buffers( std::vector>>& merge_buffers, bool segmented, @@ -1187,7 +1187,7 @@ void rebalance_data(SegmentMergePiece& merge_buffer, ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template void sample_sort_nccl_nd(SortPiece> local_sorted, Array& output_array_unbound, // only for unbound usage when !rebalance void* output_ptr, @@ -1658,7 +1658,7 @@ void sample_sort_nccl_nd(SortPiece> local_sorted, ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct SortImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/sort_cpu.inl b/src/cunumeric/sort/sort_cpu.inl index aa738d5e6..6ab8a585c 100644 --- a/src/cunumeric/sort/sort_cpu.inl +++ b/src/cunumeric/sort/sort_cpu.inl @@ -441,7 +441,7 @@ void rebalance_data(SegmentMergePiece& merge_buffer, } } -template +template void sample_sort_nd(SortPiece> local_sorted, Array& output_array_unbound, // only for unbound usage when !rebalance void* output_ptr, @@ -552,7 +552,7 @@ void sample_sort_nd(SortPiece> local_sorted, /*comm::coll::collAllgather(p_samples + num_samples_l * my_sort_rank, p_samples, num_samples_l * sizeof(SegmentSample), - comm::coll::CollDataType::CollUint8, + comm::coll::CollDataType::Code::CollUint8, comm);*/ // workaround - using alltoallv to mimic allgather on subset @@ -894,7 +894,7 @@ void sample_sort_nd(SortPiece> local_sorted, } } -template +template struct SortImplBodyCpu { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/sort_omp.cc b/src/cunumeric/sort/sort_omp.cc index 92aa751da..e117439a8 100644 --- a/src/cunumeric/sort/sort_omp.cc +++ b/src/cunumeric/sort/sort_omp.cc @@ -29,7 +29,7 @@ namespace cunumeric { using namespace legate; -template +template struct SortImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/sort/sort_template.inl b/src/cunumeric/sort/sort_template.inl index 19a927a5c..0a4d1c16b 100644 --- a/src/cunumeric/sort/sort_template.inl +++ b/src/cunumeric/sort/sort_template.inl @@ -24,7 +24,7 @@ namespace cunumeric { using namespace legate; -template +template struct SortImplBody; static int get_rank(Domain domain, DomainPoint index_point) @@ -41,7 +41,7 @@ static int get_rank(Domain domain, DomainPoint index_point) template struct SortImpl { - template + template void operator()(SortArgs& args, std::vector& comms) const { 
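Note: the CUDA sort above keeps the support_cub trait, now keyed on legate::Type::Code: complex element types specialize it to std::false_type, and std::enable_if_t on that trait picks between the CUB-based and the non-CUB local_sort overload, the same gating idea used by support_trsm earlier in this diff. Below is a compressed, self-contained sketch of that trait-gated overload selection; the trait keyed on the element type, the printed labels and the fallback comparator are stand-ins for this note, not the cuNumeric code.

// Illustrative only: a true_type/false_type trait steers overload selection,
// so unsupported element types silently take the fallback implementation.
#include <algorithm>
#include <complex>
#include <iostream>
#include <type_traits>
#include <vector>

// Default: the fast path is supported.
template <typename T>
struct support_fast_sort : std::true_type {};
// Complex element types opt out, as the complex codes do for CUB in sort.cu.
template <typename T>
struct support_fast_sort<std::complex<T>> : std::false_type {};

// Chosen when the trait holds (think: CUB radix sort).
template <typename T, std::enable_if_t<support_fast_sort<T>::value>* = nullptr>
void local_sort(std::vector<T>& v)
{
  std::cout << "fast path\n";
  std::sort(v.begin(), v.end());
}

// Chosen otherwise (think: comparison sort with a custom comparator).
template <typename T, std::enable_if_t<!support_fast_sort<T>::value>* = nullptr>
void local_sort(std::vector<T>& v)
{
  std::cout << "fallback path\n";
  std::sort(v.begin(), v.end(),
            [](const T& a, const T& b) { return std::abs(a) < std::abs(b); });
}

int main()
{
  std::vector<double> a{3.0, 1.0, 2.0};
  std::vector<std::complex<float>> b{{0.0f, 2.0f}, {1.0f, 0.0f}};
  local_sort(a);  // fast path
  local_sort(b);  // fallback path
}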
using VAL = legate_type_of; diff --git a/src/cunumeric/stat/bincount.cc b/src/cunumeric/stat/bincount.cc index dc73f69de..d4806cbab 100644 --- a/src/cunumeric/stat/bincount.cc +++ b/src/cunumeric/stat/bincount.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct BincountImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/stat/bincount.cu b/src/cunumeric/stat/bincount.cu index d4996a993..2ae4a0d05 100644 --- a/src/cunumeric/stat/bincount.cu +++ b/src/cunumeric/stat/bincount.cu @@ -143,7 +143,7 @@ static __global__ void weighted_bincount_kernel_rd_global( lhs[bin] <<= weights[idx + origin[0]]; } -template +template struct BincountImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/stat/bincount_omp.cc b/src/cunumeric/stat/bincount_omp.cc index 9d8f6375a..4f21e95a8 100644 --- a/src/cunumeric/stat/bincount_omp.cc +++ b/src/cunumeric/stat/bincount_omp.cc @@ -23,7 +23,7 @@ namespace cunumeric { using namespace legate; -template +template struct BincountImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/stat/bincount_template.inl b/src/cunumeric/stat/bincount_template.inl index 00034a486..83ae638e1 100644 --- a/src/cunumeric/stat/bincount_template.inl +++ b/src/cunumeric/stat/bincount_template.inl @@ -23,12 +23,12 @@ namespace cunumeric { using namespace legate; -template +template struct BincountImplBody; template struct BincountImpl { - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(BincountArgs& args) const { using VAL = legate_type_of; @@ -50,7 +50,7 @@ struct BincountImpl { } } - template ::value>* = nullptr> + template ::value>* = nullptr> void operator()(BincountArgs& args) const { assert(false); diff --git a/src/cunumeric/ternary/where.cc b/src/cunumeric/ternary/where.cc index 449ff3b46..85c602522 100644 --- a/src/cunumeric/ternary/where.cc +++ b/src/cunumeric/ternary/where.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct WhereImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/ternary/where.cu b/src/cunumeric/ternary/where.cu index f1d1594a2..a9dfdb1a3 100644 --- a/src/cunumeric/ternary/where.cu +++ b/src/cunumeric/ternary/where.cu @@ -40,7 +40,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) gen out[point] = mask[point] ? 
in1[point] : in2[point]; } -template +template struct WhereImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/ternary/where_omp.cc b/src/cunumeric/ternary/where_omp.cc index 26beea4bd..dd0ed7e55 100644 --- a/src/cunumeric/ternary/where_omp.cc +++ b/src/cunumeric/ternary/where_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct WhereImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/ternary/where_template.inl b/src/cunumeric/ternary/where_template.inl index 6ea668354..ccdc78b5a 100644 --- a/src/cunumeric/ternary/where_template.inl +++ b/src/cunumeric/ternary/where_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct WhereImplBody; template struct WhereImpl { - template + template void operator()(WhereArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/transform/flip.cc b/src/cunumeric/transform/flip.cc index 946cdd4a9..3aa332d57 100644 --- a/src/cunumeric/transform/flip.cc +++ b/src/cunumeric/transform/flip.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct FlipImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/transform/flip.cu b/src/cunumeric/transform/flip.cu index 88ef54227..8c6dc166b 100644 --- a/src/cunumeric/transform/flip.cu +++ b/src/cunumeric/transform/flip.cu @@ -41,7 +41,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[p] = in[q]; } -template +template struct FlipImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/transform/flip_omp.cc b/src/cunumeric/transform/flip_omp.cc index ce39ba88d..775fd6802 100644 --- a/src/cunumeric/transform/flip_omp.cc +++ b/src/cunumeric/transform/flip_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct FlipImplBody { using VAL = legate_type_of; diff --git a/src/cunumeric/transform/flip_template.inl b/src/cunumeric/transform/flip_template.inl index 82279da9b..6af541fc6 100644 --- a/src/cunumeric/transform/flip_template.inl +++ b/src/cunumeric/transform/flip_template.inl @@ -24,12 +24,12 @@ namespace cunumeric { using namespace legate; -template +template struct FlipImplBody; template struct FlipImpl { - template + template void operator()(FlipArgs& args) const { using VAL = legate_type_of; diff --git a/src/cunumeric/unary/convert.cc b/src/cunumeric/unary/convert.cc index d7ab32fc3..a3fae7fbb 100644 --- a/src/cunumeric/unary/convert.cc +++ b/src/cunumeric/unary/convert.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct ConvertImplBody { using OP = ConvertOp; using SRC = legate_type_of; diff --git a/src/cunumeric/unary/convert.cu b/src/cunumeric/unary/convert.cu index 7b839131e..ea1d7cfb1 100644 --- a/src/cunumeric/unary/convert.cu +++ b/src/cunumeric/unary/convert.cu @@ -40,7 +40,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[point] = func(in[point]); } -template +template struct ConvertImplBody { using OP = ConvertOp; using SRC = legate_type_of; diff --git a/src/cunumeric/unary/convert_omp.cc b/src/cunumeric/unary/convert_omp.cc index 139d84221..de2f20478 100644 --- a/src/cunumeric/unary/convert_omp.cc +++ b/src/cunumeric/unary/convert_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct ConvertImplBody { using OP = ConvertOp; using SRC = legate_type_of; diff --git a/src/cunumeric/unary/convert_template.inl 
b/src/cunumeric/unary/convert_template.inl index 075843ab3..8d507d35f 100644 --- a/src/cunumeric/unary/convert_template.inl +++ b/src/cunumeric/unary/convert_template.inl @@ -25,16 +25,12 @@ namespace cunumeric { using namespace legate; -template +template struct ConvertImplBody; -template +template struct ConvertImpl { - template * = nullptr> + template * = nullptr> void operator()(ConvertArgs& args) const { using OP = ConvertOp; @@ -63,14 +59,14 @@ struct ConvertImpl { ConvertImplBody()(func, out, in, pitches, rect, dense); } - template * = nullptr> + template * = nullptr> void operator()(ConvertArgs& args) const { assert(false); } }; -template +template struct ConvertDispatch { template ::value || @@ -94,7 +90,7 @@ struct ConvertDispatch { template struct SourceTypeDispatch { - template + template void operator()(ConvertArgs& args) const { op_dispatch(args.nan_op, ConvertDispatch{}, args); diff --git a/src/cunumeric/unary/convert_util.h b/src/cunumeric/unary/convert_util.h index f58c0265c..5fb340fd7 100644 --- a/src/cunumeric/unary/convert_util.h +++ b/src/cunumeric/unary/convert_util.h @@ -43,10 +43,10 @@ constexpr decltype(auto) op_dispatch(ConvertCode nan_op, Functor f, Fnargs&&... return f.template operator()(std::forward(args)...); } -template +template struct ConvertOp {}; -template +template struct ConvertOp { using SRC = legate::legate_type_of; using DST = legate::legate_type_of; @@ -64,7 +64,7 @@ struct ConvertOp { !legate::is_complex_type::value>* = nullptr> constexpr DST operator()(const _SRC& src) const { - if constexpr (DST_TYPE == legate::LegateTypeCode::BOOL_LT) + if constexpr (DST_TYPE == legate::Type::Code::BOOL) return static_cast(src.real()) || static_cast(src.imag()); else return static_cast(src.real()); @@ -74,8 +74,8 @@ struct ConvertOp { } }; -template -struct ConvertOp { +template +struct ConvertOp { using SRC = legate::legate_type_of; template ::value>* = nullptr> @@ -91,8 +91,8 @@ struct ConvertOp { } }; -template -struct ConvertOp { +template +struct ConvertOp { using DST = legate::legate_type_of; constexpr DST operator()(const __half& src) const @@ -101,7 +101,7 @@ struct ConvertOp { } }; -template +template struct ConvertOp { using SRC = legate::legate_type_of; using DST = legate::legate_type_of; @@ -123,8 +123,8 @@ struct ConvertOp { } }; -template -struct ConvertOp { +template +struct ConvertOp { using SRC = legate::legate_type_of; template ::value>* = nullptr> @@ -142,8 +142,8 @@ struct ConvertOp { } }; -template -struct ConvertOp { +template +struct ConvertOp { using DST = legate::legate_type_of; constexpr DST operator()(const __half& src) const @@ -153,7 +153,7 @@ struct ConvertOp { } }; -template +template struct ConvertOp { using SRC = legate::legate_type_of; using DST = legate::legate_type_of; @@ -175,8 +175,8 @@ struct ConvertOp { } }; -template -struct ConvertOp { +template +struct ConvertOp { using SRC = legate::legate_type_of; template ::value>* = nullptr> @@ -194,8 +194,8 @@ struct ConvertOp { } }; -template -struct ConvertOp { +template +struct ConvertOp { using DST = legate::legate_type_of; constexpr DST operator()(const __half& src) const diff --git a/src/cunumeric/unary/scalar_unary_red_template.inl b/src/cunumeric/unary/scalar_unary_red_template.inl index 1f57be92d..198a38cc7 100644 --- a/src/cunumeric/unary/scalar_unary_red_template.inl +++ b/src/cunumeric/unary/scalar_unary_red_template.inl @@ -28,7 +28,7 @@ namespace cunumeric { using namespace legate; -template +template struct ScalarUnaryRed { using OP = UnaryRedOp; using LG_OP = 
typename OP::OP; @@ -116,7 +116,7 @@ struct ScalarUnaryRed { template struct ScalarUnaryRedImpl { - template + template void operator()(ScalarUnaryRedArgs& args) const { // The operation is always valid for contains diff --git a/src/cunumeric/unary/unary_op.cc b/src/cunumeric/unary/unary_op.cc index 6004ac759..53c085113 100644 --- a/src/cunumeric/unary/unary_op.cc +++ b/src/cunumeric/unary/unary_op.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct UnaryOpImplBody { using OP = UnaryOp; using ARG = typename OP::T; @@ -70,7 +70,7 @@ struct PointCopyImplBody { } }; -template +template struct MultiOutUnaryOpImplBody { using OP = MultiOutUnaryOp; using RHS1 = typename OP::RHS1; diff --git a/src/cunumeric/unary/unary_op.cu b/src/cunumeric/unary/unary_op.cu index 8bbb21872..41de2e20b 100644 --- a/src/cunumeric/unary/unary_op.cu +++ b/src/cunumeric/unary/unary_op.cu @@ -63,7 +63,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[point] = in[point]; } -template +template struct UnaryOpImplBody { using OP = UnaryOp; using ARG = typename OP::T; @@ -143,7 +143,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) lhs[point] = func(rhs1[point], rhs2.ptr(point)); } -template +template struct MultiOutUnaryOpImplBody { using OP = MultiOutUnaryOp; using RHS1 = typename OP::RHS1; diff --git a/src/cunumeric/unary/unary_op.h b/src/cunumeric/unary/unary_op.h index c277c5d77..a4439dbd8 100644 --- a/src/cunumeric/unary/unary_op.h +++ b/src/cunumeric/unary/unary_op.h @@ -52,187 +52,112 @@ class UnaryOpTask : public CuNumericTask { template struct inner_type_dispatch_fn { template - constexpr decltype(auto) operator()(CuNumericTypeCodes code, Functor f, Fnargs&&... args) + constexpr decltype(auto) operator()(int point_dim, Functor f, Fnargs&&... 
args) { - switch (code) { + switch (point_dim) { #if LEGATE_MAX_DIM >= 1 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT1: { - return f.template operator()( - std::forward(args)...); + case 1: { + return f.template operator()<1, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 2 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT2: { - return f.template operator()( - std::forward(args)...); + case 2: { + return f.template operator()<2, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 3 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT3: { - return f.template operator()( - std::forward(args)...); + case 3: { + return f.template operator()<3, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 4 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT4: { - return f.template operator()( - std::forward(args)...); + case 4: { + return f.template operator()<4, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 5 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT5: { - return f.template operator()( - std::forward(args)...); + case 5: { + return f.template operator()<5, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 6 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT6: { - return f.template operator()( - std::forward(args)...); + case 6: { + return f.template operator()<6, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 7 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT7: { - return f.template operator()( - std::forward(args)...); + case 7: { + return f.template operator()<7, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 8 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT8: { - return f.template operator()( - std::forward(args)...); + case 8: { + return f.template operator()<8, DIM>(std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 9 - case CuNumericTypeCodes::CUNUMERIC_TYPE_POINT9: { - return f.template operator()( - std::forward(args)...); + case 9: { + return f.template operator()<9, DIM>(std::forward(args)...); } #endif default: assert(false); } - return f.template operator()( - std::forward(args)...); + return f.template operator()<1, DIM>(std::forward(args)...); } }; template -constexpr decltype(auto) double_dispatch(int dim, - CuNumericTypeCodes code, - Functor f, - Fnargs&&... args) +constexpr decltype(auto) double_dispatch(int dim, int point_dim, Functor f, Fnargs&&... 
args) { switch (dim) { #if LEGATE_MAX_DIM >= 1 case 1: { - return cunumeric::inner_type_dispatch_fn<1>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<1>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 2 case 2: { - return cunumeric::inner_type_dispatch_fn<2>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<2>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 3 case 3: { - return cunumeric::inner_type_dispatch_fn<3>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<3>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 4 case 4: { - return cunumeric::inner_type_dispatch_fn<4>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<4>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 5 case 5: { - return cunumeric::inner_type_dispatch_fn<5>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<5>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 6 case 6: { - return cunumeric::inner_type_dispatch_fn<6>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<6>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 7 case 7: { - return cunumeric::inner_type_dispatch_fn<7>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<7>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 8 case 8: { - return cunumeric::inner_type_dispatch_fn<8>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<8>{}(point_dim, f, std::forward(args)...); } #endif #if LEGATE_MAX_DIM >= 9 case 9: { - return cunumeric::inner_type_dispatch_fn<9>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<9>{}(point_dim, f, std::forward(args)...); } #endif } assert(false); - return cunumeric::inner_type_dispatch_fn<1>{}(code, f, std::forward(args)...); + return cunumeric::inner_type_dispatch_fn<1>{}(point_dim, f, std::forward(args)...); } -template -struct CuNumericTypeOf { - using type = legate::Point<1>; -}; -#if LEGATE_MAX_DIM >= 1 -template <> -struct CuNumericTypeOf { - using type = legate::Point<1>; -}; -#endif -#if LEGATE_MAX_DIM >= 2 -template <> -struct CuNumericTypeOf { - using type = legate::Point<2>; -}; -#endif -#if LEGATE_MAX_DIM >= 3 -template <> -struct CuNumericTypeOf { - using type = legate::Point<3>; -}; -#endif -#if LEGATE_MAX_DIM >= 4 -template <> -struct CuNumericTypeOf { - using type = legate::Point<4>; -}; -#endif -#if LEGATE_MAX_DIM >= 5 -template <> -struct CuNumericTypeOf { - using type = legate::Point<5>; -}; -#endif -#if LEGATE_MAX_DIM >= 6 -template <> -struct CuNumericTypeOf { - using type = legate::Point<6>; -}; -#endif -#if LEGATE_MAX_DIM >= 7 -template <> -struct CuNumericTypeOf { - using type = legate::Point<7>; -}; -#endif -#if LEGATE_MAX_DIM >= 8 -template <> -struct CuNumericTypeOf { - using type = legate::Point<8>; -}; -#endif -#if LEGATE_MAX_DIM >= 9 -template <> -struct CuNumericTypeOf { - using type = legate::Point<9>; -}; -#endif - -template -using cunumeric_type_of = typename CuNumericTypeOf::type; - } // namespace cunumeric diff --git a/src/cunumeric/unary/unary_op_omp.cc b/src/cunumeric/unary/unary_op_omp.cc index 8cbe683a5..1badb93a8 100644 --- a/src/cunumeric/unary/unary_op_omp.cc +++ b/src/cunumeric/unary/unary_op_omp.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct 
UnaryOpImplBody { using OP = UnaryOp; using ARG = typename OP::T; @@ -74,7 +74,7 @@ struct PointCopyImplBody { } }; -template +template struct MultiOutUnaryOpImplBody { using OP = MultiOutUnaryOp; using RHS1 = typename OP::RHS1; diff --git a/src/cunumeric/unary/unary_op_template.inl b/src/cunumeric/unary/unary_op_template.inl index 74f882b91..548cba9bf 100644 --- a/src/cunumeric/unary/unary_op_template.inl +++ b/src/cunumeric/unary/unary_op_template.inl @@ -24,20 +24,18 @@ namespace cunumeric { using namespace legate; -template +template struct UnaryOpImplBody; template struct PointCopyImplBody; -template +template struct MultiOutUnaryOpImplBody; template struct UnaryOpImpl { - template ::valid>* = nullptr> + template ::valid>* = nullptr> void operator()(UnaryOpArgs& args) const { using OP = UnaryOp; @@ -66,9 +64,7 @@ struct UnaryOpImpl { UnaryOpImplBody()(func, out, in, pitches, rect, dense); } - template ::valid>* = nullptr> + template ::valid>* = nullptr> void operator()(UnaryOpArgs& args) const { assert(false); @@ -77,7 +73,7 @@ struct UnaryOpImpl { template struct MultiOutUnaryOpImpl { - template ::valid>* = nullptr> void operator()(MultiOutUnaryOpArgs& args) const @@ -112,7 +108,7 @@ struct MultiOutUnaryOpImpl { func, lhs, rhs1, rhs2, pitches, rect, dense); } - template ::valid>* = nullptr> void operator()(MultiOutUnaryOpArgs& args) const @@ -123,17 +119,17 @@ struct MultiOutUnaryOpImpl { template struct UnaryCopyImpl { - template + template void operator()(UnaryOpArgs& args) const { using VAL = legate_type_of; execute_copy(args); } - template + template void operator()(UnaryOpArgs& args) const { - using VAL = cunumeric_type_of; + using VAL = Point; execute_copy(args); } @@ -168,12 +164,13 @@ struct UnaryOpDispatch { void operator()(UnaryOpArgs& args) const { auto dim = std::max(args.in.dim(), 1); - if ((OP_CODE == UnaryOpCode::COPY) && - (args.in.code() > LegateTypeCode::MAX_TYPE_NUMBER)) - cunumeric::double_dispatch( - dim, args.in.code(), UnaryCopyImpl{}, args); - else - legate::double_dispatch(dim, args.in.code(), UnaryOpImpl{}, args); + if ((OP_CODE == UnaryOpCode::COPY) && (args.in.code() == Type::Code::FIXED_ARRAY)) { + auto& type = static_cast(args.in.type()); + cunumeric::double_dispatch(dim, type.num_elements(), UnaryCopyImpl{}, args); + } else { + auto code = OP_CODE == UnaryOpCode::GETARG ? args.out.code() : args.in.code(); + legate::double_dispatch(dim, code, UnaryOpImpl{}, args); + } } }; diff --git a/src/cunumeric/unary/unary_op_util.h b/src/cunumeric/unary/unary_op_util.h index 2f6fab59a..f309cbcf7 100644 --- a/src/cunumeric/unary/unary_op_util.h +++ b/src/cunumeric/unary/unary_op_util.h @@ -183,20 +183,20 @@ constexpr decltype(auto) op_dispatch(UnaryOpCode op_code, Functor f, Fnargs&&... 
return f.template operator()(std::forward(args)...); } -template +template static constexpr bool is_floating_point = - legate::is_floating_point::value || CODE == legate::LegateTypeCode::HALF_LT; + legate::is_floating_point::value || CODE == legate::Type::Code::FLOAT16; -template +template static constexpr bool is_floating_or_complex = is_floating_point || legate::is_complex::value; -template +template struct UnaryOp { static constexpr bool valid = false; }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -235,7 +235,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -249,7 +249,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -264,7 +264,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -277,7 +277,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -291,7 +291,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -306,7 +306,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -319,7 +319,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -333,7 +333,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -348,7 +348,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -361,7 +361,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = legate::is_floating_point::value; using T = legate::legate_type_of; @@ -376,7 +376,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -389,7 +389,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_point; using T = legate::legate_type_of; @@ -403,7 +403,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -421,7 +421,7 @@ struct UnaryOp { T max; }; -template +template struct UnaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -441,7 +441,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -451,7 +451,7 @@ struct UnaryOp { constexpr T operator()(const T& x) const { return x; } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -465,7 +465,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -480,7 +480,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -493,7 +493,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static 
constexpr bool valid = is_floating_point; using T = legate::legate_type_of; @@ -504,7 +504,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -516,7 +516,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -530,7 +530,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -558,7 +558,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -571,7 +571,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -594,7 +594,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -607,7 +607,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_point; using T = legate::legate_type_of; @@ -621,7 +621,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { using T = Argval>; static constexpr bool valid = true; @@ -631,7 +631,7 @@ struct UnaryOp { constexpr decltype(auto) operator()(const T& x) const { return x.arg; } }; -template +template struct UnaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_complex_type::value; @@ -641,10 +641,10 @@ struct UnaryOp { constexpr decltype(auto) operator()(const T& x) const { return x.imag(); } }; -template +template struct UnaryOp { static constexpr bool valid = - legate::is_integral::value && CODE != legate::LegateTypeCode::BOOL_LT; + legate::is_integral::value && CODE != legate::Type::Code::BOOL; using T = legate::legate_type_of; UnaryOp(const std::vector& args) {} @@ -652,7 +652,7 @@ struct UnaryOp { constexpr T operator()(const T& x) const { return ~x; } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -680,7 +680,7 @@ struct UnaryOp { __CUDA_HD__ bool operator()(const __half& x) const { return isfinite(static_cast(x)); } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -708,7 +708,7 @@ struct UnaryOp { __CUDA_HD__ bool operator()(const __half& x) const { return isinf(x); } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -737,7 +737,7 @@ struct UnaryOp { __CUDA_HD__ bool operator()(const __half& x) const { return isnan(x); } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; ; @@ -752,7 +752,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; ; @@ -768,7 +768,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -781,7 +781,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; ; @@ -805,7 +805,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -818,7 +818,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; ; @@ -842,7 +842,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct 
UnaryOp { static constexpr bool valid = true; using T = __half; @@ -855,7 +855,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -875,7 +875,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -885,7 +885,7 @@ struct UnaryOp { constexpr T operator()(const T& x) const { return -x; } }; -template +template struct UnaryOp { static constexpr bool valid = legate::is_floating_point::value; using T = legate::legate_type_of; @@ -896,7 +896,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -908,7 +908,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { using T = legate::legate_type_of; static constexpr bool valid = legate::is_complex_type::value; @@ -918,7 +918,7 @@ struct UnaryOp { constexpr decltype(auto) operator()(const T& x) const { return x.real(); } }; -template +template struct UnaryOp { using T = legate::legate_type_of; static constexpr bool valid = true; @@ -933,7 +933,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { using T = __half; static constexpr bool valid = true; @@ -945,7 +945,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -966,7 +966,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -995,7 +995,7 @@ constexpr T sign(const T& x) } // namespace detail -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -1020,7 +1020,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -1032,7 +1032,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = legate::is_floating_point::value; using T = legate::legate_type_of; @@ -1047,7 +1047,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -1060,7 +1060,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -1074,7 +1074,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -1089,7 +1089,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -1102,7 +1102,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -1112,7 +1112,7 @@ struct UnaryOp { constexpr T operator()(const T& x) const { return x * x; } }; -template +template struct UnaryOp { static constexpr bool valid = true; using T = legate::legate_type_of; @@ -1126,7 +1126,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -1140,7 +1140,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = is_floating_or_complex; using T = legate::legate_type_of; @@ -1154,7 +1154,7 @@ struct UnaryOp { } }; -template +template struct UnaryOp { static constexpr bool valid = legate::is_floating_point::value; using T = 
legate::legate_type_of; @@ -1169,7 +1169,7 @@ struct UnaryOp { }; template <> -struct UnaryOp { +struct UnaryOp { static constexpr bool valid = true; using T = __half; @@ -1182,12 +1182,12 @@ struct UnaryOp { } }; -template +template struct MultiOutUnaryOp { static constexpr bool valid = false; }; -template +template struct MultiOutUnaryOp { static constexpr bool valid = legate::is_floating_point::value; using RHS1 = legate::legate_type_of; @@ -1202,7 +1202,7 @@ struct MultiOutUnaryOp { }; template <> -struct MultiOutUnaryOp { +struct MultiOutUnaryOp { static constexpr bool valid = true; using RHS1 = __half; using RHS2 = int32_t; @@ -1215,7 +1215,7 @@ struct MultiOutUnaryOp { } }; -template +template struct MultiOutUnaryOp { static constexpr bool valid = legate::is_floating_point::value; using RHS1 = legate::legate_type_of; @@ -1230,7 +1230,7 @@ struct MultiOutUnaryOp { }; template <> -struct MultiOutUnaryOp { +struct MultiOutUnaryOp { static constexpr bool valid = true; using RHS1 = __half; using RHS2 = __half; diff --git a/src/cunumeric/unary/unary_red.cc b/src/cunumeric/unary/unary_red.cc index 39bfc7e92..2aff3d907 100644 --- a/src/cunumeric/unary/unary_red.cc +++ b/src/cunumeric/unary/unary_red.cc @@ -21,7 +21,7 @@ namespace cunumeric { using namespace legate; -template +template struct UnaryRedImplBody { using OP = UnaryRedOp; using LG_OP = typename OP::OP; diff --git a/src/cunumeric/unary/unary_red.cu b/src/cunumeric/unary/unary_red.cu index 766f2a7fd..3aa8b4f0c 100644 --- a/src/cunumeric/unary/unary_red.cu +++ b/src/cunumeric/unary/unary_red.cu @@ -293,7 +293,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) if (result != identity) out.reduce(point, result); } -template +template struct UnaryRedImplBody { using OP = UnaryRedOp; using LG_OP = typename OP::OP; diff --git a/src/cunumeric/unary/unary_red_omp.cc b/src/cunumeric/unary/unary_red_omp.cc index 1718d49d6..823726251 100644 --- a/src/cunumeric/unary/unary_red_omp.cc +++ b/src/cunumeric/unary/unary_red_omp.cc @@ -72,7 +72,7 @@ class Splitter { size_t pitches_[DIM]; }; -template +template struct UnaryRedImplBody { using OP = UnaryRedOp; using LG_OP = typename OP::OP; diff --git a/src/cunumeric/unary/unary_red_template.inl b/src/cunumeric/unary/unary_red_template.inl index a144bdc30..1e3b298d3 100644 --- a/src/cunumeric/unary/unary_red_template.inl +++ b/src/cunumeric/unary/unary_red_template.inl @@ -27,12 +27,12 @@ namespace cunumeric { using namespace legate; -template +template struct UnaryRedImplBody; template struct UnaryRedImpl { - template 1) && UnaryRedOp::valid>* = nullptr> void operator()(UnaryRedArgs& args) const @@ -53,7 +53,7 @@ struct UnaryRedImpl { lhs, rhs, rect, pitches, args.collapsed_dim, volume); } - template ::valid>* = nullptr> void operator()(UnaryRedArgs& args) const diff --git a/src/cunumeric/unary/unary_red_util.h b/src/cunumeric/unary/unary_red_util.h index 9296ccfe2..d4ceb007b 100644 --- a/src/cunumeric/unary/unary_red_util.h +++ b/src/cunumeric/unary/unary_red_util.h @@ -72,14 +72,14 @@ constexpr decltype(auto) op_dispatch(UnaryRedCode op_code, Functor f, Fnargs&&.. 
return f.template operator()(std::forward(args)...); } -template +template struct UnaryRedOp { static constexpr bool valid = false; }; -template +template struct UnaryRedOp { - static constexpr bool valid = TYPE_CODE != legate::LegateTypeCode::COMPLEX128_LT; + static constexpr bool valid = TYPE_CODE != legate::Type::Code::COMPLEX128; using RHS = legate::legate_type_of; using VAL = bool; @@ -100,9 +100,9 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return rhs != RHS(0); } }; -template +template struct UnaryRedOp { - static constexpr bool valid = TYPE_CODE != legate::LegateTypeCode::COMPLEX128_LT; + static constexpr bool valid = TYPE_CODE != legate::Type::Code::COMPLEX128; using RHS = legate::legate_type_of; using VAL = bool; @@ -123,7 +123,7 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return rhs != RHS(0); } }; -template +template struct UnaryRedOp { static constexpr bool valid = true; @@ -146,7 +146,7 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return static_cast(rhs != RHS(0)); } }; -template +template struct UnaryRedOp { static constexpr bool valid = !legate::is_complex::value; @@ -169,7 +169,7 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return rhs; } }; -template +template struct UnaryRedOp { static constexpr bool valid = !legate::is_complex::value; @@ -192,9 +192,9 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return rhs; } }; -template +template struct UnaryRedOp { - static constexpr bool valid = TYPE_CODE != legate::LegateTypeCode::COMPLEX128_LT; + static constexpr bool valid = TYPE_CODE != legate::Type::Code::COMPLEX128; using RHS = legate::legate_type_of; using VAL = RHS; @@ -215,7 +215,7 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return rhs; } }; -template +template struct UnaryRedOp { static constexpr bool valid = true; @@ -238,7 +238,7 @@ struct UnaryRedOp { __CUDA_HD__ static VAL convert(const RHS& rhs) { return rhs; } }; -template +template struct UnaryRedOp { static constexpr bool valid = !legate::is_complex::value; @@ -271,7 +271,7 @@ struct UnaryRedOp { } }; -template +template struct UnaryRedOp { static constexpr bool valid = !legate::is_complex::value; @@ -304,7 +304,7 @@ struct UnaryRedOp { } }; -template +template struct UnaryRedOp { // Set to false so that this only gets enabled when expliclty declared valid. static constexpr bool valid = false; @@ -312,7 +312,7 @@ struct UnaryRedOp { // It does not provide fold/convert functions. using RHS = legate::legate_type_of; using VAL = bool; - using _RED_OP = UnaryRedOp; + using _RED_OP = UnaryRedOp; using OP = _RED_OP::OP; }; diff --git a/tests/integration/test_ingest.py b/tests/integration/test_ingest.py index c0f17c37b..272f2080d 100644 --- a/tests/integration/test_ingest.py +++ b/tests/integration/test_ingest.py @@ -13,12 +13,12 @@ # limitations under the License. # import numpy as np -import pyarrow as pa import pytest from legate.core import ( CustomSplit, Rect, TiledSplit, + float64, get_legion_context, get_legion_runtime, ingest, @@ -73,7 +73,7 @@ def _ingest(custom_partitioning, custom_sharding): else TiledSplit(tile_shape) ) tab = ingest( - pa.float64(), + float64, shape, colors, data_split,
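Note: with the CUNUMERIC_TYPE_POINT1..9 codes removed earlier in this diff, the copy path in unary_op.h dispatches on the point dimension instead: UnaryOpDispatch reads the element count of the FIXED_ARRAY type and cunumeric::double_dispatch switches that integer into the compile-time N of Point<N>. The sketch below isolates that int-to-template-parameter step; the four-dimension cap and the tiny Point type are placeholders for this note (the real code covers up to LEGATE_MAX_DIM and uses legate::Point).

// Illustrative only: a runtime point dimension becomes a template argument,
// so the functor can name the concrete Point<N> element type.
#include <array>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>

template <int N>
struct Point {
  std::array<long long, N> coords{};
};

template <typename Functor, typename... Args>
decltype(auto) point_dim_dispatch(int point_dim, Functor f, Args&&... args)
{
  switch (point_dim) {
    case 1: return f.template operator()<1>(std::forward<Args>(args)...);
    case 2: return f.template operator()<2>(std::forward<Args>(args)...);
    case 3: return f.template operator()<3>(std::forward<Args>(args)...);
    case 4: return f.template operator()<4>(std::forward<Args>(args)...);
  }
  assert(false);
  return f.template operator()<1>(std::forward<Args>(args)...);
}

// Plays the role of UnaryCopyImpl: it only learns the element type Point<N>
// once the dispatcher has fixed N.
struct ReportPointSize {
  template <int N>
  std::size_t operator()() const
  {
    return sizeof(Point<N>);
  }
};

int main()
{
  // A FIXED_ARRAY type with three elements would hand the dispatcher 3.
  std::cout << point_dim_dispatch(3, ReportPointSize{}) << "\n";  // typically 24
}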