Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support annotated assignment #3709

Merged
merged 13 commits into from
Dec 7, 2021
47 changes: 47 additions & 0 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,53 @@ def build_Name(ctx, node):
node.ptr = ctx.get_var_by_name(node.id)
return node

@staticmethod
def build_AnnAssign(ctx, node):

woooodyye marked this conversation as resolved.
Show resolved Hide resolved
node.value = build_stmt(ctx, node.value)
node.target = build_stmt(ctx, node.target)
node.annotation = build_stmt(ctx, node.annotation)

is_static_assign = isinstance(
node.value, ast.Call) and ASTResolver.resolve_to(
node.value.func, ti.static, globals())

node.ptr = ASTTransformer.build_ann_assign_basic(
ctx, node.target, node.value.ptr, is_static_assign,
node.annotation.ptr)
return node

@staticmethod
def build_ann_assign_basic(ctx, target, value, is_static_assign,
strongoier marked this conversation as resolved.
Show resolved Hide resolved
annotation):
"""Build basic assginment like this: target : annotation = value.
strongoier marked this conversation as resolved.
Show resolved Hide resolved

Args:
ctx (ast_builder_utils.BuilderContext): The builder context.
target (ast.Name): A variable name. `target.id` holds the name as
a string.
annotation: A type we hope to assign to the target
value: A node representing the value.
is_static_assign: A boolean value indicating whether this is a static assignment
"""
is_local = isinstance(target, ast.Name)
anno = ti.expr_init(annotation)
if is_static_assign:
strongoier marked this conversation as resolved.
Show resolved Hide resolved
raise TaichiSyntaxError(
"Static assign cannot be used on annotated assignment")
if is_local and not ctx.is_var_declared(target.id):
var = ti.cast(value, anno)
var = ti.expr_init(var)

strongoier marked this conversation as resolved.
Show resolved Hide resolved
ctx.create_variable(target.id, var)
else:
var = target.ptr
if var.ptr.get_ret_type() != anno:
raise TaichiSyntaxError(
"Static assign cannot have type overloading")
var.assign(value)
return var

@staticmethod
def build_Assign(ctx, node):
node.value = build_stmt(ctx, node.value)
Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,33 @@ def func_assign():
assert a == 1

func_assign()


@ti.test(debug=True)
def test_assign_ann():
@ti.kernel
def func_ann():
#need to introduce it
strongoier marked this conversation as resolved.
Show resolved Hide resolved
my_float = ti.f32
a: ti.i32 = 1
b: ti.f32 = a
d: my_float = 1
assert a == 1
assert b == 1.0
assert d == 1.0

func_ann()


@ti.test()
def test_assign_ann_over():
@ti.kernel
def func_ann_over():
my_int = ti.i32
d: my_int = 2
d: ti.f32 = 2.0

with pytest.raises(ti.lang.exception.TaichiCompilationError) as e:
woooodyye marked this conversation as resolved.
Show resolved Hide resolved
func_ann_over()

assert e.type is ti.lang.exception.TaichiCompilationError
strongoier marked this conversation as resolved.
Show resolved Hide resolved