Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utilize ChainRulesCore thunks #966

Merged
merged 22 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ ZygoteTrackerExt = "Tracker"

[compat]
AbstractFFTs = "1.3.1"
ChainRules = "1.44.1"
ChainRulesCore = "1.9"
ChainRules = "1.72.2"
ChainRulesCore = "1.25.1"
ChainRulesTestUtils = "1"
Colors = "0.12, 0.13"
DiffRules = "1.4"
Expand Down
3 changes: 2 additions & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
16 changes: 14 additions & 2 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
# Zygote rules here?
function unthunk_tangent end
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
oschulz marked this conversation as resolved.
Show resolved Hide resolved
@non_differentiable unthunk_tangent(::IdDict)


struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}}
context::CTX
end
Expand Down Expand Up @@ -107,7 +118,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
@inline wrap_chainrules_output(x) = x
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
oschulz marked this conversation as resolved.
Show resolved Hide resolved
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
Expand Down Expand Up @@ -261,7 +271,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
_pullback(config.context, f_args...)
end

ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
ad_pullback(Δ) = zygote2differential(
pb(wrap_chainrules_output(unthunk_tangent(Δ))),
f_args)
return y, ad_pullback
end

Expand Down
15 changes: 12 additions & 3 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ end
_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
tailmemaybe(x::Tuple) = Base.tail(x)
tailmemaybe(x::Tuple) = unthunk_tangent(Base.tail(x))

# unthunking is essentially an identity operation on a lazy value, but
# `@adjoint unthunk_tangent(x) = unthunk_tangent(x), ȳ -> (ȳ,)` is not enough to make
# nested AD work, so define
@adjoint tailmemaybe(xs::Tuple) = tailmemaybe(xs), x̄s -> ((nothing, x̄s...),)


"""
pullback(f, args...)
Expand Down Expand Up @@ -351,6 +357,9 @@ function copy!(x::AbstractVector, ps::Params)
x
end

_maybe_unthunk(x::AbstractThunk) = unthunk(x)
_maybe_unthunk(x) = x

"""
Grads(...)

Expand Down Expand Up @@ -385,7 +394,7 @@ end

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
return gs.grads[x]
return _maybe_unthunk(gs.grads[x])
end

