Skip to content

Commit

Permalink
Pick correct batch size for hessian (#574)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Oct 12, 2024
1 parent 4e70d75 commit fd7580c
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.11"
version = "0.6.12"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using DifferentiationInterface:
PushforwardSlow,
inner,
multibasis,
pick_batchsize,
pick_hessian_batchsize,
pick_jacobian_batchsize,
pushforward_performance,
unwrap,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
function DI.prepare_hessian(
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
) where {F,C}
valB = pick_batchsize(dense_ad(backend), length(x))
valB = pick_hessian_batchsize(dense_ad(backend), length(x))
return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...)
end

Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ end
function prepare_hessian(
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
valB = pick_batchsize(backend, length(x))
valB = pick_hessian_batchsize(backend, length(x))
return _prepare_hessian_aux(valB, f, backend, x, contexts...)
end

Expand Down
4 changes: 4 additions & 0 deletions DifferentiationInterface/src/utils/batchsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ function pick_jacobian_batchsize(
return pick_batchsize(backend, M)
end

function pick_hessian_batchsize(backend::AbstractADType, N::Integer)
return pick_batchsize(outer(backend), N)
end

threshold_batchsize(backend::AbstractADType, ::Integer) = backend

0 comments on commit fd7580c

Please sign in to comment.