Skip to content

Commit

Permalink
Some Adjoint and Transpose stuff (#324)
Browse files Browse the repository at this point in the history
* Extend Adjoint / Transpose

* Allow Composite cotangents

* Remove parent implementations for now

* Remove TODO comment

* Account for transposing adjoints

* Test transpose

* Loosen restriction

* Bump patch
  • Loading branch information
willtebbutt authored Dec 10, 2020
1 parent 22eddb4 commit abfe271
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.38"
version = "0.7.39"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
57 changes: 24 additions & 33 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,64 +90,55 @@ end
##### `Adjoint`
#####

# ✖️✖️✖️TODO: Deal with complex-valued arrays as well
function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real})
function Adjoint_pullback(ȳ)
return (NO_FIELDS, adjoint(ȳ))
end
function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number})
Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ))
return Adjoint(A), Adjoint_pullback
end

function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real})
function Adjoint_pullback(ȳ)
return (NO_FIELDS, vec(adjoint(ȳ)))
end
function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number})
Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ)))
return Adjoint(A), Adjoint_pullback
end

function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real})
function adjoint_pullback(ȳ)
return (NO_FIELDS, adjoint(ȳ))
end
function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number})
adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ))
return adjoint(A), adjoint_pullback
end

function rrule(::typeof(adjoint), A::AbstractVector{<:Real})
function adjoint_pullback(ȳ)
return (NO_FIELDS, vec(adjoint(ȳ)))
end
function rrule(::typeof(adjoint), A::AbstractVector{<:Number})
adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ)))
return adjoint(A), adjoint_pullback
end

#####
##### `Transpose`
#####

function rrule(::Type{<:Transpose}, A::AbstractMatrix)
function Transpose_pullback(ȳ)
return (NO_FIELDS, transpose(ȳ))
end
function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number})
Transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ))
return Transpose(A), Transpose_pullback
end

function rrule(::Type{<:Transpose}, A::AbstractVector)
function Transpose_pullback(ȳ)
return (NO_FIELDS, vec(transpose(ȳ)))
end
function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number})
Transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ)))
return Transpose(A), Transpose_pullback
end

function rrule(::typeof(transpose), A::AbstractMatrix)
function transpose_pullback(ȳ)
return (NO_FIELDS, transpose(ȳ))
end
function rrule(::typeof(transpose), A::AbstractMatrix{<:Number})
transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ))
return transpose(A), transpose_pullback
end

function rrule(::typeof(transpose), A::AbstractVector)
function transpose_pullback(ȳ)
return (NO_FIELDS, vec(transpose(ȳ)))
end
function rrule(::typeof(transpose), A::AbstractVector{<:Number})
transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ)))
return transpose(A), transpose_pullback
end

Expand Down
50 changes: 47 additions & 3 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,55 @@
end
end
end
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
@testset "$f, $T" for
f in (Adjoint, adjoint, Transpose, transpose),
T in (Float64, ComplexF64)

n = 5
m = 3
rrule_test(f, randn(m, n), (randn(n, m), randn(n, m)))
rrule_test(f, randn(1, n), (randn(n), randn(n)))
@testset "$f(::Matrix{$T})" begin
A = randn(T, n, m)
= randn(T, n, m)
Y = f(A)
Ȳ_mat = randn(T, m, n)
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))

rrule_test(f, Ȳ_mat, (A, Ā))

_, pb = rrule(f, A)
@test pb(Ȳ_mat) == pb(Ȳ_composite)
end

@testset "$f(::Vector{$T})" begin
a = randn(T, n)
= randn(T, n)
y = f(a)
ȳ_mat = randn(T, 1, n)
ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat)))

rrule_test(f, ȳ_mat, (a, ā))

_, pb = rrule(f, a)
@test pb(ȳ_mat) == pb(ȳ_composite)
end

@testset "$f(::Adjoint{$T, Vector{$T})" begin
a = randn(T, n)'
= randn(T, n)'
y = f(a)
= randn(T, n)

rrule_test(f, ȳ, (a, ā))
end

@testset "$f(::Transpose{$T, Vector{$T})" begin
a = transpose(randn(T, n))
= transpose(randn(T, n))
y = f(a)
= randn(T, n)

rrule_test(f, ȳ, (a, ā))
end
end
@testset "$T" for T in (UpperTriangular, LowerTriangular)
n = 5
Expand Down

2 comments on commit abfe271

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/26209

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.39 -m "<description of version>" abfe2712c64905e83b0ab9c80fc55ebe95fadd17
git push origin v0.7.39

Please sign in to comment.