Skip to content

Commit

Permalink
Default to QR for GaussNewton
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent 611477c commit b25c8bc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.4.0"
version = "2.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
60 changes: 48 additions & 12 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
end

function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
function set_linsolve(alg::GaussNewton{CJ}, linsolve) where {CJ}
return GaussNewton{CJ}(alg.ad, linsolve, alg.precs)

Check warning on line 50 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L49-L50

Added lines #L49 - L50 were not covered by tests
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,

Check warning on line 53 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L53

Added line #L53 was not covered by tests
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
Expand Down Expand Up @@ -81,15 +85,31 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob

# Use QR if the user did not specify a linear solver
if alg.linsolve === nothing
alg = set_linsolve(alg, QRFactorization(ColumnNorm(), 16, true))
linsolve_with_JᵀJ = Val(false)

Check warning on line 92 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L90-L92

Added lines #L90 - L92 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 94 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L94

Added line #L94 was not covered by tests
end

u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = f(u, p)
end
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_with_JᵀJ = Val(true))

if SciMLBase._unwrap_val(linsolve_with_JᵀJ)
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p,

Check warning on line 106 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L105-L106

Added lines #L105 - L106 were not covered by tests
Val(iip); linsolve_with_JᵀJ)
else
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p,

Check warning on line 109 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L109

Added line #L109 was not covered by tests
Val(iip); linsolve_with_JᵀJ)
JᵀJ, Jᵀf = nothing, nothing

Check warning on line 111 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L111

Added line #L111 was not covered by tests
end

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
Expand All @@ -99,12 +119,20 @@ end
function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
linu = _vec(du), p, reltol = cache.abstol)
if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)

Check warning on line 125 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L123-L125

Added lines #L123 - L125 were not covered by tests
end

# u = u - JᵀJ \ Jᵀfu
if cache.JᵀJ === nothing
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),

Check warning on line 130 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L129-L130

Added lines #L129 - L130 were not covered by tests
p, reltol = cache.abstol)
else
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),

Check warning on line 133 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L133

Added line #L133 was not covered by tests
linu = _vec(du), p, reltol = cache.abstol)
end
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)
Expand All @@ -125,14 +153,22 @@ function perform_step!(cache::GaussNewtonCache{false})

cache.J = jacobian!!(cache.J, cache)

cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
if cache.JᵀJ !== nothing
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1

Check warning on line 158 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L156-L158

Added lines #L156 - L158 were not covered by tests
end

# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
if cache.JᵀJ === nothing
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),

Check warning on line 166 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
linu = _vec(cache.du), p, reltol = cache.abstol)
else
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),

Check warning on line 169 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L169

Added line #L169 was not covered by tests
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
end
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
Expand Down

0 comments on commit b25c8bc

Please sign in to comment.