From c47e26c0e096d098485446f82eb1df5cb381d07a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 3 Oct 2024 22:00:23 +0200 Subject: [PATCH] Force constant chunk size when specified in ForwardDiff (#539) * Force constant chunk size when specified in ForwardDiff * Fix * Fix * Tests --- .../DifferentiationInterfaceForwardDiffExt.jl | 5 +++++ .../utils.jl | 2 +- ...tiationInterfacePolyesterForwardDiffExt.jl | 15 +++++++++++-- .../src/DifferentiationInterface.jl | 1 + DifferentiationInterface/src/utils/basis.jl | 21 ------------------ .../src/utils/batchsize.jl | 22 +++++++++++++++++++ .../test/Back/ForwardDiff/test.jl | 8 +++++++ .../test/Misc/Internals/backends.jl | 3 +++ .../src/DifferentiationInterfaceTest.jl | 2 +- .../src/scenarios/scenario.jl | 8 +++++++ .../src/test_differentiation.jl | 16 ++++++++++---- DifferentiationInterfaceTest/test/standard.jl | 4 +--- 12 files changed, 75 insertions(+), 32 deletions(-) create mode 100644 DifferentiationInterface/src/utils/batchsize.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index bfdbc6e9f..e421153ac 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -57,6 +57,11 @@ function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer) return Val(ForwardDiff.pickchunksize(dimension)) end +function DI.threshold_batchsize(backend::AutoForwardDiff{C1}, C2::Integer) where {C1} + C = (C1 === nothing) ? nothing : min(C1, C2) + return AutoForwardDiff(; chunksize=C, tag=backend.tag) +end + include("utils.jl") include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 5ad8c5dec..11a23183a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -1,5 +1,5 @@ choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x) -choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{min(length(x), C)}() +choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}() tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = typeof(Tag(f, eltype(x))) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 8a98e0fd1..8261e3e4e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -22,12 +22,23 @@ using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! using PolyesterForwardDiff.ForwardDiff: Chunk using PolyesterForwardDiff.ForwardDiff.DiffResults: DiffResults -DI.check_available(::AutoPolyesterForwardDiff) = true - function single_threaded(backend::AutoPolyesterForwardDiff{C,T}) where {C,T} return AutoForwardDiff{C,T}(backend.tag) end +DI.check_available(::AutoPolyesterForwardDiff) = true + +function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, dimension::Integer) + return DI.pick_batchsize(single_threaded(backend), dimension) +end + +function DI.threshold_batchsize( + backend::AutoPolyesterForwardDiff{C1}, C2::Integer +) where {C1} + C = (C1 === nothing) ? nothing : min(C1, C2) + return AutoPolyesterForwardDiff(; chunksize=C, tag=backend.tag) +end + include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 80b9ff818..3cd720043 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -37,6 +37,7 @@ include("second_order/second_order.jl") include("utils/prep.jl") include("utils/traits.jl") include("utils/basis.jl") +include("utils/batchsize.jl") include("utils/check.jl") include("utils/exceptions.jl") include("utils/printing.jl") diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl index b6e636f6c..2404efea1 100644 --- a/DifferentiationInterface/src/utils/basis.jl +++ b/DifferentiationInterface/src/utils/basis.jl @@ -71,24 +71,3 @@ function multibasis(a::AbstractArray{T,N}, inds::AbstractVector) where {T,N} end return seed end - -""" - pick_batchsize(backend::AbstractADType, dimension::Integer) - -Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`. - -Returns `Val(1)` for backends which have not overloaded it. -""" -pick_batchsize(::AbstractADType, dimension::Integer) = Val(1) - -function pick_jacobian_batchsize( - ::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer -) - return pick_batchsize(backend, N) -end - -function pick_jacobian_batchsize( - ::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer -) - return pick_batchsize(backend, M) -end diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl new file mode 100644 index 000000000..172b4b266 --- /dev/null +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -0,0 +1,22 @@ +""" + pick_batchsize(backend::AbstractADType, dimension::Integer) + +Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`. + +Returns `Val(1)` for backends which have not overloaded it. +""" +pick_batchsize(::AbstractADType, dimension::Integer) = Val(1) + +function pick_jacobian_batchsize( + ::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer +) + return pick_batchsize(backend, N) +end + +function pick_jacobian_batchsize( + ::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer +) + return pick_batchsize(backend, M) +end + +threshold_batchsize(backend::AbstractADType, ::Integer) = backend diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f572c643f..8be107a85 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -35,6 +35,14 @@ test_differentiation( AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING ); +test_differentiation( + AutoForwardDiff(; chunksize=5); + correctness=false, + type_stability=true, + preparation_type_stability=true, + logging=LOGGING, +); + test_differentiation( dense_backends, # ForwardDiff accesses individual indices diff --git a/DifferentiationInterface/test/Misc/Internals/backends.jl b/DifferentiationInterface/test/Misc/Internals/backends.jl index 64f93ba91..7b0366bbd 100644 --- a/DifferentiationInterface/test/Misc/Internals/backends.jl +++ b/DifferentiationInterface/test/Misc/Internals/backends.jl @@ -40,4 +40,7 @@ end @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == Val(4) @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == Val(4) @test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 100) == Val(4) + @test DI.threshold_batchsize(AutoForwardDiff(), 2) isa AutoForwardDiff{nothing} + @test DI.threshold_batchsize(AutoForwardDiff(; chunksize=4), 2) isa AutoForwardDiff{2} + @test DI.threshold_batchsize(AutoForwardDiff(; chunksize=4), 6) isa AutoForwardDiff{4} end diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index bfc346a24..663723678 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -50,8 +50,8 @@ using DifferentiationInterface: PushforwardPrep, SecondDerivativePrep, Rewrap -using DocStringExtensions import DifferentiationInterface as DI +using DocStringExtensions using Functors: fmap using JET: JET using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 30484c469..e160f6980 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -121,3 +121,11 @@ function Base.show( end return nothing end + +adapt_batchsize(backend::AbstractADType, ::Scenario) = backend + +function adapt_batchsize( + backend::Union{ADTypes.AutoForwardDiff,ADTypes.AutoPolyesterForwardDiff}, scen::Scenario +) + return DI.threshold_batchsize(backend, length(scen.x)) +end diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 7e515d462..97363d45d 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -109,16 +109,23 @@ function test_differentiation( (:nb_contexts, length(scen.contexts)), ], ) + adapted_backend = adapt_batchsize(backend, scen) correctness && @testset "Correctness" begin - test_correctness(backend, scen; isapprox, atol, rtol, scenario_intact) + test_correctness( + adapted_backend, scen; isapprox, atol, rtol, scenario_intact + ) end type_stability && @testset "Type stability" begin @static if VERSION >= v"1.7" - test_jet(backend, scen; test_preparation=preparation_type_stability) + test_jet( + adapted_backend, + scen; + test_preparation=preparation_type_stability, + ) end end sparsity && @testset "Sparsity" begin - test_sparsity(backend, scen) + test_sparsity(adapted_backend, scen) end yield() end @@ -186,7 +193,8 @@ function benchmark_differentiation( (:nb_contexts, length(scen.contexts)), ], ) - run_benchmark!(benchmark_data, backend, scen; logging) + adapted_backend = adapt_batchsize(backend, scen) + run_benchmark!(benchmark_data, adapted_backend, scen; logging) yield() end end diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index a69946db4..4b3d8b614 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -12,9 +12,7 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( AutoForwardDiff(), - default_scenarios( - Random.default_rng(); include_batchified=true, include_constantified=true - ); + default_scenarios(Random.default_rng(); include_constantified=true); logging=LOGGING, )