Revamp batch size handling (#575)
* Batch size

* Revamp batch size computations

* Fix

* Single batch modifications

* Introduce BatchSizeSettings

* Fix PolyesterFD

* Fix

* Fix internals

* Type stab

* Fix

* Fixes

* Coverage

* Fix

* Guess activity in Enzyme

* Fix

* Fixes

* No static test

* More coverage

* AutoSparse with adaptive backends

* Proper thresholding

* Fix

* Fixes

* Fix

* Up

* Fix
gdalle authored Oct 14, 2024
1 parent fd7580c commit 4496997
Showing 25 changed files with 629 additions and 301 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.6.12"
+version = "0.6.13"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
20 changes: 20 additions & 0 deletions DifferentiationInterface/docs/src/explanation/advanced.md
@@ -67,3 +67,23 @@ The complexity of sparse Jacobians or Hessians grows with the number of distinct
 To reduce this number of colors, [`GreedyColoringAlgorithm`](@ref) has two main settings: the order used for vertices and the decompression method.
 Depending on your use case, you may want to modify either of these options to increase performance.
 See the documentation of [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) for details.
+
+## Batch mode
+
+### Multiple tangents
+
+The [`jacobian`](@ref) and [`hessian`](@ref) operators compute matrices by repeatedly applying lower-level operators ([`pushforward`](@ref), [`pullback`](@ref) or [`hvp`](@ref)) to a set of tangents.
+The tangents usually correspond to basis elements of the appropriate vector space.
+We could call the lower-level operator on each tangent separately, but some packages ([ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)) have optimized implementations to handle multiple tangents at once.
+
+This behavior is often called "vector mode" AD, but we call it "batch mode" to avoid confusion with Julia's `Vector` type.
+As a matter of fact, the optimal batch size $B$ (number of simultaneous tangents) is usually very small, so tangents are passed within an `NTuple` and not a `Vector`.
+When the underlying vector space has dimension $N$, the operators `jacobian` and `hessian` process $\lceil N / B \rceil$ batches of size $B$ each.
+
+### Optimal batch size
+
+For every backend that does not support batch mode, the batch size is set to $B = 1$.
+But for [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) and [`AutoEnzyme`](@extref ADTypes.AutoEnzyme), more complicated rules apply.
+If the backend object has a pre-determined batch size $B_0$, then we always set $B = B_0$.
+In particular, preparation will throw an error when $N < B_0$.
+On the other hand, without a pre-determined batch size, we apply backend-specific heuristics to pick $B$ based on $N$.
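
To make batch mode concrete, here is a minimal sketch (an editorial illustration, not part of this diff) of calling `pushforward` with an `NTuple` of two tangents; the function, input, and `AutoForwardDiff` backend are arbitrary choices:

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = abs2.(x)
x = [1.0, 2.0, 3.0, 4.0]

# a single batch of B = 2 basis tangents, passed as an NTuple (not a Vector)
tx = ([1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0])

# ty should be an NTuple of B = 2 pushforwards (JVPs), one per tangent
ty = pushforward(f, AutoForwardDiff(), x, tx)
```

With $N = 4$ and $B = 2$, `jacobian` would process $\lceil 4 / 2 \rceil = 2$ such batches.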
(next changed file, name not shown)
@@ -16,12 +16,12 @@ using DifferentiationInterface:
     NoHVPPrep,
     NoJacobianPrep,
     NoPullbackPrep,
-    NoPushforwardPrep,
-    pick_batchsize
+    NoPushforwardPrep
 using Enzyme:
     Active,
     Annotation,
     BatchDuplicated,
+    BatchMixedDuplicated,
     Const,
     Duplicated,
     DuplicatedNoNeed,
(next changed file, name not shown)
@@ -121,7 +121,7 @@ end
 function DI.prepare_gradient(
     f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
 ) where {F}
-    valB = pick_batchsize(backend, length(x))
+    valB = to_val(DI.pick_batchsize(backend, x))
     shadows = create_shadows(valB, x)
     return EnzymeForwardGradientPrep(valB, shadows)
 end
@@ -190,7 +190,7 @@ function DI.prepare_jacobian(
     f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
 ) where {F}
     y = f(x)
-    valB = pick_batchsize(backend, length(x))
+    valB = to_val(DI.pick_batchsize(backend, x))
     shadows = create_shadows(valB, x)
     return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
 end
(next changed file, name not shown)
@@ -337,7 +337,7 @@ end
 function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
     y = f(x)
     Sy = size(y)
-    valB = pick_batchsize(backend, prod(Sy))
+    valB = to_val(DI.pick_batchsize(backend, y))
     return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
 end
 
(next changed file, name not shown)
@@ -1,5 +1,12 @@
 # until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
-DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(min(dimension, 16))
+function DI.BatchSizeSettings(::AutoEnzyme, N::Integer)
+    B = DI.reasonable_batchsize(N, 16)
+    singlebatch = B == N
+    aligned = N % B == 0
+    return DI.BatchSizeSettings{B,singlebatch,aligned}(N)
+end
+
+to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B)
 
 ## Annotations
 
@@ -17,9 +24,10 @@ function get_f_and_df(
         M,
         <:Union{
             Duplicated,
-            EnzymeCore.DuplicatedNoNeed,
+            MixedDuplicated,
             BatchDuplicated,
-            EnzymeCore.BatchDuplicatedFunc,
+            BatchMixedDuplicated,
+            EnzymeCore.DuplicatedNoNeed,
             EnzymeCore.BatchDuplicatedNoNeed,
         },
     },
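
For intuition about the heuristic above, here is a hypothetical stand-in for `DI.reasonable_batchsize`, shown only for illustration; it assumes the function caps the batch size at `Bmax` while keeping batches roughly even, which is what the `singlebatch` and `aligned` flags then record:

```julia
# Hypothetical stand-in for DI.reasonable_batchsize, for illustration only.
function sketch_reasonable_batchsize(N::Integer, Bmax::Integer)
    N <= Bmax && return N  # one batch covers the whole dimension
    A = cld(N, Bmax)       # number of batches needed with the cap
    return cld(N, A)       # spread N as evenly as possible over A batches
end

sketch_reasonable_batchsize(12, 16)  # 12: a single batch, so B == N
sketch_reasonable_batchsize(20, 16)  # 10: two batches of 10 rather than 16 + 4
```

In the Enzyme case the cap is 16, pending the upstream pull request referenced in the comment.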
(next changed file, name not shown)
@@ -4,6 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Base: Fix1, Fix2
 import DifferentiationInterface as DI
 using DifferentiationInterface:
+    BatchSizeSettings,
     Context,
     DerivativePrep,
     DifferentiateWith,
@@ -49,24 +50,6 @@ using LinearAlgebra: dot, mul!
 
 DI.check_available(::AutoForwardDiff) = true
 
-function DI.pick_batchsize(
-    ::AutoForwardDiff{chunksize}, dimension::Integer
-) where {chunksize}
-    return Val{chunksize}()
-end
-
-function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
-    # type-unstable
-    return Val(ForwardDiff.pickchunksize(dimension))
-end
-
-function DI.threshold_batchsize(
-    backend::AutoForwardDiff{chunksize1}, chunksize2::Integer
-) where {chunksize1}
-    chunksize = (chunksize1 === nothing) ? nothing : min(chunksize1, chunksize2)
-    return AutoForwardDiff(; chunksize, tag=backend.tag)
-end
-
 include("utils.jl")
 include("onearg.jl")
 include("twoarg.jl")
(next changed file, name not shown)
@@ -1,3 +1,27 @@
+function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, N::Integer)
+    B = ForwardDiff.pickchunksize(N)
+    singlebatch = B == N
+    aligned = N % B == 0
+    return BatchSizeSettings{B,singlebatch,aligned}(N)
+end
+
+function DI.BatchSizeSettings(::AutoForwardDiff{chunksize}, N::Integer) where {chunksize}
+    if chunksize > N
+        throw(ArgumentError("Fixed chunksize $chunksize larger than input size $N"))
+    end
+    B = chunksize
+    singlebatch = B == N
+    aligned = N % B == 0
+    return BatchSizeSettings{B,singlebatch,aligned}(N)
+end
+
+function DI.threshold_batchsize(
+    backend::AutoForwardDiff{chunksize1}, chunksize2::Integer
+) where {chunksize1}
+    chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
+    return AutoForwardDiff(; chunksize, tag=backend.tag)
+end
+
 choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
 choose_chunk(::AutoForwardDiff{chunksize}, x) where {chunksize} = Chunk{chunksize}()
 
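
As a usage sketch of the two `BatchSizeSettings` constructors above (internal API, shown only for illustration; concrete chunk sizes depend on `ForwardDiff.pickchunksize`):

```julia
using ADTypes: AutoForwardDiff
import DifferentiationInterface as DI
import ForwardDiff  # loads the extension defining these methods

# adaptive chunk size: B is chosen by ForwardDiff's heuristic
DI.BatchSizeSettings(AutoForwardDiff(), 100)

# fixed chunk size: B is taken from the backend object
DI.BatchSizeSettings(AutoForwardDiff(; chunksize=4), 100)

# a fixed chunk size larger than the input size is rejected
DI.BatchSizeSettings(AutoForwardDiff(; chunksize=12), 10)  # throws ArgumentError
```

This is the $N < B_0$ error case mentioned in the documentation added by this commit.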
(next changed file, name not shown)
@@ -28,14 +28,14 @@ end
 
 DI.check_available(::AutoPolyesterForwardDiff) = true
 
-function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, dimension::Integer)
-    return DI.pick_batchsize(single_threaded(backend), dimension)
+function DI.BatchSizeSettings(backend::AutoPolyesterForwardDiff, x_or_N)
+    return DI.BatchSizeSettings(single_threaded(backend), x_or_N)
 end
 
 function DI.threshold_batchsize(
     backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer
 ) where {chunksize1}
-    chunksize = (chunksize1 === nothing) ? nothing : min(chunksize1, chunksize2)
+    chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
     return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag)
 end
 
(next changed file, name not shown)
@@ -10,18 +10,19 @@ using ADTypes:
     hessian_sparsity
 using DifferentiationInterface
 using DifferentiationInterface:
+    BatchSizeSettings,
     GradientPrep,
     HessianPrep,
     HVPPrep,
     JacobianPrep,
     PullbackPrep,
     PushforwardPrep,
-    PushforwardFast,
-    PushforwardSlow,
+    PushforwardPerformance,
     inner,
     outer,
     multibasis,
-    pick_hessian_batchsize,
-    pick_jacobian_batchsize,
+    pick_batchsize,
     pushforward_performance,
     unwrap,
     with_contexts
(next changed file, name not shown)
@@ -1,38 +1,21 @@
 struct SparseHessianPrep{
-    B,
+    BS<:BatchSizeSettings,
     C<:AbstractColoringResult{:symmetric,:column},
     M<:AbstractMatrix{<:Real},
-    TD<:NTuple{B},
-    TR<:NTuple{B},
+    S<:AbstractVector{<:NTuple},
+    R<:AbstractVector{<:NTuple},
     E2<:HVPPrep,
     E1<:GradientPrep,
 } <: HessianPrep
+    batch_size_settings::BS
     coloring_result::C
     compressed_matrix::M
-    batched_seeds::Vector{TD}
-    batched_results::Vector{TR}
+    batched_seeds::S
+    batched_results::R
     hvp_prep::E2
     gradient_prep::E1
 end
 
-function SparseHessianPrep{B}(;
-    coloring_result::C,
-    compressed_matrix::M,
-    batched_seeds::Vector{TD},
-    batched_results::Vector{TR},
-    hvp_prep::E2,
-    gradient_prep::E1,
-) where {B,C,M,TD,TR,E2,E1}
-    return SparseHessianPrep{B,C,M,TD,TR,E2,E1}(
-        coloring_result,
-        compressed_matrix,
-        batched_seeds,
-        batched_results,
-        hvp_prep,
-        gradient_prep,
-    )
-end
-
 SMC.sparsity_pattern(prep::SparseHessianPrep) = sparsity_pattern(prep.coloring_result)
 SMC.column_colors(prep::SparseHessianPrep) = column_colors(prep.coloring_result)
 SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
@@ -42,13 +25,6 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
 function DI.prepare_hessian(
     f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
 ) where {F,C}
-    valB = pick_hessian_batchsize(dense_ad(backend), length(x))
-    return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...)
-end
-
-function _prepare_sparse_hessian_aux(
-    ::Val{B}, f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
-) where {B,F,C}
     dense_backend = dense_ad(backend)
     sparsity = hessian_sparsity(
         with_contexts(f, contexts...), x, sparsity_detector(backend)
@@ -57,18 +33,34 @@ function _prepare_sparse_hessian_aux(
     coloring_result = coloring(
         sparsity, problem, coloring_algorithm(backend); decompression_eltype=eltype(x)
     )
+    N = length(column_groups(coloring_result))
+    batch_size_settings = pick_batchsize(outer(dense_backend), N)
+    return _prepare_sparse_hessian_aux(
+        batch_size_settings, coloring_result, f, backend, x, contexts...
+    )
+end
+
+function _prepare_sparse_hessian_aux(
+    batch_size_settings::BatchSizeSettings{B},
+    coloring_result::AbstractColoringResult{:symmetric,:column},
+    f::F,
+    backend::AutoSparse,
+    x,
+    contexts::Vararg{Context,C},
+) where {B,F,C}
+    (; N, A) = batch_size_settings
+    dense_backend = dense_ad(backend)
     groups = column_groups(coloring_result)
-    Ng = length(groups)
     seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
     compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
     batched_seeds = [
-        ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for
-        a in 1:div(Ng, B, RoundUp)
+        ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
     hvp_prep = prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...)
     gradient_prep = prepare_gradient(f, inner(dense_backend), x, contexts...)
-    return SparseHessianPrep{B}(;
+    return SparseHessianPrep(
+        batch_size_settings,
         coloring_result,
         compressed_matrix,
         batched_seeds,
@@ -81,14 +73,21 @@ end
 function DI.hessian!(
     f::F,
     hess,
-    prep::SparseHessianPrep{B},
+    prep::SparseHessianPrep{<:BatchSizeSettings{B}},
     backend::AutoSparse,
     x,
     contexts::Vararg{Context,C},
 ) where {F,B,C}
-    (; coloring_result, compressed_matrix, batched_seeds, batched_results, hvp_prep) = prep
+    (;
+        batch_size_settings,
+        coloring_result,
+        compressed_matrix,
+        batched_seeds,
+        batched_results,
+        hvp_prep,
+    ) = prep
+    (; N) = batch_size_settings
     dense_backend = dense_ad(backend)
-    Ng = length(column_groups(coloring_result))
 
     hvp_prep_same = prepare_hvp_same_point(
         f, hvp_prep, dense_backend, x, batched_seeds[1], contexts...
@@ -107,7 +106,7 @@ function DI.hessian!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
+                view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % N),
                 vec(batched_results[a][b]),
             )
         end
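
The modular index `1 + ((a - 1) * B + (b - 1)) % N` used in `batched_seeds` and in `hessian!` cycles through the `N` color groups in `A` batches of size `B`, wrapping around when `B` does not divide `N`. A self-contained sketch of the index pattern (plain arithmetic, no DI types):

```julia
N, B = 5, 2    # 5 color groups, batch size 2
A = cld(N, B)  # 3 batches are needed
batches = [ntuple(b -> 1 + ((a - 1) * B + (b - 1)) % N, B) for a in 1:A]
# batches == [(1, 2), (3, 4), (5, 1)]; the last batch wraps around to group 1
```

The wrap-around costs one redundant seed evaluation but keeps every batch at the full size `B`; the `aligned` flag in `BatchSizeSettings` records whether any wrap-around is needed at all.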