Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing Overloads for GCC 11/12 bug #995

Merged
merged 6 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions cpp/include/raft/linalg/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ void rsvdPerc(const raft::handle_t& handle,
* @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] S_vec singular values raft::device_vector_view of shape (K)
* @param[in] p no. of upsamples
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -220,9 +220,9 @@ void rsvd_fixed_rank(Args... args)
* @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] S_vec singular values raft::device_vector_view of shape (K)
* @param[in] p no. of upsamples
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -291,9 +291,9 @@ void rsvd_fixed_rank_symmetric(Args... args)
* @param[in] p no. of upsamples
* @param[in] tol tolerance for Jacobi-based solvers
* @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -363,9 +363,9 @@ void rsvd_fixed_rank_jacobi(Args... args)
* @param[in] p no. of upsamples
* @param[in] tol tolerance for Jacobi-based solvers
* @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -435,9 +435,9 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... args)
* @param[out] S_vec singular values raft::device_vector_view of shape (K)
* @param[in] PC_perc percentage of singular values to be computed
* @param[in] UpS_perc upsampling percentage
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -490,7 +490,7 @@ void rsvd_perc(const raft::handle_t& handle,
*
* Please see above for documentation of `rsvd_perc`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 4>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 5>>
void rsvd_perc(Args... args)
{
rsvd_perc(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand All @@ -505,9 +505,9 @@ void rsvd_perc(Args... args)
* @param[out] S_vec singular values raft::device_vector_view of shape (K)
* @param[in] PC_perc percentage of singular values to be computed
* @param[in] UpS_perc upsampling percentage
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -560,7 +560,7 @@ void rsvd_perc_symmetric(const raft::handle_t& handle,
*
* Please see above for documentation of `rsvd_perc_symmetric`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 4>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 5>>
void rsvd_perc_symmetric(Args... args)
{
rsvd_perc_symmetric(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand All @@ -577,9 +577,9 @@ void rsvd_perc_symmetric(Args... args)
* @param[in] UpS_perc upsampling percentage
* @param[in] tol tolerance for Jacobi-based solvers
* @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -634,7 +634,7 @@ void rsvd_perc_jacobi(const raft::handle_t& handle,
*
* Please see above for documentation of `rsvd_perc_jacobi`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 6>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 7>>
void rsvd_perc_jacobi(Args... args)
{
rsvd_perc_jacobi(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand All @@ -651,9 +651,9 @@ void rsvd_perc_jacobi(Args... args)
* @param[in] UpS_perc upsampling percentage
* @param[in] tol tolerance for Jacobi-based solvers
* @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major
* @param[out] V_in optional right singular values of raft::device_matrix_view with layout
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with layout
* raft::col_major
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -709,7 +709,7 @@ void rsvd_perc_symmetric_jacobi(
*
* Please see above for documentation of `rsvd_perc_symmetric_jacobi`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 6>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 7>>
void rsvd_perc_symmetric_jacobi(Args... args)
{
rsvd_perc_symmetric_jacobi(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down
12 changes: 6 additions & 6 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle,
* @param[in] handle raft::handle_t
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] sing_vals singular values raft::device_vector_view of shape (K)
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* @param[out] V_in optional right singular values of raft::device_matrix_view with
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with
* layout raft::col_major and dimensions (n, n)
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -237,7 +237,7 @@ void svd_qr(const raft::handle_t& handle,
*
* Please see above for documentation of `svd_qr`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 2>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr(Args... args)
{
svd_qr(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand All @@ -249,9 +249,9 @@ void svd_qr(Args... args)
* @param[in] handle raft::handle_t
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] sing_vals singular values raft::device_vector_view of shape (K)
* @param[out] U_in optional left singular values of raft::device_matrix_view with layout
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* @param[out] V_in optional right singular values of raft::device_matrix_view with
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with
* layout raft::col_major and dimensions (n, n)
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
Expand Down Expand Up @@ -295,7 +295,7 @@ void svd_qr_transpose_right_vec(
*
* Please see above for documentation of `svd_qr_transpose_right_vec`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 2>>
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr_transpose_right_vec(Args... args)
{
svd_qr_transpose_right_vec(std::forward<Args>(args)..., std::nullopt, std::nullopt);
Expand Down
44 changes: 9 additions & 35 deletions cpp/include/raft/matrix/col_wise_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,17 @@ void sort_cols_per_row(const InType* in,
* @param[in] handle: raft handle
* @param[in] in: input matrix
* @param[out] out: output value(index) matrix
* @param[out] sorted_keys: Optional, output matrix for sorted keys (input)
* @param[out] sorted_keys_opt: std::optional, output matrix for sorted keys (input)
*/
template <typename in_t, typename out_t, typename matrix_idx_t>
template <typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_t>
void sort_cols_per_row(const raft::handle_t& handle,
raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in,
raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out,
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>
sorted_keys = std::nullopt)
sorted_keys_t&& sorted_keys_opt)
{
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>> sorted_keys =
std::forward<sorted_keys_t>(sorted_keys_opt);

RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0),
"Input and output matrices must have the same shape.");

Expand Down Expand Up @@ -109,45 +111,17 @@ void sort_cols_per_row(const raft::handle_t& handle,
}
}

namespace sort_cols_per_row_impl {
template <typename T>
struct sorted_keys_alias {
};

template <>
struct sorted_keys_alias<std::nullopt_t> {
using type = double;
};

template <typename in_t, typename matrix_idx_t>
struct sorted_keys_alias<
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>> {
using type = typename raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>::value_type;
};

template <typename T>
using sorted_keys_t = typename sorted_keys_alias<T>::type;
} // namespace sort_cols_per_row_impl

/**
* @brief Overload of `sort_keys_per_row` to help the
* compiler find the above overload, in case users pass in
* `std::nullopt` for one or both of the optional arguments.
*
* Please see above for documentation of `sort_keys_per_row`.
*/
template <typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_vector_type>
void sort_cols_per_row(const raft::handle_t& handle,
raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in,
raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out,
sorted_keys_vector_type sorted_keys)
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void sort_cols_per_row(Args... args)
{
using sorted_keys_type = sort_cols_per_row_impl::sorted_keys_t<
std::remove_const_t<std::remove_reference_t<sorted_keys_vector_type>>>;
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>> sorted_keys_opt =
std::forward<sorted_keys_vector_type>(sorted_keys);

sort_cols_per_row(handle, in, out, sorted_keys_opt);
sort_cols_per_row(std::forward<Args>(args)..., std::nullopt);
}

}; // end namespace raft::matrix
Expand Down
76 changes: 36 additions & 40 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,25 @@ void laplace(const raft::handle_t& handle,
detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream());
}

namespace sample_without_replacement_impl {
template <typename T>
struct weight_alias {
};

template <>
struct weight_alias<std::nullopt_t> {
using type = double;
};

template <typename ElementType, typename IndexType>
struct weight_alias<std::optional<raft::device_vector_view<ElementType, IndexType>>> {
using type = typename raft::device_vector_view<ElementType, IndexType>::value_type;
};

template <typename T>
using weight_t = typename weight_alias<T>::type;
} // namespace sample_without_replacement_impl

/**
* @brief Sample the input vector without replacement, optionally based on the
* input weight vector for each element in the array.
Expand Down Expand Up @@ -730,10 +749,10 @@ void laplace(const raft::handle_t& handle,
* the CUDA stream on which to run.
* @param[inout] rng_state Pseudorandom number generator state.
* @param[in] in Input vector to be sampled.
* @param[in] wts Optional weights vector.
* @param[in] weights_opt std::optional weights vector.
* If not provided, uniform sampling will be used.
* @param[out] out Vector of samples from the input vector.
* @param[out] outIdx If provided, vector of the indices
* @param[out] outIdx_opt std::optional vector of the indices
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* sampled from the input array.
*
* @pre The number of samples `out.extent(0)`
Expand All @@ -742,14 +761,22 @@ void laplace(const raft::handle_t& handle,
* @pre The number of weights `wts.extent(0)`
* equals the number of inputs `in.extent(0)`.
*/
template <typename DataT, typename IdxT, typename WeightsT>
template <typename DataT, typename IdxT, typename WeightsVectorType, class OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<const DataT, IdxT> in,
std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts,
WeightsVectorType&& weights_opt,
raft::device_vector_view<DataT, IdxT> out,
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx)
OutIndexVectorType&& outIdx_opt)
{
using weight_type = sample_without_replacement_impl::weight_t<
std::remove_const_t<std::remove_reference_t<WeightsVectorType>>>;

std::optional<raft::device_vector_view<const weight_type, IdxT>> wts =
std::forward<WeightsVectorType>(weights_opt);
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx =
std::forward<OutIndexVectorType>(outIdx_opt);

static_assert(std::is_integral<IdxT>::value, "IdxT must be an integral type.");
const IdxT sampledLen = out.extent(0);
const IdxT len = in.extent(0);
Expand Down Expand Up @@ -777,7 +804,7 @@ void sample_without_replacement(const raft::handle_t& handle,
"sampleWithoutReplacement: "
"If wts is provided, its extent(0) must equal in.extent(0)");
}
const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr;
const weight_type* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr;

detail::sampleWithoutReplacement(rng_state,
out.data_handle(),
Expand All @@ -789,48 +816,17 @@ void sample_without_replacement(const raft::handle_t& handle,
handle.get_stream());
}

namespace sample_without_replacement_impl {
template <typename T>
struct weight_alias {
};

template <>
struct weight_alias<std::nullopt_t> {
using type = double;
};

template <typename ElementType, typename IndexType>
struct weight_alias<std::optional<raft::device_vector_view<ElementType, IndexType>>> {
using type = typename raft::device_vector_view<ElementType, IndexType>::value_type;
};

template <typename T>
using weight_t = typename weight_alias<T>::type;
} // namespace sample_without_replacement_impl

/**
* @brief Overload of `sample_without_replacement` to help the
* compiler find the above overload, in case users pass in
* `std::nullopt` for one or both of the optional arguments.
*
* Please see above for documentation of `sample_without_replacement`.
*/
template <typename DataT, typename IdxT, typename WeightsVectorType, class OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<const DataT, IdxT> in,
WeightsVectorType&& wts,
raft::device_vector_view<DataT, IdxT> out,
OutIndexVectorType&& outIdx)
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 5>>
void sample_without_replacement(Args... args)
{
using weight_type = sample_without_replacement_impl::weight_t<
std::remove_const_t<std::remove_reference_t<WeightsVectorType>>>;
std::optional<raft::device_vector_view<const weight_type, IdxT>> weights =
std::forward<WeightsVectorType>(wts);
std::optional<raft::device_vector_view<IdxT, IdxT>> output_indices =
std::forward<OutIndexVectorType>(outIdx);

sample_without_replacement(handle, rng_state, in, weights, out, output_indices);
sample_without_replacement(std::forward<Args>(args)..., std::nullopt);
}

/**
Expand Down
Loading