Skip to content

Commit

Permalink
Implement Short-time Fourier transform and its inverse (#587)
Browse files Browse the repository at this point in the history
* Initial stft implementation

* Finish STFT

* Cleanup

* Bump AMDGPU compat

* Install GPU backends only when testing them

* Add spectrogram

* Move audio documentation to its own page

- Convert examples to doctests or evaluate during build time
-More tests

* Fixes

* Use Makie for spectrogram plots

* Add mel-scale filterbanks

* Run doctests when building documentation instead of a separate CI stage

* Minor fix

* Move FFTW-dependent functions to extension
  • Loading branch information
pxl-th authored Jul 4, 2024
1 parent 62f6074 commit 4159154
Show file tree
Hide file tree
Showing 19 changed files with 808 additions and 55 deletions.
8 changes: 0 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ jobs:
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: |
julia --color=yes --project=docs/ -e '
using NNlib
# using Pkg; Pkg.activate("docs")
using Documenter
using Documenter: doctest
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive=true)
doctest(NNlib)'
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
32 changes: 6 additions & 26 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,51 +17,31 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"

[compat]
AMDGPU = "0.8, 0.9"
AMDGPU = "0.9.4"
Adapt = "3.2, 4"
Atomix = "0.1"
CUDA = "4, 5"
cuDNN = "1"
ChainRulesCore = "1.13"
EnzymeCore = "0.5, 0.6, 0.7"
FFTW = "1.8.0"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
LinearAlgebra = "<0.0.1, 1"
Pkg = "<0.0.1, 1"
Random = "<0.0.1, 1"
Requires = "1.0"
Statistics = "1"
cuDNN = "1"
julia = "1.9"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils", "Interpolations", "ImageTransformations"]
6 changes: 6 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
26 changes: 14 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
using Documenter, NNlib

DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive = true)
DocMeta.setdocmeta!(NNlib, :DocTestSetup,
:(using FFTW, NNlib, UnicodePlots); recursive = true)

