Skip to content

Commit

Permalink
Binarize Dice Distance for Dense Inputs (#2370)
Browse files Browse the repository at this point in the history
Instead of binarizing in `pairwise_distances()` in cuml, do it in the distance function in raft.

Authors:
  - Anupam (https://github.com/aamijar)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Micka (https://github.com/lowener)

URL: #2370
  • Loading branch information
aamijar authored Jul 10, 2024
1 parent 7aebe87 commit 5bf6642
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 26 deletions.
28 changes: 5 additions & 23 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -279,32 +279,14 @@ void distance_impl(raft::resources const& handle,
true,
stream,
false,
raft::identity_op(),
raft::nz_op(),
raft::add_op());
} else {
y_norm += m;
raft::linalg::reduce(x_norm,
x,
k,
m,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::reduce(y_norm,
y,
k,
n,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::reduce(
x_norm, x, k, m, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op());
raft::linalg::reduce(
y_norm, y, k, n, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op());
}

ops::dice_distance_op<DataT, AccT, IdxT> distance_op{};
Expand Down
5 changes: 4 additions & 1 deletion cpp/include/raft/distance/detail/distance_ops/dice.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ struct dice_distance_op {
return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT));
}

DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; };
DI void core(AccT& acc, DataT& x, DataT& y) const
{
acc += (x != DataT(0) ? DataT(1) : DataT(0)) * (y != DataT(0) ? DataT(1) : DataT(0));
};

template <typename Policy>
DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
# NOTE: this template is not perfectly formatted. Use pre-commit to get
# everything in shape again.
header = """/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -95,6 +95,11 @@
OpT="raft::distance::detail::ops::cosine_distance_op",
archs = [60, 80],
),
dict(
path_prefix="dice",
OpT="raft::distance::detail::ops::dice_distance_op",
archs = [60, 80],
),
dict(
path_prefix="hamming_unexpanded",
OpT="raft::distance::detail::ops::hamming_distance_op",
Expand Down

0 comments on commit 5bf6642

Please sign in to comment.