
Rework default algorithm to be fully type stable #307

Merged · 21 commits into main · May 31, 2023
Conversation

@ChrisRackauckas (Member) commented May 29, 2023

Two possible ideas. One is to put all algorithms and caches in the default algorithm into a unityper. The other is to make a "mega algorithm" with runtime information on the choice. This takes the latter approach because that keeps most of the package code intact and makes it so that any algorithm choice by the user will not have any runtime behavior.

This uses an enum inside of the algorithm struct in order to choose the actual solver for the given process. In init and solve static dispatching is done through hardcoded branches.
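The pattern described above can be sketched as follows. This is a hypothetical minimal illustration (the names `DefaultAlgChoice`, `DefaultLinearSolver`, and `do_solve` are made up here, not LinearSolve's actual internals): the algorithm struct is a single concrete type, the solver choice is runtime data inside it, and each branch makes a concretely-typed call so the whole function infers.

```julia
using LinearAlgebra

# Hypothetical sketch of the enum-dispatch pattern; names are illustrative,
# not LinearSolve's actual internals.
@enum DefaultAlgChoice begin
    LU_CHOICE
    QR_CHOICE
end

struct DefaultLinearSolver
    choice::DefaultAlgChoice  # runtime information, but the struct type stays concrete
end

# "Static dispatch through hardcoded branches": each branch is a
# concretely-typed call, so the return type is inferable.
function do_solve(alg::DefaultLinearSolver, A, b)
    if alg.choice == LU_CHOICE
        return lu(A) \ b
    else
        return qr(A) \ b
    end
end

A = rand(4, 4)
b = rand(4)
x = do_solve(DefaultLinearSolver(LU_CHOICE), A, b)
```

Because `DefaultLinearSolver` is concrete regardless of which solver it selects, code that stores or passes the algorithm never pays a dispatch cost; only the branch itself is evaluated at runtime.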

Todo:

- [ ] Finish the implementation to make it actually work (check the generated functions)
- [ ] Make `OperatorAssumptions` carry dynamic information
- [ ] Test and time to see the real overhead
- [ ] Munge the output into a solution struct that doesn't have the actual algorithm so that the output is type stable

Things to consider:

The one thing that may be a blocker is the `init` cost. While I don't think it will be an issue, we will need to see if constructing all of the empty arrays simply to hold a thing of the right type is too costly. Basically, what we may need is for `Float64[]` to be optimized to be a no-op zero allocation, in which case essentially the whole cache structure should be free to build. As it stands, this may be a cause of overhead as it needs to build all of the potential caches even if it only ever uses one.
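The `Float64[]` question above is easy to probe directly. A rough check (warming up first so compilation is not counted):

```julia
# Does constructing an empty Float64 vector allocate? If this were ever
# optimized to a zero-allocation no-op, building all of the unused
# default-algorithm caches would be essentially free.
f() = Float64[]
f()  # warm up so compilation is not measured
bytes = @allocated f()
@show bytes  # nonzero today: each empty array still allocates a header
```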
@ChrisRackauckas (Member Author)

Inference works:

```julia
using LinearSolve
using Test

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob)

@inferred solve(prob)
@inferred init(prob, nothing)
```

@ChrisRackauckas (Member Author)

```julia
julia> @benchmark solve($prob)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.653 μs …   1.069 ms  ┊ GC (min … max):  0.00% … 74.47%
 Time  (median):     2.875 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.443 μs ±  36.512 μs  ┊ GC (mean ± σ):  24.39% ±  2.96%

    ▃▅▆▇█████▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁▁  ▁ ▁▁▁▁                        ▃
  ▆████████████████████████████████████████████▇█▇▇▅▇▆▇▇▅▆▅▅ █
  2.65 μs      Histogram: log(frequency) by time     4.02 μs <

 Memory estimate: 8.83 KiB, allocs estimate: 101.

julia> @benchmark solve($prob, $(LUFactorization()))
BenchmarkTools.Trial: 10000 samples with 112 evaluations.
 Range (min … max):  762.652 ns … 35.605 μs  ┊ GC (min … max): 0.00% … 96.61%
 Time  (median):     778.643 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   858.785 ns ±  1.231 μs  ┊ GC (mean ± σ):  8.18% ±  5.54%

  ▂▅▇██▇▇▅▄▃▂▂▂▂▂▁▁▁                                           ▂
  ██████████████████████▇▆▇▇▇▇█▇▇▇▇▆▇▆▆▆▆▆▆▇▇▆▆▆▆▆▆▆▆▅▅▆▆▅▅▅▃▅ █
  763 ns        Histogram: log(frequency) by time       966 ns <

 Memory estimate: 1.47 KiB, allocs estimate: 15.
```

@ChrisRackauckas (Member Author)

Needs JuliaArrays/ArrayInterface.jl#415

@ChrisRackauckas (Member Author)

```julia
using LinearSolve

A = rand(100, 100)
b = rand(100)
prob = LinearProblem(A, b)

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(LUFactorization()))
```

```julia
julia> @benchmark solve($prob)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  66.166 μs …  2.565 ms  ┊ GC (min … max): 0.00% … 89.21%
 Time  (median):     72.355 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   77.852 μs ± 88.509 μs  ┊ GC (mean ± σ):  4.89% ±  4.16%

            ▃██▅▂
  ▂▄▅▄▂▂▂▁▁▄█████▆▆▆█▇▆▅▄▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  66.2 μs         Histogram: frequency by time        91.7 μs <

 Memory estimate: 110.23 KiB, allocs estimate: 124.

julia> @benchmark solve($prob, $(LUFactorization()))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  107.875 μs …  6.946 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     116.917 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   121.663 μs ± 92.707 μs  ┊ GC (mean ± σ):  1.87% ± 3.81%

        ▆█▁  ▆▂
  ▁▁▃▂▂▅███▄███▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  108 μs          Histogram: frequency by time          163 μs <

 Memory estimate: 81.80 KiB, allocs estimate: 16.
```

The problem seems to be that we need a way to lazily initialize the GMRES cache.
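One way to get that lazy initialization, sketched with hypothetical names (`LazyCache` and `get_cache!` are made up here, not LinearSolve's actual cache API): store the expensive workspace as `nothing` until the first time it is actually used.

```julia
# Hypothetical lazy-cache sketch: the GMRES-style workspace is only
# allocated the first time it is actually requested.
mutable struct LazyCache{T}
    data::Union{Nothing, Vector{T}}
    n::Int
end
LazyCache{T}(n::Int) where {T} = LazyCache{T}(nothing, n)

function get_cache!(c::LazyCache{T}) where {T}
    if c.data === nothing
        c.data = Vector{T}(undef, c.n)  # pay the allocation only when needed
    end
    return c.data::Vector{T}  # type assertion keeps callers inferable
end

c = LazyCache{Float64}(100)
workspace = get_cache!(c)
```

The `::Vector{T}` assertion matters: the field is a small `Union`, but the getter still returns a concrete type, so the lazy cache does not reintroduce type instability.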

@ChrisRackauckas (Member Author)

```julia
using LinearSolve

A = rand(100, 100)
b = rand(100)
prob = LinearProblem(A, b)

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(RFLUFactorization()))
```

```julia
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  65.333 μs …  2.537 ms  ┊ GC (min … max): 0.00% … 87.08%
 Time  (median):     70.583 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   76.125 μs ± 91.330 μs  ┊ GC (mean ± σ):  4.66% ±  3.77%

   ▄▅▄▁   ▁▆▇█▇▆▅▃▄▄▅▅▄▄▃▃▃▂▂▁▁▁▁▂▂▁▁                         ▂
  ▇████▆▅▂█████████████████████████████████████▇▆▇▇▇▅▇▆▆▄▅▄▅▅ █
  65.3 μs      Histogram: log(frequency) by time      90.2 μs <

 Memory estimate: 88.84 KiB, allocs estimate: 105.

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  63.458 μs …  2.124 ms  ┊ GC (min … max): 0.00% … 95.95%
 Time  (median):     68.208 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   72.105 μs ± 57.144 μs  ┊ GC (mean ± σ):  3.27% ±  3.97%

   ▁       ▅█▅
  ▄█▂▁▁▁▁▁▅████▄▃▃▃▄▄▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  63.5 μs         Histogram: frequency by time        87.5 μs <

 Memory estimate: 81.81 KiB, allocs estimate: 16.
```

Reasonable scaling.

@ChrisRackauckas (Member Author)

```julia
using LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)

solve(prob).alg.alg

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(GenericLUFactorization()))
```

Small scale is not okay.

```julia
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.653 μs … 904.287 μs  ┊ GC (min … max):  0.00% … 72.30%
 Time  (median):     2.796 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.310 μs ±  33.458 μs  ┊ GC (mean ± σ):  21.23% ±  2.74%

  ▁▅▇██▇▇▅▄▃▁     ▁ ▁▁  ▁ ▁▁▁▁▁▁                              ▂
  ███████████████████████████████▇▇▇▆▇▆▆▆▆▆▇▆▆▆▆▄▆▆▆▄▆▆▄▅▅▅▅▅ █
  2.65 μs      Histogram: log(frequency) by time      4.47 μs <

 Memory estimate: 8.59 KiB, allocs estimate: 105.

BenchmarkTools.Trial: 10000 samples with 140 evaluations.
 Range (min … max):  705.057 ns …  18.727 μs  ┊ GC (min … max): 0.00% … 95.22%
 Time  (median):     727.971 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   813.099 ns ± 938.705 ns  ┊ GC (mean ± σ):  7.74% ±  6.35%

   ▂▅▇██▇▄▃▃▃▄▄▄▄▂▂▂▂▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁                        ▂
  ▆████████████████████████████████████████████▇█▅▇▇▆▆▇▆▆▆▅▆▆▆▅ █
  705 ns        Histogram: log(frequency) by time        934 ns <

 Memory estimate: 1.47 KiB, allocs estimate: 15.
```

@ChrisRackauckas (Member Author)

This is actually very good now.

```julia
using LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)

solve(prob).alg.alg

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(GenericLUFactorization(LinearSolve.RowMaximum())))
```

```julia
BenchmarkTools.Trial: 10000 samples with 202 evaluations.
 Range (min … max):  379.743 ns …   7.574 μs  ┊ GC (min … max): 0.00% … 93.00%
 Time  (median):     398.099 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   427.764 ns ± 306.002 ns  ┊ GC (mean ± σ):  4.29% ±  5.58%

       ▇█▄
  ▁▁▂▃████▇▆▆▅▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  380 ns           Histogram: frequency by time          516 ns <

 Memory estimate: 880 bytes, allocs estimate: 8.

BenchmarkTools.Trial: 10000 samples with 197 evaluations.
 Range (min … max):  452.198 ns …  11.249 μs  ┊ GC (min … max): 0.00% … 92.86%
 Time  (median):     463.411 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   498.177 ns ± 441.242 ns  ┊ GC (mean ± σ):  4.74% ±  5.09%

   ▃▇██▆▄▃▃▃▄▄▃▃▃▂▁▁▁▂▂▂▁▁▁▁▁   ▁                               ▂
  ▇█████████████████████████████████▇▇█▇▇▇▇█▇▆▆▄▇▇▅▆▅▄▆▅▅▄▅▅▅▄▄ █
  452 ns        Histogram: log(frequency) by time        599 ns <

 Memory estimate: 704 bytes, allocs estimate: 9.
```

Not sure why the default method is faster, but it's pretty consistently faster. And inference works:

```julia
using Test
@inferred solve(prob)
@inferred init(prob, nothing)
```

@ChrisRackauckas ChrisRackauckas changed the title RFC/WIP: Rework default algorithm to be fully type stable Rework default algorithm to be fully type stable May 30, 2023
codecov bot commented May 30, 2023

Codecov Report

Merging #307 (73dfa2c) into main (cb31d58) will increase coverage by 2.99%.
The diff coverage is 78.91%.

```
@@            Coverage Diff             @@
##             main     #307      +/-   ##
==========================================
+ Coverage   73.39%   76.38%   +2.99%
==========================================
  Files          15       15
  Lines        1026     1207     +181
==========================================
+ Hits          753      922     +169
- Misses        273      285      +12
```

| Impacted Files | Coverage Δ |
| --- | --- |
| ext/LinearSolveHYPREExt.jl | 90.80% <ø> (ø) |
| src/LinearSolve.jl | 54.54% <ø> (-28.22%) ⬇️ |
| src/default.jl | 67.26% <68.55%> (+30.19%) ⬆️ |
| src/factorization.jl | 79.67% <85.62%> (-0.46%) ⬇️ |
| src/iterative_wrappers.jl | 78.97% <94.44%> (-0.03%) ⬇️ |
| src/common.jl | 92.30% <100.00%> (+1.19%) ⬆️ |


@ChrisRackauckas (Member Author)

Next up I want to make IterativeSolvers and KrylovKit into extensions, and then set up every solver with docstrings, and that's 2.0.

@ChrisRackauckas ChrisRackauckas merged commit 9718ea2 into main May 31, 2023
@ChrisRackauckas ChrisRackauckas deleted the defaultalg branch May 31, 2023 16:45
@chriselrod (Contributor)

> Not sure why the default method is faster, but it's pretty consistently faster. And inference works:
>
> ```julia
> using Test
> @inferred solve(prob)
> @inferred init(prob, nothing)
> ```

Does it pass `JET.@test_opt`?

@ChrisRackauckas (Member Author)

I didn't know about that, but I did check Cthulhu and it worked. If you can PR to add the JET testing that would be helpful.
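For reference, a minimal JET check along the lines suggested above could look like this (assumes the JET.jl package is installed; `@test_opt` fails if any runtime dispatch or inference problem is detected in the call tree):

```julia
using JET, LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)

# Reports optimization failures (runtime dispatch, etc.) as test failures.
JET.@test_opt solve(prob)
```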

@ChrisRackauckas (Member Author)

#318

Some pass; others flag unrelated things like `print`.

ChrisRackauckas added a commit to SciML/OrdinaryDiffEq.jl that referenced this pull request Jan 1, 2024
This accomplishes a few things:

* Faster precompile times by precompiling less
* Full inference of results when using the automatic algorithm
* Hopefully faster load times by also precompiling less

This is done the same way as

* linearsolve SciML/LinearSolve.jl#307
* nonlinearsolve SciML/NonlinearSolve.jl#238

and is thus the more modern SciML way of doing it. It avoids dispatch by having a single algorithm that always generates the full cache and instead of dispatching between algorithms always branches for the choice.

It turns out the mechanism already existed for this in OrdinaryDiffEq: it's `CompositeAlgorithm`, the same bones as `AutoSwitch`! As such, this reuses quite a bit of code from the auto-switch algorithms, but instead of just having two choices it (currently) chooses between 6. This means it has stiffness detection and switching behavior, but also in a size-dependent way.

There are still some optimizations to do though. Like LinearSolve.jl, it would be more efficient to have a way to initialize the caches to size zero and then have a way to re-initialize them to the correct size. Right now, it'll generate the same Jacobian N times and it shouldn't need to do that.
oscardssmith pushed a commit to oscardssmith/OrdinaryDiffEq.jl that referenced this pull request May 8, 2024

fix typo / test

fix precompilation choices

Update src/composite_algs.jl

Co-authored-by: Nathanael Bosch <[email protected]>

Update src/composite_algs.jl

switch CompositeCache away from tuple so it can start undef

Default Cache

fix precompile

remove fallbacks

remove fallbacks
oscardssmith pushed a commit to oscardssmith/OrdinaryDiffEq.jl that referenced this pull request May 8, 2024
oscardssmith pushed a commit to oscardssmith/OrdinaryDiffEq.jl that referenced this pull request May 14, 2024