Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CC] [autodiff] Support AdStack on C backend #1752

Merged
merged 6 commits into from
Aug 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CCTransformer : public IRVisitor {
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(ir, config,
/*vectorize=*/false, kernel->grad,
/*ad_use_stack=*/false, config.print_ir,
/*ad_use_stack=*/true, config.print_ir,
/*lower_global_access*/ true);
}

Expand Down Expand Up @@ -520,6 +520,64 @@ class CCTransformer : public IRVisitor {
data_type_short_name(stmt->ret_type.data_type));
}

// Emits the backing storage for an AD stack as a local Ti_u8 array and then
// initializes it with Ti_ad_stack_init. The array holds the stack payload
// plus sizeof(uint32_t) extra bytes (the C runtime helpers read a Ti_u32
// element counter from the start of the buffer).
void visit(StackAllocaStmt *stmt) override {
  TI_ASSERT(stmt->width() == 1);

  const auto &stack_name = stmt->raw_name();
  const auto total_bytes = stmt->size_in_bytes() + sizeof(uint32_t);
  emit("Ti_u8 {}[{}];", stack_name, total_bytes);
  emit("Ti_ad_stack_init({});", stack_name);
}

// Emits a Ti_ad_stack_pop call on the stack this statement refers to.
void visit(StackPopStmt *stmt) override {
  const auto stack_name = stmt->stack->raw_name();
  emit("Ti_ad_stack_pop({});", stack_name);
}

// Emits a push onto the AD stack, then stores the pushed value into the
// primal slot of the newly created top entry.
void visit(StackPushStmt *stmt) override {
  auto *alloca_stmt = stmt->stack->as<StackAllocaStmt>();
  const auto &stack_id = alloca_stmt->raw_name();
  auto bytes_per_elem = alloca_stmt->element_size_in_bytes();

  // Grow the stack first so that "top" refers to the new entry.
  emit("Ti_ad_stack_push({}, {});", stack_id, bytes_per_elem);

  // Declare a typed pointer to the primal slot of the top entry ...
  auto primal_ptr = stmt->raw_name() + "_primal_";
  auto elem_type = cc_data_type_name(stmt->element_type());
  auto primal_decl = define_var(elem_type + " *", primal_ptr);
  emit("{} = ({} *) Ti_ad_stack_top_primal({}, {});", primal_decl, elem_type,
       stack_id, bytes_per_elem);

  // ... and write the pushed value through it.
  emit("*{} = {};", primal_ptr, stmt->v->raw_name());
}

// Emits code that reads the primal value of the top stack entry into the
// variable named after this statement.
void visit(StackLoadTopStmt *stmt) override {
  auto *alloca_stmt = stmt->stack->as<StackAllocaStmt>();
  const auto primal_ptr = stmt->raw_name() + "_primal_";
  auto elem_type = cc_data_type_name(stmt->element_type());

  // Typed pointer to the primal slot of the top entry.
  auto primal_decl = define_var(elem_type + " *", primal_ptr);
  emit("{} = ({} *)Ti_ad_stack_top_primal({}, {});", primal_decl, elem_type,
       alloca_stmt->raw_name(), alloca_stmt->element_size_in_bytes());

  // Load the value through the pointer into the result variable.
  auto value_decl = define_var(elem_type, stmt->raw_name());
  emit("{} = *{};", value_decl, primal_ptr);
}

// Emits code that reads the adjoint value of the top stack entry into the
// variable named after this statement.
void visit(StackLoadTopAdjStmt *stmt) override {
  auto *alloca_stmt = stmt->stack->as<StackAllocaStmt>();
  const auto adjoint_ptr = stmt->raw_name() + "_adjoint_";
  auto elem_type = cc_data_type_name(stmt->element_type());

  // Typed pointer to the adjoint slot of the top entry.
  auto adjoint_decl = define_var(elem_type + " *", adjoint_ptr);
  emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});", adjoint_decl, elem_type,
       alloca_stmt->raw_name(), alloca_stmt->element_size_in_bytes());

  // Load the value through the pointer into the result variable.
  auto value_decl = define_var(elem_type, stmt->raw_name());
  emit("{} = *{};", value_decl, adjoint_ptr);
}

// Emits code that accumulates a value into the adjoint slot of the top stack
// entry: a typed pointer to the slot is obtained via Ti_ad_stack_top_adjoint
// and the value of stmt->v is added through it.
//
// The debugging printf calls that used to be emitted here were removed: they
// wrote three lines to stdout on every accumulation in the generated kernel
// and used mismatched format specifiers (%d for a Ti_u32, %p for a pointer
// that is not void *).
void visit(StackAccAdjointStmt *stmt) override {
  auto *stack = stmt->stack->as<StackAllocaStmt>();
  const auto adjoint_name = stmt->raw_name() + "_adjoint_";
  auto dt_name = cc_data_type_name(stmt->element_type());
  auto var = define_var(dt_name + " *", adjoint_name);
  emit("{} = ({} *)Ti_ad_stack_top_adjoint({}, {});", var, dt_name,
       stack->raw_name(), stack->element_size_in_bytes());
  emit("*{} += {};", adjoint_name, stmt->v->raw_name());
}

