Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assume commutative multiplication exactly when necessary #540

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 70 additions & 48 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end
function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NoTangent(), -Ω' * ΔΩ * Ω'
return NoTangent(), Ω' * -ΔΩ * Ω'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I ask why you moved the minus?

If it was -true * Ω' * ΔΩ * Ω' then I think you'd save a copy (since this gets fused into mul!).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that if ΔΩ is an AbstractZero or a UniformScaling, then the negation is cheaper.

If it was -true * Ω' * ΔΩ * Ω' then I think you'd save a copy (since this gets fused into mul!).

I didn't follow this. How is this fused into the mul!?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this was not an important change, and I'm happy to remove)

Copy link
Member

@mcabbott mcabbott Oct 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I didn't think about those. For dense matrices there's a 4-arg method which fuses this:

julia> f1(Ω, ΔΩ) = Ω' * -ΔΩ * Ω';

julia> f2(Ω, ΔΩ) = -true * Ω' * ΔΩ * Ω';

julia> @btime f1(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
  min 74.708 μs, mean 101.133 μs (6 allocations, 234.52 KiB. GC mean 6.51%)

julia> @btime f2(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
  min 73.125 μs, mean 92.756 μs (4 allocations, 156.34 KiB. GC mean 4.82%)

julia> @which -1 * ones(2,2) * ones(2,2) * ones(2,2)
*(α::Union{Real, Complex}, B::AbstractMatrix{<:Union{Real, Complex}}, C::AbstractMatrix{<:Union{Real, Complex}}, D::AbstractMatrix{<:Union{Real, Complex}}) in LinearAlgebra at /Users/me/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:1134

But with I, no fusion, hence f2 is slower. Maybe * should have some extra methods for cases with I.

end
return Ω, inv_pullback
end
Expand All @@ -23,12 +23,7 @@ frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB)

frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC


function rrule(
::typeof(*),
A::AbstractVecOrMat{<:CommutativeMulNumber},
B::AbstractVecOrMat{<:CommutativeMulNumber},
)
function rrule(::typeof(*), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number})
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Expand All @@ -46,8 +41,8 @@ end
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411
function rrule(
::typeof(*),
A::StridedMatrix{<:CommutativeMulNumber},
B::StridedVecOrMat{<:CommutativeMulNumber},
A::StridedMatrix{<:Number},
B::StridedVecOrMat{<:Number},
)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
Expand All @@ -64,13 +59,7 @@ function rrule(
return A * B, times_pullback
end



function rrule(
::typeof(*),
A::AbstractVector{<:CommutativeMulNumber},
B::AbstractMatrix{<:CommutativeMulNumber},
)
function rrule(::typeof(*), A::AbstractVector{<:Number}, B::AbstractMatrix{<:Number})
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Expand All @@ -97,15 +86,24 @@ end
#####

function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
::typeof(*), A::Number, B::AbstractArray{<:Number}
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
return (
NoTangent(),
@thunk(project_A(dot(Ȳ, B)')),
Thunk() do
if eltype(B) isa CommutativeMulNumber
project_A(dot(Ȳ, B)')
elseif ndims(B) < 3
# https://github.com/JuliaLang/julia/issues/44152
project_A(dot(conj(Ȳ), conj(B)))
else
project_A(dot(conj(vec(Ȳ)), conj(vec(B))))
end
end,
InplaceableThunk(
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
Expand All @@ -116,7 +114,7 @@ function rrule(
end

function rrule(
::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber
::typeof(*), B::AbstractArray{<:Number}, A::Number
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
Expand All @@ -128,7 +126,17 @@ function rrule(
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
),
@thunk(project_A(dot(Ȳ, B)')),
# @thunk(project_A(eltype(A) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))),
Thunk() do
if eltype(B) isa CommutativeMulNumber
project_A(dot(Ȳ, B)')
elseif ndims(B) < 3
# https://github.com/JuliaLang/julia/issues/44152
project_A(dot(conj(Ȳ), conj(B)))
else
project_A(dot(conj(vec(Ȳ)), conj(vec(B))))
end
end,
)
end
return A * B, times_pullback
Expand Down Expand Up @@ -217,9 +225,9 @@ end

function rrule(
::typeof(muladd),
A::AbstractMatrix{<:CommutativeMulNumber},
B::AbstractVecOrMat{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
A::AbstractMatrix{<:Number},
B::AbstractVecOrMat{<:Number},
z::Union{Number, AbstractVecOrMat{<:Number}},
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
Expand Down Expand Up @@ -255,9 +263,9 @@ end

function rrule(
::typeof(muladd),
ut::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
v::AbstractVector{<:CommutativeMulNumber},
z::CommutativeMulNumber,
ut::LinearAlgebra.AdjOrTransAbsVec{<:Number},
v::AbstractVector{<:Number},
z::Number,
)
project_ut = ProjectTo(ut)
project_v = ProjectTo(v)
Expand All @@ -268,11 +276,11 @@ function rrule(
dy = unthunk(ȳ)
ut_thunk = InplaceableThunk(
dut -> dut .+= v' .* dy,
@thunk(project_ut(v' .* dy)),
@thunk(project_ut((v * dy')')),
)
v_thunk = InplaceableThunk(
dv -> dv .+= ut' .* dy,
@thunk(project_v(ut' .* dy)),
@thunk(project_v(ut' * dy)),
)
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : project_z(dy))
end
Expand All @@ -281,9 +289,9 @@ end

function rrule(
::typeof(muladd),
u::AbstractVector{<:CommutativeMulNumber},
vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
u::AbstractVector{<:Number},
vt::LinearAlgebra.AdjOrTransAbsVec{<:Number},
z::Union{Number, AbstractVecOrMat{<:Number}},
)
project_u = ProjectTo(u)
project_vt = ProjectTo(vt)
Expand All @@ -293,8 +301,8 @@ function rrule(
function muladd_pullback_3(ȳ)
Ȳ = unthunk(ȳ)
proj = (
@thunk(project_u(vec(sum(.* conj.(vt), dims=2)))),
@thunk(project_vt(vec(sum(u .* conj.(Ȳ), dims=1))')),
@thunk(project_u(Ȳ * vec(vt'))),
@thunk(project_vt((Ȳ' * u)')),
)
addon = if z isa Bool
NoTangent()
Expand All @@ -315,14 +323,14 @@ end
##### `/`
#####

function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
function rrule(::typeof(/), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number})
Aᵀ, dA_pb = rrule(adjoint, A)
Bᵀ, dB_pb = rrule(adjoint, B)
Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
C, dC_pb = rrule(adjoint, Cᵀ)
function slash_pullback()
function slash_pullback(Ȳ)
# Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
_, dC = dC_pb()
_, dC = dC_pb(Ȳ)
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

∂A = last(dA_pb(unthunk(dAᵀ)))
Expand All @@ -337,7 +345,7 @@ end
##### `\`
#####

function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
function rrule(::typeof(\), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number})
project_A = ProjectTo(A)
project_B = ProjectTo(B)

Expand All @@ -362,32 +370,46 @@ end
##### `\`, `/` matrix-scalar_rule
#####

function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
return A/b, ΔA/b - A*(Δb/b^2)

function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:Number}, b::Number)
Y = A / b
return Y, muladd(Y, -Δb, ΔA) / b
end
function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber})
return B/a, ΔB/a - B*(Δa/a^2)
function frule((_, Δa, ΔB), ::typeof(\), a::Number, B::AbstractArray{<:Number})
Y = a \ B
return Y, a \ muladd(-Δa, Y, ΔB)
end

function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
function rrule(::typeof(/), A::AbstractArray{<:Number}, b::Number)
Y = A/b
function slash_pullback_scalar(ȳ)
Ȳ = unthunk(ȳ)
Athunk = InplaceableThunk(
dA -> dA .+= Ȳ ./ conj(b),
@thunk(Ȳ / conj(b)),
)
bthunk = @thunk(-dot(A,Ȳ) / conj(b^2))
bthunk = @thunk(-dot(Y,Ȳ) / conj(b))
return (NoTangent(), Athunk, bthunk)
end
return Y, slash_pullback_scalar
end

function rrule(::typeof(\), b::CommutativeMulNumber, A::AbstractArray{<:CommutativeMulNumber})
Y, back = rrule(/, A, b)
function backslash_pullback(dY) # just reverses the arguments!
d0, dA, db = back(dY)
return (d0, db, dA)
function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number})
Y = b \ A
function backslash_pullback(ȳ)
Ȳ = unthunk(ȳ)
Athunk = InplaceableThunk(
dA -> dA .+= conj(b) .\ Ȳ,
@thunk(conj(b) \ Ȳ),
)
bthunk = if eltype(Y) isa CommutativeMulNumber
@thunk(-conj(b) \ dot(Y, Ȳ))
else
# NOTE: dot(Ȳ', Y') currently incorrect for non-commutative numbers
# https://github.com/JuliaLang/julia/issues/44152
@thunk(-conj(b) \ dot(conj(Ȳ), conj(Y)))
end
return (NoTangent(), bthunk, Athunk)
end
return Y, backslash_pullback
end
Expand Down
Loading