diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 25beaf461..e15818590 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -117,8 +117,8 @@ end function DI.prepare_gradient( f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x ) where {F} - valB = pick_batchsize(backend, length(x)) - shadows = create_shadows(valB, x) + B = pick_batchsize(backend, length(x)) + shadows = create_shadows(Val(B), x) return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows) end @@ -180,8 +180,8 @@ function DI.prepare_jacobian( f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x ) where {F} y = f(x) - valB = pick_batchsize(backend, length(x)) - shadows = create_shadows(valB, x) + B = pick_batchsize(backend, length(x)) + shadows = create_shadows(Val(B), x) return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 67152f97f..52d99a414 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -349,15 +349,11 @@ end struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end -function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B} - return EnzymeReverseOneArgJacobianPrep{Sy,B}() -end - function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F} y = f(x) Sy = size(y) - valB = pick_batchsize(backend, prod(Sy)) - return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB) + B = pick_batchsize(backend, prod(Sy)) + return EnzymeReverseOneArgJacobianPrep{Sy,B}() end function DI.jacobian(