Skip to content

Commit

Permalink
[dace] Enable GPU backend tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
edopao committed Nov 17, 2023
1 parent c76aeaf commit 5826fe2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import hashlib
import warnings
from typing import Any, Mapping, Optional, Sequence

import dace
Expand Down Expand Up @@ -94,10 +95,26 @@ def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
return {name.id: convert_arg(arg) for name, arg in zip(params, args)}


def _ensure_is_on_device(
    connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType
) -> np.typing.NDArray:
    """Return the connectivity table in a form usable on ``device``.

    When ``device`` is the GPU device type and the table is not already a
    ``cp.ndarray``, a warning is emitted and the table is converted with
    ``cp.asarray``. In every other case the argument is returned unchanged.
    """
    # Short-circuit `and`: `cp` is only consulted when the GPU is targeted.
    must_copy = device == dace.dtypes.DeviceType.GPU and not isinstance(
        connectivity_arg, cp.ndarray
    )
    if not must_copy:
        return connectivity_arg

    warnings.warn(
        "Copying connectivity to device. For performance make sure connectivity is provided on device."
    )
    return cp.asarray(connectivity_arg)


def get_connectivity_args(
    neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]],
    device: dace.dtypes.DeviceType,
) -> dict[str, Any]:
    """Build the connectivity arguments from the neighbor tables.

    Args:
        neighbor_tables: ``(offset, table)`` pairs; each ``table`` exposes its
            data through the ``table`` attribute.
        device: Device type forwarded to ``_ensure_is_on_device`` for every
            table.

    Returns:
        A dict keyed by ``connectivity_identifier(offset)`` whose values are
        the result of ``_ensure_is_on_device(table.table, device)``.
    """
    return {
        connectivity_identifier(offset): _ensure_is_on_device(table.table, device)
        for offset, table in neighbor_tables
    }


def get_shape_args(
Expand Down Expand Up @@ -181,6 +198,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU
neighbor_tables = filter_neighbor_tables(offset_provider)

cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
Expand All @@ -200,7 +218,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu)

# compile SDFG and retrieve SDFG program
Expand All @@ -216,7 +233,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables)
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non
# Optional backends are registered only when `dace_iterator` is truthy.
OPTIONAL_PROCESSORS = []
if dace_iterator:
    OPTIONAL_PROCESSORS.append(definitions.OptionalProgramBackendId.DACE_CPU)
    # The GPU backend entry carries the `requires_gpu` mark.
    OPTIONAL_PROCESSORS.append(
        pytest.param(definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu)
    )


@pytest.fixture(
Expand Down
5 changes: 5 additions & 0 deletions tests/next_tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def lift_mode(request):
# Optional backends are registered only when `dace_iterator` is truthy.
# Each entry is a `(backend id, bool)` pair.
OPTIONAL_PROCESSORS = []
if dace_iterator:
    OPTIONAL_PROCESSORS.append((definitions.OptionalProgramBackendId.DACE_CPU, True))
    # The GPU backend entry carries the `requires_gpu` mark.
    OPTIONAL_PROCESSORS.append(
        pytest.param(
            (definitions.OptionalProgramBackendId.DACE_GPU, True), marks=pytest.mark.requires_gpu
        )
    )


@pytest.fixture(
Expand Down

0 comments on commit 5826fe2

Please sign in to comment.