Skip to content

Commit

Permalink
feat: add callable structs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 23, 2024
1 parent 7e3c585 commit 2039879
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 56 deletions.
2 changes: 2 additions & 0 deletions lib/SciMLJacobianOperators/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -19,6 +20,7 @@ ConcreteStructs = "0.2.3"
ConstructionBase = "1.5.8"
DifferentiationInterface = "0.5.17"
FastClosures = "0.3.2"
LinearAlgebra = "1.11.0"
SciMLOperators = "0.3.10"
Setfield = "1.1.1"
julia = "1.10"
Expand Down
148 changes: 92 additions & 56 deletions lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ const DI = DifferentiationInterface
const True = Val(true)
const False = Val(false)

abstract type AbstractJacobianOperator{T} <: AbstractSciMLOperator{T} end

abstract type AbstractMode end

struct VJP <: AbstractMode end
Expand All @@ -21,17 +23,20 @@ struct JVP <: AbstractMode end
flip_mode(::VJP) = JVP()
flip_mode(::JVP) = VJP()

@concrete struct JacobianOperator{iip, T <: Real} <: AbstractSciMLOperator{T}
@concrete struct JacobianOperator{iip, T <: Real} <: AbstractJacobianOperator{T}
mode <: AbstractMode

jvp_op
vjp_op

size
jvp_extras
vjp_extras

output_cache
input_cache
end

SciMLBase.isinplace(::JacobianOperator{iip}) where {iip} = iip

function ConstructionBase.constructorof(::Type{<:JacobianOperator{iip, T}}) where {iip, T}
return JacobianOperator{iip, T}
end
Expand All @@ -42,6 +47,9 @@ Base.size(J::JacobianOperator, d::Integer) = J.size[d]
for op in (:adjoint, :transpose)
@eval function Base.$(op)(operator::JacobianOperator)
@set! operator.mode = flip_mode(operator.mode)
(; output_cache, input_cache) = operator
@set! operator.output_cache = input_cache
@set! operator.input_cache = output_cache
return operator
end
end
Expand All @@ -53,16 +61,66 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
f = prob.f
iip = SciMLBase.isinplace(prob)
T = promote_type(eltype(u), eltype(fu))
fₚ = SciMLBase.JacobianWrapper{iip}(f, prob.p)

vjp_op, vjp_extras = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
jvp_op, jvp_extras = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)

output_cache = iip ? similar(fu, T) : nothing
input_cache = iip ? similar(u, T) : nothing

return JacobianOperator{iip, T}(
JVP(), jvp_op, vjp_op, (length(fu), length(u)), jvp_extras, vjp_extras)
JVP(), jvp_op, vjp_op, (length(fu), length(u)), output_cache, input_cache)
end

prepare_vjp(::Val{true}, args...; kwargs...) = nothing, nothing
function (op::JacobianOperator)(v, u, p)
if op.mode isa VJP
if SciMLBase.isinplace(op)
res = zero(op.output_cache)
op.vjp_op(res, v, u, p)
return res
end
return op.vjp_op(v, u, p)
else
if SciMLBase.isinplace(op)
res = zero(op.output_cache)
op.jvp_op(res, v, u, p)
return res
end
return op.jvp_op(v, u, p)
end
end

function (op::JacobianOperator)(::Number, ::Number, _, __)
error("Inplace Jacobian Operator not possible for scalars.")
end

function (op::JacobianOperator)(Jv, v, u, p)
if op.mode isa VJP
if SciMLBase.isinplace(op)
op.vjp_op(Jv, v, u, p)
return
end
copyto!(Jv, op.vjp_op(v, u, p))
return
else
if SciMLBase.isinplace(op)
op.jvp_op(Jv, v, u, p)
return
end
copyto!(Jv, op.jvp_op(v, u, p))
return
end
end

function VecJacOperator(args...; autodiff = nothing, kwargs...)
return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)'
end

function JacVecOperator(args...; autodiff = nothing, kwargs...)
return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff)
end

prepare_vjp(::Val{true}, args...; kwargs...) = nothing

function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
Expand All @@ -71,20 +129,19 @@ end

function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
f::AbstractNonlinearFunction, u, fu; autodiff = nothing)
SciMLBase.has_vjp(f) && return f.vjp, nothing
SciMLBase.has_vjp(f) && return f.vjp

