diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index c9f4f76097..e807e7a257 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -48,13 +48,15 @@ def visit_Name( return foast.Constant(value=value, location=node.location) return node - def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Constant: + def visit_Attribute( + self, node: foast.Attribute, **kwargs: Any + ) -> foast.Constant | foast.Attribute: value = self.visit(node.value, **kwargs) if isinstance(value, foast.Constant): if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) raise errors.MissingAttributeError(node.location, node.attr) - raise errors.DSLError(node.location, "Attribute access only applicable to constants.") + return node def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs: Any diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index da77eba8b1..e12534da92 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -468,6 +468,15 @@ def visit_Symbol( return new_node return node + def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Attribute: + new_value = self.visit(node.value, **kwargs) + return foast.Attribute( + value=new_value, + attr=node.attr, + location=node.location, + type=getattr(new_value.type, node.attr), + ) + def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscript: new_value = self.visit(node.value, **kwargs) new_type: Optional[ts.TypeSpec] = None diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 3eac039af4..e1de316b15 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -162,6 +162,19 @@ def visit_BinOp(self, node: ast.BinOp, **kwargs: Any) -> past.BinOp: def visit_Name(self, node: ast.Name) -> past.Name: return past.Name(id=node.id, location=self.get_location(node)) + def visit_Attribute(self, node: ast.Attribute) -> past.Attribute: + if not isinstance(node.ctx, ast.Load): + raise errors.DSLError( + self.get_location(node), "`node.ctx` can only be of type ast.Load" + ) + assert isinstance(node.value, (ast.Name, ast.Attribute)) + + return past.Attribute( + attr=node.attr, + value=self.visit(node.value), + location=self.get_location(node), + ) + def visit_Dict(self, node: ast.Dict) -> past.Dict: return past.Dict( keys_=[self.visit(cast(ast.AST, param)) for param in node.keys], diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 357f6e0799..92f7327218 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -5,7 +5,6 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - from typing import Any, Optional, cast from gt4py.eve import NodeTranslator, traits @@ -114,6 +113,15 @@ def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript location=node.location, ) + def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute: + new_value = self.visit(node.value, **kwargs) + return past.Attribute( + value=new_value, + attr=node.attr, + location=node.location, + type=getattr(new_value.type, node.attr), + ) + def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: elts = self.visit(node.elts, **kwargs) return past.TupleExpr( diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index 83a3f87809..ea579aa211 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -72,12 +72,17 @@ class TupleExpr(Expr): elts: list[Expr] +class Attribute(Expr): + attr: str + value: Expr + + class Constant(Expr): value: Any # TODO(tehrengruber): be more restrictive class Dict(Expr): - keys_: list[Name] + keys_: list[Union[Name | Attribute]] values_: list[TupleExpr] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 90b8e9ce47..62a6781316 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -6,8 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import builtins import collections.abc +import dataclasses import functools import types import typing @@ -162,10 +165,21 @@ def from_type_hint( kw_only_args={}, # TODO returns=returns, ) - raise ValueError(f"'{type_hint}' type is not supported.") +@dataclasses.dataclass(frozen=True) +class UnknownPythonObject(ts.TypeSpec): + _object: Any + + def __getattr__(self, key: str) -> ts.TypeSpec: + value = getattr(self._object, key) + return from_value(value) + + def __deepcopy__(self, _: dict[int, Any]) -> UnknownPythonObject: + return UnknownPythonObject(self._object) # don't deep copy the module + + def from_value(value: Any) -> ts.TypeSpec: # TODO(tehrengruber): use protocol from gt4py.next.common when available # instead of importing from the embedded implementation @@ -204,6 +218,8 @@ def from_value(value: Any) -> ts.TypeSpec: elems = [from_value(el) for el in value] assert all(isinstance(elem, ts.DataType) for elem in elems) return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert + elif isinstance(value, types.ModuleType): + return UnknownPythonObject(_object=value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) symbol_type = from_type_hint(type_) diff --git a/tests/next_tests/dummy_package/__init__.py b/tests/next_tests/dummy_package/__init__.py new file mode 100644 index 0000000000..3a81af6157 --- /dev/null +++ b/tests/next_tests/dummy_package/__init__.py @@ -0,0 +1,9 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from . import dummy_module diff --git a/tests/next_tests/dummy_package/dummy_module.py b/tests/next_tests/dummy_package/dummy_module.py new file mode 100644 index 0000000000..2a56a39a6c --- /dev/null +++ b/tests/next_tests/dummy_package/dummy_module.py @@ -0,0 +1,19 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import gt4py.next as gtx +import numpy as np +from next_tests.integration_tests import cases + +dummy_int = 42 + +dummy_field = gtx.as_field([cases.IDim], np.ones((10,), dtype=gtx.int32)) + + +@gtx.field_operator +def field_op_sample(a: cases.IKField) -> cases.IKField: + return a diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py new file mode 100644 index 0000000000..87bf0e5bd7 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py @@ -0,0 +1,88 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np + +import gt4py.next as gtx +from gt4py.next import broadcast + +from next_tests import integration_tests +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import cartesian_case + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, + mesh_descriptor, +) + +from next_tests.dummy_package import dummy_module + + +def test_import_dims_module(cartesian_case): + @gtx.field_operator + def mod_op(f: cases.IField) -> cases.IKField: + f_i_k = broadcast(f, (cases.IDim, cases.KDim)) + return f_i_k + + @gtx.program + def mod_prog(f: cases.IField, out: cases.IKField): + mod_op( + f, + out=out, + domain={ + integration_tests.cases.IDim: ( + 0, + 8, + ), # Nested import done on purpose, do not change + cases.KDim: (0, 3), + }, + ) + + f = cases.allocate(cartesian_case, mod_prog, "f")() + out = cases.allocate(cartesian_case, mod_prog, "out")() + expected = np.zeros_like(out.asnumpy()) + expected[0:8, 0:3] = np.reshape(np.repeat(f.asnumpy(), out.shape[1], axis=0), out.shape)[ + 0:8, 0:3 + ] + + cases.verify(cartesian_case, mod_prog, f, out=out, ref=expected) + + +# TODO: these set of features should be allowed as module imports in a later PR +def test_import_module_errors_future_allowed(cartesian_case): + with pytest.raises(gtx.errors.DSLError): + + @gtx.field_operator + def field_op(f: cases.IField): + f_i_k = gtx.broadcast(f, (cases.IDim, cases.KDim)) + return f_i_k + + with pytest.raises(ValueError): + + @gtx.field_operator + def field_op(f: cases.IField): + type_ = gtx.int32 + return f + + with pytest.raises(gtx.errors.DSLError): + + @gtx.field_operator + def field_op(f: cases.IField): + f_new = dummy_module.field_op_sample(f) + return f_new + + with pytest.raises(gtx.errors.DSLError): + + @gtx.field_operator + def field_op(f: cases.IField): + return f + + @gtx.program + def field_op(f: cases.IField): + dummy_module.field_op_sample(f, out=f, offset_provider={}) diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index 793d851745..e9f33ade2e 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -18,6 +18,8 @@ from gt4py.next import common from gt4py.next.type_system import type_specifications as ts, type_translation +from ... import dummy_package + class CustomInt32DType: @property @@ -185,3 +187,15 @@ def test_generic_variadic_dims(value, expected_dims): ) def test_as_from_dtype(dtype): assert type_translation.as_dtype(type_translation.from_dtype(dtype)) == dtype + + +def test_from_value_module(): + assert isinstance( + type_translation.from_value(dummy_package), type_translation.UnknownPythonObject + ) + assert type_translation.from_value(dummy_package).dummy_module.dummy_int == ts.ScalarType( + kind=ts.ScalarKind.INT32 + ) + assert type_translation.from_value(dummy_package.dummy_module.dummy_int) == ts.ScalarType( + kind=ts.ScalarKind.INT32 + )