-
-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PythonCall Extension #519
Add PythonCall Extension #519
Conversation
Codecov Report
@@ Coverage Diff @@
## master #519 +/- ##
==========================================
- Coverage 54.25% 53.63% -0.63%
==========================================
Files 51 52 +1
Lines 3854 3897 +43
==========================================
- Hits 2091 2090 -1
- Misses 1763 1807 +44
... and 5 files with indirect coverage changes 📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
I think this makes sense. The other place that it could be would be the DiffEqBase solve.jl, which intercepts right before solve and makes conversions so that the problem is solvable by a given solver (and does a bunch of error throws according to the options). However, that is done at a later time because those conversions can be dependent on the solver that is chosen. Here, this would be conversions that should just always happen. So I guess this is needed here, though it is a bit tedious to add it everywhere |
Here are my internal TDD notes for posterity's sake and to draw from when writing tests# Patch 0
using SciMLBase
SciMLBase.numargs(f::ComposedFunction) = SciMLBase.numargs(f.inner) # https://github.com/SciML/SciMLBase.jl/pull/506
# Test 1
using DifferentialEquations, PythonCall
pyexec("""
from juliacall import Main
de = Main.seval("DifferentialEquations")
def f(u,p,t):
return -u
u0 = 0.5
tspan = (0., 1.)
prob = de.ODEProblem(f, u0, tspan)
sol = de.solve(prob)
""", @__MODULE__)
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution
# Test 1 Failure 1: "Detected an in-place function with an initial condition of type Number or SArray."
# Patch 1
using PythonCall: Py, pyimport, hasproperty, pyconvert
using SciMLBase: SciMLBase
# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
inspect = pyimport("inspect")
f2 = hasproperty(f, :py_func) ? f.py_func : f
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
# `self` in the `args` list. So, we subtract 1 in that case:
pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end
# Test 1 Failure 2 "ERROR: Python: Julia: MethodError: Cannot `convert` an object of type Py to an object of type Float64"
# Patch 2
function SciMLBase.ODEProblem(f::Py, u0, tspan, args...)
ODEProblem(Base.Fix1(pyconvert, Any) ∘ f, pyconvert(Any, u0), pyconvert(Any, tspan), pyconvert.(Any, args)...)
end
# Test 1 Pass.
# Test 2
pyexec("""
def f(u,p,t):
x, y, z = u
sigma, rho, beta = p
return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,8/3]
prob = de.ODEProblem(f, u0, tspan, p)
sol = de.solve(prob,saveat=0.01)
""", @__MODULE__)
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution
# Patch 3 (replaces patch 2)
using PythonCall: pyisinstance
_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x
function SciMLBase.ODEProblem(f::Py, u0, tspan, args...)
ODEProblem(_pyconvert ∘ f, _pyconvert(u0), _pyconvert(tspan), pyconvert.(Any, args)...)
end
# Test 2 passes
# Test 2 continued
pyexec("""
import matplotlib.pyplot as plt
plt.plot(sol.t, de.transpose(de.stack(sol.u))) # :( fails without the conversion
plt.show()
""", @__MODULE__)
# Test 3
@pyexec """
jul_f = Main.seval(""\"
function f(du,u,p,t)
x, y, z = u
sigma, rho, beta = p
du[1] = sigma * (y - x)
du[2] = x * (rho - z) - y
du[3] = x * y - beta * z
end""\")
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.ODEProblem(jul_f, u0, tspan, p)
sol = de.solve(prob)
"""
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution
# Test 3 failure 1: "ERROR: Python: Julia: MethodError: no method matching oneunit(::Type{Any})"
# Patch 4 (replaces patch 3)
using PythonCall: pyisinstance, Py, PyList, pybuiltins, pyconvert
_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x
SciMLBase.prepare_u0(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_f(f::Py) = _pyconvert ∘ f
# upstreamed
@eval SciMLBase begin
prepare_u0(u0) = u0
prepare_f(f) = f
function ODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
ODEProblem{isinplace(f)}(prepare_f(f), prepare_u0(u0), tspan, args...; kwargs...)
end
function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
_f = prepare_f(f)
iip = isinplace(_f, 4)
_u0 = prepare_u0(u0)
_tspan = promote_tspan(tspan)
__f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(_f)
ODEProblem{isinplace(__f)}(__f, _u0, _tspan, p; kwargs...)
end
end
# Test 3 passes
# Test 4
pyexec("""
def f(u,p,t):
return 1.01*u
def g(u,p,t):
return 0.87*u
u0 = 0.5
tspan = (0.0,1.0)
prob = de.SDEProblem(f,g,u0,tspan)
sol = de.solve(prob,reltol=1e-3,abstol=1e-3)
""", @__MODULE__)
# Test 4 failure 1: "ERROR: Python: TypeError: 'float' object is not iterable"
# Patch 5 (upstreamed)
@eval SciMLBase begin
function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...)
SDEProblem{isinplace(f)}(prepare_f(f), prepare_f(g), prepare_u0(u0), tspan, p; kwargs...)
end
function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
_g = prepare_f(g)
SDEProblem(SDEFunction(prepare_f(f), _g), _g, prepare_u0(u0), tspan, p; kwargs...)
end
end
# Test 4 Pass.
# Patch Summary
# SciMLBase
numargs(f::ComposedFunction) = numargs(f.inner) # https://github.com/SciML/SciMLBase.jl/pull/506
"""
prepare_initial_state(u0) = u0
Whenever an initial state is passed to the SciML ecosystem, is passed to
`prepare_initial_state` and the result is used instead. If you define a
type which cannot be used as a state but can be converted to something that
can be, then you may define `prepare_initial_state(x::YourType) = ...`.
!!! warning
This function is experimental and may be removed in the future.
See also: `prepare_function`.
"""
prepare_initial_state(u0) = u0
"""
prepare_function(f) = f
Whenever a function is passed to the SciML ecosystem, is passed to
`prepare_function` and the result is used instead. If you define a type which
cannot be used as a function in the SciML ecosystem but can be converted to
something that can be, then you may define `prepare_function(x::YourType) = ...`.
!!! warning
This function is experimental and may be removed in the future.
See also: `prepare_initial_state`.
"""
prepare_function(f) = f
# begin approx
function ODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
ODEProblem{isinplace(f)}(f, prepare_initial_state(u0), tspan, args...; kwargs...)
end
function ODEFunction(f; kwargs...)
_f = prepare_function(f)
ODEFunction{isinplace(_f, 4), FullSpecialize}(_f; kwargs...)
end
function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...)
SDEProblem{isinplace(f)}(f, g, prepare_initial_state(u0), tspan, p; kwargs...)
end
function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
_f = prepare_function(f)
_g = prepare_function(g)
SDEProblem(SDEFunction(_f, _g), _g, u0, tspan, p; kwargs...)
end
...
# end approx
# SciMLBase / PythonCall extension
using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins
using SciMLBase: SciMLBase
# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
inspect = pyimport("inspect")
f2 = hasproperty(f, :py_func) ? f.py_func : f
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
# `self` in the `args` list. So, we subtract 1 in that case:
pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end
_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x
SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_function(f::Py) = _pyconvert ∘ f |
TODO: add prepare_u0 to all problems (currently just ODE and SDE) TODO: add tests TODO: add package extension boilerplate (e.g. update Project.toml)
ff257cf
to
be6174a
Compare
Force push was a clean rebase onto master |
I agree that delaying these conversions is, unfortunately, probably not a great idea. For example, I think it would be reasonable when choosing a solver to perform a query on Also, looking at DiffEqBase/src/solve.jl, it seems that it would still be a bit messy to extract u0 and all user functions and convert them. I'll proceed with adding these conversions to all entrypoints I can find. |
Makes sense |
…he spooky action at a distance
CodeCov claims this has pretty high patch coverage, but that is sort of a lie. In theory, this PR enables full usage of all of DifferentialEquations via PythonCall. To test that claim would require rewriting all downstream tests in Python. That's probably not worth doing, but I want to be clear that if I failed to insert a call to |
@avik-pal @ErikQQY there's still one last remake issue with BVPs: https://github.com/SciML/SciMLBase.jl/actions/runs/6439762476/job/17487819266?pr=519#step:6:880 |
The failing tests are about |
That PR is stale though since the function form was updated and the |
The goal of this PR is to make PythonCall and the DifferentialEquations ecosystem fully compatible, making SciML/diffeqpy#118 trivial.
I think I've implemented the nontrivial design decisions that have to be made, so this is ready for review. If the design looks good and once #502 merges, I'll finish the details to get this to a mergeable state.