Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: in-place accum #981

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ include("lib/lib.jl")
include("lib/literal_getproperty.jl")
include("lib/number.jl")
include("lib/base.jl")
include("lib/protect.jl")
include("lib/array.jl")
include("lib/buffer.jl")
include("lib/broadcast.jl")
Expand Down
17 changes: 16 additions & 1 deletion src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,26 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven
struct ZBack{F} <: Function
# The wrapped pullback. The call methods of `ZBack` invoke it as `s.back(...)` on the
# converted incoming gradient and convert whatever it returns.
back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# @inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# Shortcut for an absent gradient: `nothing` in gives `nothing` out, without touching `back`.
# This method can be deleted after https://github.com/FluxML/Zygote.jl/issues/603,
# though it might be worth keeping purely as a performance optimization (benchmarking pending).
@inline function (s::ZBack)(::Nothing)
    return nothing
end

# Run the wrapped pullback on `dy`, converting input and output between the two conventions.
# Any returned gradient whose `_pointer` occurs more than once among the results is passed
# through `_protect`; everything else is returned unchanged. Returns `nothing` when the
# converted output is `nothing`.
function (s::ZBack)(dy)
    grads = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
    grads === nothing && return
    addresses = map(_pointer, grads)
    return map(grads) do g
        addr = _pointer(g)
        # `nothing` marks a gradient with no address to compare, so it is never guarded here
        repeated = addr !== nothing && count(isequal(addr), addresses) > 1
        repeated ? _protect(g) : g
    end
end

