From c29f90aa7e1badca7153e181e3f6bd30e1527baa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Sep 2024 17:26:29 -0400 Subject: [PATCH] feat: integrate SciMLJacobianOperators into NonlinearSolve --- docs/src/devdocs/operators.md | 15 - .../src/SciMLJacobianOperators.jl | 154 +++++++++- src/NonlinearSolve.jl | 3 +- src/globalization/trust_region.jl | 6 +- src/internal/jacobian.jl | 12 +- src/internal/operators.jl | 278 ------------------ 6 files changed, 162 insertions(+), 306 deletions(-) delete mode 100644 src/internal/operators.jl diff --git a/docs/src/devdocs/operators.md b/docs/src/devdocs/operators.md index b96a63f8c..15d00093a 100644 --- a/docs/src/devdocs/operators.md +++ b/docs/src/devdocs/operators.md @@ -6,21 +6,6 @@ NonlinearSolve.AbstractNonlinearSolveOperator ``` -## Jacobian Operators - -```@docs -NonlinearSolve.JacobianOperator -NonlinearSolve.VecJacOperator -NonlinearSolve.JacVecOperator -``` - -### Stateful Jacobian Operators - -```@docs -NonlinearSolve.StatefulJacobianOperator -NonlinearSolve.StatefulJacobianNormalFormOperator -``` - ## Low-Rank Jacobian Operators ```@docs diff --git a/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl b/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl index f16e3c451..dc49fcb46 100644 --- a/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl +++ b/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl @@ -5,6 +5,7 @@ using ConcreteStructs: @concrete using ConstructionBase: ConstructionBase using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure +using LinearAlgebra: LinearAlgebra using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction using SciMLOperators: AbstractSciMLOperator using Setfield: @set! @@ -23,6 +24,57 @@ struct JVP <: AbstractMode end flip_mode(::VJP) = JVP() flip_mode(::JVP) = VJP() +""" + JacobianOperator{iip, T} <: AbstractJacobianOperator{T} <: AbstractSciMLOperator{T} + +A Jacobian Operator Provides both JVP and VJP without materializing either (if possible). + +### Constructor + +```julia +JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = nothing, + vjp_autodiff = nothing, skip_vjp::Val = Val(false), skip_jvp::Val = Val(false)) +``` + +By default, the `JacobianOperator` will compute `JVP`. Use `Base.adjoint` or +`Base.transpose` to switch to `VJP`. + +### Computing the VJP + +Computing the VJP is done according to the following rules: + + - If `f` has a `vjp` method, then we use that. + - If `f` has a `jac` method and no `vjp_autodiff` is provided, then we use `jac * v`. + - If `vjp_autodiff` is provided we using DifferentiationInterface.jl to compute the VJP. + +### Computing the JVP + +Computing the JVP is done according to the following rules: + + - If `f` has a `jvp` method, then we use that. + - If `f` has a `jac` method and no `jvp_autodiff` is provided, then we use `v * jac`. + - If `jvp_autodiff` is provided we using DifferentiationInterface.jl to compute the JVP. + +### Special Case (Number) + +For Number inputs, VJP and JVP are not distinct. Hence, if either `vjp` or `jvp` is +provided, then we use that. If neither is provided, then we use `v * jac` if `jac` is +provided. Finally, we use the respective autodiff methods to compute the derivative +using DifferentiationInterface.jl and multiply by `v`. + +### Methods Provided + +!!! warning + + Currently it is expected that `p` during problem construction is same as `p` during + operator evaluation. This restriction will be lifted in the future. + + - `(op::JacobianOperator)(v, u, p)`: Computes `∂f(u, p)/∂u * v` or `∂f(u, p)/∂uᵀ * v`. + - `(op::JacobianOperator)(res, v, u, p)`: Computes `∂f(u, p)/∂u * v` or `∂f(u, p)/∂uᵀ * v` + and stores the result in `res`. + +See also [`VecJacOperator`](@ref) and [`JacVecOperator`](@ref). +""" @concrete struct JacobianOperator{iip, T <: Real} <: AbstractJacobianOperator{T} mode <: AbstractMode @@ -65,8 +117,8 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff) jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff) - output_cache = iip ? similar(fu, T) : nothing - input_cache = iip ? similar(u, T) : nothing + output_cache = similar(fu, T) + input_cache = similar(u, T) return JacobianOperator{iip, T}( JVP(), jvp_op, vjp_op, (length(fu), length(u)), output_cache, input_cache) @@ -112,14 +164,106 @@ function (op::JacobianOperator)(Jv, v, u, p) end end +""" + VecJacOperator(args...; autodiff = nothing, kwargs...) + +Constructs a [`JacobianOperator`](@ref) which only provides the VJP using the +`vjp_autodiff = autodiff`. +""" function VecJacOperator(args...; autodiff = nothing, kwargs...) return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)' end +""" + JacVecOperator(args...; autodiff = nothing, kwargs...) + +Constructs a [`JacobianOperator`](@ref) which only provides the JVP using the +`jvp_autodiff = autodiff`. +""" function JacVecOperator(args...; autodiff = nothing, kwargs...) return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff) end +""" + StatefulJacobianOperator(jac_op::JacobianOperator, u, p) + +Wrapper over a [`JacobianOperator`](@ref) which stores the input `u` and `p` and defines +`mul!` and `*` for computing VJPs and JVPs. +""" +@concrete struct StatefulJacobianOperator{M <: AbstractMode, T} <: + AbstractJacobianOperator{T} + mode::M + jac_op <: JacobianOperator + u + p + + function StatefulJacobianOperator(jac_op::JacobianOperator, u, p) + return new{ + typeof(jac_op.mode), eltype(jac_op), typeof(jac_op), typeof(u), typeof(p)}( + jac_op.mode, jac_op, u, p) + end +end + +Base.size(J::StatefulJacobianOperator) = size(J.jac_op) +Base.size(J::StatefulJacobianOperator, d::Integer) = size(J.jac_op, d) + +for op in (:adjoint, :transpose) + @eval function Base.$(op)(operator::StatefulJacobianOperator) + return StatefulJacobianOperator($(op)(operator.jac_op), operator.u, operator.p) + end +end + +Base.:*(J::StatefulJacobianOperator, v::AbstractArray) = J.jac_op(v, J.u, J.p) + +function LinearAlgebra.mul!( + Jv::AbstractArray, J::StatefulJacobianOperator, v::AbstractArray) + J.jac_op(Jv, v, J.u, J.p) + return Jv +end + +""" + StatefulJacobianNormalFormOperator(vjp_operator, jvp_operator, cache) + +This constructs a Normal Form Jacobian Operator, i.e. it constructs the operator +corresponding to `JᵀJ` where `J` is the Jacobian Operator. This is not meant to be directly +constructed, rather it is constructed with `*` on two [`StatefulJacobianOperator`](@ref)s. +""" +@concrete mutable struct StatefulJacobianNormalFormOperator{T} <: + AbstractJacobianOperator{T} + vjp_operator <: StatefulJacobianOperator{VJP} + jvp_operator <: StatefulJacobianOperator{JVP} + cache +end + +function Base.size(J::StatefulJacobianNormalFormOperator) + return size(J.vjp_operator, 1), size(J.jvp_operator, 2) +end + +function Base.:*(J1::StatefulJacobianOperator{VJP}, J2::StatefulJacobianOperator{JVP}) + cache = J2 * J2.jac_op.input_cache + T = promote_type(eltype(J1), eltype(J2)) + return StatefulJacobianNormalFormOperator{T}(J1, J2, cache) +end + +function LinearAlgebra.mul!(C::StatefulJacobianNormalFormOperator, + A::StatefulJacobianOperator{VJP}, B::StatefulJacobianOperator{JVP}) + C.vjp_operator = A + C.jvp_operator = B + return C +end + +function Base.:*(JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray) + return JᵀJ.vjp_operator * (JᵀJ.jvp_operator * x) +end + +function LinearAlgebra.mul!( + JᵀJx::AbstractArray, JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray) + mul!(JᵀJ.cache, JᵀJ.jvp_operator, x) + mul!(JᵀJx, JᵀJ.vjp_operator, JᵀJ.cache) + return JᵀJx +end + +# Helper Functions prepare_vjp(::Val{true}, args...; kwargs...) = nothing function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem, @@ -206,8 +350,8 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem, fu_cache = copy(fu) di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u) return @closure (Jv, v, u, p) -> begin - DI.pushforward!(fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v, - di_extras) + DI.pushforward!( + fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v, di_extras) return end else @@ -231,5 +375,7 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem, end export JacobianOperator, VecJacOperator, JacVecOperator +export StatefulJacobianOperator +export StatefulJacobianNormalFormOperator end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 781f6eae9..77293ce3f 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -40,6 +40,8 @@ using Preferences: Preferences, @load_preference, @set_preferences! using RecursiveArrayTools: recursivecopy!, recursivefill! using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem, AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats +using SciMLJacobianOperators: JacobianOperator, VecJacOperator, JacVecOperator, + StatefulJacobianOperator, StatefulJacobianNormalFormOperator using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection, ApproximateJacobianSparsity, JacPrototypeSparsityDetection, @@ -72,7 +74,6 @@ include("descent/dogleg.jl") include("descent/damped_newton.jl") include("descent/geodesic_acceleration.jl") -include("internal/operators.jl") include("internal/jacobian.jl") include("internal/forward_diff.jl") include("internal/linear_solve.jl") diff --git a/src/globalization/trust_region.jl b/src/globalization/trust_region.jl index 51b54959e..e6e2cba17 100644 --- a/src/globalization/trust_region.jl +++ b/src/globalization/trust_region.jl @@ -386,11 +386,13 @@ function __internal_init( p1, p2, p3, p4 = __get_parameters(T, alg.method) ϵ = T(1e-8) + reverse_ad = get_concrete_reverse_ad(alg.reverse_ad, prob; check_reverse_mode = true) vjp_operator = alg.method isa RUS.__Yuan || alg.method isa RUS.__Bastin ? - VecJacOperator(prob, fu, u; autodiff = alg.reverse_ad) : nothing + VecJacOperator(prob, fu, u; autodiff = reverse_ad) : nothing + forward_ad = get_concrete_forward_ad(alg.forward_ad, prob; check_forward_mode = true) jvp_operator = alg.method isa RUS.__Bastin ? - JacVecOperator(prob, fu, u; autodiff = alg.forward_ad) : nothing + JacVecOperator(prob, fu, u; autodiff = forward_ad) : nothing if alg.method isa RUS.__Yuan Jᵀfu_cache = StatefulJacobianOperator(vjp_operator, u, prob.p) * _vec(fu) diff --git a/src/internal/jacobian.jl b/src/internal/jacobian.jl index b712b93e2..be3402314 100644 --- a/src/internal/jacobian.jl +++ b/src/internal/jacobian.jl @@ -25,7 +25,8 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`. - `jvp_autodiff`: Automatic Differentiation or Finite Differencing backend for computing the Jacobian-vector product. - `linsolve`: Linear Solver Algorithm used to determine if we need a concrete jacobian - or if possible we can just use a [`NonlinearSolve.JacobianOperator`](@ref) instead. + or if possible we can just use a [`SciMLJacobianOperators.JacobianOperator`](@ref) + instead. """ @concrete mutable struct JacobianCache{iip} <: AbstractNonlinearSolveJacobianCache{iip} J @@ -85,8 +86,7 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, __similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) : copy(f.jac_prototype) elseif f.jac_prototype === nothing - zero(init_jacobian( - jac_cache; preserve_immutable = Val(true))) + zero(init_jacobian(jac_cache; preserve_immutable = Val(true))) else f.jac_prototype end @@ -114,9 +114,9 @@ end @inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p) @inline function (cache::JacobianCache)(::Nothing) - J = cache.J - J isa JacobianOperator && return StatefulJacobianOperator(J, cache.u, cache.p) - return J + cache.J isa JacobianOperator && + return StatefulJacobianOperator(cache.J, cache.u, cache.p) + return cache.J end function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p) diff --git a/src/internal/operators.jl b/src/internal/operators.jl deleted file mode 100644 index 5bbfcb0bf..000000000 --- a/src/internal/operators.jl +++ /dev/null @@ -1,278 +0,0 @@ -# We want a general form of this in SciMLOperators. However, we use this extensively and we -# can have a custom implementation here till -# https://github.com/SciML/SciMLOperators.jl/issues/223 is resolved. -""" - JacobianOperator{vjp, iip, T} <: AbstractNonlinearSolveOperator{T} - -A Jacobian Operator Provides both JVP and VJP without materializing either (if possible). - -This is an internal operator, and is not guaranteed to have a stable API. It might even be -moved out of NonlinearSolve.jl in the future, without a deprecation cycle. Usage of this -outside NonlinearSolve.jl (by everyone except Avik) is strictly prohibited. - -`T` denotes if the Jacobian is transposed or not. `T = true` means that the Jacobian is -transposed, and `T = false` means that the Jacobian is not transposed. - -### Constructor - -```julia -JacobianOperator( - prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = nothing, vjp_autodiff = nothing, - skip_vjp::Val{NoVJP} = False, skip_jvp::Val{NoJVP} = False) where {NoVJP, NoJVP} -``` - -See also [`NonlinearSolve.VecJacOperator`](@ref) and -[`NonlinearSolve.JacVecOperator`](@ref). -""" -@concrete struct JacobianOperator{vjp, iip, T} <: AbstractNonlinearSolveOperator{T} - jvp_op - vjp_op - - input_cache - output_cache -end - -Base.size(J::JacobianOperator) = prod(size(J.output_cache)), prod(size(J.input_cache)) -function Base.size(J::JacobianOperator, d::Integer) - d == 1 && return prod(size(J.output_cache)) - d == 2 && return prod(size(J.input_cache)) - error("Invalid dimension $d for JacobianOperator") -end - -for op in (:adjoint, :transpose) - @eval function Base.$(op)(operator::JacobianOperator{vjp, iip, T}) where {vjp, iip, T} - return JacobianOperator{!vjp, iip, T}( - operator.jvp_op, operator.vjp_op, operator.output_cache, operator.input_cache) - end -end - -function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = nothing, - vjp_autodiff = nothing, skip_vjp::Val{NoVJP} = False, - skip_jvp::Val{NoJVP} = False) where {NoVJP, NoJVP} - f = prob.f - iip = isinplace(prob) - uf = JacobianWrapper{iip}(f, prob.p) - - vjp_op = if NoVJP - nothing - elseif SciMLBase.has_vjp(f) - f.vjp - elseif u isa Number # Ignore vjp directives - if ForwardDiff.can_dual(typeof(u)) && (vjp_autodiff === nothing || - vjp_autodiff isa AutoForwardDiff || - vjp_autodiff isa AutoPolyesterForwardDiff) - # VJP is same as VJP for scalars - @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) - else - @closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v - end - else - vjp_autodiff = __get_nonsparse_ad(get_concrete_reverse_ad( - vjp_autodiff, prob, False)) - if vjp_autodiff isa AutoZygote - iip && error("`AutoZygote` cannot handle inplace problems.") - @closure (v, u, p) -> auto_vecjac(uf, u, v) - elseif vjp_autodiff isa AutoFiniteDiff - if iip - cache1 = zero(fu) - cache2 = zero(fu) - @closure (Jv, v, u, p) -> num_vecjac!(Jv, uf, u, v, cache1, cache2) - else - @closure (v, u, p) -> num_vecjac(uf, __mutable(u), v) - end - else - error("`vjp_autodiff` = `$(typeof(vjp_autodiff))` is not supported in \ - JacobianOperator.") - end - end - - jvp_op = if NoJVP - nothing - elseif SciMLBase.has_jvp(f) - f.jvp - elseif u isa Number # Ignore jvp directives - # Only ForwardDiff if user didn't override - if ForwardDiff.can_dual(typeof(u)) && (jvp_autodiff === nothing || - jvp_autodiff isa AutoForwardDiff || - jvp_autodiff isa AutoPolyesterForwardDiff) - @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) - else - @closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v - end - else - jvp_autodiff = __get_nonsparse_ad(get_concrete_forward_ad( - jvp_autodiff, prob, False)) - if jvp_autodiff isa AutoForwardDiff || jvp_autodiff isa AutoPolyesterForwardDiff - if iip - # FIXME: Technically we should propagate the tag but ignoring that for now - cache1 = Dual{typeof(ForwardDiff.Tag(uf, eltype(u))), eltype(u), - 1}.(zero(u), ForwardDiff.Partials.(tuple.(u))) - cache2 = Dual{typeof(ForwardDiff.Tag(uf, eltype(fu))), eltype(fu), - 1}.(zero(fu), ForwardDiff.Partials.(tuple.(fu))) - @closure (Jv, v, u, p) -> auto_jacvec!(Jv, uf, u, v, cache1, cache2) - else - @closure (v, u, p) -> auto_jacvec(uf, u, v) - end - elseif jvp_autodiff isa AutoFiniteDiff - if iip - cache1 = zero(fu) - cache2 = zero(u) - @closure (Jv, v, u, p) -> num_jacvec!(Jv, uf, u, v, cache1, cache2) - else - @closure (v, u, p) -> num_jacvec(uf, u, v) - end - else - error("`jvp_autodiff` = `$(typeof(jvp_autodiff))` is not supported in \ - JacobianOperator.") - end - end - - return JacobianOperator{false, iip, promote_type(eltype(fu), eltype(u))}( - jvp_op, vjp_op, u, fu) -end - -""" - VecJacOperator(args...; autodiff = nothing, kwargs...) - -Constructs a [`JacobianOperator`](@ref) which only provides the VJP using the -`vjp_autodiff = autodiff`. - -This is very similar to `SparseDiffTools.VecJac` but is geared towards -[`NonlinearProblem`](@ref)s. For arguments and keyword arguments see -[`JacobianOperator`](@ref). -""" -function VecJacOperator(args...; autodiff = nothing, kwargs...) - return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)' -end - -""" - JacVecOperator(args...; autodiff = nothing, kwargs...) - -Constructs a [`JacobianOperator`](@ref) which only provides the JVP using the -`jvp_autodiff = autodiff`. - -This is very similar to `SparseDiffTools.JacVec` but is geared towards -[`NonlinearProblem`](@ref)s. For arguments and keyword arguments see -[`JacobianOperator`](@ref). -""" -function JacVecOperator(args...; autodiff = nothing, kwargs...) - return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff) -end - -function (op::JacobianOperator{vjp, iip})(v, u, p) where {vjp, iip} - if vjp - if iip - res = zero(op.output_cache) - op.vjp_op(res, v, u, p) - return res - else - return op.vjp_op(v, u, p) - end - else - if iip - res = zero(op.output_cache) - op.jvp_op(res, v, u, p) - return res - else - return op.jvp_op(v, u, p) - end - end -end - -# Prevent Ambiguity -function (op::JacobianOperator{vjp, iip})(Jv::Number, v::Number, u, p) where {vjp, iip} - error("Inplace Jacobian Operator not possible for scalars.") -end - -function (op::JacobianOperator{vjp, iip})(Jv, v, u, p) where {vjp, iip} - if vjp - if iip - op.vjp_op(Jv, v, u, p) - else - copyto!(Jv, op.vjp_op(v, u, p)) - end - else - if iip - op.jvp_op(Jv, v, u, p) - else - copyto!(Jv, op.jvp_op(v, u, p)) - end - end - return Jv -end - -""" - StatefulJacobianOperator(jac_op::JacobianOperator, u, p) - -Wrapper over a [`JacobianOperator`](@ref) which stores the input `u` and `p` and defines -`mul!` and `*` for computing VJPs and JVPs. -""" -@concrete struct StatefulJacobianOperator{ - vjp, iip, T, J <: JacobianOperator{vjp, iip, T}} <: AbstractNonlinearSolveOperator{T} - jac_op::J - u - p -end - -Base.size(J::StatefulJacobianOperator) = size(J.jac_op) -Base.size(J::StatefulJacobianOperator, d::Integer) = size(J.jac_op, d) - -for op in (:adjoint, :transpose) - @eval function Base.$op(operator::StatefulJacobianOperator) - return StatefulJacobianOperator($(op)(operator.jac_op), operator.u, operator.p) - end -end - -Base.:*(J::StatefulJacobianOperator, v::AbstractArray) = J.jac_op(v, J.u, J.p) -function Base.:*(J_op::StatefulJacobianOperator{vjp, iip, T, J, <:Number}, - v::Number) where {vjp, iip, T, J} - return J_op.jac_op(v, J_op.u, J_op.p) -end - -function LinearAlgebra.mul!( - Jv::AbstractArray, J::StatefulJacobianOperator, v::AbstractArray) - J.jac_op(Jv, v, J.u, J.p) - return Jv -end - -""" - StatefulJacobianNormalFormOperator(vjp_operator, jvp_operator, cache) - -This constructs a Normal Form Jacobian Operator, i.e. it constructs the operator -corresponding to `JᵀJ` where `J` is the Jacobian Operator. This is not meant to be directly -constructed, rather it is constructed with `*` on two [`StatefulJacobianOperator`](@ref)s. -""" -@concrete mutable struct StatefulJacobianNormalFormOperator{T} <: - AbstractNonlinearSolveOperator{T} - vjp_operator - jvp_operator - cache -end - -function Base.size(J::StatefulJacobianNormalFormOperator) - return size(J.vjp_operator, 1), size(J.jvp_operator, 2) -end - -function Base.:*(J1::StatefulJacobianOperator{true}, J2::StatefulJacobianOperator{false}) - cache = J2 * J2.jac_op.input_cache - T = promote_type(eltype(J1), eltype(J2)) - return StatefulJacobianNormalFormOperator{T}(J1, J2, cache) -end - -function LinearAlgebra.mul!(C::StatefulJacobianNormalFormOperator, - A::StatefulJacobianOperator{true}, B::StatefulJacobianOperator{false}) - C.vjp_operator = A - C.jvp_operator = B - return C -end - -function Base.:*(JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray) - return JᵀJ.vjp_operator * (JᵀJ.jvp_operator * x) -end - -function LinearAlgebra.mul!( - JᵀJx::AbstractArray, JᵀJ::StatefulJacobianNormalFormOperator, x::AbstractArray) - mul!(JᵀJ.cache, JᵀJ.jvp_operator, x) - mul!(JᵀJx, JᵀJ.vjp_operator, JᵀJ.cache) - return JᵀJx -end