Skip to content

Commit

Permalink
Merge branch 'branch-22.08' into fix/conda_meta_cmake
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Jul 28, 2022
2 parents 3ae7c59 + 2325d2b commit 7e5a459
Show file tree
Hide file tree
Showing 127 changed files with 14,410 additions and 3,565 deletions.
3 changes: 1 addition & 2 deletions ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ function sed_runner() {
sed_runner 's/'"RAFT VERSION .* LANGUAGES"'/'"RAFT VERSION ${NEXT_FULL_TAG} LANGUAGES"'/g' cpp/CMakeLists.txt
sed_runner 's/'"pylibraft_version .*)"'/'"pylibraft_version ${NEXT_FULL_TAG})"'/g' python/pylibraft/CMakeLists.txt
sed_runner 's/'"pyraft_version .*)"'/'"pyraft_version ${NEXT_FULL_TAG})"'/g' python/raft/CMakeLists.txt
sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cmake"'/g' cpp/CMakeLists.txt
sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cmake"'/g' python/pylibraft/CMakeLists.txt
sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cmake"'/g' fetch_rapids.cmake

# Docs update
sed_runner 's/version = .*/version = '"'${NEXT_SHORT_TAG}'"'/g' docs/source/conf.py
Expand Down
4 changes: 1 addition & 3 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ set(RAPIDS_VERSION "22.06")
set(RAFT_VERSION "${RAPIDS_VERSION}.00")

cmake_minimum_required(VERSION 3.20.1 FATAL_ERROR)
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-22.10/RAPIDS.cmake
${CMAKE_BINARY_DIR}/RAPIDS.cmake)
include(${CMAKE_BINARY_DIR}/RAPIDS.cmake)
include(../fetch_rapids.cmake)
include(rapids-cmake)
include(rapids-cpm)
include(rapids-cuda)
Expand Down
37 changes: 36 additions & 1 deletion cpp/bench/spatial/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#include <common/benchmark.hpp>

#include <raft/random/rng.cuh>
#include <raft/spatial/knn/knn.cuh>

#include <raft/spatial/knn/ivf_flat.cuh>
#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.cuh>
#endif
Expand Down Expand Up @@ -126,6 +127,34 @@ struct host_uvector {
T* arr_;
};

template <typename ValT, typename IdxT>
struct ivf_flat_knn {
using dist_t = float;

std::optional<const raft::spatial::knn::ivf_flat::index<ValT, IdxT>> index;
raft::spatial::knn::ivf_flat::index_params index_params;
raft::spatial::knn::ivf_flat::search_params search_params;
params ps;

ivf_flat_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::spatial::knn::ivf_flat::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}

void search(const raft::handle_t& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
{
search_params.n_probes = 20;
raft::spatial::knn::ivf_flat::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
}
};

template <typename ValT, typename IdxT>
struct brute_force_knn {
using dist_t = ValT;
Expand Down Expand Up @@ -326,7 +355,13 @@ const std::vector<Scope> kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::B
}

KNN_REGISTER(float, int64_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull);
KNN_REGISTER(float, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(int8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(uint8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);

KNN_REGISTER(float, uint32_t, brute_force_knn, kInputs, kNoCopyOnly, kScopeFull);
KNN_REGISTER(float, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(int8_t, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(uint8_t, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);

} // namespace raft::bench::spatial
Loading

0 comments on commit 7e5a459

Please sign in to comment.