Skip to content

Commit

Permalink
Merge #2074
Browse files Browse the repository at this point in the history
2074: Use SciMLBase over OrdinaryDiffEq where possible r=charleskawczynski a=charleskawczynski

This PR is an attempt to remove our dependency on OrdinaryDiffEq, in favor of using the lighter weight package, SciMLBase.

Unfortunately, we can't remove the dependency because, for example, `ODE.Tsit5()` is an ODE algorithm that is (I think) first defined in OrdinaryDiffEq. Same goes for a bunch of methods defined in `type_getters.jl`:

```julia
is_explicit_CTS_algo_type(alg_or_tableau) =
    alg_or_tableau <: CTS.ERKAlgorithmName

is_imex_CTS_algo_type(alg_or_tableau) =
    alg_or_tableau <: CTS.IMEXARKAlgorithmName

is_implicit_type(::typeof(ODE.IMEXEuler)) = true
is_implicit_type(alg_or_tableau) =
    alg_or_tableau <: Union{
        ODE.OrdinaryDiffEqImplicitAlgorithm,
        ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
    } || is_imex_CTS_algo_type(alg_or_tableau)

is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler)) = true
is_ordinary_diffeq_newton(alg_or_tableau) =
    alg_or_tableau <: Union{
        ODE.OrdinaryDiffEqNewtonAlgorithm,
        ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
    }

is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true
is_imex_CTS_algo(::SciMLBase.AbstractODEAlgorithm) = false

is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true
is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true
is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo)

is_rosenbrock(::ODE.Rosenbrock23) = true
is_rosenbrock(::ODE.Rosenbrock32) = true
is_rosenbrock(::SciMLBase.AbstractODEAlgorithm) = false
use_transform(ode_algo) =
    !(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo))

additional_integrator_kwargs(::SciMLBase.AbstractODEAlgorithm) = (;
    adaptive = false,
    progress = isinteractive(),
    progress_steps = isinteractive() ? 1 : 1000,
)
additional_integrator_kwargs(::CTS.DistributedODEAlgorithm) = (;
    kwargshandle = ODE.KeywordArgSilent, # allow custom kwargs
    adjustfinal = true,
    # TODO: enable progress bars in ClimaTimeSteppers
)

is_cts_algo(::SciMLBase.AbstractODEAlgorithm) = false
is_cts_algo(::CTS.DistributedODEAlgorithm) = true
```
A bunch of these ODE. types are (I think) first defined in OrdinaryDiffEq.jl.

Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
bors[bot] and charleskawczynski authored Sep 8, 2023
2 parents 7768031 + cf95924 commit 4205c01
Show file tree
Hide file tree
Showing 21 changed files with 75 additions and 74 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -38,6 +37,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RRTMGP = "a01a1ee8-cea4-48fc-987c-fc7878d79da1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -59,7 +59,6 @@ ClimaTimeSteppers = "0.7"
CloudMicrophysics = "0.13"
Colors = "0.12"
Dierckx = "0.5"
DiffEqBase = "6"
DiffEqCallbacks = "2"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -76,6 +75,7 @@ OrdinaryDiffEq = "5, 6"
Pkg = "1.8"
RRTMGP = "0.9"
RootSolvers = "0.2, 0.3, 0.4"
SciMLBase = "1"
StaticArrays = "1"
StatsBase = "0.33"
SurfaceFluxes = "0.7"
Expand Down
12 changes: 6 additions & 6 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "484a0d15ba3f1cc8d5c8863c4fe1999899c42df7"
project_hash = "7fc686bc0a71a5f83e48b1285d03c4910b039279"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down Expand Up @@ -247,7 +247,7 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.ClimaAtmos]]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqBase", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "SciMLBase", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
path = ".."
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
version = "0.16.0"
Expand Down Expand Up @@ -697,9 +697,9 @@ version = "0.16.16"

[[deps.HDF5_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"]
git-tree-sha1 = "10c72358aaaa5cd6bc7cc39b95e6eadf92f5a336"
git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739"
uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
version = "1.14.2+0"
version = "1.14.2+1"

[[deps.HostCPUFeatures]]
deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"]
Expand Down Expand Up @@ -1574,9 +1574,9 @@ version = "1.9.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.6.0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
Expand Down
8 changes: 4 additions & 4 deletions examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "bbd263a7239c44698f2ead1e0abc80d97dea3133"
project_hash = "29934b734692261a625d91c36cff961d40b7b683"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down Expand Up @@ -286,7 +286,7 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.ClimaAtmos]]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqBase", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "SciMLBase", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
path = ".."
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
version = "0.16.0"
Expand Down Expand Up @@ -2349,9 +2349,9 @@ version = "1.9.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.6.0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down
1 change: 0 additions & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
4 changes: 2 additions & 2 deletions examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ using Statistics: mean
import ClimaAtmos.Parameters as CAP
import Thermodynamics as TD
import ClimaComms
using OrdinaryDiffEq
using SciMLBase
using PrettyTables
using DiffEqCallbacks
import DiffEqCallbacks as DECB
using JLD2
using NCDatasets
using ClimaTimeSteppers
Expand Down
12 changes: 6 additions & 6 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "9556cf181fd1b65b5db679b93b4882215a7f787a"
project_hash = "967da1a008eca13e2322097c9c4f6c2a0c6d67ad"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down Expand Up @@ -297,7 +297,7 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.ClimaAtmos]]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqBase", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "SciMLBase", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
path = ".."
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
version = "0.16.0"
Expand Down Expand Up @@ -2431,9 +2431,9 @@ version = "1.9.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.6.0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down Expand Up @@ -2623,9 +2623,9 @@ version = "1.3.0"

