diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index 803d182b011dc..c0b2f164f7210 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -48,7 +48,7 @@ class CCTransformer : public IRVisitor { config.demote_dense_struct_fors = true; irpass::compile_to_executable(ir, config, /*vectorize=*/false, kernel->grad, - /*ad_use_stack=*/false, config.print_ir, + /*ad_use_stack=*/true, config.print_ir, /*lower_global_access*/ true); } @@ -520,6 +520,64 @@ class CCTransformer : public IRVisitor { data_type_short_name(stmt->ret_type.data_type)); } + void visit(StackAllocaStmt *stmt) override { + TI_ASSERT(stmt->width() == 1); + + const auto &var_name = stmt->raw_name(); + emit("Ti_u8 {}[{}];", var_name, stmt->size_in_bytes() + sizeof(uint32_t)); + emit("Ti_ad_stack_init({});", var_name); + } + + void visit(StackPopStmt *stmt) override { + emit("Ti_ad_stack_pop({});", stmt->stack->raw_name()); + } + + void visit(StackPushStmt *stmt) override { + auto *stack = stmt->stack->as(); + const auto &stack_name = stack->raw_name(); + auto elem_size = stack->element_size_in_bytes(); + emit("Ti_ad_stack_push({}, {});", stack_name, elem_size); + auto primal_name = stmt->raw_name() + "_primal_"; + auto dt_name = cc_data_type_name(stmt->element_type()); + auto var = define_var(dt_name + " *", primal_name); + emit("{} = ({} *) Ti_ad_stack_top_primal({}, {});", var, dt_name, + stack_name, elem_size); + emit("*{} = {};", primal_name, stmt->v->raw_name()); + } + + void visit(StackLoadTopStmt *stmt) override { + auto *stack = stmt->stack->as(); + const auto primal_name = stmt->raw_name() + "_primal_"; + auto dt_name = cc_data_type_name(stmt->element_type()); + auto var = define_var(dt_name + " *", primal_name); + emit("{} = ({} *)Ti_ad_stack_top_primal({}, {});", var, dt_name, + stack->raw_name(), stack->element_size_in_bytes()); + emit("{} = *{};", define_var(dt_name, stmt->raw_name()), primal_name); + } + + void visit(StackLoadTopAdjStmt *stmt) override { + auto *stack = stmt->stack->as(); + const auto adjoint_name = stmt->raw_name() + "_adjoint_"; + auto dt_name = cc_data_type_name(stmt->element_type()); + auto var = define_var(dt_name + " *", adjoint_name); + emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});", var, dt_name, + stack->raw_name(), stack->element_size_in_bytes()); + emit("{} = *{};", define_var(dt_name, stmt->raw_name()), adjoint_name); + } + + void visit(StackAccAdjointStmt *stmt) override { + auto *stack = stmt->stack->as(); + const auto adjoint_name = stmt->raw_name() + "_adjoint_"; + auto dt_name = cc_data_type_name(stmt->element_type()); + auto var = define_var(dt_name + " *", adjoint_name); + emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});", var, dt_name, + stack->raw_name(), stack->element_size_in_bytes()); + emit("printf(\"%d\\n\", *Ti_ad_stack_n({}));", stack->raw_name()); + emit("printf(\"%p\\n\", {});", stack->raw_name()); + emit("printf(\"%p\\n\", {});", adjoint_name); + emit("*{} += {};", adjoint_name, stmt->v->raw_name()); + } + template void emit(std::string f, Args &&... args) { line_appender.append(std::move(f), std::move(args)...); diff --git a/taichi/backends/cc/runtime/base.h b/taichi/backends/cc/runtime/base.h index 609cc66740933..2b4f450437039 100644 --- a/taichi/backends/cc/runtime/base.h +++ b/taichi/backends/cc/runtime/base.h @@ -83,6 +83,50 @@ static inline Ti_f32 Ti_rand_f32(void) { return (Ti_f32) drand48(); // [0.0, 1.0) } +// Copied from Metal: +typedef Ti_u8 *Ti_AdStackPtr; + +static inline Ti_u32 *Ti_ad_stack_n(Ti_AdStackPtr stack) { + return (Ti_u32 *)stack; +} + +static inline Ti_AdStackPtr Ti_ad_stack_data(Ti_AdStackPtr stack) { + return stack + sizeof(Ti_u32); +} + +static inline void Ti_ad_stack_init(Ti_AdStackPtr stack) { + Ti_u32 *n = Ti_ad_stack_n(stack); + Ti_i32 *data = (Ti_i32 *)Ti_ad_stack_data(stack); + *n = 0; +} + +static inline Ti_AdStackPtr Ti_ad_stack_top_primal(Ti_AdStackPtr stack, + Ti_u32 element_size) { + Ti_u32 *n = Ti_ad_stack_n(stack); + return Ti_ad_stack_data(stack) + (*n - 1) * 2 * element_size; +} + +static inline Ti_AdStackPtr Ti_ad_stack_top_adjoint(Ti_AdStackPtr stack, + Ti_u32 element_size) { + return Ti_ad_stack_top_primal(stack, element_size) + element_size; +} + +static inline void Ti_ad_stack_pop(Ti_AdStackPtr stack) { + Ti_u32 *n = Ti_ad_stack_n(stack); + --(*n); +} + +static inline void Ti_ad_stack_push(Ti_AdStackPtr stack, Ti_u32 element_size) { + Ti_u32 i; + Ti_u32 *n = Ti_ad_stack_n(stack); + ++(*n); + + Ti_AdStackPtr data = Ti_ad_stack_top_primal(stack, element_size); + for (i = 0; i < element_size * 2; ++i) { + data[i] = 0; + } +} + ) "\n" STR( ) diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index cd2098eb552c8..82662c60c30ff 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -19,7 +19,7 @@ bool is_extension_supported(Arch arch, Extension ext) { Extension::adstack, Extension::bls, Extension::assertion}}, {Arch::metal, {Extension::adstack}}, {Arch::opengl, {Extension::extfunc}}, - {Arch::cc, {Extension::data64, Extension::extfunc}}, + {Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}}, }; // if (with_opengl_extension_data64()) // arch2ext[Arch::opengl].insert(Extension::data64); // TODO: singleton