"""
Expand Down Expand Up @@ -468,7 +477,7 @@ function pullback(f, ps::Params)
cache(cx)[p] = nothing
end
back(Δ)
Grads(cx.cache, ps) # TODO make a copy
Grads(_maybe_unthunk(cx.cache), ps)
end
end

Expand Down
12 changes: 12 additions & 0 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
insertafter!, finish, expand!, prune!, substitute!, substitute,
block, block!, branch!, return!, stmt, meta


# TODO: Temporary, to be removed when ChainRulesCore rrules are required to
oschulz marked this conversation as resolved.
Show resolved Hide resolved
# support thunks as an input and all instances of _adjoint_keepthunks in
# Zygote have been replaces by rrules:
macro _adjoint_keepthunks(ex)
ZygoteRules.gradm(ex, false, true)
end
macro _adjoint_keepthunks!(ex)
ZygoteRules.gradm(ex, true, true)
end


@inline tuple_va(N, xs) = xs
@inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...)
@inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N))
Expand Down
2 changes: 1 addition & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
x̄ = reconstruct_if_dict(x̄, _keys) # return a dictionary if needed
(nothing, (f = f̄, iter = x̄),)
end
y, collect_pullback
y, collect_pullback ∘ unthunk_tangent
end

collect_if_dict(x::Dict) = collect(x), collect(keys(x))
Expand Down
3 changes: 2 additions & 1 deletion src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end

function unbroadcast(x::AbstractArray, x̄)
function unbroadcast(x::AbstractArray, maybethunked_x̄)
x̄ = unthunk_tangent(maybethunked_x̄)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
N = ndims(x̄)
if length(x) == length(x̄)
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
Expand Down
52 changes: 29 additions & 23 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,23 @@ function accum(x::RefValue, y::RefValue)
return x
end

accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)

accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))

# Core functions
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
@_adjoint_keepthunks (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing

@adjoint ifelse(cond::Bool, t, f) =
@_adjoint_keepthunks ifelse(cond::Bool, t, f) =
ifelse(cond, t, f),
Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ)

@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
@_adjoint_keepthunks Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)

accum_param(::Context{false}, _, Δ) = Δ
@generated function accum_param(cx::Context, x, Δ)
Expand All @@ -70,11 +77,11 @@ end

unwrap(x) = x

@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
@_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)

unwrap(ref, x) = x

@adjoint unwrap(ref, x) = unwrap(x), function (x̄)
@_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄)
accum_global(__context__, ref, x̄)
(accum_param(__context__, x, x̄),)
end
Expand All @@ -88,7 +95,7 @@ function global_set(ref, val)
end
end

@adjoint! function global_set(ref, x)
@_adjoint_keepthunks! function global_set(ref, x)
global_set(ref, x), function (x̄)
gs = cache(__context__)
x̄ = accum(get(gs, ref, nothing), x̄)
Expand All @@ -101,9 +108,9 @@ end

using Base: tail

@adjoint tuple(xs...) = xs, identity
@_adjoint_keepthunks tuple(xs...) = xs, identity

@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
val = xs[i]
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand All @@ -112,7 +119,7 @@ using Base: tail
val, back
end

@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, i::Integer) where N
val = xs[i]
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand All @@ -121,10 +128,10 @@ end
return val, back
end

@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
@_adjoint_keepthunks getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
(xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing))

@adjoint function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
val = xs[r]
function back(Δ)
dxs = ntuple(Val(length(xs))) do x
Expand Down Expand Up @@ -155,18 +162,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
end

# Needed for iteration lowering
@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
@_adjoint_keepthunks Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
(xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))

@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
@_adjoint_keepthunks Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
(xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing))

@adjoint function Base.first(xs::Tuple)
@_adjoint_keepthunks function Base.first(xs::Tuple)
drest = map(_->nothing, tail(xs))
first(xs), Δ -> ((Δ, drest...),)
end

@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
@_adjoint_keepthunks Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)

_empty(x) = length(x)
_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x)
Expand All @@ -188,7 +195,7 @@ end

unapply(t, xs) = _unapply(t, xs)[1]

@adjoint! function Core._apply(f, args...)
@_adjoint_keepthunks! function Core._apply(f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
Expand All @@ -198,7 +205,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
end
end

@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...)
@_adjoint_keepthunks! function Core._apply_iterate(::typeof(iterate), f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
Expand All @@ -223,7 +230,7 @@ end
@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,)
@generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,)

@adjoint function literal_getfield(x, ::Val{f}) where f
@_adjoint_keepthunks function literal_getfield(x, ::Val{f}) where f
val = getfield(x, f)
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand Down Expand Up @@ -273,8 +280,7 @@ function _get!(default::Base.Callable, ch, x)
end
end


@adjoint! function setfield!(x, f, val)
@_adjoint_keepthunks! function setfield!(x, f, val)
y = setfield!(x, f, val)
g = grad_mut(__context__, x)
y, function (_)
Expand All @@ -290,13 +296,13 @@ end

Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)

@adjoint! function __new__(T, args...)
@_adjoint_keepthunks! function __new__(T, args...)
x = __new__(T, args...)
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),false}(g)
end

@adjoint! function __splatnew__(T, args)
@_adjoint_keepthunks! function __splatnew__(T, args)
x = __splatnew__(T, args)
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),true}(g)
Expand Down
16 changes: 8 additions & 8 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ function ngradient(f, xs::AbstractArray...)
return grads
end

function gradcheck(f, xs...)
function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5)
grad_zygote = gradient(f, xs...)
grad_finite_difference = ngradient(f, xs...)
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5))
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol))
end

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
gradtest(f, xs::AbstractArray...; kwargs...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kwargs...)
gradtest(f, dims...; kwargs...) = gradtest(f, rand.(Float64, dims)...; kwargs...)

# utilities for using gradcheck with complex matrices
_splitreim(A) = (real(A),)
Expand Down Expand Up @@ -160,8 +160,8 @@ end
@test gradient(y, x, z) == ([1, 1, 2], nothing)

# https://github.com/FluxML/Zygote.jl/issues/376
_, back = Zygote._pullback(x->x[1]*im, randn(2))
@test back(1.0)[2] == real([-im, 0]) == [0, 0]
_, back = Zygote.pullback(x -> x[1] * im, randn(2))
@test back(1.0)[1] == real([-im, 0]) == [0, 0]

# _droplike
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
Expand Down Expand Up @@ -949,8 +949,8 @@ end
_hermsymtype(::Type{<:Symmetric}) = Symmetric
_hermsymtype(::Type{<:Hermitian}) = Hermitian

function _gradtest_hermsym(f, ST, A)
gradtest(_splitreim(collect(A))...) do (args...)
function _gradtest_hermsym(f, ST, A; kwargs...)
gradtest(_splitreim(collect(A))...; kwargs...) do (args...)
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
return sum(_splitreim(B))
end
Expand Down
1 change: 0 additions & 1 deletion test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,4 @@ end
@test sgs[d.b] ≈ fill(1.f0, size(d.b))
end


end
Loading