[[deps.TypedSyntax]]
deps = ["CodeTracking", "JuliaSyntax"]
git-tree-sha1 = "34f0ab1aa1b869840cfc4e1e33074030e90ece7e"
git-tree-sha1 = "79ea8a4993ed5d341580c4044433e0259fceb4c6"
uuid = "d265eb64-f81a-44ad-a842-4247ee1503de"
version = "1.2.2"
version = "1.2.3"

[[deps.URIs]]
git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0"
Expand Down
1 change: 0 additions & 1 deletion perf/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
6 changes: 3 additions & 3 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ integrator = CA.get_integrator(config)

(; parsed_args) = config

import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaTimeSteppers as CTS
ODE.step!(integrator) # compile first
SciMLBase.step!(integrator) # compile first

(; sol, u, p, dt, t) = integrator

Expand Down Expand Up @@ -41,7 +41,7 @@ trials["implicit_tendency!"] = get_trial(implicit_fun(integrator), implicit_args
trials["remaining_tendency!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "remaining_tendency!");
trials["additional_tendency!"] = get_trial(CA.additional_tendency!, (X, u, p, t), "additional_tendency!");
trials["hyperdiffusion_tendency!"] = get_trial(CA.hyperdiffusion_tendency!, (X, u, p, t), "hyperdiffusion_tendency!");
trials["step!"] = get_trial(ODE.step!, (integrator, ), "step!");
trials["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!");
#! format: on

table_summary = OrderedCollections.OrderedDict()
Expand Down
12 changes: 6 additions & 6 deletions perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ integrator = CA.get_integrator(config)
# The callbacks flame graph is very expensive, so only do 2 steps.
@info "running step"

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # compile first
import SciMLBase
SciMLBase.step!(integrator) # compile first
CA.call_all_callbacks!(integrator) # compile callbacks
import Profile, ProfileCanvas
(; output_dir, job_id) = integrator.p.simulation
Expand All @@ -18,7 +18,7 @@ mkpath(output_dir)

@info "collect profile"
Profile.clear()
prof = Profile.@profile OrdinaryDiffEq.step!(integrator)
prof = Profile.@profile SciMLBase.step!(integrator)
results = Profile.fetch()
Profile.clear()

Expand All @@ -32,7 +32,7 @@ ProfileCanvas.html_file(joinpath(output_dir, "flame.html"), results)
# use new allocation profiler
@info "collecting allocations"
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate = 0.01 OrdinaryDiffEq.step!(integrator)
Profile.Allocs.@profile sample_rate = 0.01 SciMLBase.step!(integrator)
results = Profile.Allocs.fetch()
Profile.Allocs.clear()
profile = ProfileCanvas.view_allocs(results)
Expand All @@ -49,8 +49,8 @@ buffer = occursin("threaded", job_id) ? 1.4 : 1


## old allocation profiler (TODO: remove this)
allocs = @allocated OrdinaryDiffEq.step!(integrator)
@timev OrdinaryDiffEq.step!(integrator)
allocs = @allocated SciMLBase.step!(integrator)
@timev SciMLBase.step!(integrator)
@info "`allocs ($job_id)`: $(allocs)"

allocs_limit = Dict()
Expand Down
6 changes: 3 additions & 3 deletions perf/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
JET.@test_opt OrdinaryDiffEq.step!(integrator)
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors
JET.@test_opt SciMLBase.step!(integrator)
4 changes: 2 additions & 2 deletions perf/jet_report_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ macro n_failures(ex)
)
end

import OrdinaryDiffEq
import SciMLBase
import ClimaAtmos as CA
n = Dict()
Y = integrator.u;
Expand All @@ -20,7 +20,7 @@ t = integrator.t;
Yₜ = similar(Y);
ref_Y = similar(Y);
#! format: off
n["step!"] = @n_failures OrdinaryDiffEq.step!(integrator);
n["step!"] = @n_failures SciMLBase.step!(integrator);
n["limited_tendency!"] = @n_failures CA.limited_tendency!(Yₜ, Y, p, t);
n["horizontal_advection_tendency!"] = @n_failures CA.horizontal_advection_tendency!(Yₜ, Y, p, t);
n["horizontal_tracer_advection_tendency!"] = @n_failures CA.horizontal_tracer_advection_tendency!(Yₜ, Y, p, t);
Expand Down
8 changes: 4 additions & 4 deletions perf/jet_test_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors

# Suggested in: https://github.com/aviatesk/JET.jl/issues/455
macro n_failures(ex)
Expand All @@ -20,13 +20,13 @@ end

using Test
@testset "Test N-jet failures" begin
n = @n_failures OrdinaryDiffEq.step!(integrator)
n = @n_failures SciMLBase.step!(integrator)
# This test is intended to provide some friction when we
# add code to our tendency function that results in degraded
# inference. By increasing this counter, we acknowledge that
# we have introduced an inference failure. We hope to drive
# this number down to 0.
n_allowed_failures = 256
n_allowed_failures = 680
@test n n_allowed_failures
if n < n_allowed_failures
@info "Please update the n-failures to $n"
Expand Down
20 changes: 10 additions & 10 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import DiffEqCallbacks
import SciMLBase
#####
##### Callback helpers
#####
Expand All @@ -7,7 +7,7 @@ function call_every_n_steps(f!, n = 1; skip_first = false, call_at_end = false)
previous_step = Ref(0)
@assert n Inf "Adding callback that never gets called!"
cb! = AtmosCallback(f!, EveryNSteps(n))
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2]),
Expand All @@ -31,7 +31,7 @@ function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
next_t[] = min(next_t[], t_end)
end
end
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
Expand All @@ -50,42 +50,42 @@ function callback_from_affect(affect!)
x = getproperty(affect!, p)
if x isa AtmosCallback
return x
elseif x isa DiffEqCallbacks.SavedValues
elseif x isa DECB.SavedValues
return x
end
end
error("Callback not found in $(affect!)")
end
function atmos_callbacks(cbs::ODE.CallbackSet)
function atmos_callbacks(cbs::SciMLBase.CallbackSet)
all_cbs = [cbs.continuous_callbacks..., cbs.discrete_callbacks...]
callback_objs = map(cb -> callback_from_affect(cb.affect!), all_cbs)
filter!(x -> !(x isa DiffEqCallbacks.SavedValues), callback_objs)
filter!(x -> !(x isa DECB.SavedValues), callback_objs)
return callback_objs
end

