Skip to content

Commit

Permalink
Add Fused L2 Expanded KNN kernel (#339)
Browse files Browse the repository at this point in the history
-- Adds a fused L2 expanded kNN kernel, which is at least 20-25% faster than the L2 unexpanded version on higher dimensions (D >= 128).
-- On smaller dimensions (D <= 32), L2 expanded is also always faster, by 10-15%.
-- Slight improvement in the updateSortedWarpQ device function by reducing redundant instructions.
-- Fixes incorrect output for the NN > 32 case when taking the prod-cons kNN merge path; this was caught in the HDBSCAN pytest.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #339
  • Loading branch information
mdoijade authored Nov 23, 2021
1 parent 94e6909 commit 6166a47
Show file tree
Hide file tree
Showing 7 changed files with 664 additions and 247 deletions.
29 changes: 26 additions & 3 deletions cpp/include/raft/device_atomics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,15 @@ struct genericAtomicOperationImpl<T, Op, 4> {
__forceinline__ __device__ T operator()(T* addr, T const& update_value,
Op op) {
using T_int = unsigned int;

T old_value = *addr;
T assumed{old_value};

if constexpr (std::is_same<T, float>{} && (std::is_same<Op, DeviceMin>{})) {
if (isnan(update_value)) {
return old_value;
}
}

do {
assumed = old_value;
const T new_value = op(old_value, update_value);
Expand All @@ -191,13 +196,32 @@ struct genericAtomicOperationImpl<T, Op, 4> {
type_reinterpret<T_int, T>(assumed),
type_reinterpret<T_int, T>(new_value));
old_value = type_reinterpret<T, T_int>(ret);

} while (assumed != old_value);

return old_value;
}
};

// 4 bytes fp32 atomic Max operation
/**
 * @brief Atomic max for `float`, implemented with the integer
 *        atomicMax / atomicMin intrinsics on the value's bit pattern.
 *
 * Floats with a clear sign bit compare in the same order as their bit
 * patterns read as signed ints, so a signed atomicMax on the bits gives the
 * float maximum. Floats with the sign bit set compare in the reverse order of
 * their bit patterns read as unsigned ints, so an unsigned atomicMin on the
 * bits gives the float maximum.
 */
template <>
struct genericAtomicOperationImpl<float, DeviceMax, 4> {
  using T = float;

  /**
   * @param addr          address of the float to update
   * @param update_value  candidate value for the maximum
   * @param op            unused; its type selects this specialization
   * @return the value the atomic intrinsic returns (the previous contents of
   *         `*addr`), or a plain read of `*addr` when `update_value` is NaN
   */
  __forceinline__ __device__ T operator()(T* addr, T const& update_value,
                                          DeviceMax op) {
    // A NaN update never modifies the stored value.
    if (isnan(update_value)) {
      return *addr;
    }

    // Select the branch on the sign bit rather than on `update_value >= 0`.
    // -0.0f compares equal to zero but has the bit pattern 0x80000000, which
    // is INT_MIN as a signed int; a signed atomicMax with INT_MIN can never
    // change the stored value, so a stored negative float would not be raised
    // to -0.0f. Routing -0.0f through the unsigned atomicMin path handles it.
    const int update_bits = __float_as_int(update_value);
    if (update_bits >= 0) {
      return __int_as_float(
        atomicMax(reinterpret_cast<int*>(addr), update_bits));
    }

    return __uint_as_float(atomicMin(reinterpret_cast<unsigned int*>(addr),
                                     __float_as_uint(update_value)));
  }
};

// 8 bytes atomic operation
template <typename T, typename Op>
struct genericAtomicOperationImpl<T, Op, 8> {
Expand Down Expand Up @@ -423,7 +447,6 @@ struct typesAtomicCASImpl<T, 4> {
T_int ret = atomicCAS(reinterpret_cast<T_int*>(addr),
type_reinterpret<T_int, T>(compare),
type_reinterpret<T_int, T>(update_value));

return type_reinterpret<T, T_int>(ret);
}
};
Expand Down
Loading

0 comments on commit 6166a47

Please sign in to comment.