Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing Overloads for GCC 11/12 bug #995

Merged
merged 6 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/include/raft/linalg/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ void rsvd_perc(const raft::handle_t& handle,
*
* Please see above for documentation of `rsvd_perc`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 4>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 5>>
void rsvd_perc(Args... args)
{
rsvd_perc(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down Expand Up @@ -560,7 +560,7 @@ void rsvd_perc_symmetric(const raft::handle_t& handle,
*
* Please see above for documentation of `rsvd_perc_symmetric`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 4>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 5>>
void rsvd_perc_symmetric(Args... args)
{
rsvd_perc_symmetric(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down Expand Up @@ -634,7 +634,7 @@ void rsvd_perc_jacobi(const raft::handle_t& handle,
*
* Please see above for documentation of `rsvd_perc_jacobi`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 6>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 7>>
void rsvd_perc_jacobi(Args... args)
{
rsvd_perc_jacobi(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down Expand Up @@ -709,7 +709,7 @@ void rsvd_perc_symmetric_jacobi(
*
* Please see above for documentation of `rsvd_perc_symmetric_jacobi`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 6>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 7>>
void rsvd_perc_symmetric_jacobi(Args... args)
{
rsvd_perc_symmetric_jacobi(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ void svd_qr(const raft::handle_t& handle,
*
* Please see above for documentation of `svd_qr`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 2>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr(Args... args)
{
svd_qr(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down Expand Up @@ -295,7 +295,7 @@ void svd_qr_transpose_right_vec(
*
* Please see above for documentation of `svd_qr_transpose_right_vec`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 2>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr_transpose_right_vec(Args... args)
{
svd_qr_transpose_right_vec(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down
44 changes: 9 additions & 35 deletions cpp/include/raft/matrix/col_wise_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,17 @@ void sort_cols_per_row(const InType* in,
* @param[in] handle: raft handle
* @param[in] in: input matrix
* @param[out] out: output value(index) matrix
* @param[out] sorted_keys: Optional, output matrix for sorted keys (input)
* @param[out] sorted_keys_opt: Optional, output matrix for sorted keys (input)
*/
template <typename in_t, typename out_t, typename matrix_idx_t>
template <typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_t>
void sort_cols_per_row(const raft::handle_t& handle,
raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in,
raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out,
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>
sorted_keys = std::nullopt)
sorted_keys_t&& sorted_keys_opt)
{
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>> sorted_keys =
std::forward<sorted_keys_t>(sorted_keys_opt);

RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0),
"Input and output matrices must have the same shape.");

Expand Down Expand Up @@ -109,45 +111,17 @@ void sort_cols_per_row(const raft::handle_t& handle,
}
}

namespace sort_cols_per_row_impl {
template <typename T>
struct sorted_keys_alias {
};

template <>
struct sorted_keys_alias<std::nullopt_t> {
using type = double;
};

template <typename in_t, typename matrix_idx_t>
struct sorted_keys_alias<
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>> {
using type = typename raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>::value_type;
};

template <typename T>
using sorted_keys_t = typename sorted_keys_alias<T>::type;
} // namespace sort_cols_per_row_impl

/**
* @brief Overload of `sort_keys_per_row` to help the
* compiler find the above overload, in case users pass in
* `std::nullopt` for one or both of the optional arguments.
*
* Please see above for documentation of `sort_keys_per_row`.
*/
template <typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_vector_type>
void sort_cols_per_row(const raft::handle_t& handle,
raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in,
raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out,
sorted_keys_vector_type sorted_keys)
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void sort_cols_per_row(Args... args)
{
using sorted_keys_type = sort_cols_per_row_impl::sorted_keys_t<
std::remove_const_t<std::remove_reference_t<sorted_keys_vector_type>>>;
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>> sorted_keys_opt =
std::forward<sorted_keys_vector_type>(sorted_keys);

sort_cols_per_row(handle, in, out, sorted_keys_opt);
sort_cols_per_row(std::forward<Args>(args)..., std::nullopt);
}

}; // end namespace raft::matrix
Expand Down
76 changes: 36 additions & 40 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,25 @@ void laplace(const raft::handle_t& handle,
detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream());
}

namespace sample_without_replacement_impl {
template <typename T>
struct weight_alias {
};

template <>
struct weight_alias<std::nullopt_t> {
using type = double;
};

template <typename ElementType, typename IndexType>
struct weight_alias<std::optional<raft::device_vector_view<ElementType, IndexType>>> {
using type = typename raft::device_vector_view<ElementType, IndexType>::value_type;
};

template <typename T>
using weight_t = typename weight_alias<T>::type;
} // namespace sample_without_replacement_impl

/**
* @brief Sample the input vector without replacement, optionally based on the
* input weight vector for each element in the array.
Expand Down Expand Up @@ -730,10 +749,10 @@ void laplace(const raft::handle_t& handle,
* the CUDA stream on which to run.
* @param[inout] rng_state Pseudorandom number generator state.
* @param[in] in Input vector to be sampled.
* @param[in] wts Optional weights vector.
* @param[in] weights_opt Optional weights vector.
* If not provided, uniform sampling will be used.
* @param[out] out Vector of samples from the input vector.
* @param[out] outIdx If provided, vector of the indices
* @param[out] outIdx_opt If provided, vector of the indices
* sampled from the input array.
*
* @pre The number of samples `out.extent(0)`
Expand All @@ -742,14 +761,22 @@ void laplace(const raft::handle_t& handle,
* @pre The number of weights `wts.extent(0)`
* equals the number of inputs `in.extent(0)`.
*/
template <typename DataT, typename IdxT, typename WeightsT>
template <typename DataT, typename IdxT, typename WeightsVectorType, class OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<const DataT, IdxT> in,
std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts,
WeightsVectorType&& weights_opt,
raft::device_vector_view<DataT, IdxT> out,
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx)
OutIndexVectorType&& outIdx_opt)
{
using weight_type = sample_without_replacement_impl::weight_t<
std::remove_const_t<std::remove_reference_t<WeightsVectorType>>>;

std::optional<raft::device_vector_view<const weight_type, IdxT>> wts =
std::forward<WeightsVectorType>(weights_opt);
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx =
std::forward<OutIndexVectorType>(outIdx_opt);

static_assert(std::is_integral<IdxT>::value, "IdxT must be an integral type.");
const IdxT sampledLen = out.extent(0);
const IdxT len = in.extent(0);
Expand Down Expand Up @@ -777,7 +804,7 @@ void sample_without_replacement(const raft::handle_t& handle,
"sampleWithoutReplacement: "
"If wts is provided, its extent(0) must equal in.extent(0)");
}
const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr;
const weight_type* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr;

detail::sampleWithoutReplacement(rng_state,
out.data_handle(),
Expand All @@ -789,48 +816,17 @@ void sample_without_replacement(const raft::handle_t& handle,
handle.get_stream());
}

namespace sample_without_replacement_impl {
template <typename T>
struct weight_alias {
};

template <>
struct weight_alias<std::nullopt_t> {
using type = double;
};

template <typename ElementType, typename IndexType>
struct weight_alias<std::optional<raft::device_vector_view<ElementType, IndexType>>> {
using type = typename raft::device_vector_view<ElementType, IndexType>::value_type;
};

template <typename T>
using weight_t = typename weight_alias<T>::type;
} // namespace sample_without_replacement_impl

/**
* @brief Overload of `sample_without_replacement` to help the
* compiler find the above overload, in case users pass in
* `std::nullopt` for one or both of the optional arguments.
*
* Please see above for documentation of `sample_without_replacement`.
*/
template <typename DataT, typename IdxT, typename WeightsVectorType, class OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<const DataT, IdxT> in,
WeightsVectorType&& wts,
raft::device_vector_view<DataT, IdxT> out,
OutIndexVectorType&& outIdx)
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 5>>
void sample_without_replacement(Args... args)
{
using weight_type = sample_without_replacement_impl::weight_t<
std::remove_const_t<std::remove_reference_t<WeightsVectorType>>>;
std::optional<raft::device_vector_view<const weight_type, IdxT>> weights =
std::forward<WeightsVectorType>(wts);
std::optional<raft::device_vector_view<IdxT, IdxT>> output_indices =
std::forward<OutIndexVectorType>(outIdx);

sample_without_replacement(handle, rng_state, in, weights, out, output_indices);
sample_without_replacement(std::forward<Args>(args)..., std::nullopt);
}

/**
Expand Down
36 changes: 16 additions & 20 deletions cpp/include/raft/stats/contingency_matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,25 @@ void contingencyMatrix(const T* groundTruth,
* @param[in] ground_truth: device 1-d array for ground truth (num of rows)
* @param[in] predicted_label: device 1-d array for prediction (num of columns)
* @param[out] out_mat: output buffer for contingency matrix
* @param[in] min_label: Optional, min value in input ground truth array
* @param[in] max_label: Optional, max value in input ground truth array
* @param[in] opt_min_label: Optional, min value in input ground truth array
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* @param[in] opt_max_label: Optional, max value in input ground truth array
*/
template <typename value_t, typename out_t, typename idx_t, typename layout_t>
template <typename value_t,
typename out_t,
typename idx_t,
typename layout_t,
typename opt_min_label_t,
typename opt_max_label_t>
void contingency_matrix(const raft::handle_t& handle,
raft::device_vector_view<const value_t, idx_t> ground_truth,
raft::device_vector_view<const value_t, idx_t> predicted_label,
raft::device_matrix_view<out_t, idx_t, layout_t> out_mat,
std::optional<value_t> min_label = std::nullopt,
std::optional<value_t> max_label = std::nullopt)
opt_min_label_t&& opt_min_label,
opt_max_label_t&& opt_max_label)
{
std::optional<value_t> min_label = std::forward<opt_min_label_t>(opt_min_label);
std::optional<value_t> max_label = std::forward<opt_max_label_t>(opt_max_label);

RAFT_EXPECTS(ground_truth.size() == predicted_label.size(), "Size mismatch");
RAFT_EXPECTS(ground_truth.is_exhaustive(), "ground_truth must be contiguous");
RAFT_EXPECTS(predicted_label.is_exhaustive(), "predicted_label must be contiguous");
Expand Down Expand Up @@ -188,22 +196,10 @@ void contingency_matrix(const raft::handle_t& handle,
*
* Please see above for documentation of `contingency_matrix`.
*/
template <typename value_t,
typename out_t,
typename idx_t,
typename layout_t,
typename opt_min_label_t,
typename opt_max_label_t>
void contingency_matrix(const raft::handle_t& handle,
raft::device_vector_view<const value_t, idx_t> ground_truth,
raft::device_vector_view<const value_t, idx_t> predicted_label,
raft::device_matrix_view<out_t, idx_t, layout_t> out_mat,
opt_min_label_t&& min_label = std::nullopt,
opt_max_label_t&& max_label = std::nullopt)
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 4>>
void contingency_matrix(Args... args)
{
std::optional<value_t> opt_min_label = std::forward<opt_min_label_t>(min_label);
std::optional<value_t> opt_max_label = std::forward<opt_max_label_t>(max_label);
contingency_matrix(handle, ground_truth, predicted_label, out_mat, opt_min_label, opt_max_label);
contingency_matrix(std::forward<Args>(args)..., std::nullopt, std::nullopt);
}
}; // namespace stats
}; // namespace raft
Expand Down