From edf9b2c80058792ef401a2a5bff1352c47f23e3c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:00:17 +0200 Subject: [PATCH 1/8] Avoid slicing the whole Jacobian if batch size does not divide total size --- .../src/first_order/jacobian.jl | 18 ++++++++++-------- .../src/second_order/hessian.jl | 6 +++++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 04dc281df..23cf7b060 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -241,13 +241,14 @@ function _jacobian_aux( batched_seeds[a], contexts..., ) - stack(vec, dy_batch; dims=2) + block = stack(vec, dy_batch; dims=2) + if N % B != 0 && a == lastindex(batched_seeds) + block = block[:, 1:(N - (a - 1) * B)] + end + block end jac = reduce(hcat, jac_blocks) - if N < size(jac, 2) - jac = jac[:, 1:N] - end return jac end @@ -268,13 +269,14 @@ function _jacobian_aux( dx_batch = pullback( f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... ) - stack(vec, dx_batch; dims=1) + block = stack(vec, dx_batch; dims=1) + if M % B != 0 && a == lastindex(batched_seeds) + block = block[1:(M - (a - 1) * B), :] + end + block end jac = reduce(vcat, jac_blocks) - if M < size(jac, 1) - jac = jac[1:M, :] - end return jac end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 92b6d22d5..f2082ae24 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -107,7 +107,11 @@ function hessian( hess_blocks = map(eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) - stack(vec, dg_batch; dims=2) + block = stack(vec, dg_batch; dims=2) + if N % B != 0 && a == lastindex(batched_seeds) + block = block[:, 1:(N - (a - 1) * B)] + end + block end hess = reduce(hcat, hess_blocks) From d067e10ccb95d362fa39fc567e234ccb8c1b489e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:08:18 +0200 Subject: [PATCH 2/8] Make preparation type-stable and test it --- DifferentiationInterface/Project.toml | 2 +- .../utils.jl | 2 +- .../DifferentiationInterfaceForwardDiffExt.jl | 11 +++--- ...iationInterfaceSparseMatrixColoringsExt.jl | 1 + .../hessian.jl | 8 ++++- .../jacobian.jl | 34 +++++++++++------- .../src/first_order/jacobian.jl | 20 +++++------ .../src/second_order/hessian.jl | 8 ++++- DifferentiationInterface/src/utils/basis.jl | 16 +++++++-- .../test/Back/ForwardDiff/test.jl | 2 +- .../test/Misc/ZeroBackends/test.jl | 15 +++++++- .../src/test_differentiation.jl | 5 +-- .../src/tests/type_stability_eval.jl | 36 ++++++++++++------- .../test/zero_backends.jl | 1 + 14 files changed, 111 insertions(+), 50 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index de8cfedf8..43e493117 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.4" +version = "0.6.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index be64d3956..2ebfa52e8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -1,5 +1,5 @@ # until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged -DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = min(dimension, 16) +DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16) ## Annotations diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index 36e34aac3..4f884a4fb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -50,12 +50,11 @@ using LinearAlgebra: dot, mul! DI.check_available(::AutoForwardDiff) = true -function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} - if isnothing(C) - return ForwardDiff.pickchunksize(dimension) - else - return min(dimension, C) - end +DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} = Val(C) + +function DI.pick_batchsize(::AutoForwardDiff{Nothing}, dimension::Integer) + # type-unstable + return Val(ForwardDiff.pickchunksize(dimension)) end include("utils.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl index 4d6e86ed9..f32d823f4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -24,6 +24,7 @@ using DifferentiationInterface: outer, multibasis, pick_batchsize, + pick_jacobian_batchsize, pushforward_performance, unwrap, with_contexts diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 459d786cd..668a87b46 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -42,6 +42,13 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result) function DI.prepare_hessian( f::F, backend::AutoSparse, x, contexts::Vararg{Context,C} ) where {F,C} + valB = pick_batchsize(backend, length(x)) + return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...) +end + +function _prepare_sparse_hessian_aux( + ::Val{B}, f::F, backend::AutoSparse, x, contexts::Vararg{Context,C} +) where {B,F,C} dense_backend = dense_ad(backend) sparsity = hessian_sparsity( with_contexts(f, contexts...), x, sparsity_detector(backend) @@ -52,7 +59,6 @@ function DI.prepare_hessian( ) groups = column_groups(coloring_result) Ng = length(groups) - B = pick_batchsize(outer(dense_backend), Ng) seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2) batched_seeds = [ diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 3e0339f0b..f55b6bca9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -74,22 +74,28 @@ function DI.prepare_jacobian( f::F, backend::AutoSparse, x, contexts::Vararg{Context,C} ) where {F,C} y = f(x, map(unwrap, contexts)...) - return _prepare_sparse_jacobian_aux( - pushforward_performance(backend), y, (f,), backend, x, contexts... - ) + perf = pushforward_performance(backend) + valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y)) + return _prepare_sparse_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...) end function DI.prepare_jacobian( f!::F, y, backend::AutoSparse, x, contexts::Vararg{Context,C} ) where {F,C} - return _prepare_sparse_jacobian_aux( - pushforward_performance(backend), y, (f!, y), backend, x, contexts... - ) + perf = pushforward_performance(backend) + valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y)) + return _prepare_sparse_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...) end function _prepare_sparse_jacobian_aux( - ::PushforwardFast, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C} -) where {FY,C} + ::PushforwardFast, + ::Val{B}, + y, + f_or_f!y::FY, + backend::AutoSparse, + x, + contexts::Vararg{Context,C}, +) where {B,FY,C} dense_backend = dense_ad(backend) sparsity = jacobian_sparsity( @@ -104,7 +110,6 @@ function _prepare_sparse_jacobian_aux( ) groups = column_groups(coloring_result) Ng = length(groups) - B = pick_batchsize(dense_backend, Ng) seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2) batched_seeds = [ @@ -121,8 +126,14 @@ function _prepare_sparse_jacobian_aux( end function _prepare_sparse_jacobian_aux( - ::PushforwardSlow, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C} -) where {FY,C} + ::PushforwardSlow, + ::Val{B}, + y, + f_or_f!y::FY, + backend::AutoSparse, + x, + contexts::Vararg{Context,C}, +) where {B,FY,C} dense_backend = dense_ad(backend) sparsity = jacobian_sparsity( fy_with_contexts(f_or_f!y..., contexts...)..., x, sparsity_detector(backend) @@ -136,7 +147,6 @@ function _prepare_sparse_jacobian_aux( ) groups = row_groups(coloring_result) Ng = length(groups) - B = pick_batchsize(dense_backend, Ng) seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups] compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1) batched_seeds = [ diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 23cf7b060..a02380e94 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -86,29 +86,29 @@ function prepare_jacobian( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} y = f(x, map(unwrap, contexts)...) - return _prepare_jacobian_aux( - pushforward_performance(backend), y, (f,), backend, x, contexts... - ) + perf = pushforward_performance(backend) + valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y)) + return _prepare_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...) end function prepare_jacobian( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - return _prepare_jacobian_aux( - pushforward_performance(backend), y, (f!, y), backend, x, contexts... - ) + perf = pushforward_performance(backend) + valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y)) + return _prepare_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...) end function _prepare_jacobian_aux( ::PushforwardFast, + ::Val{B}, y, f_or_f!y::FY, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,C} +) where {B,FY,C} N = length(x) - B = pick_batchsize(backend, N) seeds = [basis(backend, x, ind) for ind in eachindex(x)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for @@ -128,14 +128,14 @@ end function _prepare_jacobian_aux( ::PushforwardSlow, + ::Val{B}, y, f_or_f!y::FY, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,C} +) where {B,FY,C} M = length(y) - B = pick_batchsize(backend, M) seeds = [basis(backend, y, ind) for ind in eachindex(y)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index f2082ae24..b031fb7ab 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -72,8 +72,14 @@ end function prepare_hessian( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + valB = pick_batchsize(backend, length(x)) + return _prepare_hessian_aux(valB, f, backend, x, contexts...) +end + +function _prepare_hessian_aux( + ::Val{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} +) where {B,F,C} N = length(x) - B = pick_batchsize(outer(backend), N) seeds = [basis(backend, x, ind) for ind in eachindex(x)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl index 4d34e3dee..b6e636f6c 100644 --- a/DifferentiationInterface/src/utils/basis.jl +++ b/DifferentiationInterface/src/utils/basis.jl @@ -77,6 +77,18 @@ end Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`. -Returns `1` for backends which have not overloaded it. +Returns `Val(1)` for backends which have not overloaded it. """ -pick_batchsize(::AbstractADType, dimension::Integer) = 1 +pick_batchsize(::AbstractADType, dimension::Integer) = Val(1) + +function pick_jacobian_batchsize( + ::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer +) + return pick_batchsize(backend, N) +end + +function pick_jacobian_batchsize( + ::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer +) + return pick_batchsize(backend, M) +end diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 0ef6e1095..f572c643f 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -32,7 +32,7 @@ test_differentiation( ); test_differentiation( - dense_backends; correctness=false, type_stability=true, logging=LOGGING + AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING ); test_differentiation( diff --git a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl index 238004568..63b57b3bf 100644 --- a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl @@ -3,6 +3,7 @@ using DifferentiationInterface: AutoZeroForward, AutoZeroReverse using DifferentiationInterfaceTest using ComponentArrays: ComponentArrays using JLArrays: JLArrays +using SparseMatrixColorings using StaticArrays: StaticArrays using Test @@ -19,9 +20,10 @@ end test_differentiation( zero_backends, - zero.(default_scenarios(; include_constantified=true)); + default_scenarios(; include_constantified=true); correctness=true, type_stability=true, + preparation_type_stability=true, logging=LOGGING, ) @@ -33,10 +35,21 @@ test_differentiation( default_scenarios(); correctness=false, type_stability=true, + preparation_type_stability=true, first_order=false, logging=LOGGING, ) +test_differentiation( + AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()), + default_scenarios(; include_constantified=true); + correctness=false, + type_stability=true, + preparation_type_stability=true, + excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative], + logging=LOGGING, +) + ## Weird arrays test_differentiation( diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index ddcf3f477..7e515d462 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -29,7 +29,7 @@ Cross-test a list of `backends` on a list of `scenarios`, running a variety of d Testing: - `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario -- `type_stability=false`: whether to check type stability with JET.jl (thanks to `JET.@test_opt`) +- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`) - `sparsity`: whether to check sparsity of the jacobian / hessian - `detailed=false`: whether to print a detailed or condensed test log @@ -52,6 +52,7 @@ function test_differentiation( # testing correctness::Bool=true, type_stability::Bool=false, + preparation_type_stability::Bool=false, call_count::Bool=false, sparsity::Bool=false, detailed=false, @@ -113,7 +114,7 @@ function test_differentiation( end type_stability && @testset "Type stability" begin @static if VERSION >= v"1.7" - test_jet(backend, scen) + test_jet(backend, scen; test_preparation=preparation_type_stability) end end sparsity && @testset "Sparsity" begin diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index a00e8d9e1..d6427a387 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -26,18 +26,20 @@ for op in [ S2in = Scenario{op,:in,:in} if op in [:derivative, :gradient, :jacobian] - @eval function test_jet(ba::AbstractADType, scen::$S1out) + @eval function test_jet(ba::AbstractADType, scen::$S1out; test_preparation::Bool) @compat (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) JET.@test_opt $op(f, prep, ba, x, contexts...) JET.@test_call $op(f, prep, ba, x, contexts...) JET.@test_opt $val_and_op(f, prep, ba, x, contexts...) JET.@test_call $val_and_op(f, prep, ba, x, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S1in) + @eval function test_jet(ba::AbstractADType, scen::$S1in; test_preparation::Bool) @compat (; f, x, res1, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) JET.@test_opt $op!(f, res1, prep, ba, x, contexts...) JET.@test_call $op!(f, res1, prep, ba, x, contexts...) JET.@test_opt $val_and_op!(f, res1, prep, ba, x, contexts...) @@ -46,18 +48,20 @@ for op in [ op == :gradient && continue - @eval function test_jet(ba::AbstractADType, scen::$S2out) + @eval function test_jet(ba::AbstractADType, scen::$S2out; test_preparation::Bool) @compat (; f, x, y, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) + test_preparation && JET.@test_opt $prep_op(f, y, ba, x, contexts...) JET.@test_opt $op(f, y, prep, ba, x, contexts...) JET.@test_call $op(f, y, prep, ba, x, contexts...) JET.@test_opt $val_and_op(f, y, prep, ba, x, contexts...) JET.@test_call $val_and_op(f, y, prep, ba, x, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S2in) + @eval function test_jet(ba::AbstractADType, scen::$S2in; test_preparation::Bool) @compat (; f, x, y, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, contexts...) + test_preparation && JET.@test_opt $prep_op(f, y, ba, x, contexts...) JET.@test_opt $op!(f, y, res1, prep, ba, x, contexts...) JET.@test_call $op!(f, y, res1, prep, ba, x, contexts...) JET.@test_opt $val_and_op!(f, y, res1, prep, ba, x, contexts...) @@ -65,18 +69,20 @@ for op in [ end elseif op in [:second_derivative, :hessian] - @eval function test_jet(ba::AbstractADType, scen::$S1out) + @eval function test_jet(ba::AbstractADType, scen::$S1out; test_preparation::Bool) @compat (; f, x, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) JET.@test_opt $op(f, prep, ba, x, contexts...) JET.@test_call $op(f, prep, ba, x, contexts...) JET.@test_opt $val_and_op(f, prep, ba, x, contexts...) JET.@test_call $val_and_op(f, prep, ba, x, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S1in) + @eval function test_jet(ba::AbstractADType, scen::$S1in; test_preparation::Bool) @compat (; f, x, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, contexts...) JET.@test_opt $op!(f, res2, prep, ba, x, contexts...) JET.@test_call $op!(f, res2, prep, ba, x, contexts...) JET.@test_opt $val_and_op!(f, res1, res2, prep, ba, x, contexts...) @@ -84,36 +90,40 @@ for op in [ end elseif op in [:pushforward, :pullback] - @eval function test_jet(ba::AbstractADType, scen::$S1out) + @eval function test_jet(ba::AbstractADType, scen::$S1out; test_preparation::Bool) @compat (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) JET.@test_opt $op(f, prep, ba, x, tang, contexts...) JET.@test_call $op(f, prep, ba, x, tang, contexts...) JET.@test_opt $val_and_op(f, prep, ba, x, tang, contexts...) JET.@test_call $val_and_op(f, prep, ba, x, tang, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S1in) + @eval function test_jet(ba::AbstractADType, scen::$S1in; test_preparation::Bool) @compat (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) JET.@test_opt $op!(f, res1, prep, ba, x, tang, contexts...) JET.@test_call $op!(f, res1, prep, ba, x, tang, contexts...) JET.@test_opt $val_and_op!(f, res1, prep, ba, x, tang, contexts...) JET.@test_call $val_and_op!(f, res1, prep, ba, x, tang, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S2out) + @eval function test_jet(ba::AbstractADType, scen::$S2out; test_preparation::Bool) @compat (; f, x, y, tang, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) + test_preparation && JET.@test_opt $prep_op(f, y, ba, x, tang, contexts...) JET.@test_opt $op(f, y, prep, ba, x, tang, contexts...) JET.@test_call $op(f, y, prep, ba, x, tang, contexts...) JET.@test_opt $val_and_op(f, y, prep, ba, x, tang, contexts...) JET.@test_call $val_and_op(f, y, prep, ba, x, tang, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S2in) + @eval function test_jet(ba::AbstractADType, scen::$S2in; test_preparation::Bool) @compat (; f, x, y, tang, res1, contexts) = deepcopy(scen) prep = $prep_op(f, y, ba, x, tang, contexts...) + test_preparation && JET.@test_opt $prep_op(f, y, ba, x, tang, contexts...) JET.@test_opt $op!(f, y, res1, prep, ba, x, tang, contexts...) JET.@test_call $op!(f, y, res1, prep, ba, x, tang, contexts...) JET.@test_opt $val_and_op!(f, y, res1, prep, ba, x, tang, contexts...) @@ -121,16 +131,18 @@ for op in [ end elseif op in [:hvp] - @eval function test_jet(ba::AbstractADType, scen::$S1out) + @eval function test_jet(ba::AbstractADType, scen::$S1out; test_preparation::Bool) @compat (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) JET.@test_opt $op(f, prep, ba, x, tang, contexts...) JET.@test_call $op(f, prep, ba, x, tang, contexts...) end - @eval function test_jet(ba::AbstractADType, scen::$S1in) + @eval function test_jet(ba::AbstractADType, scen::$S1in; test_preparation::Bool) @compat (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) + test_preparation && JET.@test_opt $prep_op(f, ba, x, tang, contexts...) JET.@test_opt $op!(f, res2, prep, ba, x, tang, contexts...) JET.@test_call $op!(f, res2, prep, ba, x, tang, contexts...) end diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index 169be72a7..ee5ee8304 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -15,6 +15,7 @@ test_differentiation( zero.(default_scenarios()); correctness=true, type_stability=true, + preparation_type_stability=true, logging=LOGGING, ) From f6541f651d42b209b64aab91f064e3639ba396cc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:16:57 +0200 Subject: [PATCH 3/8] Pick batchsize --- .../DifferentiationInterfaceForwardDiffExt.jl | 2 +- .../test/Misc/Internals/backends.jl | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index 4f884a4fb..bfdbc6e9f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -52,7 +52,7 @@ DI.check_available(::AutoForwardDiff) = true DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} = Val(C) -function DI.pick_batchsize(::AutoForwardDiff{Nothing}, dimension::Integer) +function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer) # type-unstable return Val(ForwardDiff.pickchunksize(dimension)) end diff --git a/DifferentiationInterface/test/Misc/Internals/backends.jl b/DifferentiationInterface/test/Misc/Internals/backends.jl index 5ebf10d7d..64f93ba91 100644 --- a/DifferentiationInterface/test/Misc/Internals/backends.jl +++ b/DifferentiationInterface/test/Misc/Internals/backends.jl @@ -33,7 +33,11 @@ end end @testset "Batch size" begin - @test DI.pick_batchsize(AutoZygote(), 2) == 1 - @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == 2 - @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == 4 + @test DI.pick_batchsize(AutoZygote(), 2) == Val(1) + @test DI.pick_batchsize(AutoForwardDiff(), 2) == Val(2) + @test DI.pick_batchsize(AutoForwardDiff(), 6) == Val(6) + @test DI.pick_batchsize(AutoForwardDiff(), 100) == Val(12) + @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == Val(4) + @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == Val(4) + @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 100) == Val(4) end From 67a91d76aa12796e89e298839e1b48cf5db944e6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:46:12 +0200 Subject: [PATCH 4/8] Fix chunk size --- .../ext/DifferentiationInterfaceForwardDiffExt/utils.jl | 2 +- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 5ad8c5dec..11a23183a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -1,5 +1,5 @@ choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x) -choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{min(length(x), C)}() +choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}() tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = typeof(Tag(f, eltype(x))) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f572c643f..8be107a85 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -35,6 +35,14 @@ test_differentiation( AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING ); +test_differentiation( + AutoForwardDiff(; chunksize=5); + correctness=false, + type_stability=true, + preparation_type_stability=true, + logging=LOGGING, +); + test_differentiation( dense_backends, # ForwardDiff accesses individual indices From 30919fe3a0d9f8d4e5af7c7e7490f64c96abedba Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:00:05 +0200 Subject: [PATCH 5/8] Fix ENzyme --- .../forward_onearg.jl | 22 ++++++++++++++----- .../reverse_onearg.jl | 8 +++++-- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index e15818590..226808377 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -114,12 +114,16 @@ struct EnzymeForwardGradientPrep{B,O} <: GradientPrep shadows::O end +function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O} + return EnzymeForwardGradientPrep{B,O}(shadows) +end + function DI.prepare_gradient( f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x ) where {F} - B = pick_batchsize(backend, length(x)) - shadows = create_shadows(Val(B), x) - return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows) + valB = pick_batchsize(backend, length(x)) + shadows = create_shadows(valB, x) + return EnzymeForwardGradientPrep(valB, shadows) end function DI.gradient( @@ -176,13 +180,19 @@ struct EnzymeForwardOneArgJacobianPrep{B,O} <: JacobianPrep output_length::Int end +function EnzymeForwardOneArgJacobianPrep( + ::Val{B}, shadows::O, output_length::Integer +) where {B,O} + return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length) +end + function DI.prepare_jacobian( f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x ) where {F} y = f(x) - B = pick_batchsize(backend, length(x)) - shadows = create_shadows(Val(B), x) - return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y)) + valB = pick_batchsize(backend, length(x)) + shadows = create_shadows(valB, x) + return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y)) end function DI.jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 52d99a414..67152f97f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -349,11 +349,15 @@ end struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end +function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B} + return EnzymeReverseOneArgJacobianPrep{Sy,B}() +end + function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F} y = f(x) Sy = size(y) - B = pick_batchsize(backend, prod(Sy)) - return EnzymeReverseOneArgJacobianPrep{Sy,B}() + valB = pick_batchsize(backend, prod(Sy)) + return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB) end function DI.jacobian( From 0a3377831ece2b1db6156ac4a5000be9f847154f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:18:03 +0200 Subject: [PATCH 6/8] No fully static chunk size --- .../ext/DifferentiationInterfaceForwardDiffExt/utils.jl | 2 +- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 11a23183a..5ad8c5dec 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -1,5 +1,5 @@ choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x) -choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}() +choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{min(length(x), C)}() tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = typeof(Tag(f, eltype(x))) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 8be107a85..f572c643f 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -35,14 +35,6 @@ test_differentiation( AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING ); -test_differentiation( - AutoForwardDiff(; chunksize=5); - correctness=false, - type_stability=true, - preparation_type_stability=true, - logging=LOGGING, -); - test_differentiation( dense_backends, # ForwardDiff accesses individual indices From 1f3fdfe9492abc4c41b15d4c4bd31e9c50f3d798 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 17:27:59 +0200 Subject: [PATCH 7/8] No correctness for zero backends --- DifferentiationInterface/test/Misc/ZeroBackends/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl index 63b57b3bf..9c96a5d73 100644 --- a/DifferentiationInterface/test/Misc/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Misc/ZeroBackends/test.jl @@ -21,7 +21,7 @@ end test_differentiation( zero_backends, default_scenarios(; include_constantified=true); - correctness=true, + correctness=false, type_stability=true, preparation_type_stability=true, logging=LOGGING, From 09c206d727045b0bde5182d9cac97198492fc8f8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 18:40:09 +0200 Subject: [PATCH 8/8] Bump to next SMC --- DifferentiationInterface/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 43e493117..60381de60 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -60,7 +60,7 @@ PolyesterForwardDiff = "0.1.1" ReverseDiff = "1.15.1" SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0,0.6" -SparseMatrixColorings = "0.4.4" +SparseMatrixColorings = "0.4.5" Symbolics = "5.27.1, 6" Tracker = "0.2.33" Zygote = "0.6.69"