From abfe2712c64905e83b0ab9c80fc55ebe95fadd17 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 10 Dec 2020 22:32:42 +0000 Subject: [PATCH] Some Adjoint and Transpose stuff (#324) * Extend Adjoint / Transpose * Allow Composite cotangents * Remove parent implementations for now * Remove TODO comment * Account for transposing adjoints * Test transpose * Loosen restriction * Bump patch --- Project.toml | 2 +- src/rulesets/LinearAlgebra/structured.jl | 57 ++++++++++------------- test/rulesets/LinearAlgebra/structured.jl | 50 ++++++++++++++++++-- 3 files changed, 72 insertions(+), 37 deletions(-) diff --git a/Project.toml b/Project.toml index eb7e58431..56a91477f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.38" +version = "0.7.39" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index b80673f65..753fae7ca 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -90,32 +90,27 @@ end ##### `Adjoint` ##### -# ✖️✖️✖️TODO: Deal with complex-valued arrays as well -function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) - function Adjoint_pullback(ȳ) - return (NO_FIELDS, adjoint(ȳ)) - end +function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) + Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) return Adjoint(A), Adjoint_pullback end -function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) - function Adjoint_pullback(ȳ) - return (NO_FIELDS, vec(adjoint(ȳ))) - end +function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) + Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) + Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) return Adjoint(A), Adjoint_pullback end -function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) - function adjoint_pullback(ȳ) - return (NO_FIELDS, adjoint(ȳ)) - end +function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) + adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) + adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) return adjoint(A), adjoint_pullback end -function rrule(::typeof(adjoint), A::AbstractVector{<:Real}) - function adjoint_pullback(ȳ) - return (NO_FIELDS, vec(adjoint(ȳ))) - end +function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) + adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) + adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) return adjoint(A), adjoint_pullback end @@ -123,31 +118,27 @@ end ##### `Transpose` ##### -function rrule(::Type{<:Transpose}, A::AbstractMatrix) - function Transpose_pullback(ȳ) - return (NO_FIELDS, transpose(ȳ)) - end +function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) + Transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) + Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) return Transpose(A), Transpose_pullback end -function rrule(::Type{<:Transpose}, A::AbstractVector) - function Transpose_pullback(ȳ) - return (NO_FIELDS, vec(transpose(ȳ))) - end +function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) + Transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) + Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) return Transpose(A), Transpose_pullback end -function rrule(::typeof(transpose), A::AbstractMatrix) - function transpose_pullback(ȳ) - return (NO_FIELDS, transpose(ȳ)) - end +function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) + transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) + transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) return transpose(A), transpose_pullback end -function rrule(::typeof(transpose), A::AbstractVector) - function transpose_pullback(ȳ) - return (NO_FIELDS, vec(transpose(ȳ))) - end +function rrule(::typeof(transpose), A::AbstractVector{<:Number}) + transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) + transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) return transpose(A), transpose_pullback end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 1dc04e487..6070637fa 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -104,11 +104,55 @@ end end end - @testset "$f" for f in (Adjoint, adjoint, Transpose, transpose) + @testset "$f, $T" for + f in (Adjoint, adjoint, Transpose, transpose), + T in (Float64, ComplexF64) + n = 5 m = 3 - rrule_test(f, randn(m, n), (randn(n, m), randn(n, m))) - rrule_test(f, randn(1, n), (randn(n), randn(n))) + @testset "$f(::Matrix{$T})" begin + A = randn(T, n, m) + Ā = randn(T, n, m) + Y = f(A) + Ȳ_mat = randn(T, m, n) + Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat))) + + rrule_test(f, Ȳ_mat, (A, Ā)) + + _, pb = rrule(f, A) + @test pb(Ȳ_mat) == pb(Ȳ_composite) + end + + @testset "$f(::Vector{$T})" begin + a = randn(T, n) + ā = randn(T, n) + y = f(a) + ȳ_mat = randn(T, 1, n) + ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat))) + + rrule_test(f, ȳ_mat, (a, ā)) + + _, pb = rrule(f, a) + @test pb(ȳ_mat) == pb(ȳ_composite) + end + + @testset "$f(::Adjoint{$T, Vector{$T})" begin + a = randn(T, n)' + ā = randn(T, n)' + y = f(a) + ȳ = randn(T, n) + + rrule_test(f, ȳ, (a, ā)) + end + + @testset "$f(::Transpose{$T, Vector{$T})" begin + a = transpose(randn(T, n)) + ā = transpose(randn(T, n)) + y = f(a) + ȳ = randn(T, n) + + rrule_test(f, ȳ, (a, ā)) + end end @testset "$T" for T in (UpperTriangular, LowerTriangular) n = 5