Skip to content

Commit

Permalink
Special case for a single batch
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 10, 2024
1 parent 720be01 commit ebbf6a8
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ function _prepare_sparse_hessian_aux(
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
batched_seeds = [
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for
a in 1:div(Ng, B, RoundUp)
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for a in 1:cld(Ng, B)
]
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
hvp_prep = prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...)
Expand Down Expand Up @@ -107,7 +106,7 @@ function DI.hessian!(

for b in eachindex(batched_results[a])
copyto!(
view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
view(compressed_matrix, :, mod1((a - 1) * B + (b - 1), Ng)),
vec(batched_results[a][b]),
)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ function _prepare_sparse_jacobian_aux(
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
batched_seeds = [
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for
a in 1:div(Ng, B, RoundUp)
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for a in 1:cld(Ng, B)
]
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
pushforward_prep = prepare_pushforward(
Expand Down Expand Up @@ -146,12 +145,11 @@ function _prepare_sparse_jacobian_aux(
decompression_eltype=promote_type(eltype(x), eltype(y)),
)
groups = row_groups(coloring_result)
Ng = length(groups)
Mg = length(groups)
seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
batched_seeds = [
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for
a in 1:div(Ng, B, RoundUp)
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Mg)], Val(B)) for a in 1:cld(Mg, B)
]
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
pullback_prep = prepare_pullback(
Expand Down Expand Up @@ -265,7 +263,7 @@ function _sparse_jacobian_aux!(

for b in eachindex(batched_results[a])
copyto!(
view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
view(compressed_matrix, :, mod1((a - 1) * B + (b - 1), Ng)),
vec(batched_results[a][b]),
)
end
Expand All @@ -286,7 +284,7 @@ function _sparse_jacobian_aux!(
(; coloring_result, compressed_matrix, batched_seeds, batched_results, pullback_prep) =
prep
dense_backend = dense_ad(backend)
Ng = length(row_groups(coloring_result))
Mg = length(row_groups(coloring_result))

pullback_prep_same = prepare_pullback_same_point(
f_or_f!y..., pullback_prep, dense_backend, x, batched_seeds[1], contexts...
Expand All @@ -305,7 +303,7 @@ function _sparse_jacobian_aux!(

for b in eachindex(batched_results[a])
copyto!(
view(compressed_matrix, 1 + ((a - 1) * B + (b - 1)) % Ng, :),
view(compressed_matrix, mod1((a - 1) * B + (b - 1), Mg), :),
vec(batched_results[a][b]),
)
end
Expand Down
94 changes: 56 additions & 38 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,22 @@ function jacobian! end

## Preparation

struct PushforwardJacobianPrep{B,TD<:NTuple{B},TR<:NTuple{B},E<:PushforwardPrep} <:
struct PushforwardJacobianPrep{onebatch,B,TD<:NTuple{B},TR<:NTuple{B},E<:PushforwardPrep} <:
JacobianPrep
batched_seeds::Vector{TD}
batched_results::Vector{TR}
pushforward_prep::E
N::Int
A::Int
end

struct PullbackJacobianPrep{B,TD<:NTuple{B},TR<:NTuple{B},E<:PullbackPrep} <: JacobianPrep
struct PullbackJacobianPrep{onebatch,B,TD<:NTuple{B},TR<:NTuple{B},E<:PullbackPrep} <:
JacobianPrep
batched_seeds::Vector{TD}
batched_results::Vector{TR}
pullback_prep::E
M::Int
A::Int
end

function prepare_jacobian(
Expand Down Expand Up @@ -109,10 +112,10 @@ function _prepare_jacobian_aux(
contexts::Vararg{Context,C},
) where {B,FY,C}
N = length(x)
A = cld(N, B)
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for
a in 1:div(N, B, RoundUp)
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for a in 1:A
]
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
pushforward_prep = prepare_pushforward(
Expand All @@ -121,8 +124,9 @@ function _prepare_jacobian_aux(
TD = eltype(batched_seeds)
TR = eltype(batched_results)
E = typeof(pushforward_prep)
return PushforwardJacobianPrep{B,TD,TR,E}(
batched_seeds, batched_results, pushforward_prep, N
onebatch = length(batched_seeds) == 1
return PushforwardJacobianPrep{onebatch,B,TD,TR,E}(
batched_seeds, batched_results, pushforward_prep, N, A
)
end

Expand All @@ -136,17 +140,20 @@ function _prepare_jacobian_aux(
contexts::Vararg{Context,C},
) where {B,FY,C}
M = length(y)
A = cld(M, B)
seeds = [basis(backend, y, ind) for ind in eachindex(y)]
batched_seeds = [
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), M)], Val(B)) for
a in 1:div(M, B, RoundUp)
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), M)], Val(B)) for a in 1:A
]
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
pullback_prep = prepare_pullback(f_or_f!y..., backend, x, batched_seeds[1], contexts...)
TD = eltype(batched_seeds)
TR = eltype(batched_results)
E = typeof(pullback_prep)
return PullbackJacobianPrep{B,TD,TR,E}(batched_seeds, batched_results, pullback_prep, M)
onebatch = length(batched_seeds) == 1
return PullbackJacobianPrep{onebatch,B,TD,TR,E}(
batched_seeds, batched_results, pullback_prep, M, A
)
end

