Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chain rules for DCT #273

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
11 changes: 11 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved

[extensions]
FFTWChainRulesCoreExt = "ChainRulesCore"

[compat]
AbstractFFTs = "1.0"
ChainRulesCore = "1"
FFTW_jll = "3.3.9"
MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023"
Preferences = "1.2"
Reexport = "0.2, 1.0"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
72 changes: 72 additions & 0 deletions ext/FFTWChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
module FFTWChainRulesCoreExt

using FFTW
using FFTW: r2r
using ChainRulesCore

# DCT

function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region...)
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
Δx = Δ[2]
y = dct(x, region...)
Δy = dct(Δx, region...)
return y, Δy
end

function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, region...)
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
y = dct(x, region...)
project_x = ProjectTo(x)

function dct_pullback(ȳ)
f̄ = NoTangent()
x̄ = project_x(idct(unthunk(ȳ), region...))
r̄ = NoTangent()

if isempty(region)
return f̄, x̄
else
return f̄, x̄, r̄
end
end

return y, dct_pullback
end

# IDCT

function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region...)
Δx = Δ[2]
y = idct(x, region...)
Δy = idct(Δx, region...)
return y, Δy
end

function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, region...)
y = idct(x, region...)
project_x = ProjectTo(x)

function idct_pullback(ȳ)
f̄ = NoTangent()
x̄ = project_x(dct(unthunk(ȳ), region...))
r̄ = NoTangent()

if isempty(region)
return f̄, x̄
else
return f̄, x̄, r̄
end
end

return y, idct_pullback
end

# R2R

function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the rrule for r2r is missing?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The R2R transforms are not unitary. There is some scaling involved that depends on the kind of R2R transform. Because it looks like an involved task, I chose to skip that for now. I am happy to look into that in a separate PR

Δx = Δ[2]
y = r2r(x, region...)
Δy = r2r(Δx, region...)
return y, Δy
end

end # module
10 changes: 10 additions & 0 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!

include("providers.jl")

@static if !isdefined(Base, :get_extension)
import Requires
end

vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
function __init__()
# If someone is trying to set the provider via the old environment variable, warn them that they
# should instead use `set_provider!()` instead.
Expand All @@ -35,6 +39,12 @@ function __init__()
libfftw3[] = MKL_jll.libmkl_rt_path
libfftw3f[] = MKL_jll.libmkl_rt_path
end

@static if !isdefined(Base, :get_extension)
Requires.@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin
include("../ext/FFTWChainRulesCoreExt.jl")
end
end
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
end

# most FFTW calls other than fftw_execute should be protected by a lock to be thread-safe
Expand Down
6 changes: 2 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around
# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences
# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
46 changes: 45 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file was formerly a part of Julia. License is MIT: https://julialang.org/license
using FFTW
using FFTW: fftw_provider
using FFTW: fftw_provider, r2r
using AbstractFFTs: Plan, plan_inv
using Test
using LinearAlgebra
Expand Down Expand Up @@ -577,3 +577,47 @@ end
end
end
end

@testset "ChainRules" begin

if isdefined(Base, :get_extension)
CRCEXT = Base.get_extension(FFTW, :FFTWChainRulesCoreExt)
@test isnothing(CRCEXT)
end

using ChainRulesTestUtils

if isdefined(Base, :get_extension)
CRCEXT = Base.get_extension(FFTW, :FFTWChainRulesCoreExt)
@test !isnothing(CRCEXT)
end

@testset "DCT" begin
for f in (dct, idct)
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
test_frule(f, x)
test_rrule(f, x)

N = ndims(x)
for region in unique((1, 1:N, N))
test_frule(f, x, region)
test_rrule(f, x, region)
end # for region
end # for x
end # for f
end

@testset "r2r" begin
for k in 4 #0:10
for x in (randn(3), )#randn(3, 4), randn(3, 4, 5))
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
test_frule(r2r, x, k)

N = ndims(x)
for region in unique((1, 1:N, N))
test_frule(r2r, x, k, region)
end # for region
end # for x
end # for f
end

end