diff --git a/CMakeLists.txt b/CMakeLists.txt
index f07b8c677..10e7f8b60 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -62,6 +62,7 @@ option(ENABLE_ROCSOLVER_BACKEND "Enable the rocSOLVER backend for the LAPACK int
# dft
option(ENABLE_CUFFT_BACKEND "Enable the cuFFT backend for the DFT interface" OFF)
+option(ENABLE_ROCFFT_BACKEND "Enable the rocFFT backend for the DFT interface" OFF)
set(ONEMKL_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler")
@@ -100,7 +101,8 @@ if(ENABLE_MKLCPU_BACKEND
endif()
if(ENABLE_MKLGPU_BACKEND
OR ENABLE_MKLCPU_BACKEND
- OR ENABLE_CUFFT_BACKEND)
+ OR ENABLE_CUFFT_BACKEND
+ OR ENABLE_ROCFFT_BACKEND)
list(APPEND DOMAINS_LIST "dft")
endif()
@@ -119,8 +121,8 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++")
string(REPLACE "\\" "/" CMAKE_CXX_COMPILER ${CMAKE_CXX_COMPILER})
endif()
else()
- if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUFFT_BACKEND
- OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND)
+ if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND
+ OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND)
set(CMAKE_CXX_COMPILER "clang++")
elseif(ENABLE_MKLGPU_BACKEND)
if(UNIX)
diff --git a/README.md b/README.md
index 0095e1865..1a443db2f 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,8 @@ oneMKL is part of [oneAPI](https://oneapi.io).
- oneMKL interface |
- oneMKL selector |
+ oneMKL interface |
+ oneMKL selector |
Intel(R) oneAPI Math Kernel Library for x86 CPU |
x86 CPU |
@@ -59,6 +59,10 @@ oneMKL is part of [oneAPI](https://oneapi.io).
AMD rocRAND for AMD GPU |
AMD GPU |
+
+ AMD rocFFT for AMD GPU |
+ AMD GPU |
+
SYCL-BLAS |
x86 CPU, Intel GPU, NVIDIA GPU, AMD GPU |
@@ -238,7 +242,7 @@ Supported domains: BLAS, LAPACK, RNG, DFT
LLVM*, hipSYCL |
- DFT |
+ DFT |
Intel GPU |
Intel(R) oneAPI Math Kernel Library |
Dynamic, Static |
@@ -255,6 +259,12 @@ Supported domains: BLAS, LAPACK, RNG, DFT
Dynamic, Static |
DPC++ |
+
+ AMD GPU |
+ AMD rocFFT |
+ Dynamic, Static |
+ DPC++ |
+
@@ -464,6 +474,7 @@ Python | 3.6 or higher | No | *N/A* | *Pre-installed or Installed by user* | [PS
[AMD rocBLAS](https://rocblas.readthedocs.io/en/rocm-4.5.2/) | 4.5 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/LICENSE.md)
[AMD rocRAND](https://github.com/ROCmSoftwarePlatform/rocRAND) | 5.1.0 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocRAND/blob/develop/LICENSE.txt)
[AMD rocSOLVER](https://github.com/ROCmSoftwarePlatform/rocSOLVER) | 5.0.0 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocRAND/blob/develop/LICENSE.txt)
+[AMD rocFFT](https://github.com/ROCmSoftwarePlatform/rocFFT) | rocm-5.4.3 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocFFT/blob/rocm-5.4.3/LICENSE.md)
[NETLIB LAPACK](https://www.netlib.org/) | 3.7.1 | Yes | conan-community | ~/.conan/data or $CONAN_USER_HOME/.conan/data | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt)
[Sphinx](https://www.sphinx-doc.org/en/master/) | 2.4.4 | Yes | pip | ~/.local/bin (or similar user local directory) | [BSD License](https://github.com/sphinx-doc/sphinx/blob/3.x/LICENSE)
[SYCL-BLAS](https://github.com/codeplaysoftware/sycl-blas) | 0.1 | No | *N/A* | *Installed by user* | [Apache License v2.0](https://github.com/codeplaysoftware/sycl-blas/blob/master/LICENSE)
diff --git a/examples/README.md b/examples/README.md
index f370fc21a..6b90ba208 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -380,7 +380,7 @@ Running with single precision real data type on:
DFT Complex USM example ran OK on MKLGPU
```
-Runtime dispatching example with both MKLGPU and cuFFT backend
+Runtime dispatching example with MKLGPU, cuFFT, and rocFFT backends:
```none
SYCL_DEVICE_FILTER=gpu ./bin/example_dft_real_fwd_usm
@@ -431,3 +431,28 @@ Running with single precision real data type:
DFT example run_time dispatch
DFT example ran OK
```
+
+```none
+./bin/example_dft_real_fwd_usm
+
+########################################################################
+# DFTI complex in-place forward transform with USM API example:
+#
+# Using APIs:
+# USM forward complex in-place
+# Run-time dispatch
+#
+# Using single precision (float) data type
+#
+# Device will be selected during runtime.
+# The environment variable SYCL_DEVICE_FILTER can be used to specify
+# SYCL device
+#
+########################################################################
+
+Running DFT complex forward example on GPU device
+Device name is: AMD Radeon PRO W6800
+Running with single precision real data type:
+DFT example run_time dispatch
+DFT example ran OK
+```
\ No newline at end of file
diff --git a/examples/dft/run_time_dispatching/CMakeLists.txt b/examples/dft/run_time_dispatching/CMakeLists.txt
index bba311307..c9bc1a2ce 100644
--- a/examples/dft/run_time_dispatching/CMakeLists.txt
+++ b/examples/dft/run_time_dispatching/CMakeLists.txt
@@ -18,20 +18,17 @@
#===============================================================================
# NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example (no need to specify backend when building with CMake)
+include(WarningsUtils)
+
# Build object from all example sources
set(DFT_RT_SOURCES "")
-if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUFFT_BACKEND)
- list(APPEND DFT_RT_SOURCES "real_fwd_usm")
-endif()
-
-include(WarningsUtils)
-
# Set up for the right backend for run-time dispatching examples
# If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to
# overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend
set(DEVICE_FILTERS "")
-if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUFFT_BACKEND)
+if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_ROCFFT_BACKEND)
+ list(APPEND DFT_RT_SOURCES "real_fwd_usm")
list(APPEND DEVICE_FILTERS "gpu")
endif()
diff --git a/include/oneapi/mkl/detail/backends.hpp b/include/oneapi/mkl/detail/backends.hpp
index 0d775bca4..6c9619cab 100644
--- a/include/oneapi/mkl/detail/backends.hpp
+++ b/include/oneapi/mkl/detail/backends.hpp
@@ -38,19 +38,25 @@ enum class backend {
rocrand,
syclblas,
cufft,
+ rocfft,
unsupported
};
typedef std::map backendmap;
-static backendmap backend_map = {
- { backend::mklcpu, "mklcpu" }, { backend::mklgpu, "mklgpu" },
- { backend::cublas, "cublas" }, { backend::cusolver, "cusolver" },
- { backend::curand, "curand" }, { backend::netlib, "netlib" },
- { backend::rocblas, "rocblas" }, { backend::rocrand, "rocrand" },
- { backend::rocsolver, "rocsolver" }, { backend::syclblas, "syclblas" },
- { backend::cufft, "cufft" }, { backend::unsupported, "unsupported" }
-};
+static backendmap backend_map = { { backend::mklcpu, "mklcpu" },
+ { backend::mklgpu, "mklgpu" },
+ { backend::cublas, "cublas" },
+ { backend::cusolver, "cusolver" },
+ { backend::curand, "curand" },
+ { backend::netlib, "netlib" },
+ { backend::rocblas, "rocblas" },
+ { backend::rocrand, "rocrand" },
+ { backend::rocsolver, "rocsolver" },
+ { backend::syclblas, "syclblas" },
+ { backend::cufft, "cufft" },
+ { backend::rocfft, "rocfft" },
+ { backend::unsupported, "unsupported" } };
} //namespace mkl
} //namespace oneapi
diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp
index 8070e5a0e..cfae0e43e 100644
--- a/include/oneapi/mkl/detail/backends_table.hpp
+++ b/include/oneapi/mkl/detail/backends_table.hpp
@@ -96,6 +96,12 @@ static std::map>> libraries =
{
#ifdef ENABLE_MKLGPU_BACKEND
LIB_NAME("dft_mklgpu")
+#endif
+ } },
+ { device::amdgpu,
+ {
+#ifdef ENABLE_ROCFFT_BACKEND
+ LIB_NAME("dft_rocfft"),
#endif
} },
{ device::nvidiagpu,
diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp
index f650d1f61..82f31b792 100644
--- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp
+++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp
@@ -72,6 +72,10 @@ class descriptor {
void commit(backend_selector selector);
#endif
+#ifdef ENABLE_ROCFFT_BACKEND
+ void commit(backend_selector selector);
+#endif
+
const dft_values& get_values() const noexcept {
return values_;
};
diff --git a/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp b/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp
new file mode 100644
index 000000000..8ccf57858
--- /dev/null
+++ b/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp
@@ -0,0 +1,49 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _ONEMKL_DFT_ROCFFT_HPP_
+#define _ONEMKL_DFT_ROCFFT_HPP_
+
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include "oneapi/mkl/detail/export.hpp"
+#include "oneapi/mkl/dft/detail/types_impl.hpp"
+
+namespace oneapi::mkl::dft {
+
+namespace detail {
+// Forward declarations
+template
+class commit_impl;
+
+template
+class descriptor;
+} // namespace detail
+
+namespace rocfft {
+#include "oneapi/mkl/dft/detail/dft_ct.hxx"
+} // namespace rocfft
+
+} // namespace oneapi::mkl::dft
+
+#endif // _ONEMKL_DFT_ROCFFT_HPP_
diff --git a/src/config.hpp.in b/src/config.hpp.in
index 702c11943..8a24befc5 100644
--- a/src/config.hpp.in
+++ b/src/config.hpp.in
@@ -21,19 +21,20 @@
#define ONEMKL_CONFIG_H
#cmakedefine ENABLE_CUBLAS_BACKEND
-#cmakedefine ENABLE_CUSOLVER_BACKEND
#cmakedefine ENABLE_CUFFT_BACKEND
-#cmakedefine ENABLE_ROCBLAS_BACKEND
-#cmakedefine ENABLE_ROCRAND_BACKEND
-#cmakedefine ENABLE_ROCSOLVER_BACKEND
#cmakedefine ENABLE_CURAND_BACKEND
+#cmakedefine ENABLE_CUSOLVER_BACKEND
#cmakedefine ENABLE_MKLCPU_BACKEND
#cmakedefine ENABLE_MKLGPU_BACKEND
#cmakedefine ENABLE_NETLIB_BACKEND
+#cmakedefine ENABLE_ROCBLAS_BACKEND
+#cmakedefine ENABLE_ROCFFT_BACKEND
+#cmakedefine ENABLE_ROCRAND_BACKEND
+#cmakedefine ENABLE_ROCSOLVER_BACKEND
#cmakedefine ENABLE_SYCLBLAS_BACKEND
+#cmakedefine ENABLE_SYCLBLAS_BACKEND_AMD_GPU
#cmakedefine ENABLE_SYCLBLAS_BACKEND_INTEL_CPU
#cmakedefine ENABLE_SYCLBLAS_BACKEND_INTEL_GPU
-#cmakedefine ENABLE_SYCLBLAS_BACKEND_AMD_GPU
#cmakedefine ENABLE_SYCLBLAS_BACKEND_NVIDIA_GPU
#cmakedefine BUILD_SHARED_LIBS
#cmakedefine REF_BLAS_LIBNAME "@REF_BLAS_LIBNAME@"
diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt
index 1390cbee1..1fbea19e4 100644
--- a/src/dft/backends/CMakeLists.txt
+++ b/src/dft/backends/CMakeLists.txt
@@ -28,3 +28,7 @@ endif()
if(ENABLE_CUFFT_BACKEND)
add_subdirectory(cufft)
endif()
+
+if(ENABLE_ROCFFT_BACKEND)
+ add_subdirectory(rocfft)
+endif()
diff --git a/src/dft/backends/rocfft/CMakeLists.txt b/src/dft/backends/rocfft/CMakeLists.txt
new file mode 100644
index 000000000..7d27854d6
--- /dev/null
+++ b/src/dft/backends/rocfft/CMakeLists.txt
@@ -0,0 +1,76 @@
+#===============================================================================
+# Copyright Codeplay Software Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+#
+#
+# SPDX-License-Identifier: Apache-2.0
+#===============================================================================
+
+set(LIB_NAME onemkl_dft_rocfft)
+set(LIB_OBJ ${LIB_NAME}_obj)
+
+
+add_library(${LIB_NAME})
+add_library(${LIB_OBJ} OBJECT
+ descriptor.cpp
+ commit.cpp
+ forward.cpp
+ backward.cpp
+ compute_signature.cpp
+ $<$: mkl_dft_rocfft_wrappers.cpp>
+)
+
+target_include_directories(${LIB_OBJ}
+ PRIVATE ${PROJECT_SOURCE_DIR}/include
+ ${PROJECT_SOURCE_DIR}/src
+ ${CMAKE_BINARY_DIR}/bin
+ ${MKL_INCLUDE}
+)
+
+target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT})
+
+find_package(HIP REQUIRED)
+find_package(rocfft REQUIRED)
+
+target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocfft)
+
+target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL})
+
+set_target_properties(${LIB_OBJ} PROPERTIES
+ POSITION_INDEPENDENT_CODE ON
+)
+target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ})
+
+# Set oneMKL libraries as not transitive for dynamic
+if(BUILD_SHARED_LIBS)
+ set_target_properties(${LIB_NAME} PROPERTIES
+ INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL
+ )
+endif()
+
+# Add major version to the library
+set_target_properties(${LIB_NAME} PROPERTIES
+ SOVERSION ${PROJECT_VERSION_MAJOR}
+)
+
+# Add dependencies rpath to the library
+list(APPEND CMAKE_BUILD_RPATH $)
+
+# Add the library to install package
+install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets)
+install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets
+ RUNTIME DESTINATION bin
+ ARCHIVE DESTINATION lib
+ LIBRARY DESTINATION lib
+)
diff --git a/src/dft/backends/rocfft/backward.cpp b/src/dft/backends/rocfft/backward.cpp
new file mode 100644
index 000000000..6a4616c8b
--- /dev/null
+++ b/src/dft/backends/rocfft/backward.cpp
@@ -0,0 +1,264 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include
+#include
+
+#include "execute_helper.hpp"
+#include "oneapi/mkl/dft/backward.hpp"
+#include "oneapi/mkl/dft/detail/commit_impl.hpp"
+#include "oneapi/mkl/dft/types.hpp"
+#include "oneapi/mkl/exceptions.hpp"
+#include "rocfft_handle.hpp"
+
+namespace oneapi::mkl::dft::rocfft {
+namespace detail {
+template
+rocfft_plan get_bwd_plan(dft::detail::commit_impl *commit) {
+ return static_cast(commit->get_handle())[1].plan.value();
+}
+
+template
+rocfft_execution_info get_bwd_info(dft::detail::commit_impl *commit) {
+ return static_cast(commit->get_handle())[1].info.value();
+}
+} // namespace detail
+// BUFFER version
+
+//In-place transform
+template
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto inout_acc = inout.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, inout)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ auto inout_native = detail::native_mem(ih, inout_acc);
+ detail::execute_checked(func_name, plan, &inout_native, nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re,
+ sycl::buffer &inout_im) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto inout_re_acc = inout_re.template get_access(cgh);
+ auto inout_im_acc = inout_im.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, inout_re, inout_im)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array inout_native = { detail::native_mem(ih, inout_re_acc),
+ detail::native_mem(ih, inout_im_acc) };
+ detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform
+template
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in,
+ sycl::buffer &out) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto in_acc = in.template get_access(cgh);
+ auto out_acc = out.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, in, out)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ auto in_native = detail::native_mem(ih, in_acc);
+ auto out_native = detail::native_mem(ih, out_acc);
+ detail::execute_checked(func_name, plan, &in_native, &out_native, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re,
+ sycl::buffer &in_im,
+ sycl::buffer &out_re,
+ sycl::buffer &out_im) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto in_re_acc = in_re.template get_access(cgh);
+ auto in_im_acc = in_im.template get_access(cgh);
+ auto out_re_acc = out_re.template get_access(cgh);
+ auto out_im_acc = out_im.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array in_native = { detail::native_mem(ih, in_re_acc),
+ detail::native_mem(ih, in_im_acc) };
+ std::array out_native = { detail::native_mem(ih, out_re_acc),
+ detail::native_mem(ih, out_im_acc) };
+ detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//USM version
+
+//In-place transform
+template
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout,
+ const std::vector &deps) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, inout, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ void *inout_ptr = inout;
+ detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re,
+ data_type *inout_im,
+ const std::vector &deps) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, inout_re, inout_im, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array inout_native = { inout_re, inout_im };
+ detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform
+template
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out,
+ const std::vector &deps) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_backward(desc, in, out, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ void *in_ptr = in;
+ void *out_ptr = out;
+ detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re,
+ input_type *in_im, output_type *out_re,
+ output_type *out_im,
+ const std::vector &deps) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_bwd_plan(commit);
+ auto info = detail::get_bwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name =
+ "compute_backward(desc, in_re, in_im, out_re, out_im, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array in_native = { in_re, in_im };
+ std::array out_native = { out_re, out_im };
+ detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+// Template function instantiations
+#include "dft/backends/backend_backward_instantiations.cxx"
+
+} // namespace oneapi::mkl::dft::rocfft
diff --git a/src/dft/backends/rocfft/commit.cpp b/src/dft/backends/rocfft/commit.cpp
new file mode 100644
index 000000000..723322d15
--- /dev/null
+++ b/src/dft/backends/rocfft/commit.cpp
@@ -0,0 +1,430 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "oneapi/mkl/dft/detail/commit_impl.hpp"
+#include "oneapi/mkl/dft/detail/descriptor_impl.hpp"
+#include "oneapi/mkl/dft/detail/types_impl.hpp"
+#include "oneapi/mkl/dft/types.hpp"
+#include "oneapi/mkl/exceptions.hpp"
+#include "rocfft_handle.hpp"
+
+namespace oneapi::mkl::dft::rocfft {
+namespace detail {
+
+// rocfft has global setup and cleanup functions which use some global state internally.
+// Each can be called multiple times in an application, but due to the global nature, they always need to alternate.
+// I don't believe its possible to avoid the user calling rocfft_cleanup in their own code,
+// breaking our code, but we can try avoid it for them.
+// rocfft_cleanup internally uses some singletons, so it is very difficult to decide if this is safe due to
+// the static initialisation order fiasco.
+class rocfft_singleton {
+ rocfft_singleton() {
+ const auto result = rocfft_setup();
+ if (result != rocfft_status_success) {
+ throw mkl::exception(
+ "DFT", "rocfft",
+ "Failed to setup rocfft. returned status " + std::to_string(result));
+ }
+ }
+
+ ~rocfft_singleton() {
+ (void)rocfft_cleanup();
+ }
+
+ // no copies or moves allowed
+ rocfft_singleton(const rocfft_singleton& other) = delete;
+ rocfft_singleton(rocfft_singleton&& other) noexcept = delete;
+ rocfft_singleton& operator=(const rocfft_singleton& other) = delete;
+ rocfft_singleton& operator=(rocfft_singleton&& other) noexcept = delete;
+
+public:
+ static void init() {
+ static rocfft_singleton instance;
+ (void)instance;
+ }
+};
+
+/// Commit impl class specialization for rocFFT.
+template
+class rocfft_commit final : public dft::detail::commit_impl {
+private:
+ // For real to complex transforms, the "transform_type" arg also encodes the direction (e.g. rocfft_transform_type_*_forward vs rocfft_transform_type_*_backward)
+ // in the plan so we must have one for each direction.
+ // We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while rocFFT uses a directional "in_distance" and "out_distance".
+ // The same is also true for "FORWARD_SCALE" and "BACKWARD_SCALE".
+ // handles[0] is forward, handles[1] is backward
+ std::array handles{};
+
+public:
+ rocfft_commit(sycl::queue& queue, const dft::detail::dft_values& config_values)
+ : oneapi::mkl::dft::detail::commit_impl(queue, backend::rocfft) {
+ if constexpr (prec == dft::detail::precision::DOUBLE) {
+ if (!queue.get_device().has(sycl::aspect::fp64)) {
+ throw mkl::exception("DFT", "commit", "Device does not support double precision.");
+ }
+ }
+ // initialise the rocFFT global state
+ rocfft_singleton::init();
+ }
+
+ void clean_plans() {
+ if (handles[0].plan) {
+ if (rocfft_plan_destroy(handles[0].plan.value()) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to destroy forward plan.");
+ }
+ handles[0].plan = std::nullopt;
+ }
+ if (handles[1].plan) {
+ if (rocfft_plan_destroy(handles[1].plan.value()) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to destroy backward plan.");
+ }
+ handles[1].plan = std::nullopt;
+ }
+
+ if (handles[0].info) {
+ if (rocfft_execution_info_destroy(handles[0].info.value()) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to destroy forward execution info .");
+ }
+ handles[0].info = std::nullopt;
+ }
+ if (handles[1].info) {
+ if (rocfft_execution_info_destroy(handles[1].info.value()) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to destroy backward execution info .");
+ }
+ handles[1].info = std::nullopt;
+ }
+ if (handles[0].buffer) {
+ if (hipFree(handles[0].buffer.value()) != hipSuccess) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to free forward buffer.");
+ }
+ handles[0].buffer = std::nullopt;
+ }
+ if (handles[1].buffer) {
+ if (hipFree(handles[1].buffer.value()) != hipSuccess) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to free backward buffer.");
+ }
+ handles[1].buffer = std::nullopt;
+ }
+ }
+
+ void commit(const dft::detail::dft_values& config_values) override {
+ // this could be a recommit
+ clean_plans();
+
+ const rocfft_result_placement placement =
+ (config_values.placement == dft::config_value::INPLACE) ? rocfft_placement_inplace
+ : rocfft_placement_notinplace;
+
+ constexpr rocfft_transform_type fwd_type = [] {
+ if constexpr (dom == dft::domain::COMPLEX) {
+ return rocfft_transform_type_complex_forward;
+ }
+ else {
+ return rocfft_transform_type_real_forward;
+ }
+ }();
+ constexpr rocfft_transform_type bwd_type = [] {
+ if constexpr (dom == dft::domain::COMPLEX) {
+ return rocfft_transform_type_complex_inverse;
+ }
+ else {
+ return rocfft_transform_type_real_inverse;
+ }
+ }();
+
+ constexpr rocfft_precision precision = [] {
+ if constexpr (prec == dft::precision::SINGLE) {
+ return rocfft_precision_single;
+ }
+ else {
+ return rocfft_precision_double;
+ }
+ }();
+
+ const std::size_t dimensions = config_values.dimensions.size();
+
+ constexpr std::size_t max_supported_dims = 3;
+ std::array lengths;
+ // rocfft does dimensions in the reverse order to oneMKL
+ std::copy(config_values.dimensions.crbegin(), config_values.dimensions.crend(),
+ lengths.data());
+
+ const std::size_t number_of_transforms =
+ static_cast(config_values.number_of_transforms);
+
+ const std::size_t fwd_distance = static_cast(config_values.fwd_dist);
+ const std::size_t bwd_distance = static_cast(config_values.bwd_dist);
+
+ const rocfft_array_type fwd_array_ty = [&config_values]() {
+ if constexpr (dom == dft::domain::COMPLEX) {
+ if (config_values.complex_storage == dft::config_value::COMPLEX_COMPLEX) {
+ return rocfft_array_type_complex_interleaved;
+ }
+ else {
+ return rocfft_array_type_complex_planar;
+ }
+ }
+ else {
+ return rocfft_array_type_real;
+ }
+ }();
+ const rocfft_array_type bwd_array_ty = [&config_values]() {
+ if constexpr (dom == dft::domain::COMPLEX) {
+ if (config_values.complex_storage == dft::config_value::COMPLEX_COMPLEX) {
+ return rocfft_array_type_complex_interleaved;
+ }
+ else {
+ return rocfft_array_type_complex_planar;
+ }
+ }
+ else {
+ if (config_values.conj_even_storage != dft::config_value::COMPLEX_COMPLEX) {
+ throw mkl::exception(
+ "dft/backends/rocfft", __FUNCTION__,
+ "only COMPLEX_COMPLEX conjugate_even_storage is supported");
+ }
+ return rocfft_array_type_hermitian_interleaved;
+ }
+ }();
+
+ std::array in_offsets{
+ static_cast(config_values.input_strides[0]),
+ static_cast(config_values.input_strides[0])
+ };
+ std::array out_offsets{
+ static_cast(config_values.output_strides[0]),
+ static_cast(config_values.output_strides[0])
+ };
+
+ std::array in_strides;
+ std::array out_strides;
+
+ for (std::size_t i = 0; i != dimensions; ++i) {
+ in_strides[i] = config_values.input_strides[dimensions - i];
+ out_strides[i] = config_values.output_strides[dimensions - i];
+ }
+
+ rocfft_plan_description plan_desc;
+ if (rocfft_plan_description_create(&plan_desc) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to create plan description.");
+ }
+
+ // plan_description can be destroyed afted plan_create
+ auto description_destroy = [](rocfft_plan_description p) {
+ if (rocfft_plan_description_destroy(p) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to destroy plan description.");
+ }
+ };
+ std::unique_ptr
+ description_destroyer(plan_desc, description_destroy);
+
+ // When creating real-complex descriptions, the strides will always be wrong for one of the directions.
+ // This is because the least significant dimension is symmetric.
+ // If the strides are invalid (too small to fit) then just don't bother creating the plan.
+ const bool ignore_strides = dom == dft::domain::COMPLEX || dimensions == 1;
+ const bool valid_forward =
+ ignore_strides || (lengths[0] <= in_strides[1] && lengths[0] / 2 + 1 <= out_strides[1]);
+ const bool valid_backward =
+ ignore_strides || (lengths[0] <= out_strides[1] && lengths[0] / 2 + 1 <= in_strides[1]);
+
+ if (valid_forward) {
+ auto res =
+ rocfft_plan_description_set_data_layout(plan_desc, fwd_array_ty, bwd_array_ty,
+ in_offsets.data(), // in offsets
+ out_offsets.data(), // out offsets
+ dimensions,
+ in_strides.data(), //in strides
+ fwd_distance, // in distance
+ dimensions,
+ out_strides.data(), // out strides
+ bwd_distance // out distance
+ );
+ if (res != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to set forward data layout.");
+ }
+
+ if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.fwd_scale) !=
+ rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to set forward scale factor.");
+ }
+
+ rocfft_plan fwd_plan;
+ res = rocfft_plan_create(&fwd_plan, placement, fwd_type, precision, dimensions,
+ lengths.data(), number_of_transforms, plan_desc);
+
+ if (res != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to create forward plan.");
+ }
+
+ handles[0].plan = fwd_plan;
+
+ rocfft_execution_info fwd_info;
+ if (rocfft_execution_info_create(&fwd_info) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to create forward execution info.");
+ }
+ handles[0].info = fwd_info;
+
+ // plan work buffer
+ std::size_t work_buf_size;
+ if (rocfft_plan_get_work_buffer_size(fwd_plan, &work_buf_size) !=
+ rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to get forward work buffer size.");
+ }
+ if (work_buf_size != 0) {
+ void* work_buf;
+ if (hipMalloc(&work_buf, work_buf_size) != hipSuccess) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to get allocate forward work buffer.");
+ }
+ handles[0].buffer = work_buf;
+ if (rocfft_execution_info_set_work_buffer(fwd_info, work_buf, work_buf_size) !=
+ rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to set forward work buffer.");
+ }
+ }
+ }
+
+ if (valid_backward) {
+ auto res =
+ rocfft_plan_description_set_data_layout(plan_desc, bwd_array_ty, fwd_array_ty,
+ in_offsets.data(), // in offsets
+ out_offsets.data(), // out offsets
+ dimensions,
+ in_strides.data(), //in strides
+ bwd_distance, // in distance
+ dimensions,
+ out_strides.data(), // out strides
+ fwd_distance // out distance
+ );
+ if (res != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to set backward data layout.");
+ }
+
+ if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.bwd_scale) !=
+ rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to set backward scale factor.");
+ }
+
+ rocfft_plan bwd_plan;
+ res = rocfft_plan_create(&bwd_plan, placement, bwd_type, precision, dimensions,
+ lengths.data(), number_of_transforms, plan_desc);
+ if (res != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to create backward rocFFT plan.");
+ }
+ handles[1].plan = bwd_plan;
+
+ rocfft_execution_info bwd_info;
+ if (rocfft_execution_info_create(&bwd_info) != rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to create backward execution info.");
+ }
+ handles[1].info = bwd_info;
+
+ std::size_t work_buf_size;
+ if (rocfft_plan_get_work_buffer_size(bwd_plan, &work_buf_size) !=
+ rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to get backward work buffer size.");
+ }
+
+ if (work_buf_size != 0) {
+ void* work_buf;
+ if (hipMalloc(&work_buf, work_buf_size) != hipSuccess) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to get allocate backward work buffer.");
+ }
+ handles[1].buffer = work_buf;
+
+ if (rocfft_execution_info_set_work_buffer(bwd_info, work_buf, work_buf_size) !=
+ rocfft_status_success) {
+ throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
+ "Failed to set backward work buffer.");
+ }
+ }
+ }
+ }
+
+ ~rocfft_commit() override {
+ clean_plans();
+ }
+
+ // Rule of three. Copying could lead to memory safety issues.
+ rocfft_commit(const rocfft_commit& other) = delete;
+ rocfft_commit& operator=(const rocfft_commit& other) = delete;
+
+ void* get_handle() noexcept override {
+ return handles.data();
+ }
+};
+} // namespace detail
+
+template
+dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc,
+ sycl::queue& sycl_queue) {
+ return new detail::rocfft_commit(sycl_queue, desc.get_values());
+}
+
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+
+} // namespace oneapi::mkl::dft::rocfft
diff --git a/src/dft/backends/rocfft/compute_signature.cpp b/src/dft/backends/rocfft/compute_signature.cpp
new file mode 100644
index 000000000..c26826fcc
--- /dev/null
+++ b/src/dft/backends/rocfft/compute_signature.cpp
@@ -0,0 +1,24 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp"
+
+#define BACKEND rocfft
+
+#include "dft/backends/backend_compute_signature.cxx"
diff --git a/src/dft/backends/rocfft/descriptor.cpp b/src/dft/backends/rocfft/descriptor.cpp
new file mode 100644
index 000000000..83fdbe1dc
--- /dev/null
+++ b/src/dft/backends/rocfft/descriptor.cpp
@@ -0,0 +1,51 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include "oneapi/mkl/dft/descriptor.hpp"
+#include "../../descriptor.cxx"
+
+#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp"
+
+namespace oneapi {
+namespace mkl {
+namespace dft {
+
+template
+void descriptor::commit(backend_selector selector) {
+ if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
+ if (pimpl_) {
+ pimpl_->get_queue().wait();
+ }
+ pimpl_.reset(rocfft::create_commit(*this, selector.get_queue()));
+ }
+ pimpl_->commit(values_);
+}
+
+template void descriptor::commit(
+ backend_selector);
+template void descriptor::commit(
+ backend_selector);
+template void descriptor::commit(
+ backend_selector);
+template void descriptor::commit(
+ backend_selector);
+
+} //namespace dft
+} //namespace mkl
+} //namespace oneapi
diff --git a/src/dft/backends/rocfft/execute_helper.hpp b/src/dft/backends/rocfft/execute_helper.hpp
new file mode 100644
index 000000000..4dff6831d
--- /dev/null
+++ b/src/dft/backends/rocfft/execute_helper.hpp
@@ -0,0 +1,97 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_
+#define _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_
+
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include "oneapi/mkl/dft/detail/commit_impl.hpp"
+#include "oneapi/mkl/dft/detail/descriptor_impl.hpp"
+#include "oneapi/mkl/dft/types.hpp"
+#include "oneapi/mkl/exceptions.hpp"
+
+#include
+#include
+
+namespace oneapi::mkl::dft::rocfft::detail {
+
+template
+inline dft::detail::commit_impl *checked_get_commit(
+ dft::detail::descriptor &desc) {
+ auto commit_handle = dft::detail::get_commit(desc);
+ if (commit_handle == nullptr || commit_handle->get_backend() != backend::rocfft) {
+ throw mkl::invalid_argument("dft/backends/rocfft", "get_commit",
+ "DFT descriptor has not been commited for rocFFT");
+ }
+ return commit_handle;
+}
+
+/// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match
+/// the expected value.
+template
+inline auto expect_config(DescT &desc, const char *message) {
+ dft::config_value actual{ 0 };
+ desc.get_value(Param, &actual);
+ if (actual != Expected) {
+ throw mkl::invalid_argument("dft/backends/rocfft", "expect_config", message);
+ }
+}
+
+template
+inline void *native_mem(sycl::interop_handle &ih, Acc &buf) {
+ return ih.get_native_mem(buf);
+}
+
+inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &ih,
+ rocfft_execution_info info) {
+ auto stream = ih.get_native_queue();
+ auto result = rocfft_execution_info_set_stream(info, stream);
+ if (result != rocfft_status_success) {
+ throw oneapi::mkl::exception(
+ "dft/backends/rocfft", func,
+ "rocfft_execution_info_set_stream returned " + std::to_string(result));
+ }
+ return stream;
+}
+
+inline void sync_checked(const std::string &func, hipStream_t stream) {
+ auto result = hipStreamSynchronize(stream);
+ if (result != hipSuccess) {
+ throw oneapi::mkl::exception("dft/backends/rocfft", func,
+ "hipStreamSynchronize returned " + std::to_string(result));
+ }
+}
+
+inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[],
+ void *out_buffer[], rocfft_execution_info info) {
+ auto result = rocfft_execute(plan, in_buffer, out_buffer, info);
+ if (result != rocfft_status_success) {
+ throw oneapi::mkl::exception("dft/backends/rocfft", func,
+ "rocfft_execute returned " + std::to_string(result));
+ }
+}
+
+} // namespace oneapi::mkl::dft::rocfft::detail
+
+#endif
diff --git a/src/dft/backends/rocfft/forward.cpp b/src/dft/backends/rocfft/forward.cpp
new file mode 100644
index 000000000..9b92d9097
--- /dev/null
+++ b/src/dft/backends/rocfft/forward.cpp
@@ -0,0 +1,266 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include
+#include
+
+#include "execute_helper.hpp"
+#include "oneapi/mkl/dft/detail/commit_impl.hpp"
+#include "oneapi/mkl/dft/forward.hpp"
+#include "oneapi/mkl/dft/types.hpp"
+#include "oneapi/mkl/exceptions.hpp"
+#include "rocfft_handle.hpp"
+
+namespace oneapi::mkl::dft::rocfft {
+
+namespace detail {
+template
+rocfft_plan get_fwd_plan(dft::detail::commit_impl *commit) {
+ return static_cast(commit->get_handle())[0].plan.value();
+}
+
+template
+rocfft_execution_info get_fwd_info(dft::detail::commit_impl *commit) {
+ return static_cast(commit->get_handle())[0].info.value();
+}
+} // namespace detail
+
+// BUFFER version
+
+//In-place transform
+template
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto inout_acc = inout.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, inout)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ auto inout_native = detail::native_mem(ih, inout_acc);
+ detail::execute_checked(func_name, plan, &inout_native, nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re,
+ sycl::buffer &inout_im) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto inout_re_acc = inout_re.template get_access(cgh);
+ auto inout_im_acc = inout_im.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, inout_re, inout_im)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array inout_native{ detail::native_mem(ih, inout_re_acc),
+ detail::native_mem(ih, inout_im_acc) };
+ detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform
+template
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in,
+ sycl::buffer &out) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto in_acc = in.template get_access(cgh);
+ auto out_acc = out.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, in, out)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ auto in_native = detail::native_mem(ih, in_acc);
+ auto out_native = detail::native_mem(ih, out_acc);
+ detail::execute_checked(func_name, plan, &in_native, &out_native, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re,
+ sycl::buffer &in_im,
+ sycl::buffer &out_re,
+ sycl::buffer &out_im) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ queue.submit([&](sycl::handler &cgh) {
+ auto in_re_acc = in_re.template get_access(cgh);
+ auto in_im_acc = in_im.template get_access(cgh);
+ auto out_re_acc = out_re.template get_access(cgh);
+ auto out_im_acc = out_im.template get_access(cgh);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array in_native{ detail::native_mem(ih, in_re_acc),
+ detail::native_mem(ih, in_im_acc) };
+ std::array out_native{ detail::native_mem(ih, out_re_acc),
+ detail::native_mem(ih, out_im_acc) };
+ detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//USM version
+
+//In-place transform
+template
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout,
+ const std::vector &deps) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, inout, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ void *inout_ptr = inout;
+ detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re,
+ data_type *inout_im,
+ const std::vector &deps) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, inout_re, inout_im, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array inout_native{ inout_re, inout_im };
+ detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform
+template
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out,
+ const std::vector &deps) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name = "compute_forward(desc, in, out, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ void *in_ptr = in;
+ void *out_ptr = out;
+ detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
+template
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re,
+ input_type *in_im, output_type *out_re,
+ output_type *out_im,
+ const std::vector &deps) {
+ auto commit = detail::checked_get_commit(desc);
+ auto queue = commit->get_queue();
+ auto plan = detail::get_fwd_plan(commit);
+ auto info = detail::get_fwd_info(commit);
+
+ return queue.submit([&](sycl::handler &cgh) {
+ cgh.depends_on(deps);
+
+ cgh.host_task([=](sycl::interop_handle ih) {
+ const std::string func_name =
+ "compute_forward(desc, in_re, in_im, out_re, out_im, deps)";
+ auto stream = detail::setup_stream(func_name, ih, info);
+
+ std::array in_native{ in_re, in_im };
+ std::array out_native{ out_re, out_im };
+ detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
+ detail::sync_checked(func_name, stream);
+ });
+ });
+}
+
+// Template function instantiations
+#include "dft/backends/backend_forward_instantiations.cxx"
+
+} // namespace oneapi::mkl::dft::rocfft
diff --git a/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp b/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp
new file mode 100644
index 000000000..c8f0e35c7
--- /dev/null
+++ b/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp
@@ -0,0 +1,32 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp"
+#include "dft/function_table.hpp"
+
+#define WRAPPER_VERSION 1
+#define BACKEND rocfft
+
+extern "C" dft_function_table_t mkl_dft_table = {
+ WRAPPER_VERSION,
+#include "dft/backends/backend_wrappers.cxx"
+};
+
+#undef WRAPPER_VERSION
+#undef BACKEND
diff --git a/src/dft/backends/rocfft/rocfft_handle.hpp b/src/dft/backends/rocfft/rocfft_handle.hpp
new file mode 100644
index 000000000..ea4f44d68
--- /dev/null
+++ b/src/dft/backends/rocfft/rocfft_handle.hpp
@@ -0,0 +1,34 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _ONEMKL_DFT_SRC_ROCFFT_ROCFFT_HANDLE_HPP_
+#define _ONEMKL_DFT_SRC_ROCFFT_ROCFFT_HANDLE_HPP_
+
+#include
+
+struct rocfft_plan_t;
+struct rocfft_execution_info_t;
+
+struct rocfft_handle {
+ std::optional plan = std::nullopt;
+ std::optional info = std::nullopt;
+ std::optional buffer = std::nullopt;
+};
+
+#endif
diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt
index c377bbcf3..a676fa6ef 100644
--- a/tests/unit_tests/CMakeLists.txt
+++ b/tests/unit_tests/CMakeLists.txt
@@ -154,6 +154,11 @@ foreach(domain ${TARGET_DOMAINS})
list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_cufft)
endif()
+ if(domain STREQUAL "dft" AND ENABLE_ROCFFT_BACKEND)
+ add_dependencies(test_main_${domain}_ct onemkl_dft_rocfft)
+ list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_dft_rocfft)
+ endif()
+
target_link_libraries(test_main_${domain}_ct PUBLIC
gtest
gtest_main
diff --git a/tests/unit_tests/dft/include/compute_tester.hpp b/tests/unit_tests/dft/include/compute_tester.hpp
index 53858b228..491afcc1e 100644
--- a/tests/unit_tests/dft/include/compute_tester.hpp
+++ b/tests/unit_tests/dft/include/compute_tester.hpp
@@ -130,7 +130,7 @@ struct DFT_Test {
});
// Heuristic for the average-case error margins
abs_error_margin =
- std::abs(max_norm_ref) * std::log2(static_cast(forward_elements));
+ 10 * std::abs(max_norm_ref) * std::log2(static_cast(forward_elements));
rel_error_margin = 200.0 * std::log2(static_cast(forward_elements));
return !skip_test(mem_acc);
}
diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp
index ebdcb1c03..40d1e8644 100644
--- a/tests/unit_tests/include/test_helper.hpp
+++ b/tests/unit_tests/include/test_helper.hpp
@@ -136,6 +136,16 @@
#define TEST_RUN_NVIDIAGPU_CUFFT_SELECT(q, func, ...)
#endif
+#ifdef ENABLE_ROCFFT_BACKEND
+#define TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func) \
+ func(oneapi::mkl::backend_selector{ q })
+#define TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, ...) \
+ func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__)
+#else
+#define TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func)
+#define TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, ...)
+#endif
+
#ifndef __HIPSYCL__
#define CHECK_HOST_OR_CPU(q) q.get_device().is_cpu()
#else
@@ -156,6 +166,9 @@
else if (vendor_id == NVIDIA_ID) { \
TEST_RUN_NVIDIAGPU_CUFFT_SELECT_NO_ARGS(q, func); \
} \
+ else if (vendor_id == AMD_ID) { \
+ TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func); \
+ } \
} \
} while (0);
@@ -177,6 +190,7 @@
TEST_RUN_AMDGPU_ROCBLAS_SELECT(q, func, __VA_ARGS__); \
TEST_RUN_AMDGPU_ROCRAND_SELECT(q, func, __VA_ARGS__); \
TEST_RUN_AMDGPU_ROCSOLVER_SELECT(q, func, __VA_ARGS__); \
+ TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, __VA_ARGS__); \
} \
} \
TEST_RUN_SYCLBLAS_SELECT(q, func, __VA_ARGS__); \
diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp
index 5010d54b1..84faad518 100644
--- a/tests/unit_tests/main_test.cpp
+++ b/tests/unit_tests/main_test.cpp
@@ -124,8 +124,9 @@ int main(int argc, char** argv) {
if (dev.is_gpu() && vendor_id == NVIDIA_ID)
continue;
#endif
-#if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \
- !defined(ENABLE_ROCSOLVER_BACKEND) && !defined(ENABLE_SYCLBLAS_BACKEND_AMD_GPU)
+#if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \
+ !defined(ENABLE_ROCSOLVER_BACKEND) && !defined(ENABLE_SYCLBLAS_BACKEND_AMD_GPU) && \
+ !defined(ENABLE_ROCFFT_BACKEND)
if (dev.is_gpu() && vendor_id == AMD_ID)
continue;
#endif