Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support annotated assignment #3709

Merged
merged 13 commits into from
Dec 7, 2021
Merged
71 changes: 71 additions & 0 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,53 @@ def build_Name(ctx, node):
node.ptr = ctx.get_var_by_name(node.id)
return node

@staticmethod
def build_AnnAssign(ctx, node):
annotation = getattr(node.annotation, 'attr')
strongoier marked this conversation as resolved.
Show resolved Hide resolved

woooodyye marked this conversation as resolved.
Show resolved Hide resolved
node.value = build_stmt(ctx, node.value)
node.target = build_stmt(ctx, node.target)

is_static_assign = isinstance(
node.value, ast.Call) and ASTResolver.resolve_to(
node.value.func, ti.static, globals())

node.ptr = ASTTransformer.build_ann_assign_basic(
ctx, node.target, node.value.ptr, is_static_assign, annotation)
return node

@staticmethod
def build_ann_assign_basic(ctx, target, value, is_static_assign,
strongoier marked this conversation as resolved.
Show resolved Hide resolved
annotation):
"""Build basic assginment like this: target : annotation = value.
strongoier marked this conversation as resolved.
Show resolved Hide resolved

Args:
ctx (ast_builder_utils.BuilderContext): The builder context.
target (ast.Name): A variable name. `target.id` holds the name as
a string.
annotation: A type we hope to assign to the target
value: A node representing the value.
is_static_assign: A boolean value indicating whether this is a static assignment
"""
is_local = isinstance(target, ast.Name)
if is_static_assign:
strongoier marked this conversation as resolved.
Show resolved Hide resolved
if not is_local:
raise TaichiSyntaxError(
"Static assign cannot be used on elements in arrays")
ctx.create_variable(target.id, value)
var = value
elif is_local and not ctx.is_var_declared(target.id):
var = cast_type(value, annotation)
var = ti.expr_init(var)
ctx.create_variable(target.id, var)
else:
var = target.ptr
if str(var.ptr.get_ret_type()) != annotation:
strongoier marked this conversation as resolved.
Show resolved Hide resolved
raise TaichiSyntaxError(
"Static assign cannot have type overloading")
var.assign(value)
return var

@staticmethod
def build_Assign(ctx, node):
node.value = build_stmt(ctx, node.value)
Expand Down Expand Up @@ -1018,3 +1065,27 @@ def build_stmts(ctx, stmts):
else:
result.append(stmt)
return result


def cast_type(expr, annotation):
if annotation == 'i8':
return ti.cast(expr, ti.i8)
if annotation == 'i16':
return ti.cast(expr, ti.i16)
if annotation == 'i32':
return ti.cast(expr, ti.i32)
if annotation == 'i64':
return ti.cast(expr, ti.i64)
if annotation == 'u8':
return ti.cast(expr, ti.u8)
if annotation == 'u16':
return ti.cast(expr, ti.u16)
if annotation == 'u32':
return ti.cast(expr, ti.u32)
if annotation == 'u64':
return ti.cast(expr, ti.u64)
if annotation == 'f32':
return ti.cast(expr, ti.f32)
if annotation == 'f64':
return ti.cast(expr, ti.f64)
raise TaichiSyntaxError("Typed assign must be a supported primitive type")
6 changes: 3 additions & 3 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class CCTransformer : public IRVisitor {
static inline std::string invoke_libc(std::string name,
DataType dt,
std::string const &fmt,
Args &&... args) {
Args &&...args) {
auto arguments = fmt::format(fmt, std::forward<Args>(args)...);
return invoke_libc(name, dt, arguments);
}
Expand Down Expand Up @@ -590,12 +590,12 @@ class CCTransformer : public IRVisitor {
}

template <typename... Args>
void emit(std::string f, Args &&... args) {
void emit(std::string f, Args &&...args) {
line_appender.append(std::move(f), std::move(args)...);
}

template <typename... Args>
void emit_header(std::string f, Args &&... args) {
void emit_header(std::string f, Args &&...args) {
line_appender_header.append(std::move(f), std::move(args)...);
}
}; // namespace cccp
Expand Down
12 changes: 12 additions & 0 deletions tests/python/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ def func_assign():
assert a == 1

func_assign()


@ti.test(debug=True)
def test_ann_assign():
@ti.kernel
def func_ann():
a: ti.i32 = 1
b: ti.f32 = a
assert a == 1
assert b == 1.0

func_ann()