diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 6a70b4d8..5e917e66 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -737,6 +737,9 @@ def axis_offsets( "local_js": gtscript.J[0] + self.jsc - origin[1], "j_end": j_end, "local_je": gtscript.J[-1] + self.jec - origin[1] - domain[1] + 1, + "k_start": origin[2] if len(origin) > 2 else 0, + "k_end": (origin[2] if len(origin) > 2 else 0) + + (domain[2] - 1 if len(domain) > 2 else 0), } def get_origin_domain( diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index ac189ad8..2af1218d 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -107,6 +107,23 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str): np.testing.assert_array_equal(q_orig.data, q_ref.data) +def test_stencil_vertical_bounds(backend: str): + factory = get_stencil_factory(backend) + origins = [(3, 3, 0), (2, 2, 1)] + domains = [(1, 1, 3), (2, 2, 4)] + stencils = get_stencils_with_varied_bounds( + add_1_in_region_stencil, + origins, + domains, + stencil_factory=factory, + ) + + assert "k_start" in stencils[0].externals and stencils[0].externals["k_start"] == 0 + assert "k_end" in stencils[0].externals and stencils[0].externals["k_end"] == 2 + assert "k_start" in stencils[1].externals and stencils[1].externals["k_start"] == 1 + assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 4 + + @pytest.mark.parametrize("enabled", [True, False]) def test_stencil_factory_numpy_comparison_from_dims_halo(enabled: bool): backend = "numpy"