Skip to content

Commit

Permalink
[type] Support basic custom int/float types on metal (#2145)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Jan 9, 2021
1 parent 7df6f22 commit 210b212
Show file tree
Hide file tree
Showing 17 changed files with 527 additions and 107 deletions.
242 changes: 214 additions & 28 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ std::string buffer_to_name(BuffersEnum b) {
return {};
}

bool is_ret_type_bit_pointer(Stmt *s) {
if (auto *ty = s->ret_type->cast<PointerType>()) {
// Don't use as() directly, it would fail when we inject a global tmp.
return ty->is_bit_pointer();
}
return false;
}

class KernelCodegen : public IRVisitor {
private:
enum class Section {
Expand Down Expand Up @@ -128,10 +136,12 @@ class KernelCodegen : public IRVisitor {
generate_kernels();

std::string source_code;
for (const auto s : kAllSections) {
source_code += section_appenders_.find(s)->second.lines();
source_code += '\n';
}
source_code += section_appenders_.at(Section::Headers).lines();
source_code += "namespace {\n";
source_code += section_appenders_.at(Section::Structs).lines();
source_code += section_appenders_.at(Section::KernelFuncs).lines();
source_code += "} // namespace\n";
source_code += section_appenders_.at(Section::Kernels).lines();
return source_code;
}

Expand Down Expand Up @@ -189,21 +199,6 @@ class KernelCodegen : public IRVisitor {
kRootBufferName);
}

void visit(GetChStmt *stmt) override {
// E.g. `parent.get*(runtime, mem_alloc)`
const auto get_call =
fmt::format("{}.get{}({}, {})", stmt->input_ptr->raw_name(), stmt->chid,
kRuntimeVarName, kMemAllocVarName);
if (stmt->output_snode->is_place()) {
emit(R"(device {}* {} = {}.val;)",
metal_data_type_name(stmt->output_snode->dt), stmt->raw_name(),
get_call);
} else {
emit(R"({} {} = {};)", stmt->output_snode->node_type_name,
stmt->raw_name(), get_call);
}
}

void visit(LinearizeStmt *stmt) override {
std::string val = "0";
for (int i = 0; i < (int)stmt->inputs.size(); i++) {
Expand All @@ -229,8 +224,29 @@ class KernelCodegen : public IRVisitor {
}
const auto *sn = stmt->snode;
const auto snty = sn->type;
if (snty == SNodeType::bit_struct) {
// Example *bit_struct* struct generated on Metal:
//
// struct Sx {
// // bit_struct
// Sx(device byte *b, ...) : base(b) {}
// device byte *base;
// };
emit("auto {} = {}.base;", stmt->raw_name(), parent);
return;
}
const std::string index_name = stmt->input_index->raw_name();

// Example SNode struct generated on Metal:
//
// struct S1 {
// // dense
// S1(device byte *addr, ...) { rep_.init(addr); }
// S1_ch children(int i) { return {rep_.addr() + (i * elem_stride)}; }
// inline void activate(int i) { rep_.activate(i); }
// ...
// private:
// SNodeRep_dense rep_;
// };
if (stmt->activate) {
TI_ASSERT(is_supported_sparse_type(snty));
emit("{}.activate({});", parent, index_name);
Expand All @@ -239,6 +255,32 @@ class KernelCodegen : public IRVisitor {
parent, index_name);
}

void visit(GetChStmt *stmt) override {
auto *in_snode = stmt->input_snode;
auto *out_snode = stmt->output_snode;
if (in_snode->type == SNodeType::bit_struct) {
TI_ASSERT(stmt->ret_type->as<PointerType>()->is_bit_pointer());
const auto *bit_struct_ty = in_snode->dt->cast<BitStructType>();
const auto bit_offset =
bit_struct_ty->get_member_bit_offset(in_snode->child_id(out_snode));
// stmt->input_ptr is the "base" member in the generated SNode struct.
emit("SNodeBitPointer {}({}, /*offset=*/{});", stmt->raw_name(),
stmt->input_ptr->raw_name(), bit_offset);
return;
}
// E.g. `parent.get*(runtime, mem_alloc)`
const auto get_call =
fmt::format("{}.get{}({}, {})", stmt->input_ptr->raw_name(), stmt->chid,
kRuntimeVarName, kMemAllocVarName);
if (out_snode->is_place()) {
emit(R"(device {}* {} = {}.val;)", metal_data_type_name(out_snode->dt),
stmt->raw_name(), get_call);
} else {
emit(R"({} {} = {};)", out_snode->node_type_name, stmt->raw_name(),
get_call);
}
}

void visit(SNodeOpStmt *stmt) override {
const std::string result_var = stmt->raw_name();
const auto opty = stmt->op_type;
Expand Down Expand Up @@ -292,13 +334,23 @@ class KernelCodegen : public IRVisitor {

void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());

if (!is_ret_type_bit_pointer(stmt->ptr)) {
emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());
return;
}
handle_bit_pointer_global_store(stmt);
}

void visit(GlobalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
emit(R"({} {} = *{};)", metal_data_type_name(stmt->element_type()),
stmt->raw_name(), stmt->ptr->raw_name());
std::string rhs_expr;
if (!is_ret_type_bit_pointer(stmt->ptr)) {
rhs_expr = fmt::format("*{}", stmt->ptr->raw_name());
} else {
rhs_expr = construct_bit_pointer_global_load(stmt);
}
emit("const auto {} = {};", stmt->raw_name(), rhs_expr);
}

void visit(ArgLoadStmt *stmt) override {
Expand Down Expand Up @@ -457,7 +509,6 @@ class KernelCodegen : public IRVisitor {

void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
const auto dt = stmt->val->element_type();
const auto op_type = stmt->op_type;
std::string op_name;
bool handle_float = false;
Expand All @@ -475,6 +526,11 @@ class KernelCodegen : public IRVisitor {
TI_NOT_IMPLEMENTED;
}

if (is_ret_type_bit_pointer(stmt->dest)) {
handle_bit_pointer_atomics(stmt);
return;
}

std::string val_var = stmt->val->raw_name();
// TODO(k-ye): This is not a very reliable way to detect if we're in TLS
// xlogues...
Expand All @@ -488,7 +544,7 @@ class KernelCodegen : public IRVisitor {
emit("if ({} == 0) {{", kKernelTidInSimdgroupName);
current_appender().push_indent();
}

const auto dt = stmt->val->element_type();
if (dt->is_primitive(PrimitiveTypeID::i32)) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, "
Expand Down Expand Up @@ -626,9 +682,11 @@ class KernelCodegen : public IRVisitor {
if (std::holds_alternative<Stmt *>(entry)) {
auto *arg_stmt = std::get<Stmt *>(entry);
const auto dt = arg_stmt->element_type();
TI_ASSERT_INFO(dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::f32),
"print() only supports i32 or f32 scalars for now.");
TI_ASSERT_INFO(
dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
dt->is_primitive(PrimitiveTypeID::f32),
"print() only supports i32, u32 or f32 scalars for now.");
emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_name(dt), i,
arg_stmt->raw_name());
} else {
Expand Down Expand Up @@ -773,6 +831,133 @@ class KernelCodegen : public IRVisitor {
emit_kernel_args_struct();
}

void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) {
auto *ptr_type = stmt->ptr->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
CustomIntType *cit = nullptr;
std::string store_value_expr;
if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
cit = cit_cast;
store_value_expr = stmt->data->raw_name();
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
auto *digits_cit = cft->get_digits_type()->as<CustomIntType>();
cit = digits_cit;
store_value_expr = construct_float_to_custom_int_expr(
stmt->data, cft->get_scale(), digits_cit);
} else {
TI_NOT_IMPLEMENTED;
}
// Type of |stmt->ptr| is SNodeBitPointer
const auto num_bits = cit->get_num_bits();
if (is_full_bits(num_bits)) {
emit("mtl_set_full_bits({}, {});", stmt->ptr->raw_name(),
store_value_expr);
} else {
emit("mtl_set_partial_bits({},", stmt->ptr->raw_name());
emit(" {},", store_value_expr);
emit(" /*bits=*/{});", num_bits);
}
}

// Returns the expression of the load result
std::string construct_bit_pointer_global_load(GlobalLoadStmt *stmt) const {
auto *ptr_type = stmt->ptr->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
if (auto *cit = pointee_type->cast<CustomIntType>()) {
return construct_load_as_custom_int(stmt->ptr, cit);
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
const auto loaded = construct_load_as_custom_int(
stmt->ptr, cft->get_digits_type()->as<CustomIntType>());
// Computes `float(digits_expr) * scale`
// See LLVM backend's reconstruct_custom_float()
return fmt::format("(static_cast<float>({}) * {})", loaded,
cft->get_scale());
}
TI_NOT_IMPLEMENTED;
return "";
}

void handle_bit_pointer_atomics(AtomicOpStmt *stmt) {
TI_ERROR_IF(stmt->op_type != AtomicOpType::add,
"Only atomic add is supported for bit pointer types");
// Type of |dest_ptr| is SNodeBitPointer
const auto *dest_ptr = stmt->dest;
auto *ptr_type = dest_ptr->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
CustomIntType *cit = nullptr;
std::string val_expr;
if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
cit = cit_cast;
val_expr = stmt->val->raw_name();
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
cit = cft->get_digits_type()->as<CustomIntType>();
val_expr =
construct_float_to_custom_int_expr(stmt->val, cft->get_scale(), cit);
} else {
TI_NOT_IMPLEMENTED;
}
const auto num_bits = cit->get_num_bits();
if (is_full_bits(num_bits)) {
emit("const auto {} = mtl_atomic_add_full_bits({}, {});",
stmt->raw_name(), dest_ptr->raw_name(), val_expr);
} else {
emit("const auto {} = mtl_atomic_add_partial_bits({},", stmt->raw_name(),
dest_ptr->raw_name());
emit(" {},", val_expr);
emit(" /*bits=*/{});", num_bits);
}
}

// Returns the expression of `int(val_stmt * (1.0f / scale) + 0.5f)`
std::string construct_float_to_custom_int_expr(
const Stmt *val_stmt,
float64 scale,
CustomIntType *digits_cit) const {
DataType compute_dt(digits_cit->get_compute_type()->as<PrimitiveType>());
// This implicitly casts double to float on the host.
const float inv_scale = 1.0 / scale;
// Creating an expression (instead of holding intermediate results with
// variables) because |val_stmt| could be used multiple times. If the
// intermediate variables are named based on |val_stmt|, it would result in
// symbol redefinitions.
return fmt::format("mtl_float_to_custom_int<{}>(/*inv_scale=*/{} * {})",
metal_data_type_name(compute_dt), inv_scale,
val_stmt->raw_name());
}

// Returns expression of the loaded integer.
std::string construct_load_as_custom_int(const Stmt *bit_ptr_stmt,
CustomIntType *cit) const {
DataType compute_dt(cit->get_compute_type()->as<PrimitiveType>());
const auto num_bits = cit->get_num_bits();
if (is_full_bits(num_bits)) {
return fmt::format("mtl_get_full_bits<{}>({})",
metal_data_type_name(compute_dt),
bit_ptr_stmt->raw_name());
}
return fmt::format("mtl_get_partial_bits<{}>({}, {})",
metal_data_type_name(compute_dt),
bit_ptr_stmt->raw_name(), num_bits);
}

void validate_cft_for_metal(CustomFloatType *cft) const {
if (cft->get_exponent_type() != nullptr) {
TI_NOT_IMPLEMENTED;
}
if (cft->get_compute_type()->as<PrimitiveType>() != PrimitiveType::f32) {
TI_ERROR("Metal only supports 32-bit float");
}
}

static bool is_full_bits(int bits) {
return bits == (sizeof(uint32_t) * 8);
}

void emit_kernel_args_struct() {
if (ctx_attribs_.empty()) {
return;
Expand Down Expand Up @@ -924,6 +1109,7 @@ class KernelCodegen : public IRVisitor {
emit("const int {} = {} - {};", total_elems_name, end_expr, begin_expr);
ka.advisory_total_num_threads = kMaxNumThreadsGridStrideLoop;
}
// TODO: I've seen cases where |block_dim| was set to 1...
ka.advisory_num_threads_per_group = stmt->block_dim;
// begin_ = thread_id + begin_expr
emit("const int begin_ = {} + {};", kKernelThreadIdName, begin_expr);
Expand Down
Loading

0 comments on commit 210b212

Please sign in to comment.