diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a439cf430c..7a864aa51a 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -170,7 +170,7 @@ end end """ - autodiff(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff(::ReverseMode, f, Activity, args::Annotation...) Auto-differentiate function `f` at arguments `args` using reverse mode. @@ -317,7 +317,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end """ - autodiff(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) + autodiff(mode::Mode, f, ::Type{A}, args::Annotation...) Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ @@ -345,7 +345,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. end """ - autodiff(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff(::ForwardMode, f, Activity, args::Annotation...) Auto-differentiate function `f` at arguments `args` using forward mode. @@ -431,7 +431,7 @@ f(x) = x*x end """ - autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff_deferred(::ReverseMode, f, Activity, args::Annotation...) Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. @@ -472,9 +472,9 @@ code, as well as high-order differentiation. end """ - autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff_deferred(::ForwardMode, f, Activity, args::Annotation...) -Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU +Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} @@ -532,7 +532,7 @@ code, as well as high-order differentiation. end """ - autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation, Nargs}) + autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -628,7 +628,7 @@ end end """ - autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the thunk forward mode function for annotated function type ftype when called with args of type `argtypes`. @@ -798,7 +798,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -1067,7 +1067,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ -@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{<:Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} +@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} toemit= Expr[quote act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState end]