Skip to content

Commit

Permalink
Added bounds check for sigma_n
Browse files Browse the repository at this point in the history
  • Loading branch information
axla-io committed Sep 19, 2023
1 parent 4b069bf commit 957a4ab
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ function perform_step!(cache::DFSaneCache{true})
# Line search direction
@. cache.𝒹 = -σₙ * cache.fuₙ₋₁

η = alg.ηₛ(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁)
η = alg.ηₛ(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁)

= maximum(cache.ℋ)
α₊ = α₁
Expand All @@ -154,7 +154,6 @@ function perform_step!(cache::DFSaneCache{true})
f(cache.fuₙ, cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = sum(abs2, cache.fuₙ)
f₍ₙₒᵣₘ₎ₙ ^= (nₑₓₚ / 2)

for _ in 1:(cache.alg.max_inner_iterations)
𝒸 =+ η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁

Expand All @@ -164,19 +163,19 @@ function perform_step!(cache::DFSaneCache{true})
(f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
@. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹 # correct order?
@. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹

f(cache.fuₙ, cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = sum(abs2, cache.fuₙ)
f₍ₙₒᵣₘ₎ₙ ^= (nₑₓₚ / 2)

(f₍ₙₒᵣₘ₎ₙ .≤ 𝒸) && break
f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break

α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)

@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹 # correct order?
@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹
f(cache.fuₙ, cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = sum(abs2, cache.fuₙ)
f₍ₙₒᵣₘ₎ₙ ^= (nₑₓₚ / 2)
Expand All @@ -193,7 +192,19 @@ function perform_step!(cache::DFSaneCache{true})
α₊ = sum(abs2, cache.uₙ₋₁)
@. cache.uₙ₋₁ = cache.uₙ₋₁ * cache.fuₙ₋₁
α₋ = sum(cache.uₙ₋₁)
cache.σₙ = α₊ / (α₋ + T(1e-5))
cache.σₙ = α₊ / α₋

# Spectral parameter bounds check
if abs(cache.σₙ) > σₘₐₓ || abs(cache.σₙ) < σₘᵢₙ
test_norm = sqrt(sum(abs2, cache.fuₙ₋₁))
if test_norm > 1
cache.σₙ = 1.0
elseif testnorm < 1e-5
cache.σₙ = 1e5
else
cache.σₙ = 1.0 / test_norm
end
end

# Take step
@. cache.uₙ₋₁ = cache.uₙ
Expand All @@ -219,7 +230,7 @@ function perform_step!(cache::DFSaneCache{false})
# Line search direction
@. cache.𝒹 = -σₙ * cache.fuₙ₋₁

η = alg.ηₛ(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁)
η = alg.ηₛ(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁)

= maximum(cache.ℋ)
α₊ = α₁
Expand Down Expand Up @@ -268,7 +279,19 @@ function perform_step!(cache::DFSaneCache{false})
α₊ = sum(abs2, cache.uₙ₋₁)
@. cache.uₙ₋₁ = cache.uₙ₋₁ * cache.fuₙ₋₁
α₋ = sum(cache.uₙ₋₁)
cache.σₙ = α₊ / (α₋ + T(1e-5))
cache.σₙ = α₊ / α₋

# Spectral parameter bounds check
if abs(cache.σₙ) > σₘₐₓ || abs(cache.σₙ) < σₘᵢₙ
test_norm = sqrt(sum(abs2, cache.fuₙ₋₁))
if test_norm > 1
cache.σₙ = 1.0
elseif testnorm < 1e-5
cache.σₙ = 1e5
else
cache.σₙ = 1.0 / test_norm
end
end

# Take step
@. cache.uₙ₋₁ = cache.uₙ
Expand Down

0 comments on commit 957a4ab

Please sign in to comment.