## One argument
Expand Down Expand Up @@ -221,63 +228,74 @@ end

function _jacobian_aux(
f_or_f!y::FY,
prep::PushforwardJacobianPrep{B},
prep::PushforwardJacobianPrep{onebatch,B},
backend::AbstractADType,
x,
contexts::Vararg{Context,C},
) where {FY,B,C}
(; batched_seeds, pushforward_prep, N) = prep
) where {FY,onebatch,B,C}
(; batched_seeds, pushforward_prep, N, A) = prep

pushforward_prep_same = prepare_pushforward_same_point(
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
)

jac_blocks = map(eachindex(batched_seeds)) do a
if onebatch
dy_batch = pushforward(
f_or_f!y...,
pushforward_prep_same,
backend,
x,
batched_seeds[a],
contexts...,
f_or_f!y..., pushforward_prep_same, backend, x, only(batched_seeds), contexts...
)
block = stack_vec_col(dy_batch)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
return crop_last_col_block(block, 1, 1, B, N)
else
jac_blocks = map(eachindex(batched_seeds)) do a
dy_batch = pushforward(
f_or_f!y...,
pushforward_prep_same,
backend,
x,
batched_seeds[a],
contexts...,
)
block = stack_vec_col(dy_batch)
return crop_last_col_block(block, a, A, B, N)
end
block
return reduce(hcat, jac_blocks)
end

jac = reduce(hcat, jac_blocks)
return jac
end

function _jacobian_aux(
f_or_f!y::FY,
prep::PullbackJacobianPrep{B},
prep::PullbackJacobianPrep{onebatch,B},
backend::AbstractADType,
x,
contexts::Vararg{Context,C},
) where {FY,B,C}
(; batched_seeds, pullback_prep, M) = prep
) where {FY,onebatch,B,C}
(; batched_seeds, pullback_prep, M, A) = prep

pullback_prep_same = prepare_pullback_same_point(
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
)

jac_blocks = map(eachindex(batched_seeds)) do a
if onebatch
dx_batch = pullback(
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
f_or_f!y..., pullback_prep_same, backend, x, only(batched_seeds), contexts...
)
block = stack_vec_row(dx_batch)
if M % B != 0 && a == lastindex(batched_seeds)
block = block[1:(M - (a - 1) * B), :]
return crop_last_row_block(block, 1, 1, B, M)
else
jac_blocks = map(eachindex(batched_seeds)) do a
dx_batch = pullback(
f_or_f!y...,
pullback_prep_same,
backend,
x,
batched_seeds[a],
contexts...,
)
block = stack_vec_row(dx_batch)
return crop_last_row_block(block, a, A, B, M)
end
block
return reduce(vcat, jac_blocks)
end

jac = reduce(vcat, jac_blocks)
return jac
end

