Skip to content

Commit

Permalink
Improvements in matrix::gather: test coverage, compilation errors, …
Browse files Browse the repository at this point in the history
…performance (#1126)

In order to deprecate `copy_selected` from `ann_utils.cuh`, I wanted to make sure that the performance of `matrix::gather` was on par with it. In the process, I discovered that:

- Map transforms and conditional copy were not tested at all.
- In fact, most of the API in `gather.cuh` wasn't covered in tests and some of the functions didn't even compile.
- The same type `MatrixIteratorT` was used for the input and output iterators, which made it impossible to take advantage of custom iterators, as is needed in `kmeans_balanced` to convert the dataset from `T` to `float` and gather in a single step.
- Performance was very poor when `D` (the number of elements per row) is small, because the kernel assigns one block per row, so a block could be working on only 2 or 3 elements.

This PR addresses all the aforementioned issues.

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1126
  • Loading branch information
Nyrio authored Jan 21, 2023
1 parent a9e1adc commit 0e96662
Show file tree
Hide file tree
Showing 8 changed files with 578 additions and 410 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ if(BUILD_BENCH)
bench/main.cpp
)

ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp)
ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp)

ConfigureBench(
NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu
Expand Down
17 changes: 7 additions & 10 deletions cpp/bench/matrix/argmin.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,10 +17,11 @@
#include <common/benchmark.hpp>
#include <raft/matrix/argmin.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {
namespace raft::bench::matrix {

template <typename IdxT>
struct ArgminParams {
Expand Down Expand Up @@ -57,15 +58,11 @@ struct Argmin : public fixture {
raft::device_vector<OutT, IdxT> indices;
}; // struct Argmin

const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{
{1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024},
{10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024},
{100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024},
{1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024},
{10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024},
};
const std::vector<ArgminParams<int64_t>> argmin_inputs_i64 =
raft::util::itertools::product<ArgminParams<int64_t>>({1000, 10000, 100000, 1000000, 10000000},
{64, 128, 256, 512, 1024});

RAFT_BENCH_REGISTER((Argmin<float, uint32_t, int64_t>), "", argmin_inputs_i64);
RAFT_BENCH_REGISTER((Argmin<double, uint32_t, int64_t>), "", argmin_inputs_i64);

} // namespace raft::bench::linalg
} // namespace raft::bench::matrix
101 changes: 101 additions & 0 deletions cpp/bench/matrix/gather.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::matrix {

/**
 * Dimensions describing one gather benchmark case.
 *
 * Field order matters: instances are built positionally as {rows, cols, map_length}.
 */
template <typename IdxT>
struct GatherParams {
  IdxT rows;        // number of rows of the input matrix
  IdxT cols;        // number of columns of the input and output matrices
  IdxT map_length;  // number of entries in the map, i.e. rows of the output matrix
};

/**
 * Writes the parameters to `os` as `rows#cols#map_length` and returns the stream.
 */
template <typename IdxT>
inline auto operator<<(std::ostream& os, const GatherParams<IdxT>& p) -> std::ostream&
{
  return os << p.rows << "#" << p.cols << "#" << p.map_length;
}

/**
 * Benchmark fixture for raft::matrix::gather, or raft::matrix::gather_if when
 * `Conditional` is true.
 *
 * @tparam T           element type of the input/output matrices and of the stencil
 * @tparam MapT        element type of the gather map
 * @tparam IdxT        index type used for all extents
 * @tparam Conditional selects the gather_if code path and enables the stencil buffer
 */
template <typename T, typename MapT, typename IdxT, bool Conditional = false>
struct Gather : public fixture {
  Gather(const GatherParams<IdxT>& p) : params(p) {}

  /** Allocates the device buffers and fills them with random data. */
  void allocate_data(const ::benchmark::State& state) override
  {
    // The stencil is only read by gather_if, so it is left empty otherwise.
    const IdxT stencil_len = Conditional ? params.map_length : IdxT(0);

    matrix  = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols);
    map     = raft::make_device_vector<MapT, IdxT>(handle, params.map_length);
    out     = raft::make_device_matrix<T, IdxT>(handle, params.map_length, params.cols);
    stencil = raft::make_device_vector<T, IdxT>(handle, stencil_len);

    raft::random::RngState rng{1234};  // fixed seed

    // Input matrix values are drawn from uniform(-1, 1).
    const auto n_matrix_elems = params.rows * params.cols;
    raft::random::uniform(rng, matrix.data_handle(), n_matrix_elems, T(-1), T(1), stream);

    // Map entries are drawn with bounds 0 and `rows`.
    // NOTE(review): presumably the upper bound is exclusive so every index is a valid row —
    // confirm against raft::random::uniformInt.
    raft::random::uniformInt(handle,
                             rng,
                             map.data_handle(),
                             params.map_length,
                             static_cast<MapT>(0),
                             static_cast<MapT>(params.rows));

    if constexpr (Conditional) {
      raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
    }
    handle.sync_stream(stream);
  }

  /** Labels the run with the parameters, then runs gather / gather_if via loop_on_state. */
  void run_benchmark(::benchmark::State& state) override
  {
    std::ostringstream label;
    label << params;
    state.SetLabel(label.str());

    loop_on_state(state, [this]() {
      // Read-only views over the input matrix and the map.
      auto in_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
        matrix.data_handle(), matrix.extent(0), matrix.extent(1));
      auto map_view =
        raft::make_device_vector_view<const MapT, IdxT>(map.data_handle(), map.extent(0));

      if constexpr (!Conditional) {
        raft::matrix::gather(handle, in_view, map_view, out.view());
      } else {
        auto stencil_view =
          raft::make_device_vector_view<const T, IdxT>(stencil.data_handle(), stencil.extent(0));
        // Predicate built from greater_op and the constant 0, evaluated on stencil values.
        auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
        raft::matrix::gather_if(handle, in_view, out.view(), map_view, stencil_view, pred_op);
      }
    });
  }

 private:
  GatherParams<IdxT> params;
  raft::device_matrix<T, IdxT> matrix;
  raft::device_matrix<T, IdxT> out;
  raft::device_vector<T, IdxT> stencil;
  raft::device_vector<MapT, IdxT> map;
};  // struct Gather