"""
chain_rrule(config, f, args...)

Expand Down
44 changes: 34 additions & 10 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ Base.size(A::OneElement) = map(length, A.axes)
# The axes are stored on the wrapper itself.
function Base.axes(A::OneElement)
    return A.axes
end

# Every position reads as `zero(T)`, except the index `A.ind`, which reads as `A.val`.
function Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N}
    hit = i == A.ind
    return ifelse(hit, A.val, zero(T))
end

# Add each one-element gradient into the dense array `x`, in place, touching only the
# single index that each `y` stores. Returns `x` itself.
function accum(x::DenseArray, ys::OneElement...)
    foreach(ys) do hot
        x[hot.ind...] += hot.val
    end
    return x
end
# Sum several one-element gradients into a freshly allocated array obtained from `similar`.
#
# The element type of the result is promoted across `x` and all of `ys`, so that mixing
# element types (for example an integer-valued `x` with floating-point `ys`) does not fail
# when a value is stored into the buffer. With matching element types the result is the
# same as before.
function accum(x::OneElement, ys::OneElement...)
    T = promote_type(eltype(x), map(eltype, ys)...)
    z = fill!(similar(x, T), zero(T))
    z[x.ind...] = x.val
    for y in ys
        z[y.ind...] += y.val
    end
    return z
end

# Array of zeros shaped like `xs`. When `T` is `Nothing`, the element type of `xs` is kept.
function _zero(xs::AbstractArray{<:Number}, T::Type{Nothing})
    buffer = similar(xs)
    return fill!(buffer, zero(eltype(xs)))
end

# Array of zeros shaped like `xs`, with element type `T`.
function _zero(xs::AbstractArray{<:Number}, T)
    buffer = similar(xs, T)
    return fill!(buffer, false)
end
Expand Down Expand Up @@ -444,7 +458,7 @@ end
@adjoint function inv(A::Union{Number, AbstractMatrix})
Ainv = inv(A)
return Ainv, function (Δ)
∇A = - Ainv' * Δ * Ainv'
∇A = - Ainv' * _unprotect(Δ) * Ainv'
return (∇A, )
end
end
Expand All @@ -460,7 +474,10 @@ end
rtol::Real = (eps(real(float(one(T))))*min(size(A)...))*iszero(atol),
) where {T}
Y = pinv(A)
return Y, Δ->(-Y' * Δ * Y' + (I - A * Y) * Δ' * Y * Y' + Y' * Y * Δ' * (I - Y * A),)
return Y, _Δ -> begin
Δ = _unprotect(_Δ)
(-Y' * Δ * Y' + (I - A * Y) * Δ' * Y * Y' + Y' * Y * Δ' * (I - Y * A),)
end
end

# When `A` is guaranteed to be square, definitely use the simple expression for the adjoint.
Expand All @@ -474,23 +491,26 @@ end
B::AbstractVecOrMat,
)
Y = A \ B
return Y, function(Ȳ)
return Y, function(_Ȳ)
Ȳ = _unprotect(_Ȳ)
B̄ = A' \ Ȳ
return (-B̄ * Y', B̄)
end
end

@adjoint function /(A::AbstractMatrix, B::Union{Diagonal, AbstractTriangular})
Y = A / B
return Y, function(Ȳ)
return Y, function(_Ȳ)
Ȳ = _unprotect(_Ȳ)
Ā = Ȳ / B'
return (Ā, -Y' * Ā)
end
end

@adjoint function \(A::AbstractMatrix, B::AbstractVecOrMat)
Z = A \ B
return Z, function(Z̄)
return Z, function(_Z̄)
Z̄ = _unprotect(_Z̄)
B̄ = A' \ Z̄
if size(A, 1) == size(A, 2)
return (-B̄ * Z', B̄)
Expand All @@ -514,7 +534,8 @@ end
# This is basically a hack while we don't have a working `ldiv!`.
@adjoint function \(A::Cholesky, B::AbstractVecOrMat)
Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B)
return Y, function(Ȳ)
return Y, function(_Ȳ)
Ȳ = _unprotect(_Ȳ)
Ā_factors, B̄ = back(Ȳ)
return ((uplo=nothing, info=nothing, factors=Ā_factors), B̄)
end
Expand Down Expand Up @@ -595,7 +616,8 @@ end

@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
X = lyap(A, C)
return X, function (X̄)
return X, function (_X̄)
X̄ = _unprotect(_X̄)
C̄ = lyap(collect(A'), X̄)
Ā = C̄*X' + C̄'*X
return (Ā, C̄)
Expand Down Expand Up @@ -766,7 +788,7 @@ end
return S - A, Δ->((λ=tr(Δ),), -Δ)
end

@adjoint +(A::AbstractArray, B::AbstractArray) = A + B, Δ->(Δ, Δ)
@adjoint +(A::AbstractArray, B::AbstractArray) = A + B, Δ->(_protect(Δ), _protect(Δ))
# Array subtraction: the gradient passes to `A` unchanged and to `B` negated.
@adjoint function -(A::AbstractArray, B::AbstractArray)
    return A - B, Δ -> (Δ, -Δ)
end

# Array negation: the gradient is negated.
@adjoint function -(A::AbstractArray)
    return -A, Δ -> (-Δ,)
end

Expand All @@ -792,14 +814,16 @@ AbstractFFTs.brfft(x::Fill, d, dims...) = AbstractFFTs.brfft(collect(x), d, dims
end

@adjoint function *(P::AbstractFFTs.Plan, xs)
return P * xs, function(Δ)
return P * xs, function(_Δ)
Δ = _unprotect(_Δ)
N = prod(size(xs)[[P.region...]])
return (nothing, N * (P \ Δ))
end
end

@adjoint function \(P::AbstractFFTs.Plan, xs)
return P \ xs, function(Δ)
return P \ xs, function(_Δ)
Δ = _unprotect(_Δ)
N = prod(size(Δ)[[P.region...]])
return (nothing, (P * Δ)/N)
end
Expand Down
2 changes: 1 addition & 1 deletion src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
# right arrays.

@adjoint broadcasted(::typeof(+), xs::Numeric...) =
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, _protect(ȳ)), xs)...)

# Broadcast subtraction: each argument's gradient goes through `unbroadcast`,
# and the one for `y` is additionally passed through `_minus`.
@adjoint function broadcasted(::typeof(-), x::Numeric, y::Numeric)
    difference = x .- y
    back = Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ)))
    return difference, back
end
Expand Down
9 changes: 9 additions & 0 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ accum(x, y, zs...) = accum(accum(x, y), zs...)
# Tuples are accumulated elementwise.
function accum(x::Tuple, ys::Tuple...)
    return broadcast(accum, x, ys...)
end

# Arrays are accumulated elementwise, into a new array.
function accum(x::AbstractArray, ys::AbstractArray...)
    return broadcast(accum, x, ys...)
end

# Elementwise accumulation written back into the dense array `x`; the result of
# `broadcast!` is returned.
function accum(x::DenseArray, ys::AbstractArray...)
    flatten = ndims(x) == 1 && any(y -> ndims(y) > 1, ys)
    if flatten
        # work around bug fixed in https://github.com/JuliaLang/julia/pull/39859
        return broadcast!(accum, x, vec(x), map(vec, ys)...)
    end
    return broadcast!(accum, x, x, ys...)
end

@generated function accum(x::NamedTuple, y::NamedTuple)
# assumes that y has no keys apart from those also in x
fieldnames(y) ⊆ fieldnames(x) || throw(ArgumentError("$y keys must be a subset of $x keys"))
Expand Down
114 changes: 114 additions & 0 deletions src/lib/protect.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@

"""
NoWrite(Δ::AbstractArray)

