Skip to content

Commit

Permalink
Fix conflicts with branch-23.12
Browse files Browse the repository at this point in the history
  • Loading branch information
dantegd committed Oct 10, 2023
2 parents bdc8d9a + f979607 commit 9d56f32
Show file tree
Hide file tree
Showing 50 changed files with 2,939 additions and 138 deletions.
5 changes: 4 additions & 1 deletion conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ dependencies:
- libcusolver=11.4.1.48
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- matplotlib
- nccl>=2.9.9
- ninja
- nlohmann_json>=3.11.2
- nvcc_linux-64=11.8
- openblas
- rmm=23.12.*
- pandas
- pyyaml
- rmm==23.12.*
- scikit-build>=0.13.1
- sysroot_linux-64==2.17
name: bench_ann_cuda-118_arch-x86_64
1 change: 1 addition & 0 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ void parse_build_param(const nlohmann::json& conf,
param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
}
}
if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); }
}

template <typename T, typename IdxT>
Expand Down
6 changes: 5 additions & 1 deletion cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function(ConfigureBench)
PRIVATE raft::raft
raft_internal
$<$<BOOL:${ConfigureBench_LIB}>:raft::compiled>
${RAFT_CTK_MATH_DEPENDENCIES}
benchmark::benchmark
Threads::Threads
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
Expand Down Expand Up @@ -73,11 +74,14 @@ function(ConfigureBench)
endfunction()

if(BUILD_PRIMS_BENCH)
ConfigureBench(
NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/core/copy.cu bench/prims/main.cpp
)

ConfigureBench(
NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu
bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)
ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp)

ConfigureBench(
NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/cluster/kmeans_balanced.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ struct KMeansBalanced : public fixture {
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax);
} else {
raft::random::uniform(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax);
}
resource::sync_stream(handle, stream);
}
Expand Down
401 changes: 401 additions & 0 deletions cpp/bench/prims/core/copy.cu

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions cpp/bench/prims/distance/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ struct GramMatrix : public fixture {
A.resize(params.m * params.k, stream);
B.resize(params.k * params.n, stream);
C.resize(params.m * params.n, stream);
raft::random::Rng r(123456ULL);
r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream);
raft::random::RngState rng(123456ULL);
raft::random::uniform(handle, rng, A.data(), params.m * params.k, T(-1.0), T(1.0));
raft::random::uniform(handle, rng, B.data(), params.k * params.n, T(-1.0), T(1.0));
}

~GramMatrix()
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct rowNorm : public fixture {
rowNorm(const norm_input<IdxT>& p) : params(p), in(p.rows * p.cols, stream), dots(p.rows, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/normalize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct rowNormalize : public fixture {
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/reduce_cols_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct reduce_cols_by_key : public fixture {
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.keys, stream), keys(p.cols, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys, stream);
raft::random::uniformInt(handle, rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/reduce_rows_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct reduce_rows_by_key : public fixture {
workspace(p.rows, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys, stream);
raft::random::uniformInt(handle, rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/matrix/argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct Argmin : public fixture {

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1));
resource::sync_stream(handle, stream);
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ struct Gather : public fixture {

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1));
raft::random::uniformInt(
handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows);
if constexpr (Conditional) {
raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1));
}
resource::sync_stream(handle, stream);
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ struct CagraBench : public fixture {
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream);
handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax);
raft::random::uniformInt(
state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream);
handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax);
} else {
raft::random::uniform(
state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream);
handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax);
raft::random::uniform(
state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream);
handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax);
}

// Generate random knn graph

raft::random::uniformInt<IdxT>(
state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1, stream);
handle, state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1);

auto metric = raft::distance::DistanceType::L2Expanded;

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ struct knn : public fixture {
constexpr T kRangeMax = std::is_integral_v<T> ? std::numeric_limits<T>::max() : T(1);
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(state, vec.data(), n, kRangeMin, kRangeMax, stream);
raft::random::uniformInt(handle, state, vec.data(), n, kRangeMin, kRangeMax);
} else {
raft::random::uniform(state, vec.data(), n, kRangeMin, kRangeMax, stream);
raft::random::uniform(handle, state, vec.data(), n, kRangeMin, kRangeMax);
}
}

