From 7607ec2e43c25f1722a9e4f42e88195ba020e07b Mon Sep 17 00:00:00 2001
From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:45:07 +0200
Subject: [PATCH] Custom stacking for StaticArrays (#564)

* Improve type stability tests and benchmarking

* Remove `first_order` and `second_order`

* Docs

* Zero allocs

* Fixes

* Call count

* Fix

* Fix

* Add count calls

* Default count calls

* Fix

* Custom stacking for StaticArrays

* Bump

* Clearer modulo

* Woops

* Undo mo1
---
 Review notes (commentary only; this region sits between the `---` separator
 and the first `diff --git` line and is not part of the commit message or of
 any hunk):

 * What the patch does, as visible in the hunks below:
   - Creates `src/utils/linalg.jl` with two one-line helpers:
     `stack_vec_col(t::NTuple)` is `stack(vec, t; dims=2)` and
     `stack_vec_row(t::NTuple)` is `stack(vec, t; dims=1)`.
   - Includes that file from `src/DifferentiationInterface.jl`, right after
     `utils/context.jl`.
   - Replaces the direct `stack(vec, ...; dims=...)` calls with the helpers:
     `stack_vec_col` in the first `_jacobian_aux` hunk and in `hessian`,
     `stack_vec_row` in the second `_jacobian_aux` hunk. The trimming of the
     last block (`N % B != 0` / `M % B != 0`) is unchanged context.
   - Creates the module `DifferentiationInterfaceStaticArraysExt`, which adds
     a method `DI.stack_vec_col(t::NTuple{B,<:SArray})` implemented as
     `hcat(map(vec, t)...)`. The tuple length `B` is a type parameter of the
     method signature.
   - In `Project.toml`: bumps `version` from 0.6.9 to 0.6.10 and adds three
     StaticArrays entries (a UUID, the extension-to-package mapping, and the
     version bound "1.9.7"). The TOML section headers for these three hunks
     are not visible here; presumably [weakdeps], [extensions] and [compat]
     -- confirm against the full file.

 * NOTE(review): only `stack_vec_col` receives an `SArray` method. The
   pullback-based hunk in `jacobian.jl` calls `stack_vec_row`, which has no
   counterpart in the extension and therefore keeps using the generic
   `stack(vec, t; dims=1)` definition. Confirm whether this asymmetry is
   intentional.

 * NOTE(review): in the third `Project.toml` hunk, `StaticArrays = "1.9.7"` is
   inserted before `SparseMatrixColorings`, whereas the other two hunks place
   the StaticArrays entry after the SparseMatrixColorings one. Looks like an
   ordering slip; harmless for TOML parsing.

 * NOTE(review): the new helpers and the extension method carry no docstrings
   or comments in this patch.

 DifferentiationInterface/Project.toml                | 5 ++++-
 .../DifferentiationInterfaceStaticArraysExt.jl       | 10 ++++++++++
 .../src/DifferentiationInterface.jl                  | 1 +
 DifferentiationInterface/src/first_order/jacobian.jl | 4 ++--
 DifferentiationInterface/src/second_order/hessian.jl | 2 +-
 DifferentiationInterface/src/utils/linalg.jl         | 2 ++
 6 files changed, 20 insertions(+), 4 deletions(-)
 create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl
 create mode 100644 DifferentiationInterface/src/utils/linalg.jl

diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml
index ce7502c68..06169c44e 100644
--- a/DifferentiationInterface/Project.toml
+++ b/DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.6.9"
+version = "0.6.10"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -20,6 +20,7 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -37,6 +38,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
 DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
 DifferentiationInterfaceSparseArraysExt = "SparseArrays"
 DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
+DifferentiationInterfaceStaticArraysExt = "StaticArrays"
 DifferentiationInterfaceSymbolicsExt = "Symbolics"
 DifferentiationInterfaceTrackerExt = "Tracker"
 DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
@@ -56,6 +58,7 @@ PolyesterForwardDiff = "0.1.2"
 ReverseDiff = "1.15.1"
 SparseArrays = "<0.0.1,1"
 SparseConnectivityTracer = "0.5.0,0.6"
+StaticArrays = "1.9.7"
 SparseMatrixColorings = "0.4.5"
 Symbolics = "5.27.1, 6"
 Tracker = "0.2.33"
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl
new file mode 100644
index 000000000..53d6e7aff
--- /dev/null
+++ b/DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl
@@ -0,0 +1,10 @@
+module DifferentiationInterfaceStaticArraysExt
+
+import DifferentiationInterface as DI
+using StaticArrays: SArray
+
+function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B}
+    return hcat(map(vec, t)...)
+end
+
+end
diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl
index 8cee0ad18..65c434b5c 100644
--- a/DifferentiationInterface/src/DifferentiationInterface.jl
+++ b/DifferentiationInterface/src/DifferentiationInterface.jl
@@ -43,6 +43,7 @@ include("utils/check.jl")
 include("utils/exceptions.jl")
 include("utils/printing.jl")
 include("utils/context.jl")
+include("utils/linalg.jl")
 
 include("first_order/pushforward.jl")
 include("first_order/pullback.jl")
diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl
index 3984dd49f..bde1c5255 100644
--- a/DifferentiationInterface/src/first_order/jacobian.jl
+++ b/DifferentiationInterface/src/first_order/jacobian.jl
@@ -241,7 +241,7 @@ function _jacobian_aux(
             batched_seeds[a],
             contexts...,
         )
-        block = stack(vec, dy_batch; dims=2)
+        block = stack_vec_col(dy_batch)
         if N % B != 0 && a == lastindex(batched_seeds)
             block = block[:, 1:(N - (a - 1) * B)]
         end
@@ -269,7 +269,7 @@ function _jacobian_aux(
         dx_batch = pullback(
             f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
         )
-        block = stack(vec, dx_batch; dims=1)
+        block = stack_vec_row(dx_batch)
         if M % B != 0 && a == lastindex(batched_seeds)
             block = block[1:(M - (a - 1) * B), :]
         end
diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl
index a40c2efd5..025c69315 100644
--- a/DifferentiationInterface/src/second_order/hessian.jl
+++ b/DifferentiationInterface/src/second_order/hessian.jl
@@ -113,7 +113,7 @@ function hessian(
 
     hess_blocks = map(eachindex(batched_seeds)) do a
         dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
-        block = stack(vec, dg_batch; dims=2)
+        block = stack_vec_col(dg_batch)
         if N % B != 0 && a == lastindex(batched_seeds)
             block = block[:, 1:(N - (a - 1) * B)]
         end
diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl
new file mode 100644
index 000000000..392c7416f
--- /dev/null
+++ b/DifferentiationInterface/src/utils/linalg.jl
@@ -0,0 +1,2 @@
+stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
+stack_vec_row(t::NTuple) = stack(vec, t; dims=1)