Skip to content

Commit

Permalink
PreRun support for r2c and other fft related fixes (#494)
Browse files Browse the repository at this point in the history
* PreRun support for r2c and other fft related fixes
  • Loading branch information
tbensonatl authored Sep 29, 2023
1 parent cdcdffd commit 539c1b7
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 11 deletions.
1 change: 1 addition & 0 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#pragma once

#include <cassert>
#include <type_traits>
#include <cuda/std/functional>
#include "matx/core/error.h"
Expand Down
32 changes: 24 additions & 8 deletions include/matx/operators/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ namespace matx
std::array<index_t, OpA::Rank()> out_dims_;
mutable matx::tensor_t<std::conditional_t<is_complex_v<typename OpA::scalar_type>,
typename OpA::scalar_type,
typename scalar_to_complex<OpA>::ctype>, OpA::Rank()> tmp_out_;
typename scalar_to_complex<typename OpA::scalar_type>::ctype>, OpA::Rank()> tmp_out_;

public:
using matxop = bool;
using scalar_type = typename OpA::scalar_type;
using scalar_type = std::conditional_t<is_complex_v<typename OpA::scalar_type>,
typename OpA::scalar_type,
typename scalar_to_complex<typename OpA::scalar_type>::ctype>;
using matx_transform_op = bool;
using fft_xform_op = bool;

Expand All @@ -76,11 +78,21 @@ namespace matx
}

if (fft_size_ != 0) {
if constexpr (std::is_same_v<PermDims, no_permute_t>) {
out_dims_[Rank() - 1] = fft_size_;
}
else {
out_dims_[perm_[0]] = fft_size_;
if constexpr (is_complex_v<typename OpA::scalar_type>) {
if constexpr (std::is_same_v<PermDims, no_permute_t>) {
out_dims_[Rank() - 1] = fft_size_;
}
else {
out_dims_[perm_[0]] = fft_size_;
}
} else {
// R2C transforms pack the results in fft_size_/2 + 1 complex elements
if constexpr (std::is_same_v<PermDims, no_permute_t>) {
out_dims_[Rank() - 1] = fft_size_ / 2 + 1;
}
else {
out_dims_[perm_[0]] = fft_size_ / 2 + 1;
}
}
}
else {
Expand All @@ -91,8 +103,12 @@ namespace matx
else {
out_dims_[Rank() - 1] = out_dims_[Rank() - 1] / 2 + 1;
}
// The output dimension could correspond to an input dimension of either
// (out_dim-1)*2 or (out_dim-1)*2+1. The FFT transform will be unable to
// deduce which is correct, so explicitly set the transform size here.
fft_size_ = a.Size(a.Rank()-1);
}
}
}
}

template <typename... Is>
Expand Down
18 changes: 17 additions & 1 deletion include/matx/operators/r2c.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace matx
using matxop = bool;
using scalar_type = typename T1::scalar_type;

__MATX_INLINE__ std::string str() const { return "r2c(" + op_.str() + ")"; }
__MATX_INLINE__ std::string str() const { return "r2c(" + op_.str() + ")"; }

