Skip to content

Commit

Permalink
fix[next][dace]: Bugfix in neighbor reduction (#1456)
Browse files Browse the repository at this point in the history
When visiting a neighbor expression, the DaCe ITIR backend was ignoring the node arguments and using directly the closure symbol i_K . The problem found in one diffusion stencil is that the arguments contained a vertical offset, which returns (i_K - 1), and this offset was lost.

This PR adds test coverage to GT4Py for the above case.
  • Loading branch information
edopao authored Feb 16, 2024
1 parent 2970575 commit 4276d01
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -400,37 +400,46 @@ def builtin_neighbors(
neighbor_value_node,
)
else:
data_access_index = ",".join(f"{dim}_v" for dim in sorted(iterator.dimensions))
connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v"
data_access_tasklet = state.add_tasklet(
"data_access",
code="__data = __field[__idx]"
code=f"__data = __field[{data_access_index}] "
+ (
f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
if offset_provider.has_skip_values
else ""
),
inputs={"__field", "__idx"},
inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions},
outputs={"__data"},
debuginfo=di,
)
# select full shape only in the neighbor-axis dimension
field_subset = tuple(
f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
)
state.add_memlet_path(
iterator.field,
me,
data_access_tasklet,
memlet=create_memlet_at(iterator.field.data, field_subset),
memlet=create_memlet_full(iterator.field.data, field_desc),
dst_conn="__field",
)
state.add_edge(
neighbor_index_node,
None,
data_access_tasklet,
"__idx",
dace.Memlet(data=neighbor_index_var, subset="0"),
)
for dim in iterator.dimensions:
connector = f"{dim}_v"
if dim == offset_provider.neighbor_axis.value:
state.add_edge(
neighbor_index_node,
None,
data_access_tasklet,
connector,
dace.Memlet(data=neighbor_index_var, subset="0"),
)
else:
state.add_memlet_path(
iterator.indices[dim],
me,
data_access_tasklet,
dst_conn=connector,
memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"),
)

state.add_memlet_path(
data_access_tasklet,
mx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import TypeAlias

import numpy as np
import pytest
Expand All @@ -29,7 +30,9 @@
JDim,
Joff,
KDim,
Koff,
V2EDim,
Vertex,
cartesian_case,
unstructured_case,
)
Expand Down Expand Up @@ -108,6 +111,44 @@ def fencil(edge_f: cases.EField, out: cases.VField):
)


@pytest.mark.uses_unstructured_shift
def test_reduction_execution_with_offset(unstructured_case):
EKField: TypeAlias = gtx.Field[[Edge, KDim], np.int32]
VKField: TypeAlias = gtx.Field[[Vertex, KDim], np.int32]

@gtx.field_operator
def reduction(edge_f: EKField) -> VKField:
return neighbor_sum(edge_f(V2E), axis=V2EDim)

@gtx.field_operator
def fencil_op(edge_f: EKField) -> VKField:
red = reduction(edge_f)
return red(Koff[1])

@gtx.program
def fencil(edge_f: EKField, out: VKField):
fencil_op(edge_f, out=out)

v2e_table = unstructured_case.offset_provider["V2E"].table
field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})()
out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})()

cases.verify(
unstructured_case,
fencil,
field,
out,
inout=out,
ref=np.sum(
field.asnumpy()[:, 1][v2e_table],
axis=1,
initial=0,
where=v2e_table != common.SKIP_VALUE,
).reshape(out.shape),
offset_provider=unstructured_case.offset_provider | {"Koff": KDim},
)


@pytest.mark.uses_unstructured_shift
@pytest.mark.uses_constant_fields
def test_reduction_expression_in_call(unstructured_case):
Expand Down

0 comments on commit 4276d01

Please sign in to comment.