Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to use cuVS for vector search #6085

Merged
merged 52 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
2f2c55c
Migrate from raft to cuvs for pairwise_distance and bfknn
benfred Jul 16, 2024
9116ae8
Merge branch 'branch-24.08' into cuvs
benfred Jul 22, 2024
84bc77a
.
benfred Jul 22, 2024
cbb79ec
use cuvs::distance::DistanceType where possible
benfred Jul 22, 2024
f653059
Revert "use cuvs::distance::DistanceType where possible"
benfred Sep 6, 2024
a6f2c2a
use stats from raft
benfred Sep 18, 2024
c2c9c04
Merge remote-tracking branch 'origin/branch-24.10' into cuvs
benfred Sep 18, 2024
95b9c14
use ivf-* from cuvs
benfred Sep 20, 2024
1934d40
Merge remote-tracking branch 'origin/branch-24.10' into cuvs
benfred Sep 20, 2024
0601412
Merge branch 'rapidsai:branch-24.10' into cuvs
benfred Sep 26, 2024
f30f933
add libcuvs to dependencies.yaml
benfred Sep 26, 2024
69d2398
Merge branch 'cuvs' of https://github.com/benfred/cuml into cuvs
benfred Sep 26, 2024
6c46624
attempt to fix build error in CI
benfred Sep 26, 2024
3fb5479
Merge branch 'branch-24.10' into cuvs
benfred Sep 26, 2024
5380c17
fix tsne
benfred Sep 27, 2024
dc1f7a6
Merge branch 'cuvs' of https://github.com/benfred/cuml into cuvs
benfred Sep 27, 2024
fdd18c5
fix ivf-fla
benfred Sep 27, 2024
590865d
fix test_nearest_neighbors_rbc test for haversine distance
benfred Sep 27, 2024
c7d1b0e
re-add MetricProcessor code
benfred Sep 27, 2024
e8c1b18
suggestions from code review
benfred Sep 27, 2024
bd58347
fix dask pytests
benfred Sep 29, 2024
adb450a
attempt to fix python build errors in CI
benfred Sep 29, 2024
df47d3c
use raft in header only mode
benfred Sep 29, 2024
c72173c
Use kmeans/mutual reachability code from pending cuvs PR's
benfred Sep 29, 2024
6139db0
remove comment
benfred Sep 30, 2024
3e1c465
cmake fixes
benfred Sep 30, 2024
732ea10
.
benfred Sep 30, 2024
4942bb0
pick up right cuvs version
benfred Sep 30, 2024
fef9920
.
benfred Sep 30, 2024
3c0c47e
empty commit for ci
benfred Sep 30, 2024
d8e6b6d
Exclude libcuvs.so in auditwheel.
bdice Sep 30, 2024
257898e
add cuvs to python dependencies
benfred Oct 1, 2024
9f805cd
Merge branch 'cuvs' of https://github.com/benfred/cuml into cuvs
benfred Oct 1, 2024
176c9a9
use l2expanded distance in kmeans transform
benfred Oct 2, 2024
2122120
Merge branch 'branch-24.10' into cuvs
benfred Oct 2, 2024
c12f1d5
Set rpath for cuvs
KyleFromNVIDIA Oct 2, 2024
f3edb8f
Don't link Python modules against cuvs directly
KyleFromNVIDIA Oct 2, 2024
f906e79
Remove superfluous cuvs::cuvs references
KyleFromNVIDIA Oct 2, 2024
585acd4
Add cuvs rpath
KyleFromNVIDIA Oct 2, 2024
1964a94
remove cuvs pin
benfred Oct 2, 2024
f1db388
Merge branch 'cuvs' of https://github.com/benfred/cuml into cuvs
benfred Oct 2, 2024
259256a
updates to handle bfknn api changes
benfred Oct 3, 2024
4efd97a
link cuvs statically in python wheels
benfred Oct 3, 2024
41daf01
empty commit for ci
benfred Oct 3, 2024
31331d7
empty commit for ci
benfred Oct 3, 2024
6bef472
empty commit for ci
benfred Oct 3, 2024
3e88b8e
remove pin
benfred Oct 3, 2024
064cced
Merge branch 'branch-24.10' into cuvs
cjnolet Oct 3, 2024
3959c58
re-add pin + suggestions from code review
benfred Oct 3, 2024
dc6de84
Merge branch 'cuvs' of https://github.com/benfred/cuml into cuvs
benfred Oct 3, 2024
461a271
fix
benfred Oct 3, 2024
f50b9d2
remove pin
benfred Oct 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ cd ${package_dir}
case "${RAPIDS_CUDA_VERSION}" in
12.*)
EXCLUDE_ARGS=(
--exclude "libcuvs.so"
--exclude "libcublas.so.12"
--exclude "libcublasLt.so.12"
--exclude "libcufft.so.11"
Expand All @@ -32,12 +33,14 @@ case "${RAPIDS_CUDA_VERSION}" in
EXTRA_CMAKE_ARGS=";-DUSE_CUDA_MATH_WHEELS=ON"
;;
11.*)
EXCLUDE_ARGS=()
EXCLUDE_ARGS=(
--exclude "libcuvs.so"
)
EXTRA_CMAKE_ARGS=";-DUSE_CUDA_MATH_WHEELS=OFF"
;;
esac

