Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove faiss from cuml #5293

Merged
merged 22 commits into from
Mar 28, 2023
Merged
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4f66867
Tests seem to pass w/ the new compiled bits
cjnolet Mar 14, 2023
957bd1c
Merge branch 'branch-23.04' into imp-2306-using_raft_compiled_target
cjnolet Mar 20, 2023
083f667
Reverting changes to dependencies.yaml
cjnolet Mar 20, 2023
0abad43
Update cpp/CMakeLists.txt
cjnolet Mar 20, 2023
ef2d8ee
Fixing bad merge
cjnolet Mar 20, 2023
16d39fe
Temporarily using artifacts from commit
cjnolet Mar 21, 2023
241171b
Using correct path
cjnolet Mar 21, 2023
30e531f
Something got reverted. Fixing
cjnolet Mar 21, 2023
3ee94da
Adding cuml_use_faiss_static option back in
cjnolet Mar 21, 2023
2db2603
Adding CUML_USE_RAFT_NN back
cjnolet Mar 21, 2023
3578b83
Turning USE_RAFT_NN on in the proper place
cjnolet Mar 21, 2023
e567150
Fixing get_raft
cjnolet Mar 21, 2023
6ad3f12
remove faiss from cuml
benfred Mar 22, 2023
4c3ce76
fix get_raft
benfred Mar 22, 2023
f850e93
Merge branch 'branch-23.04' into remove_faiss
benfred Mar 23, 2023
b7c5f2f
Use device inline epilogue
benfred Mar 24, 2023
bed5213
Merge branch 'branch-23.04' into remove_faiss
cjnolet Mar 25, 2023
71185fb
Add note about exposing more metrics to hdbscan
benfred Mar 27, 2023
8b39d92
Merge branch 'remove_faiss' of github.com:benfred/cuml into remove_faiss
benfred Mar 27, 2023
6002077
Merge branch 'branch-23.04' into remove_faiss
benfred Mar 27, 2023
f3a1543
remove 'k' param from brute_force::knn code
benfred Mar 28, 2023
3557be1
Merge branch 'remove_faiss' of github.com:benfred/cuml into remove_faiss
benfred Mar 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use device inline epilogue
benfred committed Mar 24, 2023
commit b7c5f2f70c41d720f5bd18c8606180b3afb59c6e
55 changes: 22 additions & 33 deletions cpp/src/hdbscan/detail/reachability.cuh
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@

#include <raft/linalg/unary_op.cuh>

#include <raft/neighbors/detail/knn_brute_force.cuh>
#include <raft/neighbors/brute_force.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/linalg/symmetrize.cuh>

@@ -164,26 +164,15 @@ void _compute_core_dists(const raft::handle_t& handle,
}

// Functor to post-process distances into reachability space
template <typename T>
template <typename value_idx, typename value_t>
struct ReachabilityPostProcess {
void operator()(T* input, size_t batch_i, size_t batch_j, size_t rows, size_t cols) const
DI value_t operator()(value_t value, value_idx row, value_idx col) const
{
// Trying to access member variables directly from the device lambda causes
// an invalid memory access for me, copy to a temporary to work
// around
const T* core_dists_ = core_dists;
const T alpha_ = alpha;

raft::linalg::map_offset(
handle, raft::make_device_vector_view(input, rows * cols), [=] __device__(size_t i) {
size_t row = i / cols, col = i % cols;
return max(core_dists_[col + batch_j], max(core_dists_[row + batch_i], alpha_ * input[i]));
});
return max(core_dists[col], max(core_dists[row], alpha * value));
}

const raft::handle_t& handle;
const T* core_dists;
T alpha;
const value_t* core_dists;
value_t alpha;
};

/**
@@ -216,22 +205,22 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle,
// `A type local to a function cannot be used in the template argument of the
// enclosing parent function (and any parent classes) of an extended __device__
// or __host__ __device__ lambda`
auto post_process = ReachabilityPostProcess<value_t>{handle, core_dists, alpha};

raft::neighbors::detail::tiled_brute_force_knn(handle,
X,
X,
m,
m,
n,
k,
out_dists,
out_inds,
raft::distance::DistanceType::L2SqrtExpanded,
0,
0,
0,
post_process);
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha};

auto X_view = raft::make_device_matrix_view(X, m, n);
std::vector<raft::device_matrix_view<const value_t, size_t>> index = {X_view};

raft::neighbors::brute_force::knn<value_idx, value_t>(
handle,
index,
X_view,
raft::make_device_matrix_view(out_inds, m, static_cast<size_t>(k)),
raft::make_device_matrix_view(out_dists, m, static_cast<size_t>(k)),
k,
raft::distance::DistanceType::L2SqrtExpanded,
std::make_optional<float>(2.0f),
std::nullopt,
epilogue);
}

/**