Skip to content

Commit

Permalink
Fix PANOCplus (#66)
Browse files Browse the repository at this point in the history
* fixed backtracks counter reset

* fixed acceleration
  • Loading branch information
Alberto De Marchi authored Feb 18, 2022
1 parent 52ef0f0 commit 69e6fc2
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/algorithms/panocplus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R}
return state, state
end

set_next_direction!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState) = mul!(state.d, state.H, -state.res)
set_next_direction!(::NoAccelerationStyle, ::PANOCplusIteration, state::PANOCplusState) = state.d .= .-state.res
set_next_direction!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState) = mul!(state.d, state.H, -state.res_prev)
set_next_direction!(::NoAccelerationStyle, ::PANOCplusIteration, state::PANOCplusState) = state.d .= .-state.res_prev
set_next_direction!(iter::PANOCplusIteration, state::PANOCplusState) = set_next_direction!(acceleration_style(typeof(iter.directions)), iter, state)

update_direction_state!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState) = update!(state.H, state.x - state.x_prev, state.res - state.res_prev)
Expand Down Expand Up @@ -140,8 +140,10 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where
# backtrack tau 1 → 0
state.tau = R(1)
state.x .= state.x_prev .+ state.d
tau_backtracks = 0
else
state.x .= (1 - state.tau) * (state.x_prev .- state.res_prev) + state.tau * (state.x_prev .+ state.d)
tau_backtracks += 1
end

mul!(state.Ax, iter.A, state.x)
Expand All @@ -154,9 +156,9 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where

f_Az_upp = f_model(iter, state)

mul!(state.Az, iter.A, state.z)
f_Az = gradient!(state.grad_f_Az, iter.f, state.Az)
if (iter.gamma === nothing || iter.adaptive == true)
mul!(state.Az, iter.A, state.z)
f_Az = gradient!(state.grad_f_Az, iter.f, state.Az)
tol = 10 * eps(R) * (1 + abs(f_Az))
if f_Az > f_Az_upp + tol && state.gamma >= iter.minimum_gamma
state.gamma *= 0.5
Expand All @@ -167,9 +169,6 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where
reset_direction_state!(iter, state)
continue
end
else
mul!(state.Az, iter.A, state.z)
f_Az = gradient!(state.grad_f_Az, iter.f, state.Az)
end
mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az)

Expand All @@ -178,7 +177,6 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where
break
end
state.tau = tau_backtracks >= iter.max_backtracks - 1 ? R(0) : state.tau / 2
tau_backtracks += 1
can_update_direction = false

end
Expand Down

0 comments on commit 69e6fc2

Please sign in to comment.