Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Jacobian and Hessian preparation #535

Merged
merged 8 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.4"
version = "0.6.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = min(dimension, 16)
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16)

## Annotations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ using LinearAlgebra: dot, mul!

DI.check_available(::AutoForwardDiff) = true

function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C}
if isnothing(C)
return ForwardDiff.pickchunksize(dimension)
else
return min(dimension, C)
end
DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} = Val(C)

function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
    # No chunk size is fixed in the backend type, so it is derived from the runtime
    # value of `dimension`. The type parameter of the returned `Val` therefore cannot
    # be inferred from the argument types alone: this method is type-unstable.
    chunksize = ForwardDiff.pickchunksize(dimension)
    return Val(chunksize)
end

include("utils.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{min(length(x), C)}()
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()
gdalle marked this conversation as resolved.
Show resolved Hide resolved

tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T
tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = typeof(Tag(f, eltype(x)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using DifferentiationInterface:
outer,
multibasis,
pick_batchsize,
pick_jacobian_batchsize,
pushforward_performance,
unwrap,
with_contexts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
# Entry point of sparse Hessian preparation: choose the batch size first, then pass it
# as a `Val` to `_prepare_sparse_hessian_aux`, where it becomes the type parameter `B`.
function DI.prepare_hessian(
    f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
    # Ask the outer part of the wrapped dense backend for its batch size, as the
    # previous implementation did. Passing the `AutoSparse` wrapper itself would hit
    # the generic `pick_batchsize` fallback (`Val(1)`) and silently disable batching,
    # unless a wrapper-specific method is defined elsewhere.
    # NOTE(review): the batch size is now chosen from `length(x)` rather than from the
    # number of color groups, which is only known after coloring.
    valB = pick_batchsize(outer(dense_ad(backend)), length(x))
    return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...)
end

function _prepare_sparse_hessian_aux(
::Val{B}, f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {B,F,C}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(
with_contexts(f, contexts...), x, sparsity_detector(backend)
Expand All @@ -52,7 +59,6 @@ function DI.prepare_hessian(
)
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(outer(dense_backend), Ng)
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
batched_seeds = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,28 @@ function DI.prepare_jacobian(
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
y = f(x, map(unwrap, contexts)...)
return _prepare_sparse_jacobian_aux(
pushforward_performance(backend), y, (f,), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_sparse_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...)
end

function DI.prepare_jacobian(
f!::F, y, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
return _prepare_sparse_jacobian_aux(
pushforward_performance(backend), y, (f!, y), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_sparse_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...)
end

function _prepare_sparse_jacobian_aux(
::PushforwardFast, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {FY,C}
::PushforwardFast,
::Val{B},
y,
f_or_f!y::FY,
backend::AutoSparse,
x,
contexts::Vararg{Context,C},
) where {B,FY,C}
dense_backend = dense_ad(backend)

sparsity = jacobian_sparsity(
Expand All @@ -104,7 +110,6 @@ function _prepare_sparse_jacobian_aux(
)
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(dense_backend, Ng)
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
batched_seeds = [
Expand All @@ -121,8 +126,14 @@ function _prepare_sparse_jacobian_aux(
end

function _prepare_sparse_jacobian_aux(
::PushforwardSlow, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {FY,C}
::PushforwardSlow,
::Val{B},
y,
f_or_f!y::FY,
backend::AutoSparse,
x,
contexts::Vararg{Context,C},
) where {B,FY,C}
dense_backend = dense_ad(backend)
sparsity = jacobian_sparsity(
fy_with_contexts(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
Expand All @@ -136,7 +147,6 @@ function _prepare_sparse_jacobian_aux(
)
groups = row_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(dense_backend, Ng)
seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
batched_seeds = [
Expand Down
38 changes: 20 additions & 18 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,29 @@ function prepare_jacobian(
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
y = f(x, map(unwrap, contexts)...)
return _prepare_jacobian_aux(
pushforward_performance(backend), y, (f,), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...)
end

function prepare_jacobian(
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
return _prepare_jacobian_aux(
pushforward_performance(backend), y, (f!, y), backend, x, contexts...
)
perf = pushforward_performance(backend)
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
return _prepare_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...)
end

function _prepare_jacobian_aux(
::PushforwardFast,
::Val{B},
y,
f_or_f!y::FY,
backend::AbstractADType,
x,
contexts::Vararg{Context,C},
) where {FY,C}
) where {B,FY,C}
N = length(x)
B = pick_batchsize(backend, N)
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
Expand All @@ -128,14 +128,14 @@ end

function _prepare_jacobian_aux(
::PushforwardSlow,
::Val{B},
y,
f_or_f!y::FY,
backend::AbstractADType,
x,
contexts::Vararg{Context,C},
) where {FY,C}
) where {B,FY,C}
M = length(y)
B = pick_batchsize(backend, M)
seeds = [basis(backend, y, ind) for ind in eachindex(y)]
batched_seeds = [
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
Expand Down Expand Up @@ -241,13 +241,14 @@ function _jacobian_aux(
batched_seeds[a],
contexts...,
)
stack(vec, dy_batch; dims=2)
block = stack(vec, dy_batch; dims=2)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
block
end

jac = reduce(hcat, jac_blocks)
if N < size(jac, 2)
jac = jac[:, 1:N]
end
return jac
end

Expand All @@ -268,13 +269,14 @@ function _jacobian_aux(
dx_batch = pullback(
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
)
stack(vec, dx_batch; dims=1)
block = stack(vec, dx_batch; dims=1)
if M % B != 0 && a == lastindex(batched_seeds)
block = block[1:(M - (a - 1) * B), :]
end
block
end

jac = reduce(vcat, jac_blocks)
if M < size(jac, 1)
jac = jac[1:M, :]
end
return jac
end

Expand Down
14 changes: 12 additions & 2 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@ end
# Entry point of dense Hessian preparation: choose the batch size first, then pass it
# as a `Val` to `_prepare_hessian_aux`, where it becomes the type parameter `B`.
function prepare_hessian(
    f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
    # Query `outer(backend)` as the previous implementation did. Passing a combined
    # second-order backend directly would fall through to the generic `pick_batchsize`
    # fallback (`Val(1)`), unless a dedicated method is defined elsewhere.
    valB = pick_batchsize(outer(backend), length(x))
    return _prepare_hessian_aux(valB, f, backend, x, contexts...)
end

function _prepare_hessian_aux(
::Val{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {B,F,C}
N = length(x)
B = pick_batchsize(outer(backend), N)
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
Expand Down Expand Up @@ -107,7 +113,11 @@ function hessian(

hess_blocks = map(eachindex(batched_seeds)) do a
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
stack(vec, dg_batch; dims=2)
block = stack(vec, dg_batch; dims=2)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
block
end

hess = reduce(hcat, hess_blocks)
Expand Down
16 changes: 14 additions & 2 deletions DifferentiationInterface/src/utils/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ end

Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`.

Returns `1` for backends which have not overloaded it.
Returns `Val(1)` for backends which have not overloaded it.
"""
pick_batchsize(::AbstractADType, dimension::Integer) = 1
pick_batchsize(::AbstractADType, dimension::Integer) = Val(1)

# Jacobian batch size when dispatching on `PushforwardFast`: the batch size is picked
# from the input length `N`. The keyword `M` (output length) is accepted so that both
# methods share one call signature, but it is not used here.
function pick_jacobian_batchsize(
    ::PushforwardFast, backend::AbstractADType; M::Integer, N::Integer
)
    input_dimension = N
    return pick_batchsize(backend, input_dimension)
end

# Jacobian batch size when dispatching on `PushforwardSlow`: the batch size is picked
# from the output length `M`. The keyword `N` (input length) is accepted so that both
# methods share one call signature, but it is not used here.
function pick_jacobian_batchsize(
    ::PushforwardSlow, backend::AbstractADType; M::Integer, N::Integer
)
    output_dimension = M
    return pick_batchsize(backend, output_dimension)
end
10 changes: 9 additions & 1 deletion DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ test_differentiation(
);

test_differentiation(
dense_backends; correctness=false, type_stability=true, logging=LOGGING
AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING
);

test_differentiation(
AutoForwardDiff(; chunksize=5);
correctness=false,
type_stability=true,
preparation_type_stability=true,
logging=LOGGING,
);

test_differentiation(
Expand Down
10 changes: 7 additions & 3 deletions DifferentiationInterface/test/Misc/Internals/backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ end
end

@testset "Batch size" begin
@test DI.pick_batchsize(AutoZygote(), 2) == 1
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == 2
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == 4
@test DI.pick_batchsize(AutoZygote(), 2) == Val(1)
@test DI.pick_batchsize(AutoForwardDiff(), 2) == Val(2)
@test DI.pick_batchsize(AutoForwardDiff(), 6) == Val(6)
@test DI.pick_batchsize(AutoForwardDiff(), 100) == Val(12)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 2) == Val(4)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 6) == Val(4)
@test DI.pick_batchsize(AutoForwardDiff(; chunksize=4), 100) == Val(4)
end
15 changes: 14 additions & 1 deletion DifferentiationInterface/test/Misc/ZeroBackends/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
using DifferentiationInterfaceTest
using ComponentArrays: ComponentArrays
using JLArrays: JLArrays
using SparseMatrixColorings
using StaticArrays: StaticArrays
using Test

Expand All @@ -19,9 +20,10 @@ end

test_differentiation(
zero_backends,
zero.(default_scenarios(; include_constantified=true));
default_scenarios(; include_constantified=true);
correctness=true,
type_stability=true,
preparation_type_stability=true,
logging=LOGGING,
)

Expand All @@ -33,10 +35,21 @@ test_differentiation(
default_scenarios();
correctness=false,
type_stability=true,
preparation_type_stability=true,
first_order=false,
logging=LOGGING,
)

test_differentiation(
AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()),
default_scenarios(; include_constantified=true);
correctness=false,
type_stability=true,
preparation_type_stability=true,
excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative],
logging=LOGGING,
)

## Weird arrays

test_differentiation(
Expand Down
5 changes: 3 additions & 2 deletions DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Cross-test a list of `backends` on a list of `scenarios`, running a variety of d
Testing:

- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario
- `type_stability=false`: whether to check type stability with JET.jl (thanks to `JET.@test_opt`)
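- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`)
- `preparation_type_stability=false`: whether to also check type stability of preparation (only has an effect when `type_stability=true`)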
- `sparsity`: whether to check sparsity of the jacobian / hessian
- `detailed=false`: whether to print a detailed or condensed test log

Expand All @@ -52,6 +52,7 @@ function test_differentiation(
# testing
correctness::Bool=true,
type_stability::Bool=false,
preparation_type_stability::Bool=false,
call_count::Bool=false,
sparsity::Bool=false,
detailed=false,
Expand Down Expand Up @@ -113,7 +114,7 @@ function test_differentiation(
end
type_stability && @testset "Type stability" begin
@static if VERSION >= v"1.7"
test_jet(backend, scen)
test_jet(backend, scen; test_preparation=preparation_type_stability)
end
end
sparsity && @testset "Sparsity" begin
Expand Down
Loading
Loading