Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve documentation of modes #1895

Merged
merged 10 commits into from
Sep 28, 2024
263 changes: 229 additions & 34 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,50 +189,172 @@ end
abstract type ABI

Abstract type for what ABI will be used.

# Subtypes

- [`FFIABI`](@ref) (the default)
- [`InlineABI`](@ref)
- [`NonGenABI`](@ref)
"""
abstract type ABI end

"""
struct FFIABI <: ABI

Foreign function call ABI. JIT the differentiated function, then inttoptr call the address.
Foreign function call [`ABI`](@ref). JIT the differentiated function, then inttoptr call the address.
"""
struct FFIABI <: ABI end

"""
struct InlineABI <: ABI

Inlining function call ABI.
Inlining function call [`ABI`](@ref).
"""
struct InlineABI <: ABI end

"""
struct NonGenABI <: ABI

Non-generated function ABI.
Non-generated function [`ABI`](@ref).
"""
struct NonGenABI <: ABI end

const DefaultABI = FFIABI

"""
abstract type Mode
abstract type Mode{ABI,ErrIfFuncWritten,RuntimeActivity}

Abstract type for which differentiation mode will be used.

Abstract type for what differentiation mode will be used.
# Subtypes

- [`ForwardMode`](@ref)
- [`ReverseMode`](@ref)
- [`ReverseModeSplit`](@ref)

# Type parameters

- `ABI`: what runtime [`ABI`](@ref) to use
- `ErrIfFuncWritten`: whether to error when the function differentiated is a closure and written to.
- `RuntimeActivity`: whether to enable runtime activity (default off)

!!! warning
The type parameters of `Mode` are not part of the public API and can change without notice.
You can modify them with the following helper functions:
- [`WithPrimal`](@ref) / [`NoPrimal`](@ref)
- [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref)
- [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref)
- [`set_abi`](@ref)
"""
abstract type Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end

"""
struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity}
set_err_if_func_written(::Mode)

Return a new mode which throws an error for any attempt to write into an unannotated function object.
"""
function set_err_if_func_written end

"""
clear_err_if_func_written(::Mode)

Return a new mode which doesn't throw an error for attempts to write into an unannotated function object.
"""
function clear_err_if_func_written end

"""
set_runtime_activity(::Mode)

Return a new mode where runtime activity analysis is activated.
"""
function set_runtime_activity end

"""
clear_runtime_activity(::Mode)

Return a new mode where runtime activity analysis is deactivated.
"""
function clear_runtime_activity end

"""
set_abi(::Mode, ::Type{ABI})

Return a new mode with its [`ABI`](@ref) set to the chosen type.
"""
function set_abi end

"""
WithPrimal(::Mode)

Return a new mode which includes the primal value.
"""
function WithPrimal end

Reverse mode differentiation.
- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward.
- `RuntimeActivity`: Should Enzyme enable runtime activity (default off)
- `ABI`: What runtime ABI to use
- `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz
- `ErrIfFuncWritten`: Should Enzyme err if the function differentiated is a closure and written to.
"""
NoPrimal(::Mode)

Return a new mode which excludes the primal value.
"""
function NoPrimal end

"""
struct ReverseMode{
ReturnPrimal,
RuntimeActivity,
ABI,
Holomorphic,
ErrIfFuncWritten
} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}

Subtype of [`Mode`](@ref) for reverse mode differentiation.

# Type parameters

- `ReturnPrimal`: whether to return the primal return value from the augmented-forward pass.
- `Holomorphic`: Whether the complex result function is holomorphic and we should compute `d/dz`
- other parameters: see [`Mode`](@ref)

!!! warning
The type parameters of `ReverseMode` are not part of the public API and can change without notice.
Please use one of the following concrete instantiations instead:
- [`Reverse`](@ref)
- [`ReverseWithPrimal`](@ref)
- [`ReverseHolomorphic`](@ref)
- [`ReverseHolomorphicWithPrimal`](@ref)
You can modify them with the following helper functions:
- [`WithPrimal`](@ref) / [`NoPrimal`](@ref)
- [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref)
- [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref)
- [`set_abi`](@ref)
"""
struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end

"""
const Reverse

Default instance of [`ReverseMode`](@ref) that doesn't return the primal
"""
const Reverse = ReverseMode{false,false,DefaultABI, false, false}()

"""
const ReverseWithPrimal

Default instance of [`ReverseMode`](@ref) that also returns the primal.
"""
const ReverseWithPrimal = ReverseMode{true,false,DefaultABI, false, false}()

"""
const ReverseHolomorphic

