Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements in matrix::gather: test coverage, compilation errors, performance #1126

Merged
merged 13 commits into from
Jan 21, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down Expand Up @@ -102,7 +102,7 @@ if(BUILD_BENCH)
bench/main.cpp
)

ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp)
ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp)

ConfigureBench(
NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu
Expand Down
17 changes: 7 additions & 10 deletions cpp/bench/matrix/argmin.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,10 +17,11 @@
#include <common/benchmark.hpp>
#include <raft/matrix/argmin.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {
namespace raft::bench::matrix {

template <typename IdxT>
struct ArgminParams {
Expand Down Expand Up @@ -57,15 +58,11 @@ struct Argmin : public fixture {
raft::device_vector<OutT, IdxT> indices;
}; // struct Argmin

const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{
{1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024},
{10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024},
{100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024},
{1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024},
{10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024},
};
const std::vector<ArgminParams<int64_t>> argmin_inputs_i64 =
raft::util::itertools::product<ArgminParams<int64_t>>({1000, 10000, 100000, 1000000, 10000000},
{64, 128, 256, 512, 1024});

RAFT_BENCH_REGISTER((Argmin<float, uint32_t, int64_t>), "", argmin_inputs_i64);
RAFT_BENCH_REGISTER((Argmin<double, uint32_t, int64_t>), "", argmin_inputs_i64);

} // namespace raft::bench::linalg
} // namespace raft::bench::matrix
101 changes: 101 additions & 0 deletions cpp/bench/matrix/gather.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::matrix {

template <typename IdxT>
struct GatherParams {
IdxT rows, cols, map_length;
};

template <typename IdxT>
inline auto operator<<(std::ostream& os, const GatherParams<IdxT>& p) -> std::ostream&
{
os << p.rows << "#" << p.cols << "#" << p.map_length;
return os;
}

template <typename T, typename MapT, typename IdxT, bool Conditional = false>
struct Gather : public fixture {
Gather(const GatherParams<IdxT>& p) : params(p) {}

void allocate_data(const ::benchmark::State& state) override
{
matrix = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols);
map = raft::make_device_vector<MapT, IdxT>(handle, params.map_length);
out = raft::make_device_matrix<T, IdxT>(handle, params.map_length, params.cols);
stencil = raft::make_device_vector<T, IdxT>(handle, Conditional ? params.map_length : IdxT(0));

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
raft::random::uniformInt(
handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows);
if constexpr (Conditional) {
raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
}
handle.sync_stream(stream);
}

void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
label_stream << params;
state.SetLabel(label_stream.str());

loop_on_state(state, [this]() {
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
matrix.data_handle(), matrix.extent(0), matrix.extent(1));
auto map_const_view =
raft::make_device_vector_view<const MapT, IdxT>(map.data_handle(), map.extent(0));
if constexpr (Conditional) {
auto stencil_const_view =
raft::make_device_vector_view<const T, IdxT>(stencil.data_handle(), stencil.extent(0));
auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
raft::matrix::gather_if(
handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op);
} else {
raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view());
}
});
}

private:
GatherParams<IdxT> params;
raft::device_matrix<T, IdxT> matrix, out;
raft::device_vector<T, IdxT> stencil;
raft::device_vector<MapT, IdxT> map;
}; // struct Gather

template <typename T, typename MapT, typename IdxT>
using GatherIf = Gather<T, MapT, IdxT, true>;

const std::vector<GatherParams<int64_t>> gather_inputs_i64 =
raft::util::itertools::product<GatherParams<int64_t>>(
{1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000});

RAFT_BENCH_REGISTER((Gather<float, uint32_t, int64_t>), "", gather_inputs_i64);
RAFT_BENCH_REGISTER((Gather<double, uint32_t, int64_t>), "", gather_inputs_i64);
RAFT_BENCH_REGISTER((GatherIf<float, uint32_t, int64_t>), "", gather_inputs_i64);
RAFT_BENCH_REGISTER((GatherIf<double, uint32_t, int64_t>), "", gather_inputs_i64);
} // namespace raft::bench::matrix
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void shuffleAndGather(const raft::handle_t& handle,
in.extent(1),
in.extent(0),
indices.data_handle(),
n_samples_to_gather,
static_cast<IndexT>(n_samples_to_gather),
out.data_handle(),
stream);
}
Expand Down
43 changes: 43 additions & 0 deletions cpp/include/raft/core/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ struct div_checkzero_op {
}
};

struct modulo_op {
template <typename T1, typename T2>
constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
{
return a % b;
}
};

struct pow_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
Expand Down Expand Up @@ -190,6 +198,38 @@ struct argmax_op {
}
};

struct greater_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
{
return a > b;
}
};

struct less_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
{
return a < b;
}
};

struct greater_or_equal_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
{
return a >= b;
}
};

struct less_or_equal_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
{
return a <= b;
}
};

struct equal_op {
template <typename Type>
constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
Expand Down Expand Up @@ -268,6 +308,9 @@ using div_const_op = plug_const_op<Type, div_op>;
template <typename Type>
using div_checkzero_const_op = plug_const_op<Type, div_checkzero_op>;

template <typename Type>
using modulo_const_op = plug_const_op<Type, modulo_op>;

template <typename Type>
using pow_const_op = plug_const_op<Type, pow_op>;

Expand Down
Loading