Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: allow import of Dimensions from modules within gt4py code #1615

Merged
merged 43 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
56b04f4
import dims in field_op domain in prog
nfarabullini Aug 15, 2024
4210d04
small edit
nfarabullini Aug 15, 2024
c0e41f0
included type in past.AttributeExpr
nfarabullini Aug 15, 2024
bbe4257
some cleanup
nfarabullini Aug 15, 2024
8d5e59e
some cleanup
nfarabullini Aug 15, 2024
c2175db
some cleanup
nfarabullini Aug 15, 2024
ec849aa
edits to test
nfarabullini Aug 15, 2024
8474af8
Update type_deduction.py
nfarabullini Aug 15, 2024
a0e70bd
Update type_deduction.py
nfarabullini Aug 15, 2024
41b6099
Update src/gt4py/next/ffront/program_ast.py
nfarabullini Aug 16, 2024
eb2eaee
Merge branch 'main' into module_import_impl
nfarabullini Aug 16, 2024
5a93b0f
edits to test to include all backends
nfarabullini Aug 16, 2024
e35d898
Update src/gt4py/next/type_system/type_translation.py
nfarabullini Aug 19, 2024
256a5ef
Update tests/next_tests/integration_tests/feature_tests/ffront_tests/…
nfarabullini Aug 19, 2024
ef29b38
Update tests/next_tests/integration_tests/feature_tests/ffront_tests/…
nfarabullini Aug 19, 2024
c7beb36
changes following review
nfarabullini Aug 19, 2024
345da57
changes to allow for module import in field_operator
nfarabullini Aug 21, 2024
c2caba0
small edit to test
nfarabullini Aug 21, 2024
ec1129b
ran pre-commit
nfarabullini Aug 21, 2024
ceb7fac
edits for closure_vars and following chain
nfarabullini Aug 21, 2024
70718e8
pull from Hannes' branch
nfarabullini Aug 22, 2024
f61f55f
edits following merge
nfarabullini Aug 22, 2024
c1ba851
removed edits from other changes
nfarabullini Aug 22, 2024
1112b6d
Merge branch 'main' into module_import_impl
nfarabullini Aug 22, 2024
ed88398
changes following review
nfarabullini Aug 23, 2024
6767447
Merge branch 'module_import_impl' of https://github.com/nfarabullini/…
nfarabullini Aug 23, 2024
ab55b7a
further cleanup
nfarabullini Aug 23, 2024
36d7dcb
small edit
nfarabullini Aug 23, 2024
1c4f6e7
small edits
nfarabullini Aug 23, 2024
b2f8182
small edit
nfarabullini Aug 23, 2024
87dd95e
removed one last unused component
nfarabullini Aug 26, 2024
8bef26b
edit to include dimension module import in field_operator
nfarabullini Aug 27, 2024
a4bcb40
Update tests/next_tests/integration_tests/feature_tests/ffront_tests/…
nfarabullini Aug 27, 2024
41b7815
small import edit
nfarabullini Aug 27, 2024
cbb4fe0
further edits
nfarabullini Aug 27, 2024
8c439de
Merge branch 'master' of https://github.com/nfarabullini/gt4py into m…
nfarabullini Aug 27, 2024
1469761
more edits and other tests
nfarabullini Aug 27, 2024
a11ad76
edits for error tests
nfarabullini Aug 28, 2024
7e7bb03
edits following review
nfarabullini Aug 28, 2024
44c3cc9
Update tests/next_tests/integration_tests/feature_tests/ffront_tests/…
nfarabullini Aug 28, 2024
ab9a137
ran pre-commit
nfarabullini Aug 28, 2024
8021f28
small edit to test
nfarabullini Aug 28, 2024
c7f85e7
edits following review
nfarabullini Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/foast_passes/closure_var_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ def visit_Name(
return foast.Constant(value=value, location=node.location)
return node

def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Constant:
def visit_Attribute(
self, node: foast.Attribute, **kwargs: Any
) -> foast.Constant | foast.Attribute:
value = self.visit(node.value, **kwargs)
if isinstance(value, foast.Constant):
if hasattr(value.value, node.attr):
return foast.Constant(value=getattr(value.value, node.attr), location=node.location)
raise errors.MissingAttributeError(node.location, node.attr)
raise errors.DSLError(node.location, "Attribute access only applicable to constants.")
return node

def visit_FunctionDefinition(
self, node: foast.FunctionDefinition, **kwargs: Any
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,18 @@ def visit_Symbol(
return new_node
return node

def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Attribute:
new_value = self.visit(node.value, **kwargs)
new_type = getattr(new_value.type, node.attr)
if isinstance(new_type, ts.FieldType):
raise errors.DSLError(node.location, "Module imports of Fields not accepted.")
return foast.Attribute(
value=new_value,
attr=node.attr,
location=node.location,
type=new_type,
)

def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscript:
new_value = self.visit(node.value, **kwargs)
new_type: Optional[ts.TypeSpec] = None
Expand Down
13 changes: 13 additions & 0 deletions src/gt4py/next/ffront/func_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ def visit_BinOp(self, node: ast.BinOp, **kwargs: Any) -> past.BinOp:
def visit_Name(self, node: ast.Name) -> past.Name:
return past.Name(id=node.id, location=self.get_location(node))

def visit_Attribute(self, node: ast.Attribute) -> past.Attribute:
if not isinstance(node.ctx, ast.Load):
raise errors.DSLError(
self.get_location(node), "`node.ctx` can only be of type ast.Load"
)
assert isinstance(node.value, (ast.Name, ast.Attribute))

return past.Attribute(
attr=node.attr,
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
value=self.visit(node.value),
location=self.get_location(node),
)

def visit_Dict(self, node: ast.Dict) -> past.Dict:
return past.Dict(
keys_=[self.visit(cast(ast.AST, param)) for param in node.keys],
Expand Down
14 changes: 13 additions & 1 deletion src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Optional, cast

from gt4py.eve import NodeTranslator, traits
Expand Down Expand Up @@ -114,6 +113,19 @@ def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript
location=node.location,
)

def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute:
new_value = self.visit(node.value, **kwargs)
new_type = getattr(new_value.type, node.attr)
if isinstance(new_type, ts.FieldType):
raise errors.DSLError(node.location, "Module imports of Fields not accepted.")

return past.Attribute(
value=new_value,
attr=node.attr,
location=node.location,
type=new_type,
)

def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr:
elts = self.visit(node.elts, **kwargs)
return past.TupleExpr(
Expand Down
7 changes: 6 additions & 1 deletion src/gt4py/next/ffront/program_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,17 @@ class TupleExpr(Expr):
elts: list[Expr]


class Attribute(Expr):
attr: str
value: Expr


class Constant(Expr):
value: Any # TODO(tehrengruber): be more restrictive


class Dict(Expr):
keys_: list[Name]
keys_: list[Union[Name | Attribute]]
values_: list[TupleExpr]


Expand Down
18 changes: 17 additions & 1 deletion src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import builtins
import collections.abc
import dataclasses
import functools
import types
import typing
Expand Down Expand Up @@ -162,10 +165,21 @@ def from_type_hint(
kw_only_args={}, # TODO
returns=returns,
)

raise ValueError(f"'{type_hint}' type is not supported.")


@dataclasses.dataclass(frozen=True)
class UnknownPythonObject(ts.TypeSpec):
_object: Any

def __getattr__(self, key: str) -> ts.TypeSpec:
value = getattr(self._object, key)
return from_value(value)

def __deepcopy__(self, _: dict[int, Any]) -> UnknownPythonObject:
return UnknownPythonObject(self._object) # don't deep copy the module


def from_value(value: Any) -> ts.TypeSpec:
# TODO(tehrengruber): use protocol from gt4py.next.common when available
# instead of importing from the embedded implementation
Expand Down Expand Up @@ -204,6 +218,8 @@ def from_value(value: Any) -> ts.TypeSpec:
elems = [from_value(el) for el in value]
assert all(isinstance(elem, ts.DataType) for elem in elems)
return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert
elif isinstance(value, types.ModuleType):
return UnknownPythonObject(_object=value)
else:
type_ = xtyping.infer_type(value, annotate_callable_kwargs=True)
symbol_type = from_type_hint(type_)
Expand Down
9 changes: 9 additions & 0 deletions tests/next_tests/dummy_package/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from . import dummy_module
19 changes: 19 additions & 0 deletions tests/next_tests/dummy_package/dummy_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
import numpy as np
from next_tests.integration_tests import cases

dummy_int = 42

dummy_field = gtx.as_field([cases.IDim], np.ones((10,), dtype=gtx.int32))


@gtx.field_operator
def field_op_sample(a: cases.IKField) -> cases.IKField:
return a
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import numpy as np

import gt4py.next as gtx
from gt4py.next import broadcast

from next_tests import integration_tests
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import cartesian_case

from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
mesh_descriptor,
)

from next_tests.dummy_package import dummy_module

from next_tests.dummy_package.dummy_module import field_op_sample


def test_import_dims_module(cartesian_case):
@gtx.field_operator
def mod_op(f: cases.IField) -> cases.IKField:
f_i_k = broadcast(f, (cases.IDim, cases.KDim))
return f_i_k

@gtx.program
def mod_prog(f: cases.IField, out: cases.IKField):
mod_op(
f,
out=out,
domain={
integration_tests.cases.IDim: (
0,
8,
), # Nested import done on purpose, do not change
cases.KDim: (0, 3),
},
)

f = cases.allocate(cartesian_case, mod_prog, "f")()
out = cases.allocate(cartesian_case, mod_prog, "out")()
expected = np.zeros_like(out.asnumpy())
expected[0:8, 0:3] = np.reshape(np.repeat(f.asnumpy(), out.shape[1], axis=0), out.shape)[
0:8, 0:3
]

cases.verify(cartesian_case, mod_prog, f, out=out, ref=expected)


nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
# TODO: these set of features should be allowed as module imports in a later PR
def test_import_module_errors_future_allowed(cartesian_case):
with pytest.raises(gtx.errors.DSLError):

@gtx.field_operator
def field_op(f: cases.IField):
f_i_k = gtx.broadcast(f, (cases.IDim, cases.KDim))
return f_i_k

with pytest.raises(ValueError):

@gtx.field_operator
def field_op(f: cases.IField):
type_ = gtx.int32
return f

with pytest.raises(gtx.errors.DSLError):

@gtx.field_operator
def field_op(f: cases.IField):
f_new = dummy_module.field_op_sample(f)
return f_new

with pytest.raises(gtx.errors.DSLError):

@gtx.field_operator
def field_op(f: cases.IField):
f_new = field_op_sample(f)
return f_new
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved

@gtx.program
def field_op(f: cases.IField):
dummy_module.field_op_sample(f, out=f, offset_provider={})


@pytest.mark.checks_specific_error
def test_import_module_errors(cartesian_case):
with pytest.raises(gtx.errors.DSLError):
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
new_field = gtx.as_field([cases.IDim], np.ones((10,), dtype=gtx.int32))

@gtx.field_operator(backend=cartesian_case.executor)
def field_op():
f_new = dummy_module.dummy_field
return f_new

field_op(out=new_field, offset_provider={})

with pytest.raises(gtx.errors.DSLError):
new_field = gtx.as_field([cases.IDim], np.ones((10,), dtype=gtx.int32))

@gtx.field_operator(backend=cartesian_case.executor)
def field_op(f: cases.IField):
return f

@gtx.program(backend=cartesian_case.executor)
def program_op(f: cases.IField, out: cases.IField):
field_op(dummy_module.dummy_field, out=out)

program_op(dummy_module.dummy_field, new_field, offset_provider={})
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from gt4py.next import common
from gt4py.next.type_system import type_specifications as ts, type_translation

from ... import dummy_package


class CustomInt32DType:
@property
Expand Down Expand Up @@ -185,3 +187,15 @@ def test_generic_variadic_dims(value, expected_dims):
)
def test_as_from_dtype(dtype):
assert type_translation.as_dtype(type_translation.from_dtype(dtype)) == dtype


def test_from_value_module():
assert isinstance(
type_translation.from_value(dummy_package), type_translation.UnknownPythonObject
)
assert type_translation.from_value(dummy_package).dummy_module.dummy_int == ts.ScalarType(
kind=ts.ScalarKind.INT32
)
assert type_translation.from_value(dummy_package.dummy_module.dummy_int) == ts.ScalarType(
kind=ts.ScalarKind.INT32
)
Loading