[DFT] compute_forward (forward_ip_rr) function template in forward.hpp overrides compute_forward (forward_op_cc) if in and out types are the same #499

Closed
CETuke opened this issue May 25, 2024 · 10 comments

CETuke commented May 25, 2024

Summary

File: forward.hpp. The compute_forward function template documented as "In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format" overlaps with the "Out-of-place transform" template whenever input_type == output_type (for example, a complex-to-complex DFT).

Version

develop branch v0.4 tag (#34)

Environment

oneMKL DFT - implementing a new GPU backend

Observed behavior

When I set up my descriptor like this:
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::detail::config_value::NOT_INPLACE);
desc.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::detail::config_value::COMPLEX_COMPLEX);

and call compute_forward() with complex in and out parameters, the in-place REAL_REAL function template overlaps with the out-of-place COMPLEX_COMPLEX transform and is incorrectly chosen instead of it.

forward.hpp:

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format <- THIS GETS CHOSEN
template <typename descriptor_type, typename data_type>
void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
                     sycl::buffer<data_type, 1> &inout_im)  {
    static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
                  "unexpected type for data_type");

    using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
    auto type_corrected_inout_re = inout_re.template reinterpret<scalar_type, 1>(
        detail::reinterpret_range<data_type, scalar_type>(inout_re.size()));
    auto type_corrected_inout_im = inout_im.template reinterpret<scalar_type, 1>(
        detail::reinterpret_range<data_type, scalar_type>(inout_im.size()));
    get_commit(desc)->forward_ip_rr(desc, type_corrected_inout_re, type_corrected_inout_im);
}

//Out-of-place transform <- THIS SHOULD HAVE BEEN CHOSEN
template <typename descriptor_type, typename input_type, typename output_type>
void compute_forward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,
                     sycl::buffer<output_type, 1> &out)
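The overlap can be reproduced without oneMKL or SYCL. The following is a minimal sketch, assuming toy stand-ins (toy_descriptor and toy_compute_forward are hypothetical names, and plain pointers replace sycl::buffer) that only mirror the shape of the two signatures quoted above. When both data arguments have the same type, both templates are viable and C++ partial ordering prefers the more specialized one, which is the template whose two data parameters share a single type.

```cpp
#include <complex>
#include <string>

// Hypothetical stand-ins for the two overloads; not the oneMKL implementation.
struct toy_descriptor {};

// Shape of the in-place REAL_REAL overload: both data arguments share one type.
template <typename descriptor_type, typename data_type>
std::string toy_compute_forward(descriptor_type&, data_type*, data_type*) {
    return "in-place REAL_REAL";
}

// Shape of the out-of-place overload: input and output types are independent.
template <typename descriptor_type, typename input_type, typename output_type>
std::string toy_compute_forward(descriptor_type&, input_type*, output_type*) {
    return "out-of-place";
}

// Same element type for both arguments: both templates are viable, and partial
// ordering selects the more specialized (data_type*, data_type*) overload.
inline std::string call_with_matching_types() {
    toy_descriptor desc;
    std::complex<float> in[8], out[8];
    return toy_compute_forward(desc, in, out);
}

// Different element types: only the three-parameter template can be deduced.
inline std::string call_with_mixed_types() {
    toy_descriptor desc;
    float in[8];
    std::complex<float> out[8];
    return toy_compute_forward(desc, in, out);
}
```

In this toy model the complex-to-complex call lands on the overload shaped like the REAL_REAL one, which matches the behaviour reported here.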

I can't use explicit template specialisation in my DFT backend (forward.cpp) file because the oneapi::mkl::dft REAL_REAL function above reinterprets the "in" and "out" data from std::complex<float> (or std::complex<double>) to scalar_type, so my backend only sees compute_forward(desc, scalar_type, scalar_type).

Expected behavior

REAL_REAL variant only gets selected if inout_re and inout_im are of float or double type (not complex type).

or, alternatively:

REAL_REAL variant only gets selected if:
desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::detail::config_value::INPLACE);
desc.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::detail::config_value::REAL_REAL);

@hjabird hjabird self-assigned this May 25, 2024
@hjabird hjabird added the API A request to add/change/fix/improve the API label May 25, 2024
@hjabird
Contributor

hjabird commented May 28, 2024

Reproducer:

    sycl::queue sycl_queue(dev);
    auto x_usm = sycl::malloc_shared<std::complex<float>>(N, sycl_queue);
    auto y_usm = sycl::malloc_shared<std::complex<float>>(N, sycl_queue);

    // 1. create descriptors
    oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
                                 oneapi::mkl::dft::domain::COMPLEX>
        desc(static_cast<std::int64_t>(N));

    // 2. variadic set_value
    desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                   oneapi::mkl::dft::config_value::NOT_INPLACE);

    // 3. commit_descriptor (runtime dispatch)
    desc.commit(sycl_queue);

    // 4. compute_forward / compute_backward (runtime dispatch)
    auto compute_event = oneapi::mkl::dft::compute_forward(desc, x_usm, y_usm);

    // Do something with transformed data.
    compute_event.wait();

    // 5. Free USM allocation.
    sycl::free(x_usm, sycl_queue);
    sycl::free(y_usm, sycl_queue);

The above calls the real-real complex storage version, as described, instead of the desired out-of-place version. Reproduction works neatly with the cuFFT backend, since cuFFT doesn't support real-real complex storage.

The problem can be worked around by giving template arguments explicitly:

    auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);
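Why explicit template arguments help can be shown with a small sketch, again assuming hypothetical toy overloads (toy_descriptor, toy_compute_forward) rather than the real oneMKL functions. Supplying three explicit template arguments rules out the template that declares only two template parameters, so only the out-of-place shape remains viable.

```cpp
#include <complex>
#include <string>

// Hypothetical stand-ins; plain pointers replace the USM pointers of the reproducer.
struct toy_descriptor {};

// Two template parameters, like the in-place REAL_REAL overload.
template <typename descriptor_type, typename data_type>
std::string toy_compute_forward(descriptor_type&, data_type*, data_type*) {
    return "in-place REAL_REAL";
}

// Three template parameters, like the out-of-place overload.
template <typename descriptor_type, typename input_type, typename output_type>
std::string toy_compute_forward(descriptor_type&, input_type*, output_type*) {
    return "out-of-place";
}

// Three explicit arguments cannot be matched to the two-parameter template,
// so it drops out of the candidate set before overload resolution.
inline std::string call_with_explicit_arguments() {
    toy_descriptor desc;
    std::complex<float> in[8], out[8];
    return toy_compute_forward<toy_descriptor, std::complex<float>, std::complex<float>>(
        desc, in, out);
}
```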

@hjabird
Contributor

hjabird commented May 28, 2024

In terms of the API:

  • I think the flexibility with regards to accepting floats for the input to the function signature that uses complex-complex data is intentional, and exists for backwards compatibility. @lhuot may be able to comment more on this.
  • Changing what the function does at runtime based on the committed descriptor would not conform to the oneAPI/oneMKL specification, and could potentially also be confusing.

My preferred solution to this is to throw a useful error if the function that gets called does not match the information with which the descriptor was committed. Do you think this is sufficient @CETuke ?
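As an illustration of that idea, here is a minimal sketch of a runtime consistency check. Everything in it is an assumption made for the example: toy_storage, toy_placement, toy_committed_config and toy_check_forward_ip_rr are hypothetical names, and std::invalid_argument stands in for whatever exception type the library would actually throw.

```cpp
#include <stdexcept>

// Hypothetical mirror of the two relevant configuration values.
enum class toy_storage { complex_complex, real_real };
enum class toy_placement { inplace, not_inplace };

// Hypothetical snapshot of what the descriptor was committed with.
struct toy_committed_config {
    toy_storage storage;
    toy_placement placement;
};

// Guard for the in-place REAL_REAL entry point: reject a call that does not
// match the committed configuration, with a message naming the mismatch.
inline void toy_check_forward_ip_rr(const toy_committed_config& cfg) {
    if (cfg.storage != toy_storage::real_real) {
        throw std::invalid_argument(
            "compute_forward(desc, inout_re, inout_im): descriptor was not "
            "committed with COMPLEX_STORAGE=REAL_REAL");
    }
    if (cfg.placement != toy_placement::inplace) {
        throw std::invalid_argument(
            "compute_forward(desc, inout_re, inout_im): descriptor was not "
            "committed with PLACEMENT=INPLACE");
    }
}

// Helper for the example: reports whether the guard rejects a configuration.
inline bool toy_call_is_rejected(const toy_committed_config& cfg) {
    try {
        toy_check_forward_ip_rr(cfg);
    } catch (const std::invalid_argument&) {
        return true;
    }
    return false;
}
```

A check of this kind reports the mismatch at run time; it does not change which overload the compiler selects.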

@lhuot
Contributor

lhuot commented May 28, 2024

  • I think the flexibility with regards to accepting floats for the input to the function signature that uses complex-complex data is intentional, and exists for backwards compatibility. @lhuot may be able to comment more on this.

I believe this is for real-to-complex and complex-to-real FFTs. But, as documented in the specification at https://spec.oneapi.io/versions/latest/elements/oneMKL/source/domains/dft/config_params/data_layouts.html, there should not be any conflict, so it looks to me like this is a gap in the way the data type and descriptor consistency check is implemented (using valid_compute_arg).

@CETuke
Author

CETuke commented May 30, 2024

Thanks for taking this on, and apologies for the late reply; I've been away.
My original issue is with the complex-to-complex out-of-place buffer memory version (not USM; sorry I didn't make that clearer).

Could you please double-check the following is true:
The problem can be worked around by giving template arguments explicitly:

    auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);

This is the actual code I am using (essentially, your USM reproducer in this case):

int reproducer() {
    sycl::queue sycl_queue(dev);
    auto x_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);
    auto y_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);

    // 1. create descriptors
    oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
                                 oneapi::mkl::dft::domain::COMPLEX>
        desc(static_cast<std::int64_t>(8));

    // 2. variadic set_value
    desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                   oneapi::mkl::dft::config_value::NOT_INPLACE);

    // 3. commit_descriptor (runtime dispatch)
    desc.commit(sycl_queue);

    // 4. compute_forward / compute_backward (runtime dispatch)
    auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);

    // Do something with transformed data.
    compute_event.wait();

    // 5. Free USM allocation.
    sycl::free(x_usm, sycl_queue);
    sycl::free(y_usm, sycl_queue);

    return 0;
}

Because even if I use your reproducer code verbatim and give the template arguments explicitly (as in the code above) it is still selecting the incorrect In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format. In fact, I've just tried your pull request code and now my driver code won't compile at all due to the added check on line 133 in forward.hpp.

/install/include/oneapi/mkl/dft/forward.hpp:53:19: error: static assertion failed due to requirement '!detail::is_complex_arg<std::complex<float>>::value': expected real type for data_type
   53 |     static_assert(!detail::is_complex_arg<data_type>::value, "expected real type for data_type");
      |                   ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

I guess this could be a difference in compiler behaviour (I'm using cmake 3.22.1 with clang++ 19.0.0git), but in my case explicit arguments aren't differentiating between the following two functions when both in and out are of type std::complex<float> (or std::complex<double>).

line 128:

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
                            const std::vector<sycl::event> &dependencies = {})

line 140:

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type>
sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out,
                            const std::vector<sycl::event> &dependencies = {})                             
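The static assertion in the error message above depends on a trait that detects complex types. A minimal sketch of such a trait follows, assuming hypothetical names (toy_is_complex_arg, toy_compute_forward); the real detail::is_complex_arg may be defined differently. The sketch also shows why a static_assert alone cannot redirect a call: it is evaluated only after the overload has already been selected.

```cpp
#include <complex>
#include <string>
#include <type_traits>

// Hypothetical trait modelled on the detail::is_complex_arg name in the error.
template <typename T>
struct toy_is_complex_arg : std::false_type {};
template <typename T>
struct toy_is_complex_arg<std::complex<T>> : std::true_type {};

struct toy_descriptor {};

// In-place REAL_REAL shape with a static_assert in the body.
template <typename descriptor_type, typename data_type>
std::string toy_compute_forward(descriptor_type&, data_type*, data_type*) {
    // Evaluated only once this overload has won overload resolution.
    static_assert(!toy_is_complex_arg<data_type>::value,
                  "expected real type for data_type");
    return "in-place REAL_REAL";
}

// Out-of-place shape.
template <typename descriptor_type, typename input_type, typename output_type>
std::string toy_compute_forward(descriptor_type&, input_type*, output_type*) {
    return "out-of-place";
}

// Real-typed pointers satisfy the assertion and select the in-place shape.
inline std::string call_with_real_pointers() {
    toy_descriptor desc;
    float re[8], im[8];
    return toy_compute_forward(desc, re, im);
}

// Passing two std::complex<float>* arguments without explicit template
// arguments would still select the in-place shape and then fail to compile
// on the static_assert, rather than falling back to the out-of-place shape.
```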

@hjabird
Contributor

hjabird commented May 30, 2024

@CETuke For the USM variation, I still get the same results:

    sycl::queue sycl_queue(dev, exception_handler);
    auto x_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);
    auto y_usm = sycl::malloc_shared<std::complex<float>>(8, sycl_queue);

    // 1. create descriptors
    oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
                                 oneapi::mkl::dft::domain::COMPLEX>
        desc(static_cast<std::int64_t>(8));

    // 2. variadic set_value
    desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                   oneapi::mkl::dft::config_value::NOT_INPLACE);

    // 3. commit_descriptor (runtime dispatch)
    desc.commit(sycl_queue);

    // 4. compute_forward / compute_backward (runtime dispatch)

    // The following selects the correct compute_forward for me
    auto compute_event = oneapi::mkl::dft::compute_forward<decltype(desc), std::complex<float>, std::complex<float>>(desc, x_usm, y_usm);

    // The following selects the wrong version of compute_forward. It either:
    // * triggers a static_assert with the patch, or
    // * calls the incorrect variant without the patch, with the cuFFT backend generating the error:
    //   "oneMKL: DFT/compute_forward(desc, inout_re, inout_im, dependencies): function is not implemented cuFFT does not support real-real complex storage."
    // auto compute_event = oneapi::mkl::dft::compute_forward(desc, x_usm, y_usm);

    // Do something with transformed data.
    compute_event.wait();

    // 5. Free USM allocation.
    sycl::free(x_usm, sycl_queue);
    sycl::free(y_usm, sycl_queue);

I'm slightly surprised by your result. To me, the fact that it doesn't compile because of the static_assert is a feature: it's stopping you from compiling the in-place version that you didn't want. However, the three template arguments that you're giving it, as I understand it (decltype(desc), std::complex<float>, std::complex<float>), are more than the two that the overload generating this static assert supports.

I'm going to look into a patch that uses SFINAE to help you choose the correct overload instead of generating errors. Perhaps you can give that a go once it's ready? I'll update this issue.

@CETuke
Author

CETuke commented May 30, 2024

However, the three template arguments that you're giving it (as I understand) (decltype(desc), std::complex<float>, std::complex<float>) are more than the two that the overload generating this static assert supports.

Yes, I am giving it those template arguments, but my compiler (a nightly snapshot of Intel's branch of LLVM, /opt/llvm/llvm-nightly-2024-03-04/build/bin/clang) still seems to be choosing the REAL_REAL overload that takes two template arguments.

Sure, I'd be happy to try a SFINAE patch if you get one, I'll see if I can get someone here to independently reproduce the behaviour I'm seeing in the meantime.

Thanks for the help.

@hjabird
Contributor

hjabird commented May 30, 2024

An implementation using SFINAE is available at https://github.com/hjabird/oneMKL/tree/dft_overload_resolution_with_enableif

  • @CETuke hopefully this will achieve what you're looking for - let me know if it doesn't.
  • @lhuot @raphael-egan @Rbiessy Let me know if you think this is what we're after and I'll create a PR.
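For readers unfamiliar with the technique, the following is a minimal sketch of overload selection with SFINAE. It is not the code from the linked branch: toy_is_complex, toy_descriptor and toy_compute_forward are hypothetical names, and placing std::enable_if_t in a defaulted non-type template parameter is just one of several possible arrangements. The in-place shape is removed from the candidate set for complex types, so the out-of-place shape is chosen without explicit template arguments.

```cpp
#include <complex>
#include <string>
#include <type_traits>

// Hypothetical trait: true only for std::complex specializations.
template <typename T>
struct toy_is_complex : std::false_type {};
template <typename T>
struct toy_is_complex<std::complex<T>> : std::true_type {};

struct toy_descriptor {};

// In-place REAL_REAL shape, enabled only for non-complex data types.
// For complex types the substitution fails and this overload is discarded.
template <typename descriptor_type, typename data_type,
          std::enable_if_t<!toy_is_complex<data_type>::value, bool> = true>
std::string toy_compute_forward(descriptor_type&, data_type*, data_type*) {
    return "in-place REAL_REAL";
}

// Out-of-place shape, left unconstrained in this sketch.
template <typename descriptor_type, typename input_type, typename output_type>
std::string toy_compute_forward(descriptor_type&, input_type*, output_type*) {
    return "out-of-place";
}

// Complex pointers, no explicit template arguments: only out-of-place is viable.
inline std::string call_with_complex_pointers() {
    toy_descriptor desc;
    std::complex<float> in[8], out[8];
    return toy_compute_forward(desc, in, out);
}

// Real pointers: both are viable and the more specialized in-place shape wins.
inline std::string call_with_real_pointers() {
    toy_descriptor desc;
    float re[8], im[8];
    return toy_compute_forward(desc, re, im);
}
```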

@CETuke
Author

CETuke commented May 30, 2024

* @CETuke hopefully this will achieve what you're looking for - let me know if it doesn't.

@hjabird I can confirm this solution works well.

In the spirit of openness, I also tracked down the earlier problem, where my compiler appeared to choose the REAL_REAL versions despite explicit template arguments: it was caused by a call to compute_forward() without explicit template arguments elsewhere in my driver code that I hadn't noticed. So the compile error from your pull request code was actually triggered by that call and not by the section of code I was looking at. Apologies.

However, I think your SFINAE solution is by far the better one, as we no longer need to provide explicit template arguments just to get the common use-case OOP C2C functions to work.

@raphael-egan

An implementation using SFINAE is available at https://github.com/hjabird/oneMKL/tree/dft_overload_resolution_with_enableif

  • @CETuke hopefully this will achieve what you're looking for - let me know if it doesn't.
  • @lhuot @raphael-egan @Rbiessy Let me know if you think this is what we're after and I'll create a PR.

Thanks, @hjabird! That SFINAE approach looks like a significant improvement in my opinion. I have one concern but I am a little rusty on the metaprogramming front so you may need to correct me below. Apologies for commenting here but I couldn't find a way to comment directly where relevant.

Consider a single-precision complex descriptor desc successfully committed, configured with COMPLEX_COMPLEX and NOT_INPLACE values set for configuration parameters COMPLEX_STORAGE and PLACEMENT, respectively. Let's assume that the user communicates I/O data as float *in, *out pointers to device-accessible USM allocations (note: not std::complex<float>*).

  1. My understanding of the specs and of the implementation is that, in such a case, a call like compute_{for,back}ward<descriptor<precision::SINGLE, domain::COMPLEX>, float>(desc, in, out); can't be expected to result in an out-of-place transform (likely intended by the user), and your branch complies with that, unless I'm mistaken. Do we agree on that?
  2. Is the same to be expected for compute_{for,back}ward(desc, in, out); (no explicit specification of template parameters)? If not, why?
  3. Finally, could you please confirm that compute_{for,back}ward<descriptor<precision::SINGLE, domain::COMPLEX>, float, float>(desc, in, out); would be the out-of-place transform likely intended by the user, with the changes from that branch, despite the implementation of valid_oop_iotypes?
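The three call forms can be explored in a toy model. This sketch does not answer the questions for the actual branch: it assumes hypothetical names (toy_is_complex, toy_descriptor, toy_compute_forward), an in-place overload constrained through a defaulted non-type template parameter, and an unconstrained out-of-place overload with nothing corresponding to valid_oop_iotypes. It only shows what the C++ overload rules do under those assumptions.

```cpp
#include <complex>
#include <string>
#include <type_traits>

template <typename T>
struct toy_is_complex : std::false_type {};
template <typename T>
struct toy_is_complex<std::complex<T>> : std::true_type {};

struct toy_descriptor {};

// In-place REAL_REAL shape, enabled for non-complex data types only.
template <typename descriptor_type, typename data_type,
          std::enable_if_t<!toy_is_complex<data_type>::value, bool> = true>
std::string toy_compute_forward(descriptor_type&, data_type*, data_type*) {
    return "in-place REAL_REAL";
}

// Out-of-place shape, unconstrained in this toy model.
template <typename descriptor_type, typename input_type, typename output_type>
std::string toy_compute_forward(descriptor_type&, input_type*, output_type*) {
    return "out-of-place";
}

// Form 1: descriptor type and one data type given explicitly.
inline std::string call_form_1() {
    toy_descriptor desc;
    float in[8], out[8];
    return toy_compute_forward<toy_descriptor, float>(desc, in, out);
}

// Form 2: no explicit template arguments.
inline std::string call_form_2() {
    toy_descriptor desc;
    float in[8], out[8];
    return toy_compute_forward(desc, in, out);
}

// Form 3: three explicit type arguments. The third cannot bind to the
// in-place overload's non-type parameter, so only out-of-place is viable.
inline std::string call_form_3() {
    toy_descriptor desc;
    float in[8], out[8];
    return toy_compute_forward<toy_descriptor, float, float>(desc, in, out);
}
```

In this toy model, forms 1 and 2 resolve to the in-place shape and form 3 resolves to the out-of-place shape. If the constraint were instead written as a defaulted type template parameter, a third explicit type argument could bind to it, so the outcome of form 3 depends on how the constraint is spelled.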

@hjabird
Copy link
Contributor

hjabird commented Jul 2, 2024

Closing since the spec change and oneMKL changes are merged.

@hjabird hjabird closed this as completed Jul 2, 2024