diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl
index e976e59fc..05a5c3a55 100644
--- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl
+++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl
@@ -62,8 +62,7 @@ function _prepare_sparse_hessian_aux(
     seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
     compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
     batched_seeds = [
-        ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for
-        a in 1:div(Ng, B, RoundUp)
+        ntuple(b -> seeds[mod1((a - 1) * B + b, Ng)], Val(B)) for a in 1:cld(Ng, B)
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
     hvp_prep = prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...)
@@ -107,7 +106,7 @@ function DI.hessian!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
+                view(compressed_matrix, :, mod1((a - 1) * B + b, Ng)),
                 vec(batched_results[a][b]),
             )
         end
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl
index 6dd5c0caa..cf014c357 100644
--- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl
+++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl
@@ -113,8 +113,7 @@ function _prepare_sparse_jacobian_aux(
     seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
     compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
     batched_seeds = [
-        ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for
-        a in 1:div(Ng, B, RoundUp)
+        ntuple(b -> seeds[mod1((a - 1) * B + b, Ng)], Val(B)) for a in 1:cld(Ng, B)
     ]
     batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
     pushforward_prep = prepare_pushforward(
@@ -146,12 +145,11 @@ function _prepare_sparse_jacobian_aux(
         decompression_eltype=promote_type(eltype(x), eltype(y)),
     )
     groups = row_groups(coloring_result)
-    Ng = length(groups)
+    Mg = length(groups)
     seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups]
     compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
     batched_seeds = [
-        ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), Ng)], Val(B)) for
-        a in 1:div(Ng, B, RoundUp)
+        ntuple(b -> seeds[mod1((a - 1) * B + b, Mg)], Val(B)) for a in 1:cld(Mg, B)
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
     pullback_prep = prepare_pullback(
@@ -265,7 +263,7 @@ function _sparse_jacobian_aux!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
+                view(compressed_matrix, :, mod1((a - 1) * B + b, Ng)),
                 vec(batched_results[a][b]),
             )
         end