SKBUILD_CMAKE_ARGS="-DDETECT_CONDA_ENV=OFF;-DDISABLE_DEPRECATION_WARNINGS=ON;-DCPM_cumlprims_mg_SOURCE=${GITHUB_WORKSPACE}/cumlprims_mg/${EXTRA_CMAKE_ARGS}" \
SKBUILD_CMAKE_ARGS="-DDETECT_CONDA_ENV=OFF;-DDISABLE_DEPRECATION_WARNINGS=ON;-DCPM_cumlprims_mg_SOURCE=${GITHUB_WORKSPACE}/cumlprims_mg/;-DUSE_CUVS_WHEEL=ON${EXTRA_CMAKE_ARGS}" \
python -m pip wheel . \
-w dist \
-vvv \
Expand Down
2 changes: 2 additions & 0 deletions ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ echo "${NEXT_FULL_TAG}" > VERSION
DEPENDENCIES=(
cudf
cuml
cuvs
dask-cuda
dask-cudf
libcuml
libcuml-tests
libcumlprims
libcuvs
libraft-headers
libraft
librmm
Expand Down
3 changes: 2 additions & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- cudatoolkit
- cudf==24.10.*,>=0.0.0a0
- cupy>=12.0.0
- cuvs==24.10.*,>=0.0.0a0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think cuml needs a dependency on cuvs python today. We could keep this for the release, though, just to avoid additional changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are correct here (we only need cuvs for the requirements/pyproject outputs for wheels). Let's file a follow-up PR once this one merges so we don't lose track. It can target 24.12.

Thankfully, cuvs is only ~350 kB: https://anaconda.org/rapidsai-nightly/cuvs/files

- cxx-compiler
- cython>=3.0.0
- dask-cuda==24.10.*,>=0.0.0a0
Expand All @@ -39,8 +40,8 @@ dependencies:
- libcusolver=11.4.1.48
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- libcuvs==24.10.*,>=0.0.0a0
- libraft-headers==24.10.*,>=0.0.0a0
- libraft==24.10.*,>=0.0.0a0
- librmm==24.10.*,>=0.0.0a0
- nbsphinx
- ninja
Expand Down
3 changes: 2 additions & 1 deletion conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- cuda-version=12.5
- cudf==24.10.*,>=0.0.0a0
- cupy>=12.0.0
- cuvs==24.10.*,>=0.0.0a0
- cxx-compiler
- cython>=3.0.0
- dask-cuda==24.10.*,>=0.0.0a0
Expand All @@ -36,8 +37,8 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libcuvs==24.10.*,>=0.0.0a0
- libraft-headers==24.10.*,>=0.0.0a0
- libraft==24.10.*,>=0.0.0a0
- librmm==24.10.*,>=0.0.0a0
- nbsphinx
- ninja
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/clang_tidy_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ dependencies:
- libcusolver=11.4.1.48
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- libcuvs==24.10.*,>=0.0.0a0
- libraft-headers==24.10.*,>=0.0.0a0
- libraft==24.10.*,>=0.0.0a0
- librmm==24.10.*,>=0.0.0a0
- ninja
- nvcc_linux-64=11.8
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/cpp_all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies:
- libcusolver=11.4.1.48
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- libcuvs==24.10.*,>=0.0.0a0
- libraft-headers==24.10.*,>=0.0.0a0
- libraft==24.10.*,>=0.0.0a0
- librmm==24.10.*,>=0.0.0a0
- ninja
- nvcc_linux-64=11.8
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/cpp_all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libcuvs==24.10.*,>=0.0.0a0
- libraft-headers==24.10.*,>=0.0.0a0
- libraft==24.10.*,>=0.0.0a0
- librmm==24.10.*,>=0.0.0a0
- ninja
- spdlog>=1.14.1,<1.15
Expand Down
4 changes: 2 additions & 2 deletions conda/recipes/libcuml/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ requirements:
{% endif %}
- fmt {{ fmt_version }}
- libcumlprims ={{ minor_version }}
- libraft ={{ minor_version }}
- libcuvs ={{ minor_version }}
- libraft-headers ={{ minor_version }}
- librmm ={{ minor_version }}
- spdlog {{ spdlog_version }}
Expand Down Expand Up @@ -116,7 +116,7 @@ outputs:
- libcusparse
{% endif %}
- libcumlprims ={{ minor_version }}
- libraft ={{ minor_version }}
- libcuvs ={{ minor_version }}
- librmm ={{ minor_version }}
- treelite {{ treelite_version }}
about:
Expand Down
21 changes: 5 additions & 16 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ option(SINGLEGPU "Disable all mnmg components and comms libraries" OFF)
option(USE_CCACHE "Cache build artifacts with ccache" OFF)
option(CUDA_STATIC_RUNTIME "Statically link the CUDA runtime" OFF)
option(CUDA_STATIC_MATH_LIBRARIES "Statically link the CUDA math libraries" OFF)
option(CUML_USE_RAFT_STATIC "Build and statically link the RAFT libraries" OFF)
option(CUML_RAFT_COMPILED "Use libraft shared library" ON)
option(CUML_USE_CUVS_STATIC "Build and statically link the CUVS library" OFF)
option(CUML_USE_TREELITE_STATIC "Build and statically link the treelite library" OFF)
option(CUML_EXPORT_TREELITE_LINKAGE "Whether to publicly or privately link treelite to libcuml++" OFF)
option(CUML_USE_CUMLPRIMS_MG_STATIC "Build and statically link the cumlprims_mg library" OFF)
Expand All @@ -78,6 +77,7 @@ option(CUML_EXCLUDE_RAFT_FROM_ALL "Exclude RAFT targets from cuML's 'all' target
option(CUML_EXCLUDE_TREELITE_FROM_ALL "Exclude Treelite targets from cuML's 'all' target" OFF)
option(CUML_EXCLUDE_CUMLPRIMS_MG_FROM_ALL "Exclude cumlprims_mg targets from cuML's 'all' target" OFF)
option(CUML_RAFT_CLONE_ON_PIN "Explicitly clone RAFT branch when pinned to non-feature branch" ON)
KyleFromNVIDIA marked this conversation as resolved.
Show resolved Hide resolved
option(CUML_CUVS_CLONE_ON_PIN "Explicitly clone CUVS branch when pinned to non-feature branch" ON)

message(VERBOSE "CUML_CPP: Building libcuml_c shared library. Contains the cuML C API: ${BUILD_CUML_C_LIBRARY}")
message(VERBOSE "CUML_CPP: Building libcuml shared library: ${BUILD_CUML_CPP_LIBRARY}")
Expand All @@ -98,7 +98,7 @@ message(VERBOSE "CUML_CPP: Disabling all mnmg components and comms libraries: ${
message(VERBOSE "CUML_CPP: Cache build artifacts with ccache: ${USE_CCACHE}")
message(VERBOSE "CUML_CPP: Statically link the CUDA runtime: ${CUDA_STATIC_RUNTIME}")
message(VERBOSE "CUML_CPP: Statically link the CUDA math libraries: ${CUDA_STATIC_MATH_LIBRARIES}")
message(VERBOSE "CUML_CPP: Build and statically link RAFT libraries: ${CUML_USE_RAFT_STATIC}")
message(VERBOSE "CUML_CPP: Build and statically link CUVS libraries: ${CUML_USE_CUVS_STATIC}")
message(VERBOSE "CUML_CPP: Build and statically link Treelite library: ${CUML_USE_TREELITE_STATIC}")

set(CUML_ALGORITHMS "ALL" CACHE STRING "Experimental: Choose which algorithms are built into libcuml++.so. Can specify individual algorithms or groups in a semicolon-separated list.")
Expand Down Expand Up @@ -228,6 +228,7 @@ endif()
include(cmake/thirdparty/get_cccl.cmake)
include(cmake/thirdparty/get_rmm.cmake)
include(cmake/thirdparty/get_raft.cmake)
include(cmake/thirdparty/get_cuvs.cmake)

if(LINK_TREELITE)
include(cmake/thirdparty/get_treelite.cmake)
Expand Down Expand Up @@ -442,18 +443,6 @@ if(BUILD_CUML_CPP_LIBRARY)
src/metrics/kl_divergence.cu
src/metrics/mutual_info_score.cu
src/metrics/pairwise_distance.cu
src/metrics/pairwise_distance_canberra.cu
src/metrics/pairwise_distance_chebyshev.cu
src/metrics/pairwise_distance_correlation.cu
src/metrics/pairwise_distance_cosine.cu
src/metrics/pairwise_distance_euclidean.cu
src/metrics/pairwise_distance_hamming.cu
src/metrics/pairwise_distance_hellinger.cu
src/metrics/pairwise_distance_jensen_shannon.cu
src/metrics/pairwise_distance_kl_divergence.cu
src/metrics/pairwise_distance_l1.cu
src/metrics/pairwise_distance_minkowski.cu
src/metrics/pairwise_distance_russell_rao.cu
src/metrics/r2_score.cu
src/metrics/rand_index.cu
src/metrics/silhouette_score.cu
Expand Down Expand Up @@ -635,7 +624,7 @@ if(BUILD_CUML_CPP_LIBRARY)
)

target_link_libraries(${CUML_CPP_TARGET}
PUBLIC rmm::rmm
PUBLIC rmm::rmm ${CUVS_LIB}
${_cuml_cpp_public_libs}
PRIVATE ${_cuml_cpp_private_libs}
)
Expand Down
1 change: 0 additions & 1 deletion cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ if(BUILD_CUML_BENCH)
benchmark::benchmark
${TREELITE_LIBS}
raft::raft
raft::compiled
)

target_include_directories(${CUML_CPP_BENCH_TARGET}
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/sg/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ std::vector<Params> getInputs()
p.kmeans.max_iter = 300;
p.kmeans.tol = 1e-4;
p.kmeans.verbosity = RAFT_LEVEL_INFO;
p.kmeans.metric = raft::distance::DistanceType::L2Expanded;
p.kmeans.metric = cuvs::distance::DistanceType::L2Expanded;
p.kmeans.rng_state = raft::random::RngState(p.blobs.seed);
p.kmeans.inertia_check = true;
std::vector<std::pair<int, int>> rowcols = {
Expand Down
77 changes: 77 additions & 0 deletions cpp/cmake/thirdparty/get_cuvs.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#=============================================================================
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#=============================================================================

set(CUML_MIN_VERSION_cuvs "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}.00")
set(CUML_BRANCH_VERSION_cuvs "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}")

function(find_and_configure_cuvs)
set(oneValueArgs VERSION FORK PINNED_TAG EXCLUDE_FROM_ALL USE_CUVS_STATIC COMPILE_LIBRARY CLONE_ON_PIN)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )

if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${CUML_BRANCH_VERSION_cuvs}")
message(STATUS "CUML: CUVS pinned tag found: ${PKG_PINNED_TAG}. Cloning cuvs locally.")
set(CPM_DOWNLOAD_cuvs ON)
elseif(PKG_USE_CUVS_STATIC AND (NOT CPM_cuvs_SOURCE))
message(STATUS "CUML: Cloning cuvs locally to build static libraries.")
set(CPM_DOWNLOAD_cuvs ON)
else()
message(STATUS "Not cloning cuvs locally")
endif()

if(PKG_USE_CUVS_STATIC)
set(CUVS_LIB cuvs::cuvs_static PARENT_SCOPE)
else()
set(CUVS_LIB cuvs::cuvs PARENT_SCOPE)
endif()

rapids_cpm_find(cuvs ${PKG_VERSION}
GLOBAL_TARGETS cuvs::cuvs
BUILD_EXPORT_SET cuml-exports
INSTALL_EXPORT_SET cuml-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/cuvs.git
GIT_TAG ${PKG_PINNED_TAG}
SOURCE_SUBDIR cpp
EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL}
OPTIONS
"BUILD_TESTS OFF"
"BUILD_BENCH OFF"
)

if(cuvs_ADDED)
message(VERBOSE "CUML: Using CUVS located in ${cuvs_SOURCE_DIR}")
else()
message(VERBOSE "CUML: Using CUVS located in ${cuvs_DIR}")
endif()


endfunction()

# Change pinned tag here to test a commit in CI
# To use a different CUVS locally, set the CMake variable
# CPM_cuvs_SOURCE=/path/to/local/cuvs
find_and_configure_cuvs(VERSION ${CUML_MIN_VERSION_cuvs}
FORK benfred
PINNED_TAG static_lib2
benfred marked this conversation as resolved.
Show resolved Hide resolved
EXCLUDE_FROM_ALL ${CUML_EXCLUDE_CUVS_FROM_ALL}
# When PINNED_TAG above doesn't match cuml,
# force local cuvs clone in build directory
# even if it's already installed.
CLONE_ON_PIN ${CUML_CUVS_CLONE_ON_PIN}
COMPILE_LIBRARY ${CUML_CUVS_COMPILED}
USE_CUVS_STATIC ${CUML_USE_CUVS_STATIC}
)
12 changes: 1 addition & 11 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,6 @@ function(find_and_configure_raft)
string(APPEND RAFT_COMPONENTS " distributed")
endif()

if(PKG_COMPILE_LIBRARY)
if(NOT PKG_USE_RAFT_STATIC)
string(APPEND RAFT_COMPONENTS " compiled")
set(RAFT_COMPILED_LIB raft::compiled PARENT_SCOPE)
else()
string(APPEND RAFT_COMPONENTS " compiled_static")
set(RAFT_COMPILED_LIB raft::compiled_static PARENT_SCOPE)
endif()
endif()

# We need to set this each time so that on subsequent calls to cmake
# the raft-config.cmake re-evaluates the RAFT_NVTX value
set(RAFT_NVTX ${PKG_NVTX})
Expand All @@ -66,7 +56,7 @@ function(find_and_configure_raft)
"BUILD_TESTS OFF"
"BUILD_BENCH OFF"
"BUILD_CAGRA_HNSWLIB OFF"
"RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}"
"RAFT_COMPILE_LIBRARY OFF"
)

