Skip to content

Commit

Permalink
Force constant chunk size when specified in ForwardDiff (#539)
Browse files Browse the repository at this point in the history
* Force constant chunk size when specified in ForwardDiff

* Fix

* Fix

* Tests
  • Loading branch information
gdalle authored Oct 3, 2024
1 parent 1208b44 commit c47e26c
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
return Val(ForwardDiff.pickchunksize(dimension))
end

function DI.threshold_batchsize(backend::AutoForwardDiff{C1}, C2::Integer) where {C1}
C = (C1 === nothing) ? nothing : min(C1, C2)
return AutoForwardDiff(; chunksize=C, tag=backend.tag)
end

include("utils.jl")
include("onearg.jl")
include("twoarg.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{min(length(x), C)}()
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()

tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T
tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = typeof(Tag(f, eltype(x)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,23 @@ using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian!
using PolyesterForwardDiff.ForwardDiff: Chunk
using PolyesterForwardDiff.ForwardDiff.DiffResults: DiffResults

DI.check_available(::AutoPolyesterForwardDiff) = true

function single_threaded(backend::AutoPolyesterForwardDiff{C,T}) where {C,T}
return AutoForwardDiff{C,T}(backend.tag)
end

DI.check_available(::AutoPolyesterForwardDiff) = true

function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, dimension::Integer)
return DI.pick_batchsize(single_threaded(backend), dimension)
end

function DI.threshold_batchsize(
backend::AutoPolyesterForwardDiff{C1}, C2::Integer
) where {C1}
C = (C1 === nothing) ? nothing : min(C1, C2)
return AutoPolyesterForwardDiff(; chunksize=C, tag=backend.tag)
end

include("onearg.jl")
include("twoarg.jl")

Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include("second_order/second_order.jl")
include("utils/prep.jl")
include("utils/traits.jl")
include("utils/basis.jl")
include("utils/batchsize.jl")
include("utils/check.jl")
include("utils/exceptions.jl")
include("utils/printing.jl")
Expand Down
21 changes: 0 additions & 21 deletions DifferentiationInterface/src/utils/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,3 @@ function multibasis(a::AbstractArray{T,N}, inds::AbstractVector) where {T,N}
end
return seed
end

"""
pick_batchsize(backend::AbstractADType, dimension::Integer)
Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`.
Returns `Val(1)` for backends which have not overloaded it.
"""
pick_batchsize(::AbstractADType, dimension::Integer) = Val(1)

function pick_jacobian_batchsize(
::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer
)
return pick_batchsize(backend, N)
end

function pick_jacobian_batchsize(
::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer
)
return pick_batchsize(backend, M)
end
22 changes: 22 additions & 0 deletions DifferentiationInterface/src/utils/batchsize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
pick_batchsize(backend::AbstractADType, dimension::Integer)
Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`.
Returns `Val(1)` for backends which have not overloaded it.
"""
pick_batchsize(::AbstractADType, dimension::Integer) = Val(1)

function pick_jacobian_batchsize(
::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer
)
return pick_batchsize(backend, N)
end

function pick_jacobian_batchsize(
::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer
)
return pick_batchsize(backend, M)
end

threshold_batchsize(backend::AbstractADType, ::Integer) = backend
8 changes: 8 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ test_differentiation(
AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING
);

test_differentiation(
AutoForwardDiff(; chunksize=5);
correctness=false,
type_stability=true,
preparation_type_stability=true,
logging=LOGGING,
);

test_differentiation(
dense_backends,
# ForwardDiff accesses individual indices
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/test/Misc/Internals/backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,7 @@ end
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == Val(4)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == Val(4)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 100) == Val(4)
@test DI.threshold_batchsize(AutoForwardDiff(), 2) isa AutoForwardDiff{nothing}
@test DI.threshold_batchsize(AutoForwardDiff(; chunksize=4), 2) isa AutoForwardDiff{2}
@test DI.threshold_batchsize(AutoForwardDiff(; chunksize=4), 6) isa AutoForwardDiff{4}
end
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ using DifferentiationInterface:
PushforwardPrep,
SecondDerivativePrep,
Rewrap
using DocStringExtensions
import DifferentiationInterface as DI
using DocStringExtensions
using Functors: fmap
using JET: JET
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent
Expand Down
8 changes: 8 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/scenario.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,11 @@ function Base.show(
end
return nothing
end

adapt_batchsize(backend::AbstractADType, ::Scenario) = backend

function adapt_batchsize(
backend::Union{ADTypes.AutoForwardDiff,ADTypes.AutoPolyesterForwardDiff}, scen::Scenario
)
return DI.threshold_batchsize(backend, length(scen.x))
end
16 changes: 12 additions & 4 deletions DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,23 @@ function test_differentiation(
(:nb_contexts, length(scen.contexts)),
],
)
adapted_backend = adapt_batchsize(backend, scen)
correctness && @testset "Correctness" begin
test_correctness(backend, scen; isapprox, atol, rtol, scenario_intact)
test_correctness(
adapted_backend, scen; isapprox, atol, rtol, scenario_intact
)
end
type_stability && @testset "Type stability" begin
@static if VERSION >= v"1.7"
test_jet(backend, scen; test_preparation=preparation_type_stability)
test_jet(
adapted_backend,
scen;
test_preparation=preparation_type_stability,
)
end
end
sparsity && @testset "Sparsity" begin
test_sparsity(backend, scen)
test_sparsity(adapted_backend, scen)
end
yield()
end
Expand Down Expand Up @@ -186,7 +193,8 @@ function benchmark_differentiation(
(:nb_contexts, length(scen.contexts)),
],
)
run_benchmark!(benchmark_data, backend, scen; logging)
adapted_backend = adapt_batchsize(backend, scen)
run_benchmark!(benchmark_data, adapted_backend, scen; logging)
yield()
end
end
Expand Down
4 changes: 1 addition & 3 deletions DifferentiationInterfaceTest/test/standard.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ LOGGING = get(ENV, "CI", "false") == "false"

test_differentiation(
AutoForwardDiff(),
default_scenarios(
Random.default_rng(); include_batchified=true, include_constantified=true
);
default_scenarios(Random.default_rng(); include_constantified=true);
logging=LOGGING,
)

Expand Down

0 comments on commit c47e26c

Please sign in to comment.