diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 169d09d691272..9ffc80db38c0e 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -486,6 +486,12 @@ def transform_as_kernel(): arg.arg, kernel_arguments.decl_matrix_arg( ctx.func.arguments[i].annotation)) + elif isinstance(ctx.func.arguments[i].annotation, + primitive_types.RefType): + ctx.create_variable( + arg.arg, + kernel_arguments.decl_scalar_arg( + ctx.func.arguments[i].annotation)) else: ctx.global_vars[ arg.arg] = kernel_arguments.decl_scalar_arg( diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index d512f2ceb5a66..05adeb18316d7 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -8,7 +8,7 @@ from taichi.lang.expr import Expr from taichi.lang.matrix import Matrix, MatrixType from taichi.lang.util import cook_dtype -from taichi.types.primitive_types import u64 +from taichi.types.primitive_types import RefType, u64 class KernelArgument: @@ -47,9 +47,13 @@ def subscript(self, i, j): def decl_scalar_arg(dtype): + is_ref = False + if isinstance(dtype, RefType): + is_ref = True + dtype = dtype.tp dtype = cook_dtype(dtype) arg_id = impl.get_runtime().prog.decl_arg(dtype, False) - return Expr(_ti_core.make_arg_load_expr(arg_id, dtype)) + return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref)) def decl_matrix_arg(matrixtype): @@ -63,8 +67,8 @@ def decl_sparse_matrix(dtype): ptr_type = cook_dtype(u64) # Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer arg_id = impl.get_runtime().prog.decl_arg(ptr_type, False) - return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type), - value_type) + return SparseMatrixProxy( + _ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type) def decl_ndarray_arg(dtype, dim, element_shape, layout): diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 0b820c11eac04..418b51bb8bba3 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -244,6 +244,9 @@ def func_call_rvalue(self, key, args): if not isinstance(anno, template): if id(anno) in primitive_types.type_ids: non_template_args.append(ops.cast(args[i], anno)) + elif isinstance(anno, primitive_types.RefType): + non_template_args.append( + _ti_core.make_reference(args[i].ptr)) else: non_template_args.append(args[i]) non_template_args = impl.make_expr_group(non_template_args) @@ -302,7 +305,8 @@ def extract_arguments(self): else: if not id(annotation ) in primitive_types.type_ids and not isinstance( - annotation, template): + annotation, template) and not isinstance( + annotation, primitive_types.RefType): raise TaichiSyntaxError( f'Invalid type annotation (argument {i}) of Taichi function: {annotation}' ) diff --git a/python/taichi/types/primitive_types.py b/python/taichi/types/primitive_types.py index 726da1831a32a..3149ffc450a8b 100644 --- a/python/taichi/types/primitive_types.py +++ b/python/taichi/types/primitive_types.py @@ -141,6 +141,16 @@ # ---------------------------------------- + +class RefType: + def __init__(self, tp): + self.tp = tp + + +def ref(tp): + return RefType(tp) + + real_types = [f16, f32, f64, float] real_type_ids = [id(t) for t in real_types] @@ -173,4 +183,5 @@ 'u32', 'uint64', 'u64', + 'ref', ] diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index ec23e8be085ef..4a018afa6bf47 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -35,6 +35,8 @@ std::vector get_load_pointers(Stmt *load_stmt) { return std::vector(1, stack_pop->stack); } else if (auto external_func = load_stmt->cast()) { return external_func->arg_stmts; + } else if (auto ref = load_stmt->cast()) { + return {ref->var}; } else { return std::vector(); } diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 37469c154c941..22f701abb7cb5 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1086,6 +1086,9 @@ llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) { llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) { auto intermediate_bits = 0; + if (type.is_pointer()) { + return builder->CreatePtrToInt(val, tlctx->get_data_type()); + } if (auto cit = type->cast()) { intermediate_bits = data_type_bits(cit->get_compute_type()); } else { @@ -1109,8 +1112,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::Type *dest_ty = nullptr; if (stmt->is_ptr) { - dest_ty = - llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); + dest_ty = llvm::PointerType::get( + tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type); @@ -2460,6 +2463,10 @@ llvm::Value *CodeGenLLVM::create_mesh_xlogue(std::unique_ptr &block) { return xlogue; } +void CodeGenLLVM::visit(ReferenceStmt *stmt) { + llvm_val[stmt] = llvm_val[stmt->var]; +} + void CodeGenLLVM::visit(FuncCallStmt *stmt) { if (!func_map.count(stmt->func)) { auto guard = get_function_creation_guard( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index bbea19ba60dd6..51707e7ae6dfc 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -369,6 +369,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(MeshPatchIndexStmt *stmt) override; + void visit(ReferenceStmt *stmt) override; + llvm::Value *create_xlogue(std::unique_ptr &block); llvm::Value *create_mesh_xlogue(std::unique_ptr &block); diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index faba902169142..4ec43c58357f9 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -19,3 +19,4 @@ PER_EXPRESSION(FuncCallExpression) PER_EXPRESSION(MeshPatchIndexExpression) PER_EXPRESSION(MeshRelationAccessExpression) PER_EXPRESSION(MeshIndexConversionExpression) +PER_EXPRESSION(ReferenceExpression) diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index b26a942860a8a..c40c89290afd8 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -18,6 +18,7 @@ PER_STATEMENT(FuncCallStmt) PER_STATEMENT(ReturnStmt) PER_STATEMENT(ArgLoadStmt) +PER_STATEMENT(ReferenceStmt) PER_STATEMENT(ExternalPtrStmt) PER_STATEMENT(PtrOffsetStmt) PER_STATEMENT(ConstStmt) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 92c882bcd19ff..f6bb7e607f2ef 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -216,6 +216,12 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { emit(")"); } + void visit(ReferenceExpression *expr) override { + emit("ref("); + expr->var->accept(this); + emit(")"); + } + static std::string expr_to_string(Expr &expr) { std::ostringstream oss; ExpressionHumanFriendlyPrinter printer(&oss); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 627ff3592cf26..95acab8df2f0c 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -121,7 +121,7 @@ void ArgLoadExpression::type_check(CompileConfig *) { } void ArgLoadExpression::flatten(FlattenContext *ctx) { - auto arg_load = std::make_unique(arg_id, dt); + auto arg_load = std::make_unique(arg_id, dt, is_ptr); ctx->push_back(std::move(arg_load)); stmt = ctx->back_stmt(); } @@ -485,17 +485,19 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { op_type = AtomicOpType::add; } // expand rhs - auto expr = val; - flatten_rvalue(expr, ctx); + flatten_rvalue(val, ctx); + auto src_val = val->stmt; if (dest.is()) { // local variable // emit local store stmt auto alloca = ctx->current_block->lookup_var(dest.cast()->id); - ctx->push_back(op_type, alloca, expr->stmt); + ctx->push_back(op_type, alloca, src_val); } else { TI_ASSERT(dest.is() || - dest.is()); + dest.is() || + (dest.is() && + dest.cast()->is_ptr)); flatten_lvalue(dest, ctx); - ctx->push_back(op_type, dest->stmt, expr->stmt); + ctx->push_back(op_type, dest->stmt, src_val); } stmt = ctx->back_stmt(); stmt->tb = tb; @@ -625,6 +627,16 @@ void MeshIndexConversionExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void ReferenceExpression::type_check(CompileConfig *) { + ret_type = var->ret_type; +} + +void ReferenceExpression::flatten(FlattenContext *ctx) { + flatten_lvalue(var, ctx); + ctx->push_back(var->stmt); + stmt = ctx->back_stmt(); +} + Block *ASTBuilder::current_block() { if (stack_.empty()) return nullptr; @@ -945,6 +957,9 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { else { TI_NOT_IMPLEMENTED } + } else if (ptr.is() && + ptr.cast()->is_ptr) { + flatten_global_load(ptr, ctx); } } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 2e6ef5bf1c9b8..d34f7b8274c7f 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -276,14 +276,20 @@ class ArgLoadExpression : public Expression { public: int arg_id; DataType dt; + bool is_ptr; - ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) { + ArgLoadExpression(int arg_id, DataType dt, bool is_ptr = false) + : arg_id(arg_id), dt(dt), is_ptr(is_ptr) { } void type_check(CompileConfig *config) override; void flatten(FlattenContext *ctx) override; + bool is_lvalue() const override { + return is_ptr; + } + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; @@ -727,6 +733,19 @@ class MeshIndexConversionExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +class ReferenceExpression : public Expression { + public: + Expr var; + void type_check(CompileConfig *config) override; + + ReferenceExpression(const Expr &expr) : var(expr) { + } + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class ASTBuilder { private: enum LoopState { None, Outermost, Inner }; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 83e9a57f49fd0..d9099ccdcf785 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -898,6 +898,26 @@ class FuncCallStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * A reference to a variable. + */ +class ReferenceStmt : public Stmt { + public: + Stmt *var; + bool global_side_effect{false}; + + ReferenceStmt(Stmt *var) : var(var) { + TI_STMT_REG_FIELDS; + } + + bool has_global_side_effect() const override { + return global_side_effect; + } + + TI_STMT_DEF_FIELDS(ret_type, var); + TI_DEFINE_ACCEPT_AND_CLONE +}; + /** * Exit the kernel or function with a return value. */ diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 29374e70fad48..6d4095fd48b89 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -726,7 +726,9 @@ void export_lang(py::module &m) { Stmt::make); m.def("make_arg_load_expr", - Expr::make); + Expr::make); + + m.def("make_reference", Expr::make); m.def("make_external_tensor_expr", Expr::maketype_hint(), stmt->name(), stmt->var->name()); + } + private: std::string expr_to_string(Expr &expr) { return expr_to_string(expr.expr.get()); diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 5141724e19deb..7920370e3444e 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -408,7 +408,9 @@ class LowerAST : public IRVisitor { TI_NOT_IMPLEMENTED } } else { // global variable - TI_ASSERT(dest.is()); + TI_ASSERT(dest.is() || + (dest.is() && + dest.cast()->is_ptr)); flatten_lvalue(dest, &fctx); fctx.push_back(dest->stmt, expr->stmt); } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 5eb1ba0e95804..fc3092b01b105 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -534,6 +534,11 @@ class TypeCheck : public IRVisitor { void visit(BitStructStoreStmt *stmt) override { // do nothing } + + void visit(ReferenceStmt *stmt) override { + stmt->ret_type = stmt->var->ret_type; + stmt->ret_type.set_is_pointer(true); + } }; namespace irpass { diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 33a19546be3de..af554afcbd266 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -80,7 +80,7 @@ def _get_expected_matrix_apis(): 'lang', 'length', 'linalg', 'log', 'loop_config', 'math', 'max', 'mesh_local', 'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange', 'no_activate', 'one', 'opengl', 'polar_decompose', 'pow', 'profiler', - 'randn', 'random', 'raw_div', 'raw_mod', 'rescale_index', 'reset', + 'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset', 'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level', 'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static', 'static_assert', 'static_print', 'stop_grad', 'svd', 'swizzle_generator', diff --git a/tests/python/test_function.py b/tests/python/test_function.py index a7d38116cf4e8..70e70692f0315 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -304,3 +304,33 @@ def bar(a: ti.i32) -> ti.i32: assert bar(10) == 11 * 5 assert bar(200) == 99 * 50 + + +@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True) +def test_ref(): + @ti.experimental.real_func + def foo(a: ti.ref(ti.f32)): + a = 7 + + @ti.kernel + def bar(): + a = 5. + foo(a) + assert a == 7 + + bar() + + +@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True) +def test_ref_atomic(): + @ti.experimental.real_func + def foo(a: ti.ref(ti.f32)): + a += a + + @ti.kernel + def bar(): + a = 5. + foo(a) + assert a == 10. + + bar()