Skip to content

Commit

Permalink
[DFT] Correct overload resolution for OOP COMPLEX vs IP REAL_REAL (ux…
Browse files Browse the repository at this point in the history
…lfoundation#503)

* OOP COMPLEX and IP REAL_REAL overload resolution is problematic
  * Inplace real-real overload would be selected when out-of-place complex-complex DFT was intended.
  * With spec update, this PR uses SFINAE to give the expected behaviour for the user.
  • Loading branch information
hjabird authored and normallytangent committed Aug 6, 2024
1 parent 08adee2 commit 0d301e6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
6 changes: 4 additions & 2 deletions include/oneapi/mkl/dft/backward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout)
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand Down Expand Up @@ -114,7 +115,8 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout,
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand Down
22 changes: 20 additions & 2 deletions include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,34 @@ struct descriptor_info<descriptor<precision::DOUBLE, domain::COMPLEX>> {
using backward_type = std::complex<double>;
};

// Get the scalar type associated with a descriptor.
template <class descriptor_t>
using descriptor_scalar_t = typename descriptor_info<descriptor_t>::scalar_type;

template <typename T>
constexpr bool is_complex_dft = false;
template <precision Prec>
constexpr bool is_complex_dft<descriptor<Prec, domain::COMPLEX>> = true;

template <typename T>
constexpr bool is_complex = false;
template <typename T>
constexpr bool is_complex<std::complex<T>> = true;

template <typename T, typename... Ts>
using is_one_of = typename std::bool_constant<(std::is_same_v<T, Ts> || ...)>;

template <typename descriptor_type, typename T>
using valid_compute_arg = typename std::bool_constant<
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, float> &&
(std::is_same_v<descriptor_scalar_t<descriptor_type>, float> &&
is_one_of<T, float, sycl::float2, sycl::float4, std::complex<float>>::value) ||
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, double> &&
(std::is_same_v<descriptor_scalar_t<descriptor_type>, double> &&
is_one_of<T, double, sycl::double2, sycl::double4, std::complex<double>>::value)>;

template <class descriptor_t, typename data_t>
constexpr bool valid_ip_realreal_impl =
is_complex_dft<descriptor_t>&& std::is_same_v<descriptor_scalar_t<descriptor_t>, data_t>;

// compute the range of a reinterpreted buffer
template <typename In, typename Out>
std::size_t reinterpret_range(std::size_t size) {
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/mkl/dft/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout) {
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand Down Expand Up @@ -114,12 +115,12 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout,
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast<scalar_type *>(inout_re),
reinterpret_cast<scalar_type *>(inout_im), dependencies);
Expand All @@ -133,7 +134,6 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;
return get_commit(desc)->forward_op_cc(desc, reinterpret_cast<fwd_type *>(in),
Expand Down

0 comments on commit 0d301e6

Please sign in to comment.