[DFT] compute_forward (forward_ip_rr) function template in forward.hpp overrides compute_forward (forward_op_cc) if in and out types are the same #499
Comments
Reproducer:
sycl::queue sycl_queue(dev);
auto x_usm = sycl::malloc_shared<std::complex<float>>(N, sycl_queue);
auto y_usm = sycl::malloc_shared<std::complex<float>>(N, sycl_queue);
// 1. create descriptors
oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>
desc(static_cast<std::int64_t>(N));
// 2. variadic set_value
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
// 3. commit_descriptor (runtime dispatch)
desc.commit(sycl_queue);
// 4. compute_forward / compute_backward (runtime dispatch)
auto compute_event = oneapi::mkl::dft::compute_forward(desc, x_usm, y_usm);
// Do something with transformed data.
compute_event.wait();
// 5. Free USM allocation.
sycl::free(x_usm, sycl_queue);
sycl::free(y_usm, sycl_queue);
The above calls the real-real complex storage version, as described, instead of the desired out-of-place version. The problem reproduces neatly with the cuFFT backend, since cuFFT doesn't support real-real complex storage. It can be worked around by giving the template arguments explicitly:
auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm); |
In terms of the API:
My preferred solution to this is to throw a useful error if the function that gets called does not match the information with which the descriptor was committed. Do you think this is sufficient @CETuke ? |
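A minimal sketch of the kind of runtime check proposed above, under stated assumptions: the enums, the `committed_config` struct, and `require_inplace_real_real` are hypothetical stand-ins written for illustration and are not oneMKL types or functions. The idea is only that the in-place REAL_REAL entry point compares what it was called as against what the descriptor was committed with, and throws a descriptive error on a mismatch.

```cpp
#include <stdexcept>

// Hypothetical mirrors of the PLACEMENT and COMPLEX_STORAGE settings
// discussed in this thread; these are not the oneMKL types.
enum class placement { inplace, not_inplace };
enum class complex_storage { complex_complex, real_real };

// Hypothetical snapshot of what a descriptor was committed with.
struct committed_config {
    placement place;
    complex_storage storage;
};

// Intended to run at the top of the in-place REAL_REAL entry point.
// Throws if the committed settings do not describe that kind of transform.
inline void require_inplace_real_real(const committed_config &cfg) {
    if (cfg.place != placement::inplace) {
        throw std::invalid_argument(
            "compute_forward(desc, inout_re, inout_im): descriptor was "
            "not committed with PLACEMENT=INPLACE");
    }
    if (cfg.storage != complex_storage::real_real) {
        throw std::invalid_argument(
            "compute_forward(desc, inout_re, inout_im): descriptor was "
            "not committed with COMPLEX_STORAGE=REAL_REAL");
    }
}
```

A check like this only reports the mismatch at run time; it does not change which overload the compiler selects.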
I believe this is for real-to-complex and complex-to-real FFTs. But as documented in the specification at https://spec.oneapi.io/versions/latest/elements/oneMKL/source/domains/dft/config_params/data_layouts.html there should not be any conflict, so it looks to me that this is a gap in the way the data type and descriptor consistency check is implemented (using valid_compute_arg). |
Thanks for taking this on, and apologies for the late reply as I've been away. Could you please double-check that the following is true:
auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);
This is the actual code I am using (essentially, your USM reproducer in this case):
int reproducer() {
sycl::queue sycl_queue(dev);
auto x_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);
auto y_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);
// 1. create descriptors
oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>
desc(static_cast<std::int64_t>(8));
// 2. variadic set_value
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
// 3. commit_descriptor (runtime dispatch)
desc.commit(sycl_queue);
// 4. compute_forward / compute_backward (runtime dispatch)
auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);
// Do something with transformed data.
compute_event.wait();
// 5. Free USM allocation.
sycl::free(x_usm, sycl_queue);
sycl::free(y_usm, sycl_queue);
}
Because even if I use your reproducer code verbatim and give the template arguments explicitly (as in the code above), it still selects the incorrect in-place transform that uses the config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format. In fact, I've just tried your pull request code and now my driver code won't compile at all due to the added check on line 133 in forward.hpp.
/install/include/oneapi/mkl/dft/forward.hpp:53:19: error: static assertion failed due to requirement '!detail::is_complex_arg<std::complex<float>>::value': expected real type for data_type
53 | static_assert(!detail::is_complex_arg<data_type>::value, "expected real type for data_type");
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
I guess this could be a difference in compiler behaviour (I'm using cmake 3.22.1 with clang++ 19.0.0git), but in my case explicit arguments aren't differentiating between the following two functions when both in and out are of type std::complex<float> (or std::complex<double>).
line 128: //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {})
line 140: //Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type>
sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out,
const std::vector<sycl::event> &dependencies = {}) |
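To make the interaction between these two signatures concrete, here is a small sketch that does not depend on oneMKL or SYCL. `which_overload` and `fake_descriptor` are hypothetical stand-ins that copy only the shape of the two templates quoted above (the `dependencies` parameter is omitted). Under standard C++ partial ordering, a deduced call with two pointers of the same type should prefer the two-parameter template, while spelling out three template arguments should leave only the three-parameter template viable.

```cpp
#include <complex>
#include <string>

// Hypothetical stand-in with the shape of the in-place REAL_REAL overload:
// both pointer parameters share a single deduced type.
template <typename descriptor_type, typename data_type>
std::string which_overload(descriptor_type &, data_type *, data_type *) {
    return "ip_rr";
}

// Hypothetical stand-in with the shape of the out-of-place overload:
// input and output types are deduced independently.
template <typename descriptor_type, typename input_type, typename output_type>
std::string which_overload(descriptor_type &, input_type *, output_type *) {
    return "op";
}

struct fake_descriptor {};

// Deduced call, identical pointer types: both templates are viable and the
// more specialised two-parameter template is preferred.
inline std::string deduced_same_types() {
    fake_descriptor desc;
    std::complex<float> in, out;
    return which_overload(desc, &in, &out);
}

// Three explicit template arguments: the two-parameter template cannot
// accept them, so only the out-of-place shape remains.
inline std::string explicit_three_args() {
    fake_descriptor desc;
    std::complex<float> in, out;
    return which_overload<fake_descriptor, std::complex<float>,
                          std::complex<float>>(desc, &in, &out);
}

// Deduced call, different pointer types: only the out-of-place shape matches.
inline std::string deduced_mixed_types() {
    fake_descriptor desc;
    float in;
    std::complex<float> out;
    return which_overload(desc, &in, &out);
}
```

In this sketch, `deduced_same_types()` takes the "ip_rr" path and the other two helpers take the "op" path, which matches the behaviour described for the reproducer and its explicit-argument workaround.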
@CETuke For the USM variation, I still get the same results:
sycl::queue sycl_queue(dev, exception_handler);
auto x_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);
auto y_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);
// 1. create descriptors
oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>
desc(static_cast<std::int64_t>(8));
// 2. variadic set_value
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
// 3. commit_descriptor (runtime dispatch)
desc.commit(sycl_queue);
// 4. compute_forward / compute_backward (runtime dispatch)
// The following selects the correct compute_forward for me
auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);
// The following selects the wrong version of compute_forward and either:
// * Generates a static_assert with the patch
// * Calls incorrect variant without the patch, with CuFFT backend generating the error:
// "oneMKL: DFT/compute_forward(desc, inout_re, inout_im, dependencies): function is not implemented cuFFT does not support real-real complex storage."
// auto compute_event = oneapi::mkl::dft::compute_forward(desc, x_usm, y_usm);
// Do something with transformed data.
compute_event.wait();
// 5. Free USM allocation.
sycl::free(x_usm, sycl_queue);
I'm slightly surprised by your result. The fact that it isn't compiling due to the static_assert is a feature to me: it's stopping you from compiling the in-place version that you didn't want. However, the three template arguments that you're giving it (as I understand) ( I'm going to look into a patch that uses SFINAE instead, to help you choose the correct overload rather than generating errors. Perhaps you can give that a go once it's ready? I'll update this issue. |
Yes, I am giving it that template argument, but my compiler (which is actually from a nightly snapshot of Intel's branch of LLVM, /opt/llvm/llvm-nightly-2024-03-04/build/bin/clang) still seems to be choosing the REAL_REAL version with two template arguments. Sure, I'd be happy to try a SFINAE patch if you get one. In the meantime, I'll see if I can get someone here to independently reproduce the behaviour I'm seeing. Thanks for the help. |
An implementation using SFINAE is available at https://github.com/hjabird/oneMKL/tree/dft_overload_resolution_with_enableif
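For readers unfamiliar with the technique, the following is a rough sketch of how `std::enable_if` can remove the REAL_REAL-shaped overload from consideration when the element type is complex. It is not taken from the linked branch: `is_complex_sketch`, `pick`, and `fake_descriptor` are hypothetical names, and the actual patch may constrain the overloads differently.

```cpp
#include <complex>
#include <string>
#include <type_traits>

// Hypothetical trait: true only for std::complex<T>.
template <typename T>
struct is_complex_sketch : std::false_type {};
template <typename T>
struct is_complex_sketch<std::complex<T>> : std::true_type {};

// REAL_REAL-shaped stand-in: participates in overload resolution only when
// data_type is not complex.
template <typename descriptor_type, typename data_type,
          std::enable_if_t<!is_complex_sketch<data_type>::value, int> = 0>
std::string pick(descriptor_type &, data_type *, data_type *) {
    return "ip_rr";
}

// Out-of-place-shaped stand-in: left unconstrained in this sketch.
template <typename descriptor_type, typename input_type, typename output_type>
std::string pick(descriptor_type &, input_type *, output_type *) {
    return "op";
}

struct fake_descriptor {};

// Complex pointers: the constrained overload is discarded by SFINAE, so the
// out-of-place shape is selected without explicit template arguments.
inline std::string pick_for_complex() {
    fake_descriptor desc;
    std::complex<float> in, out;
    return pick(desc, &in, &out);
}

// Real pointers: the constrained overload is viable and more specialised.
inline std::string pick_for_real() {
    fake_descriptor desc;
    float re, im;
    return pick(desc, &re, &im);
}
```

With this arrangement a plain call with two `std::complex<float>` pointers resolves to the out-of-place shape, so no compile error is needed to steer the user away from the in-place version.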
|
@hjabird I can confirm this solution works well. In the spirit of openness, I also tracked down the earlier problem, where my compiler appeared to choose the REAL_REAL version despite explicit template arguments: it was caused by a call to compute_forward() without explicit template arguments elsewhere in my driver code, which I hadn't noticed. So the compile error from your pull request code was actually triggered by that call and not by the section of code I was looking at; apologies. However, I think your SFINAE solution is by far the better one, as we no longer need to provide explicit template arguments just to get the common out-of-place (OOP) complex-to-complex (C2C) functions to work. |
Thanks, @hjabird! That SFINAE approach looks like a significant improvement in my opinion. I have one concern but I am a little rusty on the metaprogramming front so you may need to correct me below. Apologies for commenting here but I couldn't find a way to comment directly where relevant. Consider a single-precision complex descriptor
|
Closing since the spec change and oneMKL changes are merged. |
Summary
File: forward.hpp - the "In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format" compute_forward function template overlaps with the "Out-of-place transform" template when input_type == output_type (for example, a complex-to-complex DFT).
Version
develop branch v0.4 tag (#34)
Environment
oneMKL DFT - implementing a new GPU backend
Observed behavior
When I set up my descriptor like this:
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::detail::config_value::NOT_INPLACE);
desc.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::detail::config_value::COMPLEX_COMPLEX);
and call compute_forward() with complex in and out parameters, the in-place REAL_REAL function template overlaps and is incorrectly chosen instead of the out-of-place COMPLEX_COMPLEX transform.
forward.hpp:
I can't use explicit template specialisation in my DFT backend (forward.cpp) file because the oneapi::mkl::dft REAL_REAL function above reinterprets the "in" and "out" data from complex (or complex) to scalar_type so my backend only sees compute_forward(desc, scalar_type, scalar_type).
Expected behavior
REAL_REAL variant only gets selected if inout_re and inout_im are of float or double type (not complex type).
or, alternatively:
REAL_REAL variant only gets selected if:
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::detail::config_value::INPLACE);
desc.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::detail::config_value::REAL_REAL);