Skip to content

Commit

Permalink
Reenable OOP complex-complex version for float input
Browse files Browse the repository at this point in the history
  • Loading branch information
hjabird committed Jun 6, 2024
1 parent 037668b commit 914d742
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 26 deletions.
8 changes: 2 additions & 6 deletions include/oneapi/mkl/dft/backward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_r
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
template <typename descriptor_type, typename input_type, typename output_type>
void compute_backward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,
sycl::buffer<output_type, 1> &out) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down Expand Up @@ -131,9 +129,7 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_ty
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
template <typename descriptor_type, typename input_type, typename output_type>
sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down
14 changes: 0 additions & 14 deletions include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,6 @@ using valid_compute_arg = typename std::bool_constant<
(std::is_same_v<descriptor_scalar_t<descriptor_type>, double> &&
is_one_of<T, double, sycl::double2, sycl::double4, std::complex<double>>::value)>;

// For out-of-place complex-complex DFTs, are the input and output types correct? For SFINAE.
template <class descriptor_t, typename input_t, typename output_t>
constexpr bool valid_oop_iotypes = []() {
if constexpr (is_complex_dft<descriptor_t>) {
// Both input and output types must be complex, otherwise select real-real inplace overload.
return is_complex<input_t> && is_complex<output_t>;
}
else {
// I/O can be real or complex - no issues resolving overload with real-real inplace.
return valid_compute_arg<descriptor_t, input_t>::value &&
valid_compute_arg<descriptor_t, output_t>::value;
}
}();

template <class descriptor_t, typename data_t>
constexpr bool valid_ip_realreal_impl =
is_complex_dft<descriptor_t>&& std::is_same_v<descriptor_scalar_t<descriptor_t>, data_t>;
Expand Down
8 changes: 2 additions & 6 deletions include/oneapi/mkl/dft/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
template <typename descriptor_type, typename input_type, typename output_type>
void compute_forward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,
sycl::buffer<output_type, 1> &out) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down Expand Up @@ -129,9 +127,7 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_typ
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
template <typename descriptor_type, typename input_type, typename output_type>
sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down

0 comments on commit 914d742

Please sign in to comment.