Holomorphic instance of [`ReverseMode`](@ref) that doesn't return the primal
"""
const ReverseHolomorphic = ReverseMode{false,false,DefaultABI, true, false}()

"""
const ReverseHolomorphicWithPrimal

Holomorphic instance of [`ReverseMode`](@ref) that also returns the primal
"""
const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, false}()

@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,true}()
Expand All @@ -244,35 +366,75 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa
@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}()
@inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}()

@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()

@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()

"""
struct ReverseModeSplit{
ReturnPrimal,
ReturnShadow,
Width,
RuntimeActivity,
ModifiedBetween,
ABI,
ErrFuncIfWritten
} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}
WithPrimal(::Enzyme.Mode)

Modifies the mode to include the primal value.
"""
@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()
Subtype of [`Mode`](@ref) for split reverse mode differentiation, to use in [`autodiff_thunk`](@ref) and variants.

# Type parameters

- `ReturnShadow`: whether to return the shadow return value from the augmented-forward.
- `Width`: batch size (pick `0` to derive it automatically)
- `ModifiedBetween`: `Tuple` of each argument's "modified between" state (pick `true` to derive it automatically).
- other parameters: see [`ReverseMode`](@ref)

!!! warning
The type parameters of `ReverseModeSplit` are not part of the public API and can change without notice.
Please use one of the following concrete instantiations instead:
- [`ReverseSplitNoPrimal`](@ref)
- [`ReverseSplitWithPrimal`](@ref)
You can modify them with the following helper functions:
- [`WithPrimal`](@ref) / [`NoPrimal`](@ref)
- [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref)
- [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref)
- [`set_abi`](@ref)
- [`ReverseSplitModified`](@ref), [`ReverseSplitWidth`](@ref)
"""
NoPrimal(::Enzyme.Mode)
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end

Modifies the mode to exclude the primal value.
"""
@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()
const ReverseSplitNoPrimal

Default instance of [`ReverseModeSplit`](@ref) that doesn't return the primal
"""
const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}()

"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}
const ReverseSplitWithPrimal

Reverse mode differentiation.
- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward.
- `ReturnShadow`: Should Enzyme return the shadow return value from the augmented-forward.
- `RuntimeActivity`: Should Enzyme differentiate with runtime activity on (default off).
- `Width`: Batch Size (0 if to be automatically derived)
- `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived).
Default instance of [`ReverseModeSplit`](@ref) that also returns the primal
"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end
const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}()
const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}()

"""
ReverseSplitModified(::ReverseModeSplit, ::Val{MB})

Return a new instance of [`ReverseModeSplit`](@ref) mode where `ModifiedBetween` is set to `MB`.
"""
function ReverseSplitModified end

@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}()

"""
ReverseSplitWidth(::ReverseModeSplit, ::Val{W})

Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to `W`.
"""
function ReverseSplitWidth end

@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}()

@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}()
Expand All @@ -287,13 +449,46 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau


"""
struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity}
struct ForwardMode{
ReturnPrimal,
ABI,
ErrIfFuncWritten,
RuntimeActivity
} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}

Forward mode differentiation
Subtype of [`Mode`](@ref) for forward mode differentiation.

# Type parameters

- `ReturnPrimal`: whether to return the primal return value from the augmented-forward.
- other parameters: see [`Mode`](@ref)

!!! warning
The type parameters of `ForwardMode` are not part of the public API and can change without notice.
Please use one of the following concrete instantiations instead:
- [`Forward`](@ref)
- [`ForwardWithPrimal`](@ref)
You can modify them with the following helper functions:
- [`WithPrimal`](@ref) / [`NoPrimal`](@ref)
- [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref)
- [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref)
- [`set_abi`](@ref)
"""
struct ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity}
end

"""
const Forward

Default instance of [`ForwardMode`](@ref) that doesn't return the primal
"""
const Forward = ForwardMode{false, DefaultABI, false, false}()

"""
const ForwardWithPrimal

Default instance of [`ForwardMode`](@ref) that also returns the primal
"""
const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}()

@inline set_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,true,RuntimeActivity}()
Expand All @@ -317,22 +512,22 @@ function autodiff_deferred_thunk end
"""
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T

Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
"""
function make_zero end

"""
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing

Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
"""
function make_zero! end

"""
make_zero(prev::T)

Helper function to recursively make zero.
Helper function to recursively make zero.
"""
@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
Expand Down Expand Up @@ -363,7 +558,7 @@ if !isdefined(Base, :get_extension)
end

"""
within_autodiff()
within_autodiff()

Returns true if within autodiff, otherwise false.
"""
Expand Down
Loading