Skip to content

Commit

Permalink
edits following review
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Aug 28, 2024
1 parent 8021f28 commit c7f85e7
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 35 deletions.
5 changes: 1 addition & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,14 +470,11 @@ def visit_Symbol(

def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Attribute:
new_value = self.visit(node.value, **kwargs)
new_type = getattr(new_value.type, node.attr)
if isinstance(new_type, ts.FieldType):
raise errors.DSLError(node.location, "Module imports of Fields not accepted.")
return foast.Attribute(
value=new_value,
attr=node.attr,
location=node.location,
type=new_type,
type=getattr(new_value.type, node.attr),
)

def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscript:
Expand Down
6 changes: 1 addition & 5 deletions src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,11 @@ def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript

def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute:
new_value = self.visit(node.value, **kwargs)
new_type = getattr(new_value.type, node.attr)
if isinstance(new_type, ts.FieldType):
raise errors.DSLError(node.location, "Module imports of Fields not accepted.")

return past.Attribute(
value=new_value,
attr=node.attr,
location=node.location,
type=new_type,
type=getattr(new_value.type, node.attr),
)

def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,3 @@ def field_op(f: cases.IField):
@gtx.program
def field_op(f: cases.IField):
dummy_module.field_op_sample(f, out=f, offset_provider={})


@pytest.mark.checks_specific_error
def test_import_module_errors(cartesian_case):
with pytest.raises(gtx.errors.DSLError):
new_field = gtx.as_field([cases.IDim], np.ones((10,), dtype=gtx.int32))

@gtx.field_operator(backend=cartesian_case.executor)
def field_op():
f_new = dummy_module.dummy_field
return f_new

field_op(out=new_field, offset_provider={})

with pytest.raises(gtx.errors.DSLError):
new_field = gtx.as_field([cases.IDim], np.ones((10,), dtype=gtx.int32))

@gtx.field_operator(backend=cartesian_case.executor)
def field_op(f: cases.IField):
return f

@gtx.program(backend=cartesian_case.executor)
def program_op(out: cases.IField):
field_op(dummy_module.dummy_field, out=out)

program_op(new_field, offset_provider={})

0 comments on commit c7f85e7

Please sign in to comment.