@@ -286,7 +284,7 @@ function _sparse_jacobian_aux!(
     (; coloring_result, compressed_matrix, batched_seeds, batched_results, pullback_prep) =
         prep
     dense_backend = dense_ad(backend)
-    Ng = length(row_groups(coloring_result))
+    Mg = length(row_groups(coloring_result))
 
     pullback_prep_same = prepare_pullback_same_point(
         f_or_f!y..., pullback_prep, dense_backend, x, batched_seeds[1], contexts...
@@ -305,7 +303,7 @@ function _sparse_jacobian_aux!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(compressed_matrix, 1 + ((a - 1) * B + (b - 1)) % Ng, :),
+                view(compressed_matrix, mod1((a - 1) * B + b, Mg), :),
                 vec(batched_results[a][b]),
             )
         end
diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl
index 0a2304b13..f02968b3a 100644
--- a/DifferentiationInterface/src/first_order/jacobian.jl
+++ b/DifferentiationInterface/src/first_order/jacobian.jl
@@ -67,19 +67,22 @@ function jacobian! end
 
 ## Preparation
 
-struct PushforwardJacobianPrep{B,TD<:NTuple{B},TR<:NTuple{B},E<:PushforwardPrep} <:
+struct PushforwardJacobianPrep{onebatch,B,TD<:NTuple{B},TR<:NTuple{B},E<:PushforwardPrep} <:
        JacobianPrep
     batched_seeds::Vector{TD}
     batched_results::Vector{TR}
     pushforward_prep::E
     N::Int
+    A::Int
 end
 
-struct PullbackJacobianPrep{B,TD<:NTuple{B},TR<:NTuple{B},E<:PullbackPrep} <: JacobianPrep
+struct PullbackJacobianPrep{onebatch,B,TD<:NTuple{B},TR<:NTuple{B},E<:PullbackPrep} <:
+       JacobianPrep
     batched_seeds::Vector{TD}
     batched_results::Vector{TR}
     pullback_prep::E
     M::Int
+    A::Int
 end
 
 function prepare_jacobian(
@@ -109,10 +112,10 @@ function _prepare_jacobian_aux(
     contexts::Vararg{Context,C},
 ) where {B,FY,C}
     N = length(x)
+    A = cld(N, B)
     seeds = [basis(backend, x, ind) for ind in eachindex(x)]
     batched_seeds = [
-        ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for
-        a in 1:div(N, B, RoundUp)
+        ntuple(b -> seeds[mod1((a - 1) * B + b, N)], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
     pushforward_prep = prepare_pushforward(
@@ -121,8 +124,9 @@ function _prepare_jacobian_aux(
     TD = eltype(batched_seeds)
     TR = eltype(batched_results)
     E = typeof(pushforward_prep)
-    return PushforwardJacobianPrep{B,TD,TR,E}(
-        batched_seeds, batched_results, pushforward_prep, N
+    onebatch = length(batched_seeds) == 1
+    return PushforwardJacobianPrep{onebatch,B,TD,TR,E}(
+        batched_seeds, batched_results, pushforward_prep, N, A
     )
 end
 
@@ -136,17 +140,20 @@ function _prepare_jacobian_aux(
     contexts::Vararg{Context,C},
 ) where {B,FY,C}
     M = length(y)
+    A = cld(M, B)
     seeds = [basis(backend, y, ind) for ind in eachindex(y)]
     batched_seeds = [
-        ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), M)], Val(B)) for
-        a in 1:div(M, B, RoundUp)
+        ntuple(b -> seeds[mod1((a - 1) * B + b, M)], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
     pullback_prep = prepare_pullback(f_or_f!y..., backend, x, batched_seeds[1], contexts...)
     TD = eltype(batched_seeds)
     TR = eltype(batched_results)
     E = typeof(pullback_prep)
-    return PullbackJacobianPrep{B,TD,TR,E}(batched_seeds, batched_results, pullback_prep, M)
+    onebatch = length(batched_seeds) == 1
+    return PullbackJacobianPrep{onebatch,B,TD,TR,E}(
+        batched_seeds, batched_results, pullback_prep, M, A
+    )
 end
 
 ## One argument
@@ -221,63 +228,74 @@ end
 
 function _jacobian_aux(
     f_or_f!y::FY,
-    prep::PushforwardJacobianPrep{B},
+    prep::PushforwardJacobianPrep{onebatch,B},
     backend::AbstractADType,
     x,
     contexts::Vararg{Context,C},
-) where {FY,B,C}
-    (; batched_seeds, pushforward_prep, N) = prep
+) where {FY,onebatch,B,C}
+    (; batched_seeds, pushforward_prep, N, A) = prep
 
     pushforward_prep_same = prepare_pushforward_same_point(
         f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
     )
 
-    jac_blocks = map(eachindex(batched_seeds)) do a
+    if onebatch
         dy_batch = pushforward(
-            f_or_f!y...,
-            pushforward_prep_same,
-            backend,
-            x,
-            batched_seeds[a],
-            contexts...,
+            f_or_f!y..., pushforward_prep_same, backend, x, only(batched_seeds), contexts...
         )
         block = stack_vec_col(dy_batch)
-        if N % B != 0 && a == lastindex(batched_seeds)
-            block = block[:, 1:(N - (a - 1) * B)]
+        return crop_last_col_block(block, 1, 1, B, N)
+    else
+        jac_blocks = map(eachindex(batched_seeds)) do a
+            dy_batch = pushforward(
+                f_or_f!y...,
+                pushforward_prep_same,
+                backend,
+                x,
+                batched_seeds[a],
+                contexts...,
+            )
+            block = stack_vec_col(dy_batch)
+            return crop_last_col_block(block, a, A, B, N)
         end
-        block
+        return reduce(hcat, jac_blocks)
     end
-
-    jac = reduce(hcat, jac_blocks)
-    return jac
 end
 
 function _jacobian_aux(
     f_or_f!y::FY,
-    prep::PullbackJacobianPrep{B},
+    prep::PullbackJacobianPrep{onebatch,B},
     backend::AbstractADType,
     x,
     contexts::Vararg{Context,C},
-) where {FY,B,C}
-    (; batched_seeds, pullback_prep, M) = prep
+) where {FY,onebatch,B,C}
+    (; batched_seeds, pullback_prep, M, A) = prep
 
     pullback_prep_same = prepare_pullback_same_point(
         f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
     )
 
-    jac_blocks = map(eachindex(batched_seeds)) do a
+    if onebatch
         dx_batch = pullback(
-            f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
+            f_or_f!y..., pullback_prep_same, backend, x, only(batched_seeds), contexts...
         )
         block = stack_vec_row(dx_batch)
-        if M % B != 0 && a == lastindex(batched_seeds)
-            block = block[1:(M - (a - 1) * B), :]
+        return crop_last_row_block(block, 1, 1, B, M)
+    else
+        jac_blocks = map(eachindex(batched_seeds)) do a
+            dx_batch = pullback(
+                f_or_f!y...,
+                pullback_prep_same,
+                backend,
+                x,
+                batched_seeds[a],
+                contexts...,
+            )
+            block = stack_vec_row(dx_batch)
+            return crop_last_row_block(block, a, A, B, M)
         end
-        block
+        return reduce(vcat, jac_blocks)
     end
-
-    jac = reduce(vcat, jac_blocks)
-    return jac
 end
 
 function _jacobian_aux!(
@@ -307,7 +325,7 @@ function _jacobian_aux!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), vec(batched_results[a][b])
+                view(jac, :, mod1((a - 1) * B + b, N)), vec(batched_results[a][b])
             )
         end
     end
@@ -342,7 +360,7 @@ function _jacobian_aux!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), vec(batched_results[a][b])
+                view(jac, mod1((a - 1) * B + b, M), :), vec(batched_results[a][b])
             )
         end
     end
diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl
index d56027d50..34e68dfc0 100644
--- a/DifferentiationInterface/src/second_order/hessian.jl
+++ b/DifferentiationInterface/src/second_order/hessian.jl
@@ -60,13 +60,15 @@ function value_gradient_and_hessian! end
 
 ## Preparation
 
