Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[cartesian]: read-only data dims direct access & Fields #1451

Merged
29 changes: 27 additions & 2 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,18 @@ class ParsingContext(enum.Enum):
COMPUTATION = 2


def _is_absolute_indexing_name(name: str):
FlorianDeconinck marked this conversation as resolved.
Show resolved Hide resolved
return name.endswith(".at")


def _is_absolute_indexing_node(node):
return (
isinstance(node.value, ast.Attribute)
and node.value.attr == "at"
and isinstance(node.value.value, ast.Name)
)


class IRMaker(ast.NodeVisitor):
def __init__(
self,
Expand Down Expand Up @@ -1039,6 +1051,10 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref:
result = nodes.VarRef(name=symbol, loc=nodes.Location.from_ast_node(node))
elif self._is_local_symbol(symbol):
raise AssertionError("Logic error")
elif _is_absolute_indexing_name(symbol):
result = nodes.FieldRef.absolute_index(
name=symbol[:-3], loc=nodes.Location.from_ast_node(node)
)
else:
raise AssertionError(f"Missing '{symbol}' symbol definition")

Expand Down Expand Up @@ -1147,12 +1163,16 @@ def visit_Subscript(self, node: ast.Subscript):
field_axes = self.fields[result.name].axes
if index is not None:
if len(field_axes) != len(index):
ro_field_message = ""
if len(field_axes) == 0:
ro_field_message = f"Did you mean .at{index}?"
raise GTScriptSyntaxError(
f"Incorrect offset specification detected. Found {index}, "
f"but the field has dimensions ({', '.join(field_axes)})"
f"but the field has dimensions ({', '.join(field_axes)}). "
f"{ro_field_message}"
)
result.offset = {axis: value for axis, value in zip(field_axes, index)}
elif isinstance(node.value, ast.Subscript):
elif isinstance(node.value, ast.Subscript) or _is_absolute_indexing_node(node):
result.data_index = [
(
nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32)
Expand Down Expand Up @@ -1605,6 +1625,11 @@ def visit_Assign(self, node: ast.Assign):
elif isinstance(t, ast.Subscript):
if isinstance(t.value, ast.Name):
name_node = t.value
elif _is_absolute_indexing_node(t):
raise GTScriptSyntaxError(
message="writing to an OffgridField ('at' global indexation) is forbidden",
loc=nodes.Location.from_ast_node(node),
)
elif isinstance(t.value, ast.Subscript) and isinstance(t.value.value, ast.Name):
name_node = t.value.value
else:
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ def at_center(
name=name, offset={axis: 0 for axis in axes}, data_index=data_index or [], loc=loc
)

@classmethod
def absolute_index(cls, name: str, loc=None):
return cls(name=name, offset={}, data_index=[], loc=loc)


@attribclass
class Cast(Expr):
Expand Down
13 changes: 13 additions & 0 deletions src/gt4py/cartesian/gtscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,23 @@ def __getitem__(self, field_spec):
return _FieldDescriptor(dtype, axes, data_dims)


class _ReadOnlyFieldDescriptorMaker(_FieldDescriptorMaker):
def __getitem__(self, field_spec):
if not isinstance(field_spec, collections.abc.Collection) and not len(field_spec) == 2:
raise ValueError("OffgridField is defined by a tuple (type, [axes_size..])")

dtype, data_dims = field_spec

return _FieldDescriptor(dtype, [], data_dims)


# GTScript builtins: variable annotations
Field = _FieldDescriptorMaker()
"""Field descriptor."""

OffgridField = _ReadOnlyFieldDescriptorMaker()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally like ConstantField (constant with respect to the grid), but I think not everyone likes it for confusion with non-mutable. Here it both meanings coincide.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConstantField would be confusing for users I think. Though, indeed, the system make it constant in stencils, the Field is still modifiable outside of gtscript

"""Field with no spatial dimension descriptor."""


class _SequenceDescriptor:
def __init__(self, dtype, length):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Field,
I,
J,
OffgridField,
computation,
horizontal,
interval,
Expand Down Expand Up @@ -586,3 +587,19 @@ def test(out: Field[np.float64], inp: Field[np.float64]):
backend=backend, aligned_index=(0, 0, 0), shape=(2, 2, 2), dtype=np.float64
)
test(out, inp)


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_global_index(backend):
F64_VEC4 = (np.float64, (2, 2, 2, 2))

@gtscript.stencil(backend=backend)
def test(out: Field[np.float64], inp: OffgridField[F64_VEC4]):
with computation(PARALLEL), interval(...):
out = inp.at[1, 0, 1, 0]

inp = gt_storage.ones(backend=backend, shape=(2, 2, 2, 2), dtype=np.float64)
inp[1, 0, 1, 0] = 42
out = gt_storage.zeros(backend=backend, shape=(2, 2, 2), dtype=np.float64)
test(out, inp)
assert (out[:] == 42).all()
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,71 @@ def func(in_field: gtscript.Field[np.float_]):

parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__)

def test_global_access(self):
# Check classic data dimensions are working
def data_dims(
out_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.Field[(np.int32, (3, 3, 3))],
):
with computation(PARALLEL), interval(...):
out_field = global_field[0, 0, 0][1, 0, 2]

parse_definition(data_dims, name=inspect.stack()[0][3], module=self.__class__.__name__)

# Check .at on read
def at_read(
out_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.OffgridField[(np.int32, (3, 3, 3, 3))],
):
with computation(PARALLEL), interval(...):
out_field = global_field.at[1, 0, 2, 2]

parse_definition(at_read, name=inspect.stack()[0][3], module=self.__class__.__name__)

# Can't write to the field
def at_write(
in_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.OffgridField[(np.int32, (3, 3, 3))],
):
with computation(PARALLEL), interval(...):
global_field.at[1, 0, 2] = in_field

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="writing to an OffgridField \('at' global indexation\) is forbidden",
):
parse_definition(at_write, name=inspect.stack()[0][3], module=self.__class__.__name__)

# Can't index cartesian style
def OffgridField_access_as_IJK(
out_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.OffgridField[(np.int32, (3, 3, 3))],
):
with computation(PARALLEL), interval(...):
out_field = global_field[1, 0, 2]

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Incorrect offset specification detected. Found .* but the field has dimensions .* Did you mean .at",
):
parse_definition(
OffgridField_access_as_IJK,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)

# Check .at on read with a Field with data dimensions
def data_dims_with_at(
out_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.Field[(np.int32, (3, 3, 3))],
):
with computation(PARALLEL), interval(...):
out_field = global_field.at[1, 0, 2]

parse_definition(
data_dims_with_at, name=inspect.stack()[0][3], module=self.__class__.__name__
)


class TestNestedWithSyntax:
def test_nested_with(self):
Expand Down
Loading