From 4d39e77b9f7894ca933c9bcf6c16221b6be251d4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:02:06 +0200 Subject: [PATCH] Better explanation of preparation (#536) * Clarify documentation on preparation * Handle Val in Enzyme --- .../docs/src/explanation/backends.md | 11 ++++++++--- .../docs/src/explanation/operators.md | 11 ++++++----- .../forward_onearg.jl | 8 ++++---- .../reverse_onearg.jl | 8 ++++++-- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index f1edbaad6..0ad276905 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -137,9 +137,12 @@ For every operator, preparation generates an [executable function](https://brian ### FiniteDiff Whenever possible, preparation creates a cache object. +Pushforward is implemented rather slowly using a closure. ### FiniteDifferences +Nothing specific to mention. + ### ForwardDiff We implement [`pushforward`](@ref) directly using [`Dual` numbers](https://juliadiff.org/ForwardDiff.jl/stable/dev/how_it_works/), and preparation allocates the necessary space. @@ -152,9 +155,12 @@ Most operators fall back on `AutoForwardDiff`. ### ReverseDiff Wherever possible, preparation records a [tape](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractTape-API) of the function's execution. +This tape is computed from the arguments `x` and `contexts...` provided at preparation time. +It is control-flow dependent, so only one branch is recorded at each `if` statement. -!!! warning - This tape is specific to the control flow inside the function, and cannot be reused if the control flow is value-dependent (like `if x[1] > 0`). +!!! danger + If your function has value-specific control flow (like `if x[1] > 0` or `if c == 1`), you may get silently wrong results whenever it takes new branches that were not taken during preparation. + You must make sure to run preparation with an input and contexts whose values trigger the correct control flow for future executions. ### Symbolics @@ -176,4 +182,3 @@ Same-point preparation runs the forward sweep and returns the pullback closure a We implement `pullback` based on `Zygote.pullback`. Same-point preparation runs the forward sweep and returns the pullback closure at `x`. - diff --git a/DifferentiationInterface/docs/src/explanation/operators.md b/DifferentiationInterface/docs/src/explanation/operators.md index 1f6a2e84f..d8e257c54 100644 --- a/DifferentiationInterface/docs/src/explanation/operators.md +++ b/DifferentiationInterface/docs/src/explanation/operators.md @@ -125,13 +125,14 @@ Here are the general rules that we strive to implement: For different-point preparation, the output `prep` of `prepare_op(f, b, x, [t])` can be reused in `op(f, prep, b, other_x, [other_t])`, provided that: -- the inputs `x` and `other_x` have similar types and equal shapes -- the tangents in `t` and `other_t` have similar types and equal shapes +- the inputs `x` and `other_x` have the same types and sizes +- the tangents in `t` and `other_t` have the same types and sizes For same-point preparation, the output `prep` of `prepare_op_same_point(f, b, x, [t])` can be reused in `op(f, prep, b, x, other_t)`, provided that: -- the input `x` remains the same (as well as the [`Context`](@ref) constants) -- the tangents in `t` and `other_t` have similar types and equal shapes +- the input `x` remains exactly the same (as well as any [`Constant`](@ref) context) +- the tangents in `t` and `other_t` have the same types and sizes !!! warning - These rules hold for the majority of backends, but there are some exceptions. \ No newline at end of file + These rules hold for the majority of backends, but there are some exceptions. + The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function. diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index e15818590..25beaf461 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -117,8 +117,8 @@ end function DI.prepare_gradient( f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x ) where {F} - B = pick_batchsize(backend, length(x)) - shadows = create_shadows(Val(B), x) + valB = pick_batchsize(backend, length(x)) + shadows = create_shadows(valB, x) return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows) end @@ -180,8 +180,8 @@ function DI.prepare_jacobian( f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x ) where {F} y = f(x) - B = pick_batchsize(backend, length(x)) - shadows = create_shadows(Val(B), x) + valB = pick_batchsize(backend, length(x)) + shadows = create_shadows(valB, x) return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 52d99a414..67152f97f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -349,11 +349,15 @@ end struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end +function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B} + return EnzymeReverseOneArgJacobianPrep{Sy,B}() +end + function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F} y = f(x) Sy = size(y) - B = pick_batchsize(backend, prod(Sy)) - return EnzymeReverseOneArgJacobianPrep{Sy,B}() + valB = pick_batchsize(backend, prod(Sy)) + return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB) end function DI.jacobian(