__MATX_INLINE__ R2COp(T1 op, index_t orig) : op_(op), orig_size_(orig) {
static_assert(Rank() >= 1, "R2COp must have a rank 1 operator or higher");
Expand Down Expand Up @@ -97,6 +97,22 @@ namespace matx
return op_.Size(dim);
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, [[maybe_unused]] Executor &&ex) const noexcept
{
if constexpr (is_matx_op<T1>()) {
op_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape, [[maybe_unused]] Executor &&ex) const noexcept
{
if constexpr (is_matx_op<T1>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
};
}

Expand Down
84 changes: 84 additions & 0 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3235,6 +3235,90 @@ TYPED_TEST(OperatorTestsNumericAllExecs, Downsample)
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsFloatNonComplexNonHalf, R2COp)
{
MATX_ENTER_HANDLER();
using TestType = TypeParam;
using ComplexType = detail::complex_from_scalar_t<TypeParam>;

const int N1 = 5;
const int N2 = 6;

auto t1 = make_tensor<TestType>({N1});
auto t2 = make_tensor<TestType>({N2});
auto T1 = make_tensor<ComplexType>({N1});
auto T2 = make_tensor<ComplexType>({N2});

for (int i = 0; i < N1; i++) { t1(i) = static_cast<TestType>(i+1); }
for (int i = 0; i < N2; i++) { t2(i) = static_cast<TestType>(i+1); }
cudaStreamSynchronize(0);

const std::array<ComplexType, N1> T1_expected = {{
{ 15.0, 0.0 }, { -2.5, static_cast<TestType>(3.4409548) }, { -2.5, static_cast<TestType>(0.81229924) },
{ -2.5, static_cast<TestType>(-0.81229924) }, { -2.5, static_cast<TestType>(-3.4409548) }
}};

const std::array<ComplexType, N2> T2_expected = {{
{ 21.0, 0.0 }, { -3.0, static_cast<TestType>(5.19615242) }, { -3.0, static_cast<TestType>(1.73205081) },
{ -3.0, static_cast<TestType>(-4.44089210e-16) }, { -3.0, static_cast<TestType>(-1.73205081) },
{ -3.0, static_cast<TestType>(-5.19615242) }
}};

const TestType thresh = static_cast<TestType>(1.0e-6);

// Test the regular r2c path with fft() deducing the transform size
(T1 = r2c(fft(t1), N1)).run();
(T2 = r2c(fft(t2), N2)).run();

cudaStreamSynchronize(0);

for (int i = 0; i < N1; i++) {
ASSERT_NEAR(T1(i).real(), T1_expected[i].real(), thresh);
ASSERT_NEAR(T1(i).imag(), T1_expected[i].imag(), thresh);
}

for (int i = 0; i < N2; i++) {
ASSERT_NEAR(T2(i).real(), T2_expected[i].real(), thresh);
ASSERT_NEAR(T2(i).imag(), T2_expected[i].imag(), thresh);
}

// Test the r2c path when specifying the fft() transform size
(T1 = r2c(fft(t1, N1), N1)).run();
(T2 = r2c(fft(t2, N2), N2)).run();

cudaStreamSynchronize(0);

for (int i = 0; i < N1; i++) {
ASSERT_NEAR(T1(i).real(), T1_expected[i].real(), thresh);
ASSERT_NEAR(T1(i).imag(), T1_expected[i].imag(), thresh);
}

for (int i = 0; i < N2; i++) {
ASSERT_NEAR(T2(i).real(), T2_expected[i].real(), thresh);
ASSERT_NEAR(T2(i).imag(), T2_expected[i].imag(), thresh);
}

// Add an ifft to the composition to return the original tensor,
// but now in complex rather than real form. The imaginary components
// should be ~0.
(T1 = ifft(r2c(fft(t1), N1))).run();
(T2 = ifft(r2c(fft(t2), N2))).run();

cudaStreamSynchronize(0);

for (int i = 0; i < N1; i++) {
ASSERT_NEAR(T1(i).real(), t1(i), thresh);
ASSERT_NEAR(T1(i).imag(), static_cast<TestType>(0.0), thresh);
}

for (int i = 0; i < N2; i++) {
ASSERT_NEAR(T2(i).real(), t2(i), thresh);
ASSERT_NEAR(T2(i).imag(), static_cast<TestType>(0.0), thresh);
}

MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsFloatNonHalf, FftShiftWithTransform)
{
MATX_ENTER_HANDLER();
Expand Down
4 changes: 2 additions & 2 deletions test/00_transform/FFT.cu
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ TYPED_TEST(FFTTestComplexNonHalfTypes, FFT1D1024PadR2C)
tensor_t<TypeParam, 1> avo{{fft_dim + 1}};
this->pb->NumpyToTensorView(av, "a_in");

(avo = fft(av)).run();
(avo = fft(av, fft_dim*2)).run();
cudaStreamSynchronize(0);

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
Expand All @@ -556,7 +556,7 @@ TYPED_TEST(FFTTestComplexNonHalfTypes, FFT1D1024PadBatchedR2C)
tensor_t<TypeParam, 2> avo{{fft_dim, fft_dim + 1}};
this->pb->NumpyToTensorView(av, "a_in");

(avo = fft(av)).run();
(avo = fft(av, fft_dim*2)).run();
cudaStreamSynchronize(0);

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
Expand Down

0 comments on commit 539c1b7

Please sign in to comment.