/** Shorthand for the Gather fixture with `Conditional = true` (the gather_if code path). */
template <typename T, typename MapT, typename IdxT>
using GatherIf = Gather<T, MapT, IdxT, true>;

// Benchmark cases built by itertools::product from three value lists, given in
// GatherParams field order: rows, cols, map_length.
// NOTE(review): presumably this yields every combination of the listed values — confirm
// against raft/util/itertools.hpp.
const std::vector<GatherParams<int64_t>> gather_inputs_i64 =
raft::util::itertools::product<GatherParams<int64_t>>(
{1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000});

// Register the plain and conditional gather fixtures for float and double data,
// with uint32_t map entries and int64_t indexing, over gather_inputs_i64.
// NOTE(review): the empty string is presumably a benchmark name suffix — confirm
// against the RAFT_BENCH_REGISTER definition.
RAFT_BENCH_REGISTER((Gather<float, uint32_t, int64_t>), "", gather_inputs_i64);
RAFT_BENCH_REGISTER((Gather<double, uint32_t, int64_t>), "", gather_inputs_i64);
RAFT_BENCH_REGISTER((GatherIf<float, uint32_t, int64_t>), "", gather_inputs_i64);
RAFT_BENCH_REGISTER((GatherIf<double, uint32_t, int64_t>), "", gather_inputs_i64);
} // namespace raft::bench::matrix
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void shuffleAndGather(const raft::handle_t& handle,
in.extent(1),
in.extent(0),
indices.data_handle(),
n_samples_to_gather,
static_cast<IndexT>(n_samples_to_gather),
out.data_handle(),
stream);
}
Expand Down
51 changes: 47 additions & 4 deletions cpp/include/raft/core/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ struct div_checkzero_op {
}
};

/** Binary functor returning `lhs % rhs`; the operands may have different types. */
struct modulo_op {
  template <typename LhsT, typename RhsT>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const LhsT& lhs, const RhsT& rhs) const
  {
    return lhs % rhs;
  }
};

struct pow_op {
template <typename Type>
RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
Expand Down Expand Up @@ -189,17 +197,49 @@ struct argmax_op {
}
};

/** Binary functor returning `lhs > rhs`; the operands may have different types. */
struct greater_op {
  template <typename LhsT, typename RhsT>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const LhsT& lhs, const RhsT& rhs) const
  {
    return lhs > rhs;
  }
};

/** Binary functor returning `lhs < rhs`; the operands may have different types. */
struct less_op {
  template <typename LhsT, typename RhsT>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const LhsT& lhs, const RhsT& rhs) const
  {
    return lhs < rhs;
  }
};

/** Binary functor returning `lhs >= rhs`; the operands may have different types. */
struct greater_or_equal_op {
  template <typename LhsT, typename RhsT>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const LhsT& lhs, const RhsT& rhs) const
  {
    return lhs >= rhs;
  }
};

/** Binary functor returning `lhs <= rhs`; the operands may have different types. */
struct less_or_equal_op {
  template <typename LhsT, typename RhsT>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const LhsT& lhs, const RhsT& rhs) const
  {
    return lhs <= rhs;
  }
};

struct equal_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
template <typename T1, typename T2>
constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
{
return a == b;
}
};

struct notequal_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
template <typename T1, typename T2>
constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
{
return a != b;
}
Expand Down Expand Up @@ -267,6 +307,9 @@ using div_const_op = plug_const_op<Type, div_op>;
/** plug_const_op specialized with div_checkzero_op. */
template <typename Type>
using div_checkzero_const_op = plug_const_op<Type, div_checkzero_op>;

/** plug_const_op specialized with modulo_op. */
template <typename Type>
using modulo_const_op = plug_const_op<Type, modulo_op>;

/** plug_const_op specialized with pow_op. */
template <typename Type>
using pow_const_op = plug_const_op<Type, pow_op>;

Expand Down
Loading

0 comments on commit 0e96662

Please sign in to comment.