Skip to content

Commit

Permalink
Encapsulate fatomic_fetch_add and ifloordiv as functions (#465)
Browse files Browse the repository at this point in the history
* Encapsulate fatomic_fetch_add and ifloor_div as functions

* simplify how functions are emitted
  • Loading branch information
k-ye authored Feb 15, 2020
1 parent 993c112 commit abbf5b1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 33 deletions.
73 changes: 42 additions & 31 deletions taichi/backends/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,8 @@ class MetalKernelCodegen : public IRVisitor {
const auto bin_name = bin->raw_name();
if (bin->op_type == BinaryOpType::floordiv) {
if (is_integral(bin->element_type())) {
const auto intm = fmt::format("{}_intermediate_", bin_name);
emit("const {} {} = ({} / {});", dt_name, intm, lhs_name, rhs_name);
// Should we construct an AST for this?
const auto expr_str = fmt::format(
"(({lhs} * {rhs} < 0) && ({rhs} * {intm} != {lhs})) ? ({intm} - 1) "
": {intm}",
fmt::arg("lhs", lhs_name), fmt::arg("rhs", rhs_name),
fmt::arg("intm", intm));
emit("const {} {} = ({});", dt_name, bin_name, expr_str);
emit("const {} {} = ifloordiv({}, {});", dt_name, bin_name, lhs_name,
rhs_name);
} else {
emit("const {} {} = floor({} / {});", dt_name, bin_name, lhs_name,
rhs_name);
Expand Down Expand Up @@ -280,27 +273,8 @@ class MetalKernelCodegen : public IRVisitor {
"metal::memory_order_relaxed);",
stmt->raw_name(), stmt->dest->raw_name(), stmt->val->raw_name());
} else if (dt == DataType::f32) {
// A huge hack! Metal does not support atomic floating point numbers
// natively.
const auto dest_name = stmt->dest->raw_name();
const auto cas_ok = fmt::format("{}_cas_ok_", dest_name);
const auto old_val = fmt::format("{}_old_", dest_name);
const auto new_val = fmt::format("{}_new_", dest_name);
emit("bool {} = false;", cas_ok);
emit("float {} = 0.0f;", stmt->raw_name());
emit("while (!{}) {{", cas_ok);
push_indent();
emit("float {} = *{};", old_val, dest_name);
emit("float {} = ({} + {});", new_val, old_val, stmt->val->raw_name());
emit("{} = atomic_compare_exchange_weak_explicit(", cas_ok);
emit(" (device atomic_int *){},", dest_name);
emit(" (thread int*)(&{}),", old_val);
emit(" *((thread int *)(&{})),", new_val);
emit(" metal::memory_order_relaxed,");
emit(" metal::memory_order_relaxed);");
emit("{} = {};", stmt->raw_name(), old_val);
pop_indent();
emit("}}");
emit("const float {} = fatomic_fetch_add({}, {});", stmt->raw_name(),
stmt->dest->raw_name(), stmt->val->raw_name());
} else {
TC_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -395,13 +369,18 @@ class MetalKernelCodegen : public IRVisitor {
emit("namespace {{");
emit("");
generate_common_functions();
emit("");
kernel_src_code_ += snode_structs_source_code;
emit("}} // namespace");
emit("");
}

void generate_common_functions() {
gen_union_cast();
gen_ifloordiv();
gen_fatomic_fetch_add_func();
}

void gen_union_cast() {
// For some reason, if I emit taichi/common.h's union_cast(), Metal failed
// to compile. More strangely, if I copy the generated code to XCode as a
// Metal kernel, it compiled successfully...
Expand All @@ -410,6 +389,38 @@ class MetalKernelCodegen : public IRVisitor {
emit(" static_assert(sizeof(T) == sizeof(G), \"Size mismatch\");");
emit(" return *reinterpret_cast<thread const T*>(&g);");
emit("}}");
emit("");
}

void gen_ifloordiv() {
emit("inline int ifloordiv(int lhs, int rhs) {{");
emit(" const int intm = (lhs / rhs);");
emit(
" return (((lhs * rhs < 0) && (rhs * intm != lhs)) ? (intm - 1) : "
"intm);");
emit("}}");
emit("");
}

void gen_fatomic_fetch_add_func() {
// A huge hack! Metal does not support atomic floating point numbers
// natively.
emit("float fatomic_fetch_add(device float* dest, const float operand) {{");
emit(" bool ok = false;");
emit(" float old_val = 0.0f;");
emit(" while (!ok) {{");
emit(" old_val = *dest;");
emit(" float new_val = (old_val + operand);");
emit(" ok = atomic_compare_exchange_weak_explicit(");
emit(" (device atomic_int *)dest,");
emit(" (thread int*)(&old_val),");
emit(" *((thread int *)(&new_val)),");
emit(" metal::memory_order_relaxed,");
emit(" metal::memory_order_relaxed);");
emit(" }}");
emit(" return old_val;");
emit("}}");
emit("");
}

void generate_kernel_args_struct(Kernel *kernel) {
Expand Down
5 changes: 4 additions & 1 deletion tests/python/test_atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ def place():
ti.root.dense(ti.i, n).place(x, y)
ti.root.place(c)

# Make Taichi correctly infer the type
# TODO: Taichi seems to treat numpy.int32 as a float type, fix that.
init_ck = 0 if vartype == ti.i32 else 0.0
@ti.kernel
def func():
ck = ti.to_numpy_type(vartype)(0)
ck = init_ck
for i in range(n):
x[i] = ti.atomic_add(c[None], step)
y[i] = ti.atomic_add(ck, step)
Expand Down
17 changes: 16 additions & 1 deletion tests/python/test_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,29 @@ def test_floor_div():
_test_floor_div(ti.i32, 10, ti.f32, 3, ti.f32, 3)
_test_floor_div(ti.f32, 10, ti.i32, 3, ti.f32, 3)

_test_floor_div(ti.i32, -10, ti.i32, 3, ti.f32, -4)
_test_floor_div(ti.f32, -10, ti.f32, 3, ti.f32, -4)
_test_floor_div(ti.i32, -10, ti.f32, 3, ti.f32, -4)
_test_floor_div(ti.f32, -10, ti.i32, 3, ti.f32, -4)

_test_floor_div(ti.i32, 10, ti.i32, -3, ti.f32, -4)
_test_floor_div(ti.f32, 10, ti.f32, -3, ti.f32, -4)
_test_floor_div(ti.i32, 10, ti.f32, -3, ti.f32, -4)
_test_floor_div(ti.f32, 10, ti.i32, -3, ti.f32, -4)

def test_true_div():
_test_true_div(ti.i32, 3, ti.i32, 2, ti.f32, 1.5)
_test_true_div(ti.f32, 3, ti.f32, 2, ti.f32, 1.5)
_test_true_div(ti.i32, 3, ti.f32, 2, ti.f32, 1.5)
_test_true_div(ti.f32, 3, ti.i32, 2, ti.f32, 1.5)
_test_true_div(ti.f32, 3, ti.i32, 2, ti.i32, 1)


_test_true_div(ti.i32, -3, ti.i32, 2, ti.f32, -1.5)
_test_true_div(ti.f32, -3, ti.f32, 2, ti.f32, -1.5)
_test_true_div(ti.i32, -3, ti.f32, 2, ti.f32, -1.5)
_test_true_div(ti.f32, -3, ti.i32, 2, ti.f32, -1.5)
_test_true_div(ti.f32, -3, ti.i32, 2, ti.i32, -1)

@ti.all_archs
def test_div_default_ip():
ti.get_runtime().set_default_ip(ti.i64)
Expand Down

0 comments on commit abbf5b1

Please sign in to comment.