Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Jacobian and Hessian preparation #535

Merged
merged 8 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.4"
version = "0.6.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -60,7 +60,7 @@ PolyesterForwardDiff = "0.1.1"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
SparseMatrixColorings = "0.4.4"
SparseMatrixColorings = "0.4.5"
Symbolics = "5.27.1, 6"
Tracker = "0.2.33"
Zygote = "0.6.69"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,16 @@ struct EnzymeForwardGradientPrep{B,O} <: GradientPrep
shadows::O
end

# Outer constructor taking the batch size as a `Val{B}`: `B` and the type of `shadows`
# are lifted into the type parameters, so callers need not spell out `typeof(shadows)`.
EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O} =
    EnzymeForwardGradientPrep{B,O}(shadows)

function DI.prepare_gradient(
f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
) where {F}
B = pick_batchsize(backend, length(x))
shadows = create_shadows(Val(B), x)
return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows)
valB = pick_batchsize(backend, length(x))
shadows = create_shadows(valB, x)
return EnzymeForwardGradientPrep(valB, shadows)
end

function DI.gradient(
Expand Down Expand Up @@ -176,13 +180,19 @@ struct EnzymeForwardOneArgJacobianPrep{B,O} <: JacobianPrep
output_length::Int
end

# Outer constructor taking the batch size as a `Val{B}`: `B` and the type of `shadows`
# become the type parameters, and the remaining arguments are forwarded unchanged to
# the fully parametrized constructor.
function EnzymeForwardOneArgJacobianPrep(
    ::Val{B}, shadows::O, output_length::Integer
) where {B,O}
    PrepType = EnzymeForwardOneArgJacobianPrep{B,O}
    return PrepType(shadows, output_length)
end

function DI.prepare_jacobian(
f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
) where {F}
y = f(x)
B = pick_batchsize(backend, length(x))
shadows = create_shadows(Val(B), x)
return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y))
valB = pick_batchsize(backend, length(x))
shadows = create_shadows(valB, x)
return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
end

function DI.jacobian(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,15 @@ end

# Field-less preparation object: all information lives in the type parameters.
# In `prepare_jacobian` below, `Sy` is `size(y)` for `y = f(x)` and `B` is the batch
# size obtained from `pick_batchsize`.
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end

# Outer constructor lifting the two `Val` arguments into the type parameters `Sy` and
# `B` of the field-less prep object.
EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B} =
    EnzymeReverseOneArgJacobianPrep{Sy,B}()

function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
y = f(x)
Sy = size(y)
B = pick_batchsize(backend, prod(Sy))
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
valB = pick_batchsize(backend, prod(Sy))
return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
end

function DI.jacobian(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = min(dimension, 16)
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16)

## Annotations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ using LinearAlgebra: dot, mul!

DI.check_available(::AutoForwardDiff) = true

function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C}
if isnothing(C)
return ForwardDiff.pickchunksize(dimension)
else
return min(dimension, C)
end
DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} = Val(C)

# Backend whose chunk size parameter is `nothing`: defer to ForwardDiff's own
# heuristic. This method is type-unstable, because the parameter of the returned
# `Val` is computed from the runtime value of `dimension`.
function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
    chunksize = ForwardDiff.pickchunksize(dimension)
    return Val(chunksize)
end