makedocs(modules = [NNlib],
sitename = "NNlib.jl",
doctest = false,
pages = ["Home" => "index.md",
"Reference" => "reference.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)
sitename = "NNlib.jl",
doctest = true,
pages = ["Home" => "index.md",
"Reference" => "reference.md",
"Audio" => "audio.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)

deploydocs(repo = "github.com/FluxML/NNlib.jl.git",
target = "build",
Expand Down
Binary file added docs/src/assets/jfk.flac
Binary file not shown.
61 changes: 61 additions & 0 deletions docs/src/audio.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Reference

!!! note
    Spectral functions are only available after loading the `FFTW` package (`using FFTW`).

## Window functions

```@docs
hann_window
hamming_window
```

## Spectral

```@docs
stft
istft
NNlib.power_to_db
NNlib.db_to_power
```

## Spectrogram

```@docs
melscale_filterbanks
spectrogram
```

Example:

```@example 1
using FFTW # <- required for STFT support.
using NNlib
using FileIO
using Makie, CairoMakie
CairoMakie.activate!()
waveform, sampling_rate = load("./assets/jfk.flac")
fig = lines(reshape(waveform, :))
save("waveform.png", fig)
# Spectrogram.
n_fft = 1024
spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))
fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))
save("spectrogram.png", fig)
# Mel-scale spectrogram.
n_freqs = n_fft ÷ 2 + 1
fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))
mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)
fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])
save("mel-spectrogram.png", fig)
nothing # hide
```

|Waveform|Spectrogram|Mel Spectrogram|
|:---:|:---:|:---:|
|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|
9 changes: 9 additions & 0 deletions ext/NNlibFFTWExt/NNlibFFTWExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Package extension that is activated when `FFTW` is loaded alongside `NNlib`.
# It supplies the FFT-backed methods included from `stft.jl`.
module NNlibFFTWExt

using FFTW, NNlib, KernelAbstractions

include("stft.jl")

end # module NNlibFFTWExt
127 changes: 127 additions & 0 deletions ext/NNlibFFTWExt/stft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Short-time Fourier transform of `x`, computed along its first dimension.
#
# Keywords:
# - `n_fft`: number of samples in every frame that is Fourier-transformed.
# - `hop_length`: shift, in samples, between the starts of consecutive frames.
#   Must be strictly positive.
# - `window`: optional window multiplied with every frame. It must live on the
#   same backend as `x` and satisfy `0 < length(window) ≤ n_fft`; shorter
#   windows are zero-padded on both sides up to `n_fft`.
# - `center`: when `true`, `x` is padded with `pad_reflect` by `n_fft ÷ 2`
#   along dimension 1 before framing.
# - `normalized`: when `true`, the result is scaled by `n_fft^-0.5`.
#
# Returns the `fft` of the frames for complex input and the `rfft` otherwise.
#
# Throws `ArgumentError` if the window is on a different backend or has an
# invalid length, if `hop_length ≤ 0`, or if `n_fft` is not in
# `1:size(x, 1)` (size measured after the optional padding).
function NNlib.stft(x;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
)
    kab = get_backend(x)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as stft input `x` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # `hop_length == 0` must be rejected as well: it is used as a divisor below.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    if center
        pad_amount = n_fft ÷ 2
        x = pad_reflect(x, pad_amount; dims=1)
    end

    n = size(x, 1)
    (0 < n_fft ≤ n) || throw(ArgumentError(
        "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))

    n_frames = 1 + (n - n_fft) ÷ hop_length

    # time2col.
    # Reshape `x` to (n_fft, n_frames, B) if needed.
    # Each column of `ids` holds the sample indices of one frame; consecutive
    # frames are shifted by `hop_length`.
    # NOTE(review): when `n_frames == 1` the input is passed on unchanged, so
    # it is neither reshaped nor cut to `n_fft` samples — confirm this is
    # intended for inputs with `n_fft < size(x, 1) < n_fft + hop_length`.
    if n_frames > 1
        # TODO can be more efficient if we support something like torch.as_strided
        ids = [
            row + hop_length * col
            for row in 1:n_fft, col in 0:(n_frames - 1)]
        x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
    end

    region = 1
    if use_window
        x = x .* window
    end
    # Real input only needs the non-redundant half of the spectrum.
    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)

    if normalized
        y = y .* eltype(y)(n_fft^-0.5)
    end
    return y
end

# Inverse short-time Fourier transform: turns the framed spectrum `y`
# (frames along dimension 2, optional batch along dimension 3) back into a
# signal along dimension 1.
#
# Keywords:
# - `n_fft`: frame length that was used for the forward transform.
# - `hop_length`: shift, in samples, between consecutive frames.
#   Must be strictly positive.
# - `window`: optional window that was applied in the forward transform. It
#   must live on the same backend as `y` and satisfy
#   `0 < length(window) ≤ n_fft`; shorter windows are zero-padded on both
#   sides up to `n_fft`. The frames are divided by it.
# - `center`: when `true`, `n_fft ÷ 2` samples are trimmed from the start
#   (and, unless `original_length` is given, from the end) of the result.
# - `normalized`: when `true`, `y` is first scaled by `n_fft^0.5`.
# - `return_complex`: use `ifft` (complex output) instead of `irfft`.
# - `original_length`: if given, the result is cut to this many samples
#   counted from the first kept sample.
#
# Throws `ArgumentError` if the window is on a different backend or has an
# invalid length, or if `hop_length ≤ 0`.
function NNlib.istft(y;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
    return_complex::Bool = false,
    original_length::Union{Nothing, Int} = nothing,
)
    kab = get_backend(y)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as istft input `y` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # Reject `hop_length == 0` too, matching the error message.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # TODO check `y` eltype is complex

    n_frames = size(y, 2)

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    # Denormalize.
    if normalized
        y = y .* eltype(y)(n_fft^0.5)
    end

    region = 1
    x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)

    # De-apply window.
    if use_window
        x = x ./ window
    end

    # col2time.
    # For every output time position keep the linear index of the first frame
    # sample that covers it; samples overlapping an earlier frame are skipped.
    expected_output_len = n_fft + hop_length * (n_frames - 1)

    # NOTE(review): if `hop_length > n_fft` the frames leave gaps, so fewer
    # than `expected_output_len` entries are written and the tail of `ids`
    # stays uninitialized — confirm whether such inputs should be rejected.
    ids = Vector{Int}(undef, expected_output_len)
    in_idx, out_idx = 0, 0
    prev_e, v = 0, 0

    for col in 0:(n_frames - 1)
        for row in 1:n_fft
            in_idx += 1
            v = row + hop_length * col
            # Time position `v` was already taken from a previous frame.
            v > prev_e || continue

            out_idx += 1
            ids[out_idx] = in_idx
        end
        # Last time position covered so far.
        prev_e = v
    end

    # In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch).
    nd = ntuple(_ -> Colon(), ndims(x) - 2)
    if ndims(x) == 3
        x = reshape(x, (:, size(x, 3)))
    end
    x = x[ids, nd...]

    # Trim padding.
    left = center ? (n_fft ÷ 2 + 1) : 1
    right = if isnothing(original_length)
        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
    else
        left + original_length - 1
    end
    x = x[left:right, nd...]
    return x
end
5 changes: 5 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,9 @@ include("deprecations.jl")
include("rotation.jl")
export imrotate, ∇imrotate

include("audio/stft.jl")
include("audio/spectrogram.jl")
include("audio/mel.jl")
export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks

end # module NNlib
Loading

0 comments on commit 4159154

Please sign in to comment.