Skip to content

Commit

Permalink
Merge branch 'branch-23.08' into cagra_pad_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher authored Jun 6, 2023
2 parents 21e2d08 + 64f15b6 commit da1fb53
Show file tree
Hide file tree
Showing 37 changed files with 1,196 additions and 670 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
build_type: pull-request
package-name: raft_dask
# Always want to test against latest dask/distributed.
test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/[email protected]"
test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/[email protected]"
test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/[email protected]"
test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/[email protected]"
test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test"
test-smoketest: "python ./ci/wheel_smoke_test_raft_dask.py"
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ jobs:
date: ${{ inputs.date }}
sha: ${{ inputs.sha }}
package-name: raft_dask
test-before-amd64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/[email protected]"
test-before-arm64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/[email protected]"
test-before-amd64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/[email protected]"
test-before-arm64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/[email protected]"
test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test"
6 changes: 3 additions & 3 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ dependencies:
- cupy>=12.0.0
- cxx-compiler
- cython>=0.29,<0.30
- dask-core==2023.3.2
- dask-core>=2023.5.1
- dask-cuda==23.8.*
- dask==2023.3.2
- distributed==2023.3.2.1
- dask>=2023.5.1
- distributed>=2023.5.1
- doxygen>=1.8.20
- gcc_linux-64=11.*
- gmock>=1.13.0
Expand Down
6 changes: 3 additions & 3 deletions conda/recipes/raft-dask/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ requirements:
run:
- {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }}
- cuda-python >=11.7.1,<12.0
- dask ==2023.3.2
- dask-core ==2023.3.2
- dask >=2023.5.1
- dask-core >=2023.5.1
- dask-cuda ={{ minor_version }}
- distributed ==2023.3.2.1
- distributed >=2023.5.1
- joblib >=0.11
- nccl >=2.9.9
- pylibraft {{ version }}
Expand Down
79 changes: 46 additions & 33 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ void distance_impl(raft::resources const& handle,
bool is_row_major,
DataT) // unused
{
ASSERT(
!(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || (worksize < 2 * m * sizeof(AccT))),
"workspace size error");
ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
Expand All @@ -137,9 +135,27 @@ void distance_impl(raft::resources const& handle,
AccT* y_norm = workspace;
AccT* sq_x_norm = workspace;
AccT* sq_y_norm = workspace;
if (x != y) {
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::reduce(x_norm,
x,
k,
std::max(m, n),
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
sq_x_norm += std::max(m, n);
sq_y_norm = sq_x_norm;
raft::linalg::rowNorm(
sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream);
} else {
y_norm += m;

raft::linalg::reduce(x_norm,
x,
k,
Expand Down Expand Up @@ -167,21 +183,6 @@ void distance_impl(raft::resources const& handle,
sq_y_norm = sq_x_norm + m;
raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream);
raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream);
} else {
raft::linalg::reduce(x_norm,
x,
k,
m,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
sq_x_norm += m;
sq_y_norm = sq_x_norm;
raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream);
}

using OpT = ops::correlation_distance_op<DataT, AccT, IdxT>;
Expand Down Expand Up @@ -210,23 +211,25 @@ void distance_impl(raft::resources const& handle,
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

DataT* x_norm = workspace;
DataT* y_norm = workspace;
if (x != y) {
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::rowNorm(
x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
} else {
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
raft::linalg::rowNorm(
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
} else {
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
}

ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
Expand Down Expand Up @@ -453,21 +456,29 @@ void distance_impl_l2_expanded( // NOTE: different name
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

DataT* x_norm = workspace;
DataT* y_norm = workspace;
if (x != y) {
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if ((x == y) && is_row_major) {
raft::linalg::rowNorm(x_norm,
x,
k,
std::max(m, n),
raft::linalg::L2Norm,
is_row_major,
stream,
raft::identity_op{});
} else {
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
raft::linalg::rowNorm(
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
} else {
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
}

ops::l2_exp_distance_op<DataT, AccT, IdxT> distance_op{perform_sqrt};
Expand Down Expand Up @@ -789,8 +800,10 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In
(distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1;

if (is_allocated) {
// TODO : when X == Y allocate std::max(m, n) instead of m + n when column major input
// accuracy issue is resolved until then we allocate as m + n.
worksize += numOfBuffers * m * sizeof(AccType);
if (x != y) worksize += numOfBuffers * n * sizeof(AccType);
worksize += numOfBuffers * n * sizeof(AccType);
}

return worksize;
Expand Down
52 changes: 31 additions & 21 deletions cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,27 @@

#pragma once

#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view
#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/neighbors/sample_filter_types.hpp> // none_ivf_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::neighbors::ivf_flat::detail {

template <typename T, typename AccT, typename IdxT>
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const raft::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
Expand All @@ -43,23 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \
extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \
T, AccT, IdxT, IvfSampleFilterT) \
extern template void \
raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT, IvfSampleFilterT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
Loading

0 comments on commit da1fb53

Please sign in to comment.