-struct HVPGradientHessianPrep{B,TD<:NTuple{B},TR<:NTuple{B},E2<:HVPPrep,E1<:GradientPrep} <:
-       HessianPrep
+struct HVPGradientHessianPrep{
+    onebatch,B,TD<:NTuple{B},TR<:NTuple{B},E2<:HVPPrep,E1<:GradientPrep
+} <: HessianPrep
     batched_seeds::Vector{TD}
     batched_results::Vector{TR}
     hvp_prep::E2
     gradient_prep::E1
     N::Int
+    A::Int
 end
 
 function prepare_hessian(
@@ -80,10 +82,10 @@ function _prepare_hessian_aux(
     ::Val{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
 ) where {B,F,C}
     N = length(x)
+    A = cld(N, B)
     seeds = [basis(backend, x, ind) for ind in eachindex(x)]
     batched_seeds = [
-        ntuple(b -> seeds[mod1((a - 1) * B + (b - 1), N)], Val(B)) for
-        a in 1:div(N, B, RoundUp)
+        ntuple(b -> seeds[mod1((a - 1) * B + b, N)], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
     hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...)
@@ -91,8 +93,9 @@ function _prepare_hessian_aux(
     TD = eltype(batched_seeds)
     TR = eltype(batched_results)
     E2, E1 = typeof(hvp_prep), typeof(gradient_prep)
-    return HVPGradientHessianPrep{B,TD,TR,E2,E1}(
-        batched_seeds, batched_results, hvp_prep, gradient_prep, N
+    onebatch = length(batched_seeds) == 1
+    return HVPGradientHessianPrep{onebatch,B,TD,TR,E2,E1}(
+        batched_seeds, batched_results, hvp_prep, gradient_prep, N, A
     )
 end
 
@@ -105,23 +108,24 @@ function hessian(
     x,
     contexts::Vararg{Context,C},
 ) where {F,B,C}
-    (; batched_seeds, hvp_prep, N) = prep
+    (; batched_seeds, hvp_prep, N, A) = prep
 
     hvp_prep_same = prepare_hvp_same_point(
         f, hvp_prep, backend, x, batched_seeds[1], contexts...
     )
 
-    hess_blocks = map(eachindex(batched_seeds)) do a
-        dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
+    if A == 1
+        dg_batch = hvp(f, hvp_prep_same, backend, x, only(batched_seeds), contexts...)
         block = stack_vec_col(dg_batch)
-        if N % B != 0 && a == lastindex(batched_seeds)
-            block = block[:, 1:(N - (a - 1) * B)]
+        return crop_last_col_block(block, 1, 1, B, N)
+    else
+        hess_blocks = map(eachindex(batched_seeds)) do a
+            dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
+            block = stack_vec_col(dg_batch)
+            return crop_last_col_block(block, a, A, B, N)
         end
-        block
+        return reduce(hcat, hess_blocks)
     end
-
-    hess = reduce(hcat, hess_blocks)
-    return hess
 end
 
 function hessian!(
@@ -145,7 +149,7 @@ function hessian!(
 
         for b in eachindex(batched_results[a])
             copyto!(
-                view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), vec(batched_results[a][b])
+                view(hess, :, mod1((a - 1) * B + b, N)), vec(batched_results[a][b])
             )
         end
     end
diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl
index 392c7416f..89aecb095 100644
--- a/DifferentiationInterface/src/utils/linalg.jl
+++ b/DifferentiationInterface/src/utils/linalg.jl
@@ -1,2 +1,22 @@
 stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
 stack_vec_row(t::NTuple) = stack(vec, t; dims=1)
+
+# Return `block` unchanged unless it is the last of `A` batches of `B` columns and
+# `A * B != N`; in that case keep only its first `N - (A - 1) * B` columns.
+@inline function crop_last_col_block(block, a::Integer, A::Integer, B::Integer, N::Integer)
+    if A * B == N || a < A
+        return block
+    else
+        return block[:, 1:(N - (A - 1) * B)]
+    end
+end
+
+# Return `block` unchanged unless it is the last of `A` batches of `B` rows and
+# `A * B != M`; in that case keep only its first `M - (A - 1) * B` rows.
+@inline function crop_last_row_block(block, a::Integer, A::Integer, B::Integer, M::Integer)
+    if A * B == M || a < A
+        return block
+    else
+        return block[1:(M - (A - 1) * B), :]
+    end
+end
diff --git a/DifferentiationInterface/test/Misc/FromPrimitive/test.jl b/DifferentiationInterface/test/Misc/FromPrimitive/test.jl
index f55167945..f844d879b 100644
--- a/DifferentiationInterface/test/Misc/FromPrimitive/test.jl
+++ b/DifferentiationInterface/test/Misc/FromPrimitive/test.jl
@@ -7,18 +7,20 @@ using Test
 LOGGING = get(ENV, "CI", "false") == "false"
 
 backends = [ #
+    AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=3)),
     AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
+    AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=3)),
     AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
 ]
 
 second_order_backends = [ #
     SecondOrder(
-        AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
-        AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
+        AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=3)),
+        AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=2)),
     ),
     SecondOrder(
-        AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
-        AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
+        AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=3)),
+        AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=2)),
     ),
 ]
 