Skip to content

Commit

Permalink
actually avoid making allocations for inexact jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
yonatanwesen committed Oct 28, 2023
1 parent 7efded5 commit 467009c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Expand All @@ -35,9 +36,8 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"

[compat]
BandedMatrices = "1"
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.130"
EnumX = "1"
Expand All @@ -63,6 +63,7 @@ julia = "1.9"
[extras]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -79,7 +80,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import PrecompileTools

PrecompileTools.@recompile_invalidations begin
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
using FastBroadcast: @.., True, False
import ArrayInterface: restructure

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix,
matrix_colors, parameterless_type, ismutable, issingular
matrix_colors, parameterless_type, ismutable, issingular,fast_scalar_indexing
import ConcreteStructs: @concrete
import EnumX: @enumx
import ForwardDiff
Expand Down
16 changes: 13 additions & 3 deletions src/pseudotransient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,19 @@ end
function perform_step!(cache::PseudoTransientCache{true})
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, alpha, tc_storage = cache
jacobian!!(J, cache)
inv_alpha = inv(alpha)

if J isa SciMLBase.AbstractSciMLOperator
J = J - (1 / alpha) * I
J = J - inv_alpha * I

Check warning on line 120 in src/pseudotransient.jl

View check run for this annotation

Codecov / codecov/patch

src/pseudotransient.jl#L119-L120

Added lines #L119 - L120 were not covered by tests
else
J .= J - (1 / alpha) * I
idxs = diagind(J)
if fast_scalar_indexing(J)

Check warning on line 123 in src/pseudotransient.jl

View check run for this annotation

Codecov / codecov/patch

src/pseudotransient.jl#L123

Added line #L123 was not covered by tests
@inbounds for i in axes(J, 1)
J[i, i] = J[i, i] - inv_alpha
end
else
@.. broadcast=false @view(J[idxs])=@view(J[idxs]) - inv_alpha

Check warning on line 128 in src/pseudotransient.jl

View check run for this annotation

Codecov / codecov/patch

src/pseudotransient.jl#L128

Added line #L128 was not covered by tests
end
end

termination_condition = cache.termination_condition(tc_storage)
Expand Down Expand Up @@ -151,8 +160,9 @@ function perform_step!(cache::PseudoTransientCache{false})
termination_condition = cache.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)
inv_alpha = inv(alpha)

cache.J = cache.J - (1 / alpha) * I
cache.J = cache.J - inv_alpha * I
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
Expand Down

0 comments on commit 467009c

Please sign in to comment.