Skip to content

Commit

Permalink
[Bug] Warp reduction bug fix (#2519)
Browse files Browse the repository at this point in the history
* bug fix

* update tests to cover the bug
  • Loading branch information
AmesingFlank authored Jul 13, 2021
1 parent d2a139a commit 286328c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
26 changes: 13 additions & 13 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,19 +1083,19 @@ i32 op_xor_i32(i32 a, i32 b) {
return a ^ b;
}

#define DEFINE_REDUCTION(op, dtype) \
dtype warp_reduce_##op##_##dtype(dtype val) { \
for (int offset = 16; offset > 0; offset /= 2) \
val = op_##op##_##dtype( \
val, cuda_shfl_down_sync_i32(0xFFFFFFFF, val, offset, 31)); \
return val; \
} \
dtype reduce_##op##_##dtype(dtype *result, dtype val) { \
dtype warp_result = warp_reduce_##op##_##dtype(val); \
if ((thread_idx() & (warp_size() - 1)) == 0) { \
atomic_##op##_##dtype(result, warp_result); \
} \
return val; \
// DEFINE_REDUCTION(op, dtype) generates two functions by token pasting:
//
//   dtype warp_reduce_<op>_<dtype>(dtype val)
//     Folds `val` with values obtained from cuda_shfl_down_sync_<dtype>,
//     using op_<op>_<dtype> as the combiner, for offsets 16, 8, 4, 2, 1.
//     The shuffle helper is selected by `dtype`, so it always matches the
//     type of the value being reduced (do not hard-code the i32 variant
//     here: non-i32 values would then go through the i32 shuffle).
//     NOTE(review): the 0xFFFFFFFF mask, the trailing 31 and the starting
//     offset of 16 presumably assume a 32-lane warp -- confirm against
//     the definition of cuda_shfl_down_sync_*.
//
//   dtype reduce_<op>_<dtype>(dtype *result, dtype val)
//     Calls warp_reduce_<op>_<dtype>(val); threads for which
//     (thread_idx() & (warp_size() - 1)) == 0 then apply
//     atomic_<op>_<dtype>(result, warp_result).
//     Returns the caller's original `val`, not the warp-level result.
//
// Requires op_<op>_<dtype>, atomic_<op>_<dtype> and
// cuda_shfl_down_sync_<dtype> to exist for every (op, dtype) pair the
// macro is instantiated with.
#define DEFINE_REDUCTION(op, dtype) \
dtype warp_reduce_##op##_##dtype(dtype val) { \
for (int offset = 16; offset > 0; offset /= 2) \
val = op_##op##_##dtype( \
val, cuda_shfl_down_sync_##dtype(0xFFFFFFFF, val, offset, 31)); \
return val; \
} \
dtype reduce_##op##_##dtype(dtype *result, dtype val) { \
dtype warp_result = warp_reduce_##op##_##dtype(val); \
if ((thread_idx() & (warp_size() - 1)) == 0) { \
atomic_##op##_##dtype(result, warp_result); \
} \
return val; \
}

DEFINE_REDUCTION(add, i32);
Expand Down
19 changes: 14 additions & 5 deletions tests/python/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ def _test_reduction_single(dtype, criterion, op):
a = ti.field(dtype, shape=N)
tot = ti.field(dtype, shape=())

@ti.kernel
def fill():
for i in a:
a[i] = i
if dtype in [ti.f32, ti.f64]:

@ti.kernel
def fill():
for i in a:
a[i] = i + 0.5
else:

@ti.kernel
def fill():
for i in a:
a[i] = i

ti_op = ti_ops[op]

Expand All @@ -62,7 +70,8 @@ def reduce_tmp() -> dtype:
reduce()
tot2 = reduce_tmp()

ground_truth = np_ops[op](a.to_numpy())
np_arr = np.append(a.to_numpy(), [0])
ground_truth = np_ops[op](np_arr)

assert criterion(tot[None], ground_truth)
assert criterion(tot2, ground_truth)
Expand Down

0 comments on commit 286328c

Please sign in to comment.