Skip to content

Commit

Permalink
Generic linalg::map (rapidsai#1329)
Browse files Browse the repository at this point in the history
Update the implementation behind `raft::linalg::map` and `raft::linalg::map_offset` to allow multiple inputs and an optional index.

Originally, this was part of the effort to reduce the latency of ivf-pq search. The new implementation replaces several helpers that have been using thrust; at the moment, raft uses a thrust policy that occasionally inserts extra `cudaStreamSynchronize` calls, which negatively affects latency on small inputs.

The new implementation is generic enough to replace many of raft's utility functions. It uses vectorized loads/stores where possible, which improves performance.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1329
  • Loading branch information
achirkin authored and lowener committed Mar 15, 2023
1 parent 291932a commit 5f497bd
Show file tree
Hide file tree
Showing 14 changed files with 470 additions and 580 deletions.
26 changes: 4 additions & 22 deletions cpp/include/raft/linalg/binary_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@

#pragma once

#include "detail/binary_op.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/input_validation.hpp>
#include <raft/linalg/map.cuh>

namespace raft {
namespace linalg {
Expand Down Expand Up @@ -52,7 +49,7 @@ template <typename InType,
void binaryOp(
OutType* out, const InType* in1, const InType* in2, IdxType len, Lambda op, cudaStream_t stream)
{
detail::binaryOp(out, in1, in2, len, op, stream);
return detail::map<false>(stream, out, len, op, in1, in2);
}

/**
Expand Down Expand Up @@ -80,27 +77,12 @@ template <typename InType,
typename = raft::enable_if_output_device_mdspan<OutType>>
void binary_op(raft::device_resources const& handle, InType in1, InType in2, OutType out, Lambda op)
{
RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous");
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
"Size mismatch between Output and Inputs");

using in_value_t = typename InType::value_type;
using out_value_t = typename OutType::value_type;

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
binaryOp<in_value_t, Lambda, out_value_t, std::uint32_t>(
out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream());
} else {
binaryOp<in_value_t, Lambda, out_value_t, std::uint64_t>(
out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream());
}
return map(handle, in1, in2, out, op);
}

/** @} */ // end of group binary_op

}; // end namespace linalg
}; // end namespace raft

#endif
#endif
98 changes: 0 additions & 98 deletions cpp/include/raft/linalg/detail/binary_op.cuh

This file was deleted.

54 changes: 0 additions & 54 deletions cpp/include/raft/linalg/detail/init.hpp

This file was deleted.

Loading

0 comments on commit 5f497bd

Please sign in to comment.