Skip to content

Commit

Permalink
Add FFT2 and FFT2 inverse (#2845)
Browse files Browse the repository at this point in the history
* Added 2D FFT

* Format java

* Add default fft2

* Convert array to vectors

* Add inverse fft2

* Add better assersion in ifft2 test

* Add really better assersion in ifft2 test

* Move cast bellow ifft2 for unsupported exception

* Format java

* changed dims to axes

* changed dims to axes
  • Loading branch information
TalGrbr authored Nov 13, 2023
1 parent f84d3bb commit f39640c
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 0 deletions.
42 changes: 42 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -3393,6 +3393,48 @@ NDArray stft(
boolean normalize,
boolean returnComplex);

/**
* Computes the two-dimensional Discrete Fourier Transform.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @param axes Axes over which to compute the 2D-FFT.
* @return The truncated or zero-padded input, transformed along the axes.
*/
NDArray fft2(long[] sizes, long[] axes);

/**
* Computes the two-dimensional Discrete Fourier Transform along the last 2 axes.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @return The truncated or zero-padded input, transformed along the last two axes
*/
default NDArray fft2(long[] sizes) {
return fft2(sizes, new long[] {-2, -1});
}

/**
* Computes the two-dimensional inverse Discrete Fourier Transform.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @param axes Axes over which to compute the 2D-Inverse-FFT.
* @return The truncated or zero-padded input, transformed along the axes.
*/
NDArray ifft2(long[] sizes, long[] axes);

/**
* Computes the two-dimensional inverse Discrete Fourier Transform along the last 2 axes.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @return The truncated or zero-padded input, transformed along the axes.
*/
default NDArray ifft2(long[] sizes) {
return ifft2(sizes, new long[] {-2, -1});
}

/**
* Reshapes this {@code NDArray} to the given {@link Shape}.
*
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,18 @@ public NDArray stft(
this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex);
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
return JniUtils.fft2(this, sizes, axes);
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
return JniUtils.ifft2(this, sizes, axes);
}

/** {@inheritDoc} */
@Override
public PtNDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,18 @@ public static PtNDArray stft(
return new PtNDArray(ndArray.getManager(), handle);
}

public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes));
}

public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes));
}

public static PtNDArray real(PtNDArray ndArray) {
long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle());
if (handle == -1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ native long torchStft(
boolean normalize,
boolean returnComplex);

native long torchFft2(long handle, long[] sizes, long[] axes);

native long torchIfft2(long handle, long[] sizes, long[] axes);

native long torchViewAsReal(long handle);

native long torchViewAsComplex(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,28 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const std::vector<int64_t> sizes = djl::utils::jni::GetVecFromJLongArray(env, js);
const std::vector<int64_t> axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes);
const auto* result_ptr = new torch::Tensor(torch::fft_fft2(*tensor_ptr, sizes, axes));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft2(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const std::vector<int64_t> sizes = djl::utils::jni::GetVecFromJLongArray(env, js);
const std::vector<int64_t> axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes);
const auto* result_ptr = new torch::Tensor(torch::fft_ifft2(*tensor_ptr, sizes, axes));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStft(JNIEnv* env, jobject jthis, jlong jhandle,
jlong jn_fft, jlong jhop_length, jlong jwindow, jboolean jcenter, jboolean jnormalize, jboolean jreturn_complex) {
#ifdef V1_11_X
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1087,4 +1087,58 @@ public void testStft() {
Assertions.assertAlmostEquals(result.real().flatten(), expected);
}
}

@Test
public void testFft2() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
NDArray array =
manager.create(
new float[][] {
{1f, 6.6f, 4.315f, 2.0f},
{16.9f, 6.697f, 2.399f, 67.9f},
{0f, 5f, 67.09f, 9.87f}
});
NDArray result = array.fft2(new long[] {3, 4}, new long[] {0, 1});
result = result.real().flatten(1, 2); // flatten complex numbers
NDArray expected =
manager.create(
new float[][] {
{189.771f, 0f, -55.904f, 61.473f, -6.363f, 0f, -55.904f, -61.473f},
{
-74.013f,
-10.3369f,
71.7653f,
-108.2964f,
-1.746f,
93.1133f,
-25.8063f,
-33.0234f
},
{
-74.013f, 10.3369f, -25.8063f, 33.0234f, -1.746f, -93.1133f,
71.7653f, 108.2964f
}
});
Assertions.assertAlmostEquals(result, expected);
}
}

@Test
public void testIfft2() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
NDArray array =
manager.create(
new float[][] {
{1f, 6.6f, 4.315f, 2.0f},
{16.9f, 6.697f, 2.399f, 67.9f},
{0f, 5f, 67.09f, 9.87f}
});
long[] sizes = {3, 4};
long[] axes = {0, 1};
NDArray fft2 = array.fft2(sizes, axes);
NDArray actual = fft2.ifft2(sizes, axes).real();
NDArray expected = array.toType(DataType.COMPLEX64, true).real();
Assertions.assertAlmostEquals(expected, actual);
}
}
}

0 comments on commit f39640c

Please sign in to comment.