-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature[next]: Nested scalars args & cleanup #1540
feature[next]: Nested scalars args & cleanup #1540
Conversation
" tuple) need to have the same shape and dimensions." | ||
) | ||
size_args.extend(shape if shape else [None] * len(dims)) | ||
if shapes_and_dims: # scalar or zero-dim field otherwise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it mean we allow writing to a scalar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the implementation _field_constituents_shape_and_dims
, it seems the comment is wrong as 0d fields probably returns a tuple of empty tuples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it mean we allow writing to a scalar?
Not sure what motivated this question. This code block only extracts size arguments, the change here merely avoids errors for cases like (scalar, (field1, field2))
. I've added a comment that explains this a little more.
Looking at the implementation _field_constituents_shape_and_dims, it seems the comment is wrong as 0d fields probably returns a tuple of empty tuples.
The comment was alright though not helpful^^, I've fixed _field_constituents_shape_and_dims
which was wrong and added a test test_zero_dim_tuple_arg
.
.filter(lambda dims: len(dims) > 0) | ||
.to_list() | ||
) | ||
if len(fields_dims) > 0: # param has no field constituent otherwise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is a "field constituent"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A constituent of a composite with type field, e.g. here (scalar, (field, scalar))
the composite is a tuple and field
is a field constituent. I have introduced the nomenclature composite and constituent as a generalization of tuple
s and struct
s where the constituents are the elements and members respectively. Given the amount of question marks this caused we should probably revisit this naming.
@@ -676,11 +676,23 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: | |||
|
|||
def _get_axes( | |||
field_or_tuple: LocatedField | tuple, | |||
*, | |||
ignore_zero_dims=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the PR it's not obvious to me how this is used. Please add an itir test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added doctests here. The motivation is similar as in the frontend for (scalar, (field1, field2))
we sometimes only want to check that field1
and field2
have the same dimensions, but don't care about scalar
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the changes in this file related to the newly added tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, they are required for "mixed" tuple args, e.g. an argument with type (scalar, (field, scalar))
. This is the code path in make_in_iterator
where this function is called.
for input_ in inputs: | ||
lowered_input = self.visit(input_, **kwargs) | ||
|
||
def _convert_input( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- unclear why this should be a closure
_convert_input
-> better name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactored this section and the convert_output
case below to use a common utility function _process_elements
.
raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.") | ||
lowered_output = self.visit(node) | ||
|
||
def _convert_output(el_type: ts.ScalarType | ts.FieldType, path: tuple[int, ...]) -> Expr: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see comment at _convert_input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's discuss this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment above.
src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py
Show resolved
Hide resolved
src/gt4py/next/config.py
Outdated
@@ -61,7 +61,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool: | |||
#: Master debug flag | |||
#: Changes defaults for all the other options to be as helpful for debugging as possible. | |||
#: Does not override values set in environment variables. | |||
DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) | |||
DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What motivated you to change this here instead of setting the env var? Anyway, this is a reminder to revert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't figured out how to have default env vars for test execution in a project in PyCharm. Reverted.
if dims: | ||
assert hasattr(arg, "shape") and len(arg.shape) == len(dims) | ||
yield (arg.shape, dims) | ||
else: | ||
yield (tuple(), dims) | ||
pass | ||
case ts.ScalarType(): | ||
yield (tuple(), []) | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This constructs looks weird to me, I guess it's equivalent to replacing pass
by return
which might be slightly more expressive, but still ugly. Alternatives could be yield from []
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used yield from []
now.
@@ -676,11 +676,23 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: | |||
|
|||
def _get_axes( | |||
field_or_tuple: LocatedField | tuple, | |||
*, | |||
ignore_zero_dims=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the changes in this file related to the newly added tests?
src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py
Outdated
Show resolved
Hide resolved
@@ -895,7 +935,7 @@ def deref(self) -> Any: | |||
|
|||
assert self.pos is not None | |||
shifted_pos = self.pos.copy() | |||
axes = _get_axes(self.field) | |||
axes = _get_axes(self.field, ignore_zero_dims=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO:
- write iterator tests
- check if
(vertex_field, vertex_k_field)
should also be valid here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I've added two tests.
- I've added two test named
test_tuple_arg_with_unpromotable_dims
andtest_scalar_arg_with_field
. The implementation is beyond the scope of this PR.
…ple args with different dims
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 remarks
if ( | ||
isinstance(expr, FunCall) | ||
and isinstance(expr.fun, SymRef) | ||
and expr.fun.id == "tuple_get" | ||
and len(expr.args) == 2 | ||
and _is_ref_or_tuple_expr_of_ref(expr.args[1]) | ||
): | ||
return True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this case tested?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is indeed tested, the _process_elements
creates such gtfn_ir
expressions for example in the test_multicopy
case.
as_fieldop
are detected using the new type inference.