if(raft_ADDED)
Expand Down
2 changes: 1 addition & 1 deletion cpp/examples/kmeans/kmeans_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ int main(int argc, char* argv[])
params.max_iter = 300;
params.tol = 0.05;
}
params.metric = raft::distance::DistanceType::L2SqrtExpanded;
params.metric = cuvs::distance::DistanceType::L2SqrtExpanded;
params.init = ML::kmeans::KMeansParams::InitMethod::Random;

// Inputs copied from kmeans_test.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cuml/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <cuml/common/log_levels.hpp>

#include <raft/cluster/kmeans_types.hpp>
#include <cuvs/cluster/kmeans.hpp>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!


namespace raft {
class handle_t;
Expand All @@ -28,7 +28,7 @@ namespace ML {

namespace kmeans {

using KMeansParams = raft::cluster::KMeansParams;
using KMeansParams = cuvs::cluster::kmeans::params;

/**
* @brief Compute k-means clustering and predicts cluster index for each sample
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/cuml/cluster/kmeans_mg.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -48,7 +48,7 @@ namespace opg {
* @param[out] n_iter Number of iterations run.
*/

void fit(const raft::handle_t& handle,
void fit(const raft::resources& handle,
const KMeansParams& params,
const float* X,
int n_samples,
Expand All @@ -58,7 +58,7 @@ void fit(const raft::handle_t& handle,
float& inertia,
int& n_iter);

void fit(const raft::handle_t& handle,
void fit(const raft::resources& handle,
const KMeansParams& params,
const double* X,
int n_samples,
Expand All @@ -68,7 +68,7 @@ void fit(const raft::handle_t& handle,
double& inertia,
int& n_iter);

void fit(const raft::handle_t& handle,
void fit(const raft::resources& handle,
const KMeansParams& params,
const float* X,
int64_t n_samples,
Expand All @@ -78,7 +78,7 @@ void fit(const raft::handle_t& handle,
float& inertia,
int64_t& n_iter);

void fit(const raft::handle_t& handle,
void fit(const raft::resources& handle,
const KMeansParams& params,
const double* X,
int64_t n_samples,
Expand Down
Loading
Loading