if autodiff === nothing && SciMLBase.has_jac(f)
if SciMLBase.isinplace(f)
vjp_extras = (; jac_cache = similar(u, eltype(fu), length(fu), length(u)))
vjp_op = @closure (vJ, v, u, p, extras) -> begin
f.jac(extras.jac_cache, u, p)
mul!(vec(vJ), extras.jac_cache', vec(v))
jac_cache = similar(u, eltype(fu), length(fu), length(u))
return @closure (vJ, v, u, p) -> begin
f.jac(jac_cache, u, p)
mul!(vec(vJ), jac_cache', vec(v))
return
end
return vjp_op, vjp_extras
else
vjp_op = @closure (v, u, p, _) -> reshape(f.jac(u, p)' * vec(v), size(u))
return vjp_op, nothing
return @closure (v, u, p) -> reshape(f.jac(u, p)' * vec(v), size(u))
end
end

Expand All @@ -102,21 +159,16 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
fu_cache = copy(fu)
v_fake = copy(fu)
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
vjp_op = @closure (vJ, v, u, p, extras) -> begin
DI.pullback!(
fₚ, extras.fu_cache, reshape(vJ, size(u)), autodiff, u, v, extras.di_extras)
return @closure (vJ, v, u, p) -> begin
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff, u, v, di_extras)
end
return vjp_op, (; di_extras, fu_cache)
else
di_extras = DI.prepare_pullback(f, autodiff, u, fu)
vjp_op = @closure (v, u, p, extras) -> begin
return DI.pullback(f, autodiff, u, v, extras.di_extras)
end
return vjp_op, (; di_extras)
return @closure (v, u, p) -> DI.pullback(f, autodiff, u, v, di_extras)
end
end

prepare_jvp(skip::Val{true}, args...; kwargs...) = nothing, nothing
prepare_jvp(skip::Val{true}, args...; kwargs...) = nothing

function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
Expand All @@ -125,20 +177,18 @@ end

function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
f::AbstractNonlinearFunction, u, fu; autodiff = nothing)
SciMLBase.has_vjp(f) && return f.vjp, nothing
SciMLBase.has_vjp(f) && return f.vjp

if autodiff === nothing && SciMLBase.has_jac(f)
if SciMLBase.isinplace(f)
jvp_extras = (; jac_cache = similar(u, eltype(fu), length(fu), length(u)))
jvp_op = @closure (Jv, v, u, p, extras) -> begin
f.jac(extras.jac_cache, u, p)
mul!(vec(Jv), extras.jac_cache, vec(v))
jac_cache = similar(u, eltype(fu), length(fu), length(u))
return @closure (Jv, v, u, p) -> begin
f.jac(jac_cache, u, p)
mul!(vec(Jv), jac_cache, vec(v))
return
end
return jvp_op, jvp_extras
else
jvp_op = @closure (v, u, p, _) -> reshape(f.jac(u, p) * vec(v), size(u))
return jvp_op, nothing
return @closure (v, u, p, _) -> reshape(f.jac(u, p) * vec(v), size(u))
end
end

Expand All @@ -155,43 +205,29 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
if SciMLBase.isinplace(f)
fu_cache = copy(fu)
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
jvp_op = @closure (Jv, v, u, p, extras) -> begin
DI.pushforward!(fₚ, extras.fu_cache, reshape(Jv, size(extras.fu_cache)),
autodiff, u, v, extras.di_extras)
return @closure (Jv, v, u, p) -> begin
DI.pushforward!(fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v,
di_extras)
return
end
return jvp_op, (; di_extras, fu_cache)
else
di_extras = DI.prepare_pushforward(f, autodiff, u, u)
jvp_op = @closure (v, u, p, extras) -> begin
return DI.pushforward(f, autodiff, u, v, extras.di_extras)
end
return jvp_op, (; di_extras)
di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u)
return @closure (v, u, p) -> DI.pushforward(fₚ, autodiff, u, v, di_extras)
end
end

function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
SciMLBase.has_vjp(f) && return f.vjp, nothing
SciMLBase.has_jvp(f) && return f.jvp, nothing
SciMLBase.has_jac(f) && return @closure((v, u, p, _)->f.jac(u, p) * v), nothing
SciMLBase.has_vjp(f) && return f.vjp
SciMLBase.has_jvp(f) && return f.jvp
SciMLBase.has_jac(f) && return @closure((v, u, p)->f.jac(u, p) * v)

@assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \
analytic `vjp` or `jvp` or `jac`."
# TODO: Once DI supports const params we can use `p`
fₚ = Base.Fix2(f, prob.p)
di_extras = DI.prepare_derivative(fₚ, autodiff, u)
op = @closure (v, u, p, extras) -> begin
return DI.derivative(fₚ, autodiff, u, extras.di_extras) * v
end
return op, (; di_extras)
end

function VecJacOperator(args...; autodiff = nothing, kwargs...)
return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)'
end

function JacVecOperator(args...; autodiff = nothing, kwargs...)
return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff)
return @closure (v, u, p) -> DI.derivative(fₚ, autodiff, u, di_extras) * v
end

export JacobianOperator, VecJacOperator, JacVecOperator
Expand Down
Empty file.

0 comments on commit 2039879

Please sign in to comment.