include("utils.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using DifferentiationInterface:
outer,
multibasis,
pick_batchsize,
pick_jacobian_batchsize,
pushforward_performance,
unwrap,
with_contexts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
function DI.prepare_hessian(
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
valB = pick_batchsize(backend, length(x))
return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...)
end

function _prepare_sparse_hessian_aux(
::Val{B}, f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {B,F,C}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(
with_contexts(f, contexts...), x, sparsity_detector(backend)
Expand All @@ -52,7 +59,6 @@ function DI.prepare_hessian(
)
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(outer(dense_backend), Ng)
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
batched_seeds = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,28 @@ function DI.prepare_jacobian(
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
y = f(x, map(unwrap, contexts)...)
return _prepare_sparse_jacobian_aux(
pushforward_performance(backend), y, (f,), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_sparse_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...)
end

function DI.prepare_jacobian(
f!::F, y, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
return _prepare_sparse_jacobian_aux(
pushforward_performance(backend), y, (f!, y), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_sparse_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...)
end

function _prepare_sparse_jacobian_aux(
::PushforwardFast, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {FY,C}
::PushforwardFast,
::Val{B},
y,
f_or_f!y::FY,
backend::AutoSparse,
x,
contexts::Vararg{Context,C},
) where {B,FY,C}
dense_backend = dense_ad(backend)

sparsity = jacobian_sparsity(
Expand All @@ -104,7 +110,6 @@ function _prepare_sparse_jacobian_aux(
)
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(dense_backend, Ng)
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
batched_seeds = [
Expand All @@ -121,8 +126,14 @@ function _prepare_sparse_jacobian_aux(
end

function _prepare_sparse_jacobian_aux(
::PushforwardSlow, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {FY,C}
::PushforwardSlow,
::Val{B},
y,
f_or_f!y::FY,
backend::AutoSparse,
x,
contexts::Vararg{Context,C},
) where {B,FY,C}
dense_backend = dense_ad(backend)
sparsity = jacobian_sparsity(
fy_with_contexts(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
Expand All @@ -136,7 +147,6 @@ function _prepare_sparse_jacobian_aux(
)
groups = row_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(dense_backend, Ng)
seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
batched_seeds = [
Expand Down
38 changes: 20 additions & 18 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,29 @@ function prepare_jacobian(
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
y = f(x, map(unwrap, contexts)...)
return _prepare_jacobian_aux(
pushforward_performance(backend), y, (f,), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...)
end

function prepare_jacobian(
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
return _prepare_jacobian_aux(
pushforward_performance(backend), y, (f!, y), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...)
end

function _prepare_jacobian_aux(
::PushforwardFast,
::Val{B},
y,
f_or_f!y::FY,
backend::AbstractADType,
x,
contexts::Vararg{Context,C},
) where {FY,C}
) where {B,FY,C}
N = length(x)
B = pick_batchsize(backend, N)
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
Expand All @@ -128,14 +128,14 @@ end

function _prepare_jacobian_aux(
::PushforwardSlow,
::Val{B},
y,
f_or_f!y::FY,
backend::AbstractADType,
x,
contexts::Vararg{Context,C},
) where {FY,C}
) where {B,FY,C}
M = length(y)
B = pick_batchsize(backend, M)
seeds = [basis(backend, y, ind) for ind in eachindex(y)]
batched_seeds = [
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
Expand Down Expand Up @@ -241,13 +241,14 @@ function _jacobian_aux(
batched_seeds[a],
contexts...,
)
stack(vec, dy_batch; dims=2)
block = stack(vec, dy_batch; dims=2)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
block
end

jac = reduce(hcat, jac_blocks)
if N < size(jac, 2)
jac = jac[:, 1:N]
end
return jac
end

Expand All @@ -268,13 +269,14 @@ function _jacobian_aux(
dx_batch = pullback(
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
)
stack(vec, dx_batch; dims=1)
block = stack(vec, dx_batch; dims=1)
if M % B != 0 && a == lastindex(batched_seeds)
block = block[1:(M - (a - 1) * B), :]
end
block
end

jac = reduce(vcat, jac_blocks)
if M < size(jac, 1)
jac = jac[1:M, :]
end
return jac
end

Expand Down
14 changes: 12 additions & 2 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@ end
function prepare_hessian(
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
valB = pick_batchsize(backend, length(x))
return _prepare_hessian_aux(valB, f, backend, x, contexts...)
end

function _prepare_hessian_aux(
::Val{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {B,F,C}
N = length(x)
B = pick_batchsize(outer(backend), N)
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
Expand Down Expand Up @@ -107,7 +113,11 @@ function hessian(

hess_blocks = map(eachindex(batched_seeds)) do a
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
stack(vec, dg_batch; dims=2)
block = stack(vec, dg_batch; dims=2)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
block
end

hess = reduce(hcat, hess_blocks)
Expand Down
16 changes: 14 additions & 2 deletions DifferentiationInterface/src/utils/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ end

Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`.

Returns `1` for backends which have not overloaded it.
Returns `Val(1)` for backends which have not overloaded it.
"""
pick_batchsize(::AbstractADType, dimension::Integer) = 1
pick_batchsize(::AbstractADType, dimension::Integer) = Val(1)

# Fast pushforwards: the batch size is picked from `N`, which callers set to
# `length(x)`. `M` is accepted for a uniform keyword interface but is not used here.
function pick_jacobian_batchsize(
    ::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer
)
    batched_dimension = N
    return pick_batchsize(backend, batched_dimension)
end

# Slow pushforwards: the batch size is picked from `M`, which callers set to
# `length(y)`. `N` is accepted for a uniform keyword interface but is not used here.
function pick_jacobian_batchsize(
    ::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer
)
    batched_dimension = M
    return pick_batchsize(backend, batched_dimension)
end
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ test_differentiation(
);

test_differentiation(
dense_backends; correctness=false, type_stability=true, logging=LOGGING
AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING
);

test_differentiation(
Expand Down
10 changes: 7 additions & 3 deletions DifferentiationInterface/test/Misc/Internals/backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ end
end

@testset "Batch size" begin
@test DI.pick_batchsize(AutoZygote(), 2) == 1
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == 2
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == 4
@test DI.pick_batchsize(AutoZygote(), 2) == Val(1)
@test DI.pick_batchsize(AutoForwardDiff(), 2) == Val(2)
@test DI.pick_batchsize(AutoForwardDiff(), 6) == Val(6)
@test DI.pick_batchsize(AutoForwardDiff(), 100) == Val(12)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == Val(4)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == Val(4)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 100) == Val(4)
end
Loading
Loading