Expand Down
74 changes: 74 additions & 0 deletions cpp/include/raft/core/copy.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <raft/core/detail/copy.hpp>
namespace raft {
/**
* @brief Copy data from one mdspan to another with the same extents
*
* This function copies data from one mdspan to another, regardless of whether
* or not the mdspans have the same layout, memory type (host/device/managed)
* or data type. So long as it is possible to convert the data type from source
* to destination, and the extents are equal, this function should be able to
* perform the copy. Any necessary device operations will be stream-ordered via the CUDA stream
* provided by the `raft::resources` argument.
*
* This header includes a custom kernel used for copying data between
* completely arbitrary mdspans on device. To compile this function in a
* non-CUDA translation unit, `raft/core/copy.hpp` may be used instead. The
* pure C++ header will correctly compile even without a CUDA compiler.
* Depending on the specialization, this CUDA header may invoke the kernel and
* therefore require a CUDA compiler.
*
* Limitations: Currently this function does not support copying directly
* between two arbitrary mdspans on different CUDA devices. It is assumed that the caller sets the
* correct CUDA device. Furthermore, host-to-host copies that require a transformation of the
* underlying memory layout are currently not performant, although they are supported.
*
* Note that when copying to an mdspan with a non-unique layout (i.e. the same
* underlying memory is addressed by different element indexes), the source
* data must contain non-unique values for every non-unique destination
* element. If this is not the case, the behavior is undefined. Some copies
* to non-unique layouts which are well-defined will nevertheless fail with an
* exception to avoid race conditions in the underlying copy.
*
* @tparam DstType An mdspan type for the destination container.
* @tparam SrcType An mdspan type for the source container
* @param res raft::resources used to provide a stream for copies involving the
* device.
* @param dst The destination mdspan.
* @param src The source mdspan.
*/
template <typename DstType, typename SrcType>
detail::mdspan_copyable_with_kernel_t<DstType, SrcType> copy(resources const& res,
DstType&& dst,
SrcType&& src)
{
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src));
}

#ifndef RAFT_NON_CUDA_COPY_IMPLEMENTED
#define RAFT_NON_CUDA_COPY_IMPLEMENTED
template <typename DstType, typename SrcType>
detail::mdspan_copyable_not_with_kernel_t<DstType, SrcType> copy(resources const& res,
DstType&& dst,
SrcType&& src)
{
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src));
}
#endif
} // namespace raft
69 changes: 69 additions & 0 deletions cpp/include/raft/core/copy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <raft/core/detail/copy.hpp>
namespace raft {

#ifndef RAFT_NON_CUDA_COPY_IMPLEMENTED
#define RAFT_NON_CUDA_COPY_IMPLEMENTED
/**
* @brief Copy data from one mdspan to another with the same extents
*
* This function copies data from one mdspan to another, regardless of whether
* or not the mdspans have the same layout, memory type (host/device/managed)
* or data type. So long as it is possible to convert the data type from source
* to destination, and the extents are equal, this function should be able to
* perform the copy.
*
* This header does _not_ include the custom kernel used for copying data
* between completely arbitrary mdspans on device. For arbitrary copies of this
* kind, `#include <raft/core/copy.cuh>` instead. Specializations of this
* function that require the custom kernel will be SFINAE-omitted when this
* header is used instead of `copy.cuh`. This header _does_ support
* device-to-device copies that can be performed with cuBLAS or a
* straightforward cudaMemcpy. Any necessary device operations will be stream-ordered via the CUDA
* stream provided by the `raft::resources` argument.
*
* Limitations: Currently this function does not support copying directly
* between two arbitrary mdspans on different CUDA devices. It is assumed that the caller sets the
* correct CUDA device. Furthermore, host-to-host copies that require a transformation of the
* underlying memory layout are currently not performant, although they are supported.
*
* Note that when copying to an mdspan with a non-unique layout (i.e. the same
* underlying memory is addressed by different element indexes), the source
* data must contain non-unique values for every non-unique destination
* element. If this is not the case, the behavior is undefined. Some copies
* to non-unique layouts which are well-defined will nevertheless fail with an
* exception to avoid race conditions in the underlying copy.
*
* @tparam DstType An mdspan type for the destination container.
* @tparam SrcType An mdspan type for the source container
* @param res raft::resources used to provide a stream for copies involving the
* device.
* @param dst The destination mdspan.
* @param src The source mdspan.
*/
template <typename DstType, typename SrcType>
detail::mdspan_copyable_not_with_kernel_t<DstType, SrcType> copy(resources const& res,
DstType&& dst,
SrcType&& src)
{
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src));
}
#endif

} // namespace raft
23 changes: 23 additions & 0 deletions cpp/include/raft/core/cuda_support.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace raft {
#ifndef RAFT_DISABLE_CUDA
auto constexpr static const CUDA_ENABLED = true;
#else
auto constexpr static const CUDA_ENABLED = false;
#endif
} // namespace raft
Loading

0 comments on commit 9d56f32

Please sign in to comment.