diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 23d452be367..2b4a9df095a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -3393,6 +3393,48 @@ NDArray stft( boolean normalize, boolean returnComplex); + /** + * Computes the two-dimensional Discrete Fourier Transform. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @param axes Axes over which to compute the 2D-FFT. + * @return The truncated or zero-padded input, transformed along the axes. + */ + NDArray fft2(long[] sizes, long[] axes); + + /** + * Computes the two-dimensional Discrete Fourier Transform along the last 2 axes. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @return The truncated or zero-padded input, transformed along the last two axes + */ + default NDArray fft2(long[] sizes) { + return fft2(sizes, new long[] {-2, -1}); + } + + /** + * Computes the two-dimensional inverse Discrete Fourier Transform. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @param axes Axes over which to compute the 2D-Inverse-FFT. + * @return The truncated or zero-padded input, transformed along the axes. + */ + NDArray ifft2(long[] sizes, long[] axes); + + /** + * Computes the two-dimensional inverse Discrete Fourier Transform along the last 2 axes. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @return The truncated or zero-padded input, transformed along the axes. + */ + default NDArray ifft2(long[] sizes) { + return ifft2(sizes, new long[] {-2, -1}); + } + /** * Reshapes this {@code NDArray} to the given {@link Shape}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 855c7183003..9a4ad8db93a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -912,6 +912,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 19eec837259..8b884b3993a 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -1160,6 +1160,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 1e9ac83c173..606aaf24e00 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1103,6 +1103,18 @@ public NDArray stft( this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + return JniUtils.fft2(this, sizes, axes); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + return JniUtils.ifft2(this, sizes, axes); + } + /** {@inheritDoc} */ @Override public PtNDArray reshape(Shape shape) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 8e6be7b8d15..40a6a0065bc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1040,6 +1040,18 @@ public static PtNDArray stft( return new PtNDArray(ndArray.getManager(), handle); } + public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes)); + } + + public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes)); + } + public static PtNDArray real(PtNDArray ndArray) { long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle()); if (handle == -1) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index a1829306d20..54fc5419145 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -273,6 +273,10 @@ native long torchStft( boolean normalize, boolean returnComplex); + native long torchFft2(long handle, long[] sizes, long[] axes); + + native long torchIfft2(long handle, long[] sizes, long[] axes); + native long torchViewAsReal(long handle); native long torchViewAsComplex(long handle); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc index 5a65e1eca69..08932098da9 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc @@ -34,6 +34,28 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2( + JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js); + const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes); + const auto* result_ptr = new torch::Tensor(torch::fft_fft2(*tensor_ptr, sizes, axes)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft2( + JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js); + const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes); + const auto* result_ptr = new torch::Tensor(torch::fft_ifft2(*tensor_ptr, sizes, axes)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStft(JNIEnv* env, jobject jthis, jlong jhandle, jlong jn_fft, jlong jhop_length, jlong jwindow, jboolean jcenter, jboolean jnormalize, jboolean jreturn_complex) { #ifdef V1_11_X diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 44488537765..419be4c09f6 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -1184,6 +1184,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 00e7465f745..66bb136ab37 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -1087,4 +1087,58 @@ public void testStft() { Assertions.assertAlmostEquals(result.real().flatten(), expected); } } + + @Test + public void testFft2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array = + manager.create( + new float[][] { + {1f, 6.6f, 4.315f, 2.0f}, + {16.9f, 6.697f, 2.399f, 67.9f}, + {0f, 5f, 67.09f, 9.87f} + }); + NDArray result = array.fft2(new long[] {3, 4}, new long[] {0, 1}); + result = result.real().flatten(1, 2); // flatten complex numbers + NDArray expected = + manager.create( + new float[][] { + {189.771f, 0f, -55.904f, 61.473f, -6.363f, 0f, -55.904f, -61.473f}, + { + -74.013f, + -10.3369f, + 71.7653f, + -108.2964f, + -1.746f, + 93.1133f, + -25.8063f, + -33.0234f + }, + { + -74.013f, 10.3369f, -25.8063f, 33.0234f, -1.746f, -93.1133f, + 71.7653f, 108.2964f + } + }); + Assertions.assertAlmostEquals(result, expected); + } + } + + @Test + public void testIfft2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array = + manager.create( + new float[][] { + {1f, 6.6f, 4.315f, 2.0f}, + {16.9f, 6.697f, 2.399f, 67.9f}, + {0f, 5f, 67.09f, 9.87f} + }); + long[] sizes = {3, 4}; + long[] axes = {0, 1}; + NDArray fft2 = array.fft2(sizes, axes); + NDArray actual = fft2.ifft2(sizes, axes).real(); + NDArray expected = array.toType(DataType.COMPLEX64, true).real(); + Assertions.assertAlmostEquals(expected, actual); + } + } }