This is a trivial wrapper, without `setindex!`, to prevent mutation of gradients
when the same array may be in use elsewhere. It should be applied using `_protect`
in rules like `@adjoint +(A,B) = A+B, Δ -> (_protect(Δ), _protect(Δ))`.
(This will be handled automatically for rules defined via ChainRules.jl.)
"""
struct NoWrite{T,N,P} <: AbstractArray{T,N}
    # the wrapped array; reads are forwarded to it, and no `setindex!` is defined here
    data::P
    NoWrite(x::P) where {P <: AbstractArray{T,N}} where {T,N} = new{T,N,P}(x)
end

# `parent` hands back the wrapped array.
Base.parent(x::NoWrite) = x.data

# Indexing reads straight from the wrapped array.
Base.@propagate_inbounds Base.getindex(A::NoWrite, i...) = getindex(A.data, i...)

# One-argument queries are answered by the wrapped array.
for f in (:size, :axes, :length, :similar, :copy, :IndexStyle, :strides, :pointer, :Tuple)
    @eval Base.$f(x::NoWrite) = Base.$f(x.data)
end

# Iteration must forward BOTH methods. Forwarding only `iterate(x)` returns the parent's
# iteration state, which the generic `iterate(::AbstractArray, state)` fallback then
# misreads: a `for` loop over a wrapped `Array` stopped early or threw a `BoundsError`.
Base.iterate(x::NoWrite) = iterate(x.data)
Base.iterate(x::NoWrite, state) = iterate(x.data, state)

# Summary text: the wrapped array's own `showarg` output, enclosed in `NoWrite(...)`.
function Base.showarg(io::IO, x::NoWrite, top)
    print(io, "NoWrite(")
    Base.showarg(io, x.data, false)
    print(io, ")")
    return nothing
end

# Strip the `NoWrite` wrapper. For use on the RHS of rules, e.g. to avoid generic matmul.
function _unprotect(A::NoWrite)
    return parent(A)
end

# Anything that is not wrapped is returned as it is.
function _unprotect(A)
    return A
end

# Dense arrays always get the wrapper.
function _protect(A::DenseArray)
    return NoWrite(A)
end

# Already wrapped: never need to wrap twice.
function _protect(A::NoWrite)
    return A
end

# Other arrays are wrapped only if `_maybewrite` reports that they could be
# unwrapped to something writable.
function _protect(A::AbstractArray)
    if _maybewrite(A)
        return NoWrite(A)
    end
    return A
end

# Non-arrays pass through untouched.
function _protect(A)
    return A
end

# Does following `parent` from `A` ever reach a `DenseArray`?
_maybewrite(A) = false
_maybewrite(A::DenseArray) = true
function _maybewrite(A::AbstractArray)
    inner = parent(A)
    # an array which is its own parent has nothing further to unwrap
    inner === A && return false
    return _maybewrite(inner)
end

##### For Params & Grads, don't accumulate NoWrite objects

# Storing a `NoWrite` in an `IdDict` stores a copy of the wrapped array instead.
function Base.setindex!(dict::IdDict, dx::NoWrite, x)
    unwrapped = copy(dx.data)
    return dict[x] = unwrapped
end

##### For ChainRules rules, unwrap & re-wrap automatically:

# Identify the memory behind a gradient, following `parent` through wrappers.
_pointer(A::Array) = pointer(A)  # pointer survives reshape, objectid does not
function _pointer(A::AbstractArray)
    inner = parent(A)
    # an array that is its own parent (and is not an `Array`) yields `NaN`; not strictly necessary
    return inner === A ? NaN : _pointer(inner)
end
_pointer(A) = nothing  # compares == self

# Call the wrapped pullback on the array inside a `NoWrite`, then pass a returned gradient
# through `_protect` when its `_pointer` equals that of the incoming array, or occurs more
# than once among the results. Gradients for which `_pointer` gives `nothing` are returned
# unchanged. Returns `nothing` when the converted output is `nothing`.
@inline function (s::ZBack)(dy::NoWrite)
    incoming = _pointer(dy.data)
    grads = wrap_chainrules_output(s.back(wrap_chainrules_input(dy.data)))
    grads === nothing && return
    addresses = map(_pointer, grads)
    return map(grads) do g
        addr = _pointer(g)
        addr === nothing && return g
        if addr == incoming || count(isequal(addr), addresses) > 1
            return _protect(g)
        end
        return g
    end
end


###### For @adjoint rules:

# always unwrap on the RHS of broadcasting
function Broadcast.broadcastable(A::NoWrite)
    return A.data
end

# always unwrap within sum, etc.
function Base.mapreduce(f, op, A::NoWrite; kw...)
    return mapreduce(f, op, A.data; kw...)
end

# Try to keep NoWrite outside, to maximise chances of successful unwrapping:
# Reshape the wrapped array and put the wrapper back on the outside. The parent has to be
# unwrapped first: calling `reshape(x, dims)` on the wrapper itself dispatches back to
# `Base._reshape(x, dims)`, i.e. to this very method, and never terminates.
Base._reshape(x::NoWrite, dims::Tuple{Vararg{Int}}) = NoWrite(reshape(x.data, dims))
# Vector-to-vector reshape gets its own method, so that the call is not ambiguous with
# Base's `_reshape(::AbstractVector, ::Dims{1})`.
Base._reshape(x::NoWrite{<:Any,1}, dims::Tuple{Int}) = NoWrite(reshape(x.data, dims))
# Each of these applies the LinearAlgebra function to the wrapped array and wraps the
# result in `NoWrite` again.
for wrapper in (:transpose, :adjoint, :Transpose, :Adjoint, :Diagonal)
    @eval function LinearAlgebra.$wrapper(x::NoWrite)
        return NoWrite(LinearAlgebra.$wrapper(x.data))
    end
end

using AbstractFFTs # many rules, easier to overload here:

# Each of these transforms is applied to the unwrapped array, with any `dims` passed along.
for transform in (:fft, :bfft, :ifft, :rfft, :irfft, :brfft, :fftshift, :ifftshift)
    @eval function AbstractFFTs.$transform(x::NoWrite, dims...)
        return AbstractFFTs.$transform(_unprotect(x), dims...)
    end
end

# LinearAlgebra.:\(A::AbstractMatrix, B::NoWriteVecOrMat)

# The dispatch for * is very messy, better just to unwrap by hand. For debugging:

# Aliases for one- and two-dimensional `NoWrite` arrays with element type `T`.
const NoWriteVector = NoWrite{T,1} where {T}
const NoWriteMatrix = NoWrite{T,2} where {T}
const NoWriteVecOrMat = Union{NoWrite{T,1}, NoWrite{T,2}} where {T}

# Matrix-vector case: when the matrix/vector `A` or the vector `B` is a `NoWrite`,
# `generic_matvecmul!` is routed through `_mulv`.
LinearAlgebra.generic_matvecmul!(C::AbstractVector, tA, A::NoWriteVecOrMat, B::AbstractVector, _add::LinearAlgebra.MulAddMul) = _mulv(C, tA, A, B, _add)
LinearAlgebra.generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::NoWriteVector, _add::LinearAlgebra.MulAddMul) = _mulv(C, tA, A, B, _add)
# Both arguments wrapped: a separate method, presumably so that the two methods above do
# not leave this combination ambiguous.
LinearAlgebra.generic_matvecmul!(C::AbstractVector, tA, A::NoWriteVecOrMat, B::NoWriteVector, _add::LinearAlgebra.MulAddMul) = _mulv(C, tA, A, B, _add)

# Single place for the (commented-out) debug message. `invoke` then selects the
# `generic_matvecmul!` method matching the abstract signature given here, rather than
# dispatching on the `NoWrite` argument types again.
function _mulv(C, tA, A, B, _add)
# @debug "generic matrix-vector due to NoWrite" summary(A) summary(B)
invoke(LinearAlgebra.generic_matvecmul!, Tuple{AbstractVector, Any, AbstractVecOrMat, AbstractVector, LinearAlgebra.MulAddMul}, C, tA, A, B, _add)
end

# Matrix-matrix case: when either matrix `A` or `B` is a `NoWrite`,
# `generic_matmatmul!` is routed through `_mulm`.
LinearAlgebra.generic_matmatmul!(C::AbstractMatrix, tA, A::NoWriteMatrix, B::AbstractMatrix, _add::LinearAlgebra.MulAddMul) = _mulm(C, tA, A, B, _add)
LinearAlgebra.generic_matmatmul!(C::AbstractMatrix, tA, A::AbstractMatrix, B::NoWriteMatrix, _add::LinearAlgebra.MulAddMul) = _mulm(C, tA, A, B, _add)
# Both arguments wrapped: a separate method, presumably so that the two methods above do
# not leave this combination ambiguous.
LinearAlgebra.generic_matmatmul!(C::AbstractMatrix, tA, A::NoWriteMatrix, B::NoWriteMatrix, _add::LinearAlgebra.MulAddMul) = _mulm(C, tA, A, B, _add)

# Single place for the (commented-out) debug message. `invoke` then selects the
# `generic_matmatmul!` method matching the abstract signature given here, rather than
# dispatching on the `NoWrite` argument types again.
function _mulm(C, tA, A, B, _add)
# @debug "generic matrix-matrix multiplication due to NoWrite" summary(A) summary(B)
invoke(LinearAlgebra.generic_matmatmul!, Tuple{AbstractMatrix, Any, AbstractMatrix, AbstractMatrix, LinearAlgebra.MulAddMul}, C, tA, A, B, _add)
end