template <typename... Args>
void emit(std::string f, Args &&... args) {
line_appender.append(std::move(f), std::move(args)...);
Expand Down
44 changes: 44 additions & 0 deletions taichi/backends/cc/runtime/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,50 @@ static inline Ti_f32 Ti_rand_f32(void) {
return (Ti_f32) drand48(); // [0.0, 1.0)
}

// Copied from Metal:
// An AD stack is addressed as a raw byte buffer. As used by the helpers
// below, the buffer starts with a Ti_u32 element counter, followed by the
// entries; each entry occupies 2 * element_size bytes: the primal value
// first, then its adjoint.
typedef Ti_u8 *Ti_AdStackPtr;

static inline Ti_u32 *Ti_ad_stack_n(Ti_AdStackPtr stack) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well... I would suggest not using a capital letter at the beginning of a function's name. I haven't taken a look at the C backend before, so is there any reason that the Ti_ prefix is used? I would suggest a cc_ prefix to show that it's the C backend (and probably Cc for classes).

return (Ti_u32 *)stack;
}

// Returns the address of the entry storage, which begins sizeof(Ti_u32)
// bytes into the buffer (right after the element counter).
static inline Ti_AdStackPtr Ti_ad_stack_data(Ti_AdStackPtr stack) {
  Ti_AdStackPtr first_entry = &stack[sizeof(Ti_u32)];
  return first_entry;
}

// Resets the stack to empty by zeroing its element counter. The entry
// storage is not touched here; Ti_ad_stack_push zero-fills each entry when
// it is created.
static inline void Ti_ad_stack_init(Ti_AdStackPtr stack) {
  Ti_u32 *n = Ti_ad_stack_n(stack);
  *n = 0;
}

static inline Ti_AdStackPtr Ti_ad_stack_top_primal(Ti_AdStackPtr stack,
Ti_u32 element_size) {
Ti_u32 *n = Ti_ad_stack_n(stack);
return Ti_ad_stack_data(stack) + (*n - 1) * 2 * element_size;
Copy link
Collaborator Author

@archibate archibate Aug 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from ad_stack.metal.h; can you tell me why it uses n - 1 here? @k-ye
Sometimes n can be 0, so the unsigned subtraction wraps around (0 - 1 = 0xffffffff), resulting in a serious segfault when the offset is added to a 64-bit pointer. But it somehow passed silently on Metal, whose pointers are 32-bit?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when n is 0, it's pretty OK to have a segfault here -- just like this in C++:

std::stack<int> s;
s.top();  // runtime error

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, this is just l[len(l) - 1]. As mentioned, accessing top without push sounds like a bug.

}

// Returns the adjoint slot of the top entry, located element_size bytes past
// that entry's primal slot.
static inline Ti_AdStackPtr Ti_ad_stack_top_adjoint(Ti_AdStackPtr stack,
                                                    Ti_u32 element_size) {
  Ti_AdStackPtr primal = Ti_ad_stack_top_primal(stack, element_size);
  return primal + element_size;
}

// Drops the top entry by decrementing the element counter. No emptiness
// check is performed.
static inline void Ti_ad_stack_pop(Ti_AdStackPtr stack) {
  *Ti_ad_stack_n(stack) -= 1;
}

// Appends a new entry: increments the element counter, then zero-fills the
// 2 * element_size bytes (primal followed by adjoint) of the new top entry.
static inline void Ti_ad_stack_push(Ti_AdStackPtr stack, Ti_u32 element_size) {
  Ti_u32 *count = Ti_ad_stack_n(stack);
  Ti_AdStackPtr entry;
  Ti_u32 remaining;

  *count += 1;

  // The counter was bumped above, so "top" now addresses the new entry.
  entry = Ti_ad_stack_top_primal(stack, element_size);
  for (remaining = element_size * 2; remaining != 0; --remaining) {
    *entry++ = 0;
  }
}

) "\n" STR(
)

Expand Down
2 changes: 1 addition & 1 deletion taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ bool is_extension_supported(Arch arch, Extension ext) {
Extension::adstack, Extension::bls, Extension::assertion}},
{Arch::metal, {Extension::adstack}},
{Arch::opengl, {Extension::extfunc}},
{Arch::cc, {Extension::data64, Extension::extfunc}},
{Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
};
// if (with_opengl_extension_data64())
// arch2ext[Arch::opengl].insert(Extension::data64); // TODO: singleton
Expand Down