Skip to content

Commit

Permalink
Merge pull request #275 from avik-pal/ap/scimlbase_nullparams
Browse files Browse the repository at this point in the history
Default to the older behavior
  • Loading branch information
ChrisRackauckas authored Nov 17, 2023
2 parents 9a42b50 + 77a337a commit 43b6b6b
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 23 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ Zygote = "0.6"
julia = "1.6"

[extras]
ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d"
ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -77,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
68 changes: 52 additions & 16 deletions src/differentiation/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,70 @@ __internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop
(f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
(f::JacFunctionWrapper{false, true, 3})(u) = f.f(u)

function JacFunctionWrapper(f::F, fu_, u, p, t) where {F}
# NOTE: `use_deprecated_ordering` is a way for external libraries to update to the correct
# style. In the next release, we will drop the first check
function JacFunctionWrapper(f::F, fu_, u, p, t;
use_deprecated_ordering::Val{deporder} = Val(true)) where {F, deporder}
# The warning instead of error ensures a non-breaking change for users relying on an
# undefined / undocumented feature
fu = fu_ === nothing ? copy(u) : copy(fu_)

if deporder
# Check this first else we were breaking things
# In the next breaking release, we will fix the ordering of the checks
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
if iip || oop
if p !== nothing || t !== nothing
Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we
potentially detected `f(du, u)` or `f(u)`. This can be caused by:
1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not
be supplied.
2. `f(args...)` is defined, in which case `hasmethod` can be spurious.
Currently, we perform the check for `f(du, u)` and `f(u)` first, but in
future breaking releases, this check will be performed last, which means
that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given
precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be
given precedence.""", :JacFunctionWrapper)
end
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end
end

if t !== nothing
iip = static_hasmethod(f, typeof((fu, u, p, t)))
oop = static_hasmethod(f, typeof((u, p, t)))
if !iip && !oop
@warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined
for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
else
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
throw(ArgumentError("""`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)`
not defined for `f`!"""))
end
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
elseif p !== nothing
iip = static_hasmethod(f, typeof((fu, u, p)))
oop = static_hasmethod(f, typeof((u, p)))
if !iip && !oop
@warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will
fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
else
return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
throw(ArgumentError("""`p` is provided but `f(u, p)` or `f(fu, u, p)`
not defined for `f`!"""))
end
return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end

if !deporder
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
if !iip && !oop
throw(ArgumentError("""`p` is provided but `f(u)` or `f(fu, u)` not defined for
`f`!"""))
end
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
else
throw(ArgumentError("""Couldn't determine the function signature of `f` to
construct a JacobianWrapper!"""))
end
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
!iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`"))
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end
5 changes: 3 additions & 2 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,9 @@ f(du, u) # Otherwise
```
"""
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
ff = JacFunctionWrapper(f, fu, u, p, t)
autodiff = AutoForwardDiff(), tag = DeivVecTag(),
use_deprecated_ordering::Val = Val(true), kwargs...)
ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering)
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)

cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
Expand Down
4 changes: 2 additions & 2 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ f(du, u) # Otherwise
```
"""
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
autodiff = AutoFiniteDiff(), kwargs...)
ff = JacFunctionWrapper(f, fu, u, p, t)
autodiff = AutoFiniteDiff(), use_deprecated_ordering::Val = Val(true), kwargs...)
ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering)

if !__internal_oop(ff) && autodiff isa AutoZygote
msg = "Zygote requires an out of place method with signature f(u)."
Expand Down

0 comments on commit 43b6b6b

Please sign in to comment.