function _jacobian_aux!(
Expand Down Expand Up @@ -307,7 +325,7 @@ function _jacobian_aux!(

for b in eachindex(batched_results[a])
copyto!(
view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), vec(batched_results[a][b])
view(jac, :, mod1((a - 1) * B + (b - 1), N)), vec(batched_results[a][b])
)
end
end
Expand Down Expand Up @@ -342,7 +360,7 @@ function _jacobian_aux!(

for b in eachindex(batched_results[a])
copyto!(
view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), vec(batched_results[a][b])
view(jac, mod1((a - 1) * B + (b - 1), M), :), vec(batched_results[a][b])
)
end
end
Expand Down
36 changes: 20 additions & 16 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ function value_gradient_and_hessian! end

## Preparation

struct HVPGradientHessianPrep{B,TD<:NTuple{B},TR<:NTuple{B},E2<:HVPPrep,E1<:GradientPrep} <:
HessianPrep
struct HVPGradientHessianPrep{
onebatch,B,TD<:NTuple{B},TR<:NTuple{B},E2<:HVPPrep,E1<:GradientPrep
} <: HessianPrep
batched_seeds::Vector{TD}
batched_results::Vector{TR}
hvp_prep::E2
gradient_prep::E1
N::Int
A::Int
end

function prepare_hessian(
Expand All @@ -80,19 +82,20 @@ function _prepare_hessian_aux(
::Val{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {B,F,C}
N = length(x)
A = cld(N, B)
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for
a in 1:div(N, B, RoundUp)
ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for a in 1:A
]
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...)
gradient_prep = prepare_gradient(f, inner(backend), x, contexts...)
TD = eltype(batched_seeds)
TR = eltype(batched_results)
E2, E1 = typeof(hvp_prep), typeof(gradient_prep)
return HVPGradientHessianPrep{B,TD,TR,E2,E1}(
batched_seeds, batched_results, hvp_prep, gradient_prep, N
onebatch = length(batched_seeds) == 1
return HVPGradientHessianPrep{onebatch,B,TD,TR,E2,E1}(
batched_seeds, batched_results, hvp_prep, gradient_prep, N, A
)
end

Expand All @@ -105,23 +108,24 @@ function hessian(
x,
contexts::Vararg{Context,C},
) where {F,B,C}
(; batched_seeds, hvp_prep, N) = prep
(; batched_seeds, hvp_prep, N, A) = prep

hvp_prep_same = prepare_hvp_same_point(
f, hvp_prep, backend, x, batched_seeds[1], contexts...
)

hess_blocks = map(eachindex(batched_seeds)) do a
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
if onebatch
dg_batch = hvp(f, hvp_prep_same, backend, x, only(batched_seeds), contexts...)
block = stack_vec_col(dg_batch)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
return crop_last_col_block(block, 1, 1, B, N)
else
hess_blocks = map(eachindex(batched_seeds)) do a
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
block = stack_vec_col(dg_batch)
return crop_last_col_block(block, a, A, B, N)
end
block
return reduce(hcat, hess_blocks)
end

hess = reduce(hcat, hess_blocks)
return hess
end

function hessian!(
Expand All @@ -145,7 +149,7 @@ function hessian!(

for b in eachindex(batched_results[a])
copyto!(
view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), vec(batched_results[a][b])
view(hess, :, mod1((a - 1) * B + (b - 1), N)), vec(batched_results[a][b])
)
end
end
Expand Down
16 changes: 16 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,18 @@
"""
    stack_vec_col(t::NTuple)

Apply `vec` to every element of the tuple `t` and assemble the flattened results
with `stack` along dimension 2, so that element `b` of `t` becomes column `b`.
"""
function stack_vec_col(t::NTuple)
    return stack(vec, t; dims=2)
end
"""
    stack_vec_row(t::NTuple)

Apply `vec` to every element of the tuple `t` and assemble the flattened results
with `stack` along dimension 1, so that element `b` of `t` becomes row `b`.
"""
function stack_vec_row(t::NTuple)
    return stack(vec, t; dims=1)
end

"""
    crop_last_col_block(block, a, A, B, N)

Return `block` untouched when `a < A` or when `A * B == N`. Otherwise return
`block[:, 1:(N - (A - 1) * B)]`, i.e. only the leading columns of `block`.

Callers pass the batch index as `a`, the number of batches as `A`, the batch size as
`B` and the total number of columns as `N`, so only a partially filled last batch is
cropped.
"""
@inline function crop_last_col_block(block, a::Integer, A::Integer, B::Integer, N::Integer)
    # Guard clause: nothing to trim for non-final batches or when batches tile `N` exactly.
    (a >= A && A * B != N) || return block
    leftover = N - (A - 1) * B
    return block[:, 1:leftover]
end

"""
    crop_last_row_block(block, a, A, B, M)

Return `block` untouched when `a < A` or when `A * B == M`. Otherwise return
`block[1:(M - (A - 1) * B), :]`, i.e. only the leading rows of `block`.

Callers pass the batch index as `a`, the number of batches as `A`, the batch size as
`B` and the total number of rows as `M`, so only a partially filled last batch is
cropped.
"""
@inline function crop_last_row_block(block, a::Integer, A::Integer, B::Integer, M::Integer)
    # Only the final batch can hold surplus rows, and only if `B` does not divide `M`.
    needs_crop = a >= A && A * B != M
    if needs_crop
        kept_rows = 1:(M - (A - 1) * B)
        return block[kept_rows, :]
    end
    return block
end
10 changes: 6 additions & 4 deletions DifferentiationInterface/test/Misc/FromPrimitive/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@ using Test
LOGGING = get(ENV, "CI", "false") == "false"

backends = [ #
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=3)),
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=3)),
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
]

second_order_backends = [ #
SecondOrder(
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=3)),
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=2)),
),
SecondOrder(
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=3)),
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=2)),
),
]

Expand Down

0 comments on commit ebbf6a8

Please sign in to comment.