From 8e0636fa68d0f3fa504f9fe0f14abfc2a582c761 Mon Sep 17 00:00:00 2001 From: Finlay Date: Thu, 22 Jun 2023 17:37:45 +0100 Subject: [PATCH] [DFT] Add rocFFT backend for DFT interface (#330) * Add rocfft backend * avoid creating plans with invalid strides * update example * Update readme to show rocfft support * Update product and version information * increase tolerances * update README * formatting changes * fix unique_ptr creation * Apply rule of three to rocfft commit class --- CMakeLists.txt | 8 +- README.md | 17 +- examples/README.md | 27 +- .../dft/run_time_dispatching/CMakeLists.txt | 11 +- include/oneapi/mkl/detail/backends.hpp | 22 +- include/oneapi/mkl/detail/backends_table.hpp | 6 + .../oneapi/mkl/dft/detail/descriptor_impl.hpp | 4 + .../dft/detail/rocfft/onemkl_dft_rocfft.hpp | 49 ++ src/config.hpp.in | 11 +- src/dft/backends/CMakeLists.txt | 4 + src/dft/backends/rocfft/CMakeLists.txt | 76 ++++ src/dft/backends/rocfft/backward.cpp | 264 +++++++++++ src/dft/backends/rocfft/commit.cpp | 430 ++++++++++++++++++ src/dft/backends/rocfft/compute_signature.cpp | 24 + src/dft/backends/rocfft/descriptor.cpp | 51 +++ src/dft/backends/rocfft/execute_helper.hpp | 97 ++++ src/dft/backends/rocfft/forward.cpp | 266 +++++++++++ .../rocfft/mkl_dft_rocfft_wrappers.cpp | 32 ++ src/dft/backends/rocfft/rocfft_handle.hpp | 34 ++ tests/unit_tests/CMakeLists.txt | 5 + .../unit_tests/dft/include/compute_tester.hpp | 2 +- tests/unit_tests/include/test_helper.hpp | 14 + tests/unit_tests/main_test.cpp | 5 +- 23 files changed, 1429 insertions(+), 30 deletions(-) create mode 100644 include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp create mode 100644 src/dft/backends/rocfft/CMakeLists.txt create mode 100644 src/dft/backends/rocfft/backward.cpp create mode 100644 src/dft/backends/rocfft/commit.cpp create mode 100644 src/dft/backends/rocfft/compute_signature.cpp create mode 100644 src/dft/backends/rocfft/descriptor.cpp create mode 100644 
src/dft/backends/rocfft/execute_helper.hpp create mode 100644 src/dft/backends/rocfft/forward.cpp create mode 100644 src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp create mode 100644 src/dft/backends/rocfft/rocfft_handle.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f07b8c677..10e7f8b60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,7 @@ option(ENABLE_ROCSOLVER_BACKEND "Enable the rocSOLVER backend for the LAPACK int # dft option(ENABLE_CUFFT_BACKEND "Enable the cuFFT backend for the DFT interface" OFF) +option(ENABLE_ROCFFT_BACKEND "Enable the rocFFT backend for the DFT interface" OFF) set(ONEMKL_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler") @@ -100,7 +101,8 @@ if(ENABLE_MKLCPU_BACKEND endif() if(ENABLE_MKLGPU_BACKEND OR ENABLE_MKLCPU_BACKEND - OR ENABLE_CUFFT_BACKEND) + OR ENABLE_CUFFT_BACKEND + OR ENABLE_ROCFFT_BACKEND) list(APPEND DOMAINS_LIST "dft") endif() @@ -119,8 +121,8 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") string(REPLACE "\\" "/" CMAKE_CXX_COMPILER ${CMAKE_CXX_COMPILER}) endif() else() - if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUFFT_BACKEND - OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) + if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND + OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND) set(CMAKE_CXX_COMPILER "clang++") elseif(ENABLE_MKLGPU_BACKEND) if(UNIX) diff --git a/README.md b/README.md index 0095e1865..1a443db2f 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ oneMKL is part of [oneAPI](https://oneapi.io). - oneMKL interface - oneMKL selector + oneMKL interface + oneMKL selector Intel(R) oneAPI Math Kernel Library for x86 CPU x86 CPU @@ -59,6 +59,10 @@ oneMKL is part of [oneAPI](https://oneapi.io). 
AMD rocRAND for AMD GPU AMD GPU + + AMD rocFFT for AMD GPU + AMD GPU + SYCL-BLAS x86 CPU, Intel GPU, NVIDIA GPU, AMD GPU @@ -238,7 +242,7 @@ Supported domains: BLAS, LAPACK, RNG, DFT LLVM*, hipSYCL - DFT + DFT Intel GPU Intel(R) oneAPI Math Kernel Library Dynamic, Static @@ -255,6 +259,12 @@ Supported domains: BLAS, LAPACK, RNG, DFT Dynamic, Static DPC++ + + AMD GPU + AMD rocFFT + Dynamic, Static + DPC++ + @@ -464,6 +474,7 @@ Python | 3.6 or higher | No | *N/A* | *Pre-installed or Installed by user* | [PS [AMD rocBLAS](https://rocblas.readthedocs.io/en/rocm-4.5.2/) | 4.5 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/LICENSE.md) [AMD rocRAND](https://github.com/ROCmSoftwarePlatform/rocRAND) | 5.1.0 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocRAND/blob/develop/LICENSE.txt) [AMD rocSOLVER](https://github.com/ROCmSoftwarePlatform/rocSOLVER) | 5.0.0 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocRAND/blob/develop/LICENSE.txt) +[AMD rocFFT](https://github.com/ROCmSoftwarePlatform/rocFFT) | rocm-5.4.3 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocFFT/blob/rocm-5.4.3/LICENSE.md) [NETLIB LAPACK](https://www.netlib.org/) | 3.7.1 | Yes | conan-community | ~/.conan/data or $CONAN_USER_HOME/.conan/data | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt) [Sphinx](https://www.sphinx-doc.org/en/master/) | 2.4.4 | Yes | pip | ~/.local/bin (or similar user local directory) | [BSD License](https://github.com/sphinx-doc/sphinx/blob/3.x/LICENSE) [SYCL-BLAS](https://github.com/codeplaysoftware/sycl-blas) | 0.1 | No | *N/A* | *Installed by user* | [Apache License v2.0](https://github.com/codeplaysoftware/sycl-blas/blob/master/LICENSE) diff --git a/examples/README.md b/examples/README.md index f370fc21a..6b90ba208 100644 --- a/examples/README.md +++ b/examples/README.md @@ 
-380,7 +380,7 @@ Running with single precision real data type on: DFT Complex USM example ran OK on MKLGPU ``` -Runtime dispatching example with both MKLGPU and cuFFT backend +Runtime dispatching example with MKLGPU, cuFFT, and rocFFT backends: ```none SYCL_DEVICE_FILTER=gpu ./bin/example_dft_real_fwd_usm @@ -431,3 +431,28 @@ Running with single precision real data type: DFT example run_time dispatch DFT example ran OK ``` + +```none +./bin/example_dft_real_fwd_usm + +######################################################################## +# DFTI complex in-place forward transform with USM API example: +# +# Using APIs: +# USM forward complex in-place +# Run-time dispatch +# +# Using single precision (float) data type +# +# Device will be selected during runtime. +# The environment variable SYCL_DEVICE_FILTER can be used to specify +# SYCL device +# +######################################################################## + +Running DFT complex forward example on GPU device +Device name is: AMD Radeon PRO W6800 +Running with single precision real data type: +DFT example run_time dispatch +DFT example ran OK +``` \ No newline at end of file diff --git a/examples/dft/run_time_dispatching/CMakeLists.txt b/examples/dft/run_time_dispatching/CMakeLists.txt index bba311307..c9bc1a2ce 100644 --- a/examples/dft/run_time_dispatching/CMakeLists.txt +++ b/examples/dft/run_time_dispatching/CMakeLists.txt @@ -18,20 +18,17 @@ #=============================================================================== # NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example (no need to specify backend when building with CMake) +include(WarningsUtils) + # Build object from all example sources set(DFT_RT_SOURCES "") -if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUFFT_BACKEND) - list(APPEND DFT_RT_SOURCES "real_fwd_usm") -endif() - -include(WarningsUtils) - # Set up for the right backend for run-time dispatching examples # If users build more than one backend (i.e. 
mklcpu and mklgpu, or mklcpu and CUDA), they may need to # overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend set(DEVICE_FILTERS "") -if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUFFT_BACKEND) +if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_ROCFFT_BACKEND) + list(APPEND DFT_RT_SOURCES "real_fwd_usm") list(APPEND DEVICE_FILTERS "gpu") endif() diff --git a/include/oneapi/mkl/detail/backends.hpp b/include/oneapi/mkl/detail/backends.hpp index 0d775bca4..6c9619cab 100644 --- a/include/oneapi/mkl/detail/backends.hpp +++ b/include/oneapi/mkl/detail/backends.hpp @@ -38,19 +38,25 @@ enum class backend { rocrand, syclblas, cufft, + rocfft, unsupported }; typedef std::map backendmap; -static backendmap backend_map = { - { backend::mklcpu, "mklcpu" }, { backend::mklgpu, "mklgpu" }, - { backend::cublas, "cublas" }, { backend::cusolver, "cusolver" }, - { backend::curand, "curand" }, { backend::netlib, "netlib" }, - { backend::rocblas, "rocblas" }, { backend::rocrand, "rocrand" }, - { backend::rocsolver, "rocsolver" }, { backend::syclblas, "syclblas" }, - { backend::cufft, "cufft" }, { backend::unsupported, "unsupported" } -}; +static backendmap backend_map = { { backend::mklcpu, "mklcpu" }, + { backend::mklgpu, "mklgpu" }, + { backend::cublas, "cublas" }, + { backend::cusolver, "cusolver" }, + { backend::curand, "curand" }, + { backend::netlib, "netlib" }, + { backend::rocblas, "rocblas" }, + { backend::rocrand, "rocrand" }, + { backend::rocsolver, "rocsolver" }, + { backend::syclblas, "syclblas" }, + { backend::cufft, "cufft" }, + { backend::rocfft, "rocfft" }, + { backend::unsupported, "unsupported" } }; } //namespace mkl } //namespace oneapi diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp index 8070e5a0e..cfae0e43e 100644 --- a/include/oneapi/mkl/detail/backends_table.hpp +++ b/include/oneapi/mkl/detail/backends_table.hpp @@ -96,6 +96,12 @@ static std::map>> libraries = { #ifdef 
ENABLE_MKLGPU_BACKEND LIB_NAME("dft_mklgpu") +#endif + } }, + { device::amdgpu, + { +#ifdef ENABLE_ROCFFT_BACKEND + LIB_NAME("dft_rocfft"), #endif } }, { device::nvidiagpu, diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp index f650d1f61..82f31b792 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -72,6 +72,10 @@ class descriptor { void commit(backend_selector selector); #endif +#ifdef ENABLE_ROCFFT_BACKEND + void commit(backend_selector selector); +#endif + const dft_values& get_values() const noexcept { return values_; }; diff --git a/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp b/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp new file mode 100644 index 000000000..8ccf57858 --- /dev/null +++ b/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp @@ -0,0 +1,49 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. 
+* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_ROCFFT_HPP_ +#define _ONEMKL_DFT_ROCFFT_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" + +namespace oneapi::mkl::dft { + +namespace detail { +// Forward declarations +template +class commit_impl; + +template +class descriptor; +} // namespace detail + +namespace rocfft { +#include "oneapi/mkl/dft/detail/dft_ct.hxx" +} // namespace rocfft + +} // namespace oneapi::mkl::dft + +#endif // _ONEMKL_DFT_ROCFFT_HPP_ diff --git a/src/config.hpp.in b/src/config.hpp.in index 702c11943..8a24befc5 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -21,19 +21,20 @@ #define ONEMKL_CONFIG_H #cmakedefine ENABLE_CUBLAS_BACKEND -#cmakedefine ENABLE_CUSOLVER_BACKEND #cmakedefine ENABLE_CUFFT_BACKEND -#cmakedefine ENABLE_ROCBLAS_BACKEND -#cmakedefine ENABLE_ROCRAND_BACKEND -#cmakedefine ENABLE_ROCSOLVER_BACKEND #cmakedefine ENABLE_CURAND_BACKEND +#cmakedefine ENABLE_CUSOLVER_BACKEND #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine ENABLE_NETLIB_BACKEND +#cmakedefine ENABLE_ROCBLAS_BACKEND +#cmakedefine ENABLE_ROCFFT_BACKEND +#cmakedefine ENABLE_ROCRAND_BACKEND +#cmakedefine ENABLE_ROCSOLVER_BACKEND #cmakedefine ENABLE_SYCLBLAS_BACKEND +#cmakedefine ENABLE_SYCLBLAS_BACKEND_AMD_GPU #cmakedefine ENABLE_SYCLBLAS_BACKEND_INTEL_CPU #cmakedefine ENABLE_SYCLBLAS_BACKEND_INTEL_GPU -#cmakedefine ENABLE_SYCLBLAS_BACKEND_AMD_GPU #cmakedefine ENABLE_SYCLBLAS_BACKEND_NVIDIA_GPU #cmakedefine BUILD_SHARED_LIBS #cmakedefine REF_BLAS_LIBNAME "@REF_BLAS_LIBNAME@" diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt index 1390cbee1..1fbea19e4 100644 --- a/src/dft/backends/CMakeLists.txt +++ b/src/dft/backends/CMakeLists.txt @@ -28,3 +28,7 @@ endif() if(ENABLE_CUFFT_BACKEND) 
add_subdirectory(cufft) endif() + +if(ENABLE_ROCFFT_BACKEND) + add_subdirectory(rocfft) +endif() diff --git a/src/dft/backends/rocfft/CMakeLists.txt b/src/dft/backends/rocfft/CMakeLists.txt new file mode 100644 index 000000000..7d27854d6 --- /dev/null +++ b/src/dft/backends/rocfft/CMakeLists.txt @@ -0,0 +1,76 @@ +#=============================================================================== +# Copyright Codeplay Software Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_dft_rocfft) +set(LIB_OBJ ${LIB_NAME}_obj) + + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + descriptor.cpp + commit.cpp + forward.cpp + backward.cpp + compute_signature.cpp + $<$: mkl_dft_rocfft_wrappers.cpp> +) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${MKL_INCLUDE} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) + +find_package(HIP REQUIRED) +find_package(rocfft REQUIRED) + +target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocfft) + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +# Set oneMKL libraries as not transitive for dynamic 
+if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/rocfft/backward.cpp b/src/dft/backends/rocfft/backward.cpp new file mode 100644 index 000000000..6a4616c8b --- /dev/null +++ b/src/dft/backends/rocfft/backward.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. 
+* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include +#include + +#include "execute_helper.hpp" +#include "oneapi/mkl/dft/backward.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "rocfft_handle.hpp" + +namespace oneapi::mkl::dft::rocfft { +namespace detail { +template +rocfft_plan get_bwd_plan(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[1].plan.value(); +} + +template +rocfft_execution_info get_bwd_info(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[1].info.value(); +} +} // namespace detail +// BUFFER version + +//In-place transform +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto inout_acc = inout.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, inout)"; + auto stream = detail::setup_stream(func_name, ih, info); + + auto inout_native = detail::native_mem(ih, inout_acc); + detail::execute_checked(func_name, plan, &inout_native, nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = 
detail::get_bwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto inout_re_acc = inout_re.template get_access(cgh); + auto inout_im_acc = inout_im.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, inout_re, inout_im)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native = { detail::native_mem(ih, inout_re_acc), + detail::native_mem(ih, inout_im_acc) }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, in, out)"; + auto stream = detail::setup_stream(func_name, ih, info); + + auto in_native = detail::native_mem(ih, in_acc); + auto out_native = detail::native_mem(ih, out_acc); + detail::execute_checked(func_name, plan, &in_native, &out_native, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + + 
queue.submit([&](sycl::handler &cgh) { + auto in_re_acc = in_re.template get_access(cgh); + auto in_im_acc = in_im.template get_access(cgh); + auto out_re_acc = out_re.template get_access(cgh); + auto out_im_acc = out_im.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native = { detail::native_mem(ih, in_re_acc), + detail::native_mem(ih, in_im_acc) }; + std::array out_native = { detail::native_mem(ih, out_re_acc), + detail::native_mem(ih, out_im_acc) }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//USM version + +//In-place transform +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout, + const std::vector &deps) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, inout, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + void *inout_ptr = inout; + detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, + data_type *inout_im, + const std::vector &deps) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto 
info = detail::get_bwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, inout_re, inout_im, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native = { inout_re, inout_im }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, + const std::vector &deps) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, in, out, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + void *in_ptr = in; + void *out_ptr = out; + detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re, + input_type *in_im, output_type *out_re, + output_type *out_im, + const std::vector &deps) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = + "compute_backward(desc, in_re, in_im, 
out_re, out_im, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native = { in_re, in_im }; + std::array out_native = { out_re, out_im }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); +} + +// Template function instantiations +#include "dft/backends/backend_backward_instantiations.cxx" + +} // namespace oneapi::mkl::dft::rocfft diff --git a/src/dft/backends/rocfft/commit.cpp b/src/dft/backends/rocfft/commit.cpp new file mode 100644 index 000000000..723322d15 --- /dev/null +++ b/src/dft/backends/rocfft/commit.cpp @@ -0,0 +1,430 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include +#include +#include + +#include +#include + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "rocfft_handle.hpp" + +namespace oneapi::mkl::dft::rocfft { +namespace detail { + +// rocfft has global setup and cleanup functions which use some global state internally. 
+// Each can be called multiple times in an application, but due to the global nature, they always need to alternate. +// I don't believe it's possible to avoid the user calling rocfft_cleanup in their own code, +// breaking our code, but we can try to avoid it for them. +// rocfft_cleanup internally uses some singletons, so it is very difficult to decide if this is safe due to +// the static initialisation order fiasco. +class rocfft_singleton { + rocfft_singleton() { + const auto result = rocfft_setup(); + if (result != rocfft_status_success) { + throw mkl::exception( + "DFT", "rocfft", + "Failed to setup rocfft. returned status " + std::to_string(result)); + } + } + + ~rocfft_singleton() { + (void)rocfft_cleanup(); + } + + // no copies or moves allowed + rocfft_singleton(const rocfft_singleton& other) = delete; + rocfft_singleton(rocfft_singleton&& other) noexcept = delete; + rocfft_singleton& operator=(const rocfft_singleton& other) = delete; + rocfft_singleton& operator=(rocfft_singleton&& other) noexcept = delete; + +public: + static void init() { + static rocfft_singleton instance; + (void)instance; + } +}; + +/// Commit impl class specialization for rocFFT. +template +class rocfft_commit final : public dft::detail::commit_impl { +private: + // For real to complex transforms, the "transform_type" arg also encodes the direction (e.g. rocfft_transform_type_*_forward vs rocfft_transform_type_*_backward) + // in the plan so we must have one for each direction. + // We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while rocFFT uses a directional "in_distance" and "out_distance". + // The same is also true for "FORWARD_SCALE" and "BACKWARD_SCALE". 
+ // handles[0] is forward, handles[1] is backward + std::array handles{}; + +public: + rocfft_commit(sycl::queue& queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::rocfft) { + if constexpr (prec == dft::detail::precision::DOUBLE) { + if (!queue.get_device().has(sycl::aspect::fp64)) { + throw mkl::exception("DFT", "commit", "Device does not support double precision."); + } + } + // initialise the rocFFT global state + rocfft_singleton::init(); + } + + void clean_plans() { + if (handles[0].plan) { + if (rocfft_plan_destroy(handles[0].plan.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy forward plan."); + } + handles[0].plan = std::nullopt; + } + if (handles[1].plan) { + if (rocfft_plan_destroy(handles[1].plan.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy backward plan."); + } + handles[1].plan = std::nullopt; + } + + if (handles[0].info) { + if (rocfft_execution_info_destroy(handles[0].info.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy forward execution info ."); + } + handles[0].info = std::nullopt; + } + if (handles[1].info) { + if (rocfft_execution_info_destroy(handles[1].info.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy backward execution info ."); + } + handles[1].info = std::nullopt; + } + if (handles[0].buffer) { + if (hipFree(handles[0].buffer.value()) != hipSuccess) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to free forward buffer."); + } + handles[0].buffer = std::nullopt; + } + if (handles[1].buffer) { + if (hipFree(handles[1].buffer.value()) != hipSuccess) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to free backward buffer."); + } + handles[1].buffer = 
std::nullopt; + } + } + + void commit(const dft::detail::dft_values& config_values) override { + // this could be a recommit + clean_plans(); + + const rocfft_result_placement placement = + (config_values.placement == dft::config_value::INPLACE) ? rocfft_placement_inplace + : rocfft_placement_notinplace; + + constexpr rocfft_transform_type fwd_type = [] { + if constexpr (dom == dft::domain::COMPLEX) { + return rocfft_transform_type_complex_forward; + } + else { + return rocfft_transform_type_real_forward; + } + }(); + constexpr rocfft_transform_type bwd_type = [] { + if constexpr (dom == dft::domain::COMPLEX) { + return rocfft_transform_type_complex_inverse; + } + else { + return rocfft_transform_type_real_inverse; + } + }(); + + constexpr rocfft_precision precision = [] { + if constexpr (prec == dft::precision::SINGLE) { + return rocfft_precision_single; + } + else { + return rocfft_precision_double; + } + }(); + + const std::size_t dimensions = config_values.dimensions.size(); + + constexpr std::size_t max_supported_dims = 3; + std::array lengths; + // rocfft does dimensions in the reverse order to oneMKL + std::copy(config_values.dimensions.crbegin(), config_values.dimensions.crend(), + lengths.data()); + + const std::size_t number_of_transforms = + static_cast(config_values.number_of_transforms); + + const std::size_t fwd_distance = static_cast(config_values.fwd_dist); + const std::size_t bwd_distance = static_cast(config_values.bwd_dist); + + const rocfft_array_type fwd_array_ty = [&config_values]() { + if constexpr (dom == dft::domain::COMPLEX) { + if (config_values.complex_storage == dft::config_value::COMPLEX_COMPLEX) { + return rocfft_array_type_complex_interleaved; + } + else { + return rocfft_array_type_complex_planar; + } + } + else { + return rocfft_array_type_real; + } + }(); + const rocfft_array_type bwd_array_ty = [&config_values]() { + if constexpr (dom == dft::domain::COMPLEX) { + if (config_values.complex_storage == 
dft::config_value::COMPLEX_COMPLEX) { + return rocfft_array_type_complex_interleaved; + } + else { + return rocfft_array_type_complex_planar; + } + } + else { + if (config_values.conj_even_storage != dft::config_value::COMPLEX_COMPLEX) { + throw mkl::exception( + "dft/backends/rocfft", __FUNCTION__, + "only COMPLEX_COMPLEX conjugate_even_storage is supported"); + } + return rocfft_array_type_hermitian_interleaved; + } + }(); + + std::array in_offsets{ + static_cast(config_values.input_strides[0]), + static_cast(config_values.input_strides[0]) + }; + std::array out_offsets{ + static_cast(config_values.output_strides[0]), + static_cast(config_values.output_strides[0]) + }; + + std::array in_strides; + std::array out_strides; + + for (std::size_t i = 0; i != dimensions; ++i) { + in_strides[i] = config_values.input_strides[dimensions - i]; + out_strides[i] = config_values.output_strides[dimensions - i]; + } + + rocfft_plan_description plan_desc; + if (rocfft_plan_description_create(&plan_desc) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create plan description."); + } + + // plan_description can be destroyed afted plan_create + auto description_destroy = [](rocfft_plan_description p) { + if (rocfft_plan_description_destroy(p) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy plan description."); + } + }; + std::unique_ptr + description_destroyer(plan_desc, description_destroy); + + // When creating real-complex descriptions, the strides will always be wrong for one of the directions. + // This is because the least significant dimension is symmetric. + // If the strides are invalid (too small to fit) then just don't bother creating the plan. 
+ const bool ignore_strides = dom == dft::domain::COMPLEX || dimensions == 1; + const bool valid_forward = + ignore_strides || (lengths[0] <= in_strides[1] && lengths[0] / 2 + 1 <= out_strides[1]); + const bool valid_backward = + ignore_strides || (lengths[0] <= out_strides[1] && lengths[0] / 2 + 1 <= in_strides[1]); + + if (valid_forward) { + auto res = + rocfft_plan_description_set_data_layout(plan_desc, fwd_array_ty, bwd_array_ty, + in_offsets.data(), // in offsets + out_offsets.data(), // out offsets + dimensions, + in_strides.data(), //in strides + fwd_distance, // in distance + dimensions, + out_strides.data(), // out strides + bwd_distance // out distance + ); + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set forward data layout."); + } + + if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.fwd_scale) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set forward scale factor."); + } + + rocfft_plan fwd_plan; + res = rocfft_plan_create(&fwd_plan, placement, fwd_type, precision, dimensions, + lengths.data(), number_of_transforms, plan_desc); + + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create forward plan."); + } + + handles[0].plan = fwd_plan; + + rocfft_execution_info fwd_info; + if (rocfft_execution_info_create(&fwd_info) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create forward execution info."); + } + handles[0].info = fwd_info; + + // plan work buffer + std::size_t work_buf_size; + if (rocfft_plan_get_work_buffer_size(fwd_plan, &work_buf_size) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get forward work buffer size."); + } + if (work_buf_size != 0) { + void* work_buf; + if (hipMalloc(&work_buf, work_buf_size) != hipSuccess) { + throw 
mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get allocate forward work buffer."); + } + handles[0].buffer = work_buf; + if (rocfft_execution_info_set_work_buffer(fwd_info, work_buf, work_buf_size) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set forward work buffer."); + } + } + } + + if (valid_backward) { + auto res = + rocfft_plan_description_set_data_layout(plan_desc, bwd_array_ty, fwd_array_ty, + in_offsets.data(), // in offsets + out_offsets.data(), // out offsets + dimensions, + in_strides.data(), //in strides + bwd_distance, // in distance + dimensions, + out_strides.data(), // out strides + fwd_distance // out distance + ); + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set backward data layout."); + } + + if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.bwd_scale) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set backward scale factor."); + } + + rocfft_plan bwd_plan; + res = rocfft_plan_create(&bwd_plan, placement, bwd_type, precision, dimensions, + lengths.data(), number_of_transforms, plan_desc); + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create backward rocFFT plan."); + } + handles[1].plan = bwd_plan; + + rocfft_execution_info bwd_info; + if (rocfft_execution_info_create(&bwd_info) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create backward execution info."); + } + handles[1].info = bwd_info; + + std::size_t work_buf_size; + if (rocfft_plan_get_work_buffer_size(bwd_plan, &work_buf_size) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get backward work buffer size."); + } + + if (work_buf_size != 0) { + void* work_buf; + if (hipMalloc(&work_buf, work_buf_size) 
!= hipSuccess) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get allocate backward work buffer."); + } + handles[1].buffer = work_buf; + + if (rocfft_execution_info_set_work_buffer(bwd_info, work_buf, work_buf_size) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set backward work buffer."); + } + } + } + } + + ~rocfft_commit() override { + clean_plans(); + } + + // Rule of three. Copying could lead to memory safety issues. + rocfft_commit(const rocfft_commit& other) = delete; + rocfft_commit& operator=(const rocfft_commit& other) = delete; + + void* get_handle() noexcept override { + return handles.data(); + } +}; +} // namespace detail + +template +dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { + return new detail::rocfft_commit(sycl_queue, desc.get_values()); +} + +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); + +} // namespace oneapi::mkl::dft::rocfft diff --git a/src/dft/backends/rocfft/compute_signature.cpp b/src/dft/backends/rocfft/compute_signature.cpp new file mode 100644 index 000000000..c26826fcc --- /dev/null +++ b/src/dft/backends/rocfft/compute_signature.cpp @@ -0,0 +1,24 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" + +#define BACKEND rocfft + +#include "dft/backends/backend_compute_signature.cxx" diff --git a/src/dft/backends/rocfft/descriptor.cpp b/src/dft/backends/rocfft/descriptor.cpp new file mode 100644 index 000000000..83fdbe1dc --- /dev/null +++ b/src/dft/backends/rocfft/descriptor.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. 
+* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "../../descriptor.cxx" + +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(backend_selector selector) { + if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) { + if (pimpl_) { + pimpl_->get_queue().wait(); + } + pimpl_.reset(rocfft::create_commit(*this, selector.get_queue())); + } + pimpl_->commit(values_); +} + +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); + +} //namespace dft +} //namespace mkl +} //namespace oneapi diff --git a/src/dft/backends/rocfft/execute_helper.hpp b/src/dft/backends/rocfft/execute_helper.hpp new file mode 100644 index 000000000..4dff6831d --- /dev/null +++ b/src/dft/backends/rocfft/execute_helper.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. 
+* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_ +#define _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/exceptions.hpp" + +#include +#include + +namespace oneapi::mkl::dft::rocfft::detail { + +template +inline dft::detail::commit_impl *checked_get_commit( + dft::detail::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::rocfft) { + throw mkl::invalid_argument("dft/backends/rocfft", "get_commit", + "DFT descriptor has not been commited for rocFFT"); + } + return commit_handle; +} + +/// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +/// the expected value. 
+template +inline auto expect_config(DescT &desc, const char *message) { + dft::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("dft/backends/rocfft", "expect_config", message); + } +} + +template +inline void *native_mem(sycl::interop_handle &ih, Acc &buf) { + return ih.get_native_mem(buf); +} + +inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &ih, + rocfft_execution_info info) { + auto stream = ih.get_native_queue(); + auto result = rocfft_execution_info_set_stream(info, stream); + if (result != rocfft_status_success) { + throw oneapi::mkl::exception( + "dft/backends/rocfft", func, + "rocfft_execution_info_set_stream returned " + std::to_string(result)); + } + return stream; +} + +inline void sync_checked(const std::string &func, hipStream_t stream) { + auto result = hipStreamSynchronize(stream); + if (result != hipSuccess) { + throw oneapi::mkl::exception("dft/backends/rocfft", func, + "hipStreamSynchronize returned " + std::to_string(result)); + } +} + +inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[], + void *out_buffer[], rocfft_execution_info info) { + auto result = rocfft_execute(plan, in_buffer, out_buffer, info); + if (result != rocfft_status_success) { + throw oneapi::mkl::exception("dft/backends/rocfft", func, + "rocfft_execute returned " + std::to_string(result)); + } +} + +} // namespace oneapi::mkl::dft::rocfft::detail + +#endif diff --git a/src/dft/backends/rocfft/forward.cpp b/src/dft/backends/rocfft/forward.cpp new file mode 100644 index 000000000..9b92d9097 --- /dev/null +++ b/src/dft/backends/rocfft/forward.cpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include +#include + +#include "execute_helper.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/forward.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "rocfft_handle.hpp" + +namespace oneapi::mkl::dft::rocfft { + +namespace detail { +template +rocfft_plan get_fwd_plan(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[0].plan.value(); +} + +template +rocfft_execution_info get_fwd_info(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[0].info.value(); +} +} // namespace detail + +// BUFFER version + +//In-place transform +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto inout_acc = inout.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, inout)"; + auto stream = detail::setup_stream(func_name, ih, info); + + auto inout_native = detail::native_mem(ih, inout_acc); + detail::execute_checked(func_name, plan, &inout_native, nullptr, 
info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto inout_re_acc = inout_re.template get_access(cgh); + auto inout_im_acc = inout_im.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, inout_re, inout_im)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native{ detail::native_mem(ih, inout_re_acc), + detail::native_mem(ih, inout_im_acc) }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, in, out)"; + auto stream = detail::setup_stream(func_name, ih, info); + + auto in_native = detail::native_mem(ih, in_acc); + auto out_native = detail::native_mem(ih, out_acc); + detail::execute_checked(func_name, plan, &in_native, &out_native, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place 
transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_re_acc = in_re.template get_access(cgh); + auto in_im_acc = in_im.template get_access(cgh); + auto out_re_acc = out_re.template get_access(cgh); + auto out_im_acc = out_im.template get_access(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native{ detail::native_mem(ih, in_re_acc), + detail::native_mem(ih, in_im_acc) }; + std::array out_native{ detail::native_mem(ih, out_re_acc), + detail::native_mem(ih, out_im_acc) }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//USM version + +//In-place transform +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout, + const std::vector &deps) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, inout, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + void *inout_ptr = inout; + detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info); + 
detail::sync_checked(func_name, stream); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, + data_type *inout_im, + const std::vector &deps) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, inout_re, inout_im, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native{ inout_re, inout_im }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, + const std::vector &deps) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, in, out, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + void *in_ptr = in; + void *out_ptr = out; + detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re, + input_type *in_im, output_type 
*out_re, + output_type *out_im, + const std::vector &deps) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = + "compute_forward(desc, in_re, in_im, out_re, out_im, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native{ in_re, in_im }; + std::array out_native{ out_re, out_im }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); +} + +// Template function instantiations +#include "dft/backends/backend_forward_instantiations.cxx" + +} // namespace oneapi::mkl::dft::rocfft diff --git a/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp b/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp new file mode 100644 index 000000000..c8f0e35c7 --- /dev/null +++ b/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. 
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp"
+#include "dft/function_table.hpp"
+
+#define WRAPPER_VERSION 1
+#define BACKEND rocfft
+
+extern "C" dft_function_table_t mkl_dft_table = {
+    WRAPPER_VERSION, // first initializer; the remaining ones come from the include below
+#include "dft/backends/backend_wrappers.cxx"
+};
+
+#undef WRAPPER_VERSION
+#undef BACKEND
diff --git a/src/dft/backends/rocfft/rocfft_handle.hpp b/src/dft/backends/rocfft/rocfft_handle.hpp
new file mode 100644
index 000000000..ea4f44d68
--- /dev/null
+++ b/src/dft/backends/rocfft/rocfft_handle.hpp
@@ -0,0 +1,34 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_ROCFFT_ROCFFT_HANDLE_HPP_ +#define _ONEMKL_DFT_SRC_ROCFFT_ROCFFT_HANDLE_HPP_ + +#include + +struct rocfft_plan_t; +struct rocfft_execution_info_t; + +struct rocfft_handle { + std::optional plan = std::nullopt; + std::optional info = std::nullopt; + std::optional buffer = std::nullopt; +}; + +#endif diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index c377bbcf3..a676fa6ef 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -154,6 +154,11 @@ foreach(domain ${TARGET_DOMAINS}) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_cufft) endif() + if(domain STREQUAL "dft" AND ENABLE_ROCFFT_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_dft_rocfft) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_dft_rocfft) + endif() + target_link_libraries(test_main_${domain}_ct PUBLIC gtest gtest_main diff --git a/tests/unit_tests/dft/include/compute_tester.hpp b/tests/unit_tests/dft/include/compute_tester.hpp index 53858b228..491afcc1e 100644 --- a/tests/unit_tests/dft/include/compute_tester.hpp +++ b/tests/unit_tests/dft/include/compute_tester.hpp @@ -130,7 +130,7 @@ struct DFT_Test { }); // Heuristic for the average-case error margins abs_error_margin = - std::abs(max_norm_ref) * std::log2(static_cast(forward_elements)); + 10 * std::abs(max_norm_ref) * std::log2(static_cast(forward_elements)); rel_error_margin = 200.0 * std::log2(static_cast(forward_elements)); return !skip_test(mem_acc); } diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index ebdcb1c03..40d1e8644 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ b/tests/unit_tests/include/test_helper.hpp @@ -136,6 +136,16 @@ #define TEST_RUN_NVIDIAGPU_CUFFT_SELECT(q, func, ...) 
#endif +#ifdef ENABLE_ROCFFT_BACKEND +#define TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func) \ + func(oneapi::mkl::backend_selector{ q }) +#define TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func) +#define TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, ...) +#endif + #ifndef __HIPSYCL__ #define CHECK_HOST_OR_CPU(q) q.get_device().is_cpu() #else @@ -156,6 +166,9 @@ else if (vendor_id == NVIDIA_ID) { \ TEST_RUN_NVIDIAGPU_CUFFT_SELECT_NO_ARGS(q, func); \ } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func); \ + } \ } \ } while (0); @@ -177,6 +190,7 @@ TEST_RUN_AMDGPU_ROCBLAS_SELECT(q, func, __VA_ARGS__); \ TEST_RUN_AMDGPU_ROCRAND_SELECT(q, func, __VA_ARGS__); \ TEST_RUN_AMDGPU_ROCSOLVER_SELECT(q, func, __VA_ARGS__); \ + TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, __VA_ARGS__); \ } \ } \ TEST_RUN_SYCLBLAS_SELECT(q, func, __VA_ARGS__); \ diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index 5010d54b1..84faad518 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -124,8 +124,9 @@ int main(int argc, char** argv) { if (dev.is_gpu() && vendor_id == NVIDIA_ID) continue; #endif -#if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \ - !defined(ENABLE_ROCSOLVER_BACKEND) && !defined(ENABLE_SYCLBLAS_BACKEND_AMD_GPU) +#if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \ + !defined(ENABLE_ROCSOLVER_BACKEND) && !defined(ENABLE_SYCLBLAS_BACKEND_AMD_GPU) && \ + !defined(ENABLE_ROCFFT_BACKEND) if (dev.is_gpu() && vendor_id == AMD_ID) continue; #endif