From 92aa08c2def6a50ec21a17e591aab72cdfc477c2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:07:30 +0200 Subject: [PATCH 1/2] Revert "Better explanation of preparation (#536)" This reverts commit 4d39e77b9f7894ca933c9bcf6c16221b6be251d4. --- .../docs/src/explanation/backends.md | 11 +++-------- .../docs/src/explanation/operators.md | 11 +++++------ .../forward_onearg.jl | 8 ++++---- .../reverse_onearg.jl | 8 ++------ 4 files changed, 14 insertions(+), 24 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 0ad276905..f1edbaad6 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -137,12 +137,9 @@ For every operator, preparation generates an [executable function](https://brian ### FiniteDiff Whenever possible, preparation creates a cache object. -Pushforward is implemented rather slowly using a closure. ### FiniteDifferences -Nothing specific to mention. - ### ForwardDiff We implement [`pushforward`](@ref) directly using [`Dual` numbers](https://juliadiff.org/ForwardDiff.jl/stable/dev/how_it_works/), and preparation allocates the necessary space. @@ -155,12 +152,9 @@ Most operators fall back on `AutoForwardDiff`. ### ReverseDiff Wherever possible, preparation records a [tape](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractTape-API) of the function's execution. -This tape is computed from the arguments `x` and `contexts...` provided at preparation time. -It is control-flow dependent, so only one branch is recorded at each `if` statement. -!!! danger - If your function has value-specific control flow (like `if x[1] > 0` or `if c == 1`), you may get silently wrong results whenever it takes new branches that were not taken during preparation. - You must make sure to run preparation with an input and contexts whose values trigger the correct control flow for future executions. +!!! warning + This tape is specific to the control flow inside the function, and cannot be reused if the control flow is value-dependent (like `if x[1] > 0`). ### Symbolics @@ -182,3 +176,4 @@ Same-point preparation runs the forward sweep and returns the pullback closure a We implement `pullback` based on `Zygote.pullback`. Same-point preparation runs the forward sweep and returns the pullback closure at `x`. + diff --git a/DifferentiationInterface/docs/src/explanation/operators.md b/DifferentiationInterface/docs/src/explanation/operators.md index d8e257c54..1f6a2e84f 100644 --- a/DifferentiationInterface/docs/src/explanation/operators.md +++ b/DifferentiationInterface/docs/src/explanation/operators.md @@ -125,14 +125,13 @@ Here are the general rules that we strive to implement: For different-point preparation, the output `prep` of `prepare_op(f, b, x, [t])` can be reused in `op(f, prep, b, other_x, [other_t])`, provided that: -- the inputs `x` and `other_x` have the same types and sizes -- the tangents in `t` and `other_t` have the same types and sizes +- the inputs `x` and `other_x` have similar types and equal shapes +- the tangents in `t` and `other_t` have similar types and equal shapes For same-point preparation, the output `prep` of `prepare_op_same_point(f, b, x, [t])` can be reused in `op(f, prep, b, x, other_t)`, provided that: -- the input `x` remains exactly the same (as well as any [`Constant`](@ref) context) -- the tangents in `t` and `other_t` have the same types and sizes +- the input `x` remains the same (as well as the [`Context`](@ref) constants) +- the tangents in `t` and `other_t` have similar types and equal shapes !!! warning - These rules hold for the majority of backends, but there are some exceptions. - The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function. + These rules hold for the majority of backends, but there are some exceptions. \ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 25beaf461..e15818590 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -117,8 +117,8 @@ end function DI.prepare_gradient( f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x ) where {F} - valB = pick_batchsize(backend, length(x)) - shadows = create_shadows(valB, x) + B = pick_batchsize(backend, length(x)) + shadows = create_shadows(Val(B), x) return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows) end @@ -180,8 +180,8 @@ function DI.prepare_jacobian( f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x ) where {F} y = f(x) - valB = pick_batchsize(backend, length(x)) - shadows = create_shadows(valB, x) + B = pick_batchsize(backend, length(x)) + shadows = create_shadows(Val(B), x) return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 67152f97f..52d99a414 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -349,15 +349,11 @@ end struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end -function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B} - return EnzymeReverseOneArgJacobianPrep{Sy,B}() -end - function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F} y = f(x) Sy = size(y) - valB = pick_batchsize(backend, prod(Sy)) - return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB) + B = pick_batchsize(backend, prod(Sy)) + return EnzymeReverseOneArgJacobianPrep{Sy,B}() end function DI.jacobian( From aad99b7b1370c3d3c3bdfde0ba3ab2b63dbf1202 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:08:20 +0200 Subject: [PATCH 2/2] Fix --- .../docs/src/explanation/backends.md | 11 ++++++++--- .../docs/src/explanation/operators.md | 11 ++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index f1edbaad6..0ad276905 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -137,9 +137,12 @@ For every operator, preparation generates an [executable function](https://brian ### FiniteDiff Whenever possible, preparation creates a cache object. +Pushforward is implemented rather slowly using a closure. ### FiniteDifferences +Nothing specific to mention. + ### ForwardDiff We implement [`pushforward`](@ref) directly using [`Dual` numbers](https://juliadiff.org/ForwardDiff.jl/stable/dev/how_it_works/), and preparation allocates the necessary space. @@ -152,9 +155,12 @@ Most operators fall back on `AutoForwardDiff`. ### ReverseDiff Wherever possible, preparation records a [tape](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractTape-API) of the function's execution. +This tape is computed from the arguments `x` and `contexts...` provided at preparation time. +It is control-flow dependent, so only one branch is recorded at each `if` statement. -!!! warning - This tape is specific to the control flow inside the function, and cannot be reused if the control flow is value-dependent (like `if x[1] > 0`). +!!! danger + If your function has value-specific control flow (like `if x[1] > 0` or `if c == 1`), you may get silently wrong results whenever it takes new branches that were not taken during preparation. + You must make sure to run preparation with an input and contexts whose values trigger the correct control flow for future executions. ### Symbolics @@ -176,4 +182,3 @@ Same-point preparation runs the forward sweep and returns the pullback closure a We implement `pullback` based on `Zygote.pullback`. Same-point preparation runs the forward sweep and returns the pullback closure at `x`. - diff --git a/DifferentiationInterface/docs/src/explanation/operators.md b/DifferentiationInterface/docs/src/explanation/operators.md index 1f6a2e84f..d8e257c54 100644 --- a/DifferentiationInterface/docs/src/explanation/operators.md +++ b/DifferentiationInterface/docs/src/explanation/operators.md @@ -125,13 +125,14 @@ Here are the general rules that we strive to implement: For different-point preparation, the output `prep` of `prepare_op(f, b, x, [t])` can be reused in `op(f, prep, b, other_x, [other_t])`, provided that: -- the inputs `x` and `other_x` have similar types and equal shapes -- the tangents in `t` and `other_t` have similar types and equal shapes +- the inputs `x` and `other_x` have the same types and sizes +- the tangents in `t` and `other_t` have the same types and sizes For same-point preparation, the output `prep` of `prepare_op_same_point(f, b, x, [t])` can be reused in `op(f, prep, b, x, other_t)`, provided that: -- the input `x` remains the same (as well as the [`Context`](@ref) constants) -- the tangents in `t` and `other_t` have similar types and equal shapes +- the input `x` remains exactly the same (as well as any [`Constant`](@ref) context) +- the tangents in `t` and `other_t` have the same types and sizes !!! warning - These rules hold for the majority of backends, but there are some exceptions. \ No newline at end of file + These rules hold for the majority of backends, but there are some exceptions. + The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function.