n_measured_calls(integrator) = n_measured_calls(integrator.callback)
n_measured_calls(cbs::ODE.CallbackSet) =
n_measured_calls(cbs::SciMLBase.CallbackSet) =
map(x -> x.n_measured_calls, atmos_callbacks(cbs))

n_expected_calls(integrator) = n_expected_calls(
integrator.callback,
integrator.dt,
integrator.sol.prob.tspan,
)
n_expected_calls(cbs::ODE.CallbackSet, dt, tspan) =
n_expected_calls(cbs::SciMLBase.CallbackSet, dt, tspan) =
map(x -> n_expected_calls(x, dt, tspan), atmos_callbacks(cbs))

n_steps_per_cycle(integrator) =
n_steps_per_cycle(integrator.callback, integrator.dt)
function n_steps_per_cycle(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle(cbs::SciMLBase.CallbackSet, dt)
nspc = n_steps_per_cycle_per_cb(cbs, dt)
return isempty(nspc) ? 1 : lcm(nspc)
end

n_steps_per_cycle_per_cb(integrator) =
n_steps_per_cycle_per_cb(integrator.callback, integrator.dt)

function n_steps_per_cycle_per_cb(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle_per_cb(cbs::SciMLBase.CallbackSet, dt)
return map(atmos_callbacks(cbs)) do cb
cbf = callback_frequency(cb)
if cbf isa EveryΔt
Expand Down
6 changes: 3 additions & 3 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import ClimaCore.Fields
import ClimaComms
import ClimaCore as CC
import ClimaCore.Spaces
import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaAtmos.Parameters as CAP
import DiffEqCallbacks as DEQ
import DiffEqCallbacks as DECB
import ClimaCore: InputOutput
import Dates
using Insolation: instantaneous_zenith_angle
Expand Down Expand Up @@ -70,7 +70,7 @@ function turb_conv_affect_filter!(integrator)
# paying for an additional `∑tendencies!` call, which is required
# to support supplying a continuous representation of the
# solution.
ODE.u_modified!(integrator, false)
SciMLBase.u_modified!(integrator, false)
end

function rrtmgp_model_callback!(integrator)
Expand Down
Loading

0 comments on commit 4205c01

Please sign in to comment.