Commit

Merge f26f30b into 4ab5afa
fabian-sp authored Aug 6, 2021
2 parents 4ab5afa + f26f30b commit eb72536
Showing 4 changed files with 151 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/src/functions.md
@@ -65,6 +65,7 @@ NormL21
NormLinf
NuclearNorm
SqrNormL2
TotalVariation1D
```

## Penalties and other functions
1 change: 1 addition & 0 deletions src/ProximalOperators.jl
@@ -66,6 +66,7 @@ include("functions/sqrNormL2.jl")
include("functions/sumPositive.jl")
include("functions/sqrHingeLoss.jl")
include("functions/crossEntropy.jl")
include("functions/TotalVariation1D.jl")

# Calculus rules

112 changes: 112 additions & 0 deletions src/functions/TotalVariation1D.jl
@@ -0,0 +1,112 @@
# 1-dimensional Total Variation (times a constant)

export TotalVariation1D

"""
**1-dimensional Total Variation**

    TotalVariation1D(λ=1)

With a nonnegative scalar parameter λ, returns the function
```math
f(x) = λ ∑_{i=2}^{n} |x_i - x_{i-1}|.
```
"""
struct TotalVariation1D{T <: Real} <: ProximableFunction
    lambda::T
    function TotalVariation1D{T}(lambda::T) where {T <: Real}
        if lambda < 0
            error("parameter λ must be nonnegative")
        else
            new(lambda)
        end
    end
end

is_separable(f::TotalVariation1D) = false
is_convex(f::TotalVariation1D) = true
is_positively_homogeneous(f::TotalVariation1D) = true

TotalVariation1D(lambda::R=1) where {R <: Real} = TotalVariation1D{R}(lambda)

function (f::TotalVariation1D)(x::AbstractArray)
    return f.lambda * norm(x[2:end] - x[1:end-1], 1)
end

# Condat algorithm
# https://lcondat.github.io/publis/Condat-fast_TV-SPL-2013.pdf
function tvnorm_prox_condat(y::AbstractArray, x::AbstractArray, lambda::Real)
    # solves y = arg min_z lambda*sum_k |z_{k+1}-z_k| + 1/2 * ||z-x||^2
    N = length(x)

    k = k0 = kmin = kplus = 1
    vmin = x[1] - lambda
    vmax = x[1] + lambda
    umin = lambda
    umax = -lambda

    while true
        while k == N
            if umin < 0
                y[k0:kmin] .= vmin
                kmin += 1
                k = k0 = kmin
                vmin = x[k]; umin = lambda;
                umax = x[k] + lambda - vmax
            elseif umax > 0
                y[k0:kplus] .= vmax
                kplus += 1
                k = k0 = kplus
                vmax = x[k]; umax = -lambda;
                umin = x[k] - lambda - vmin
            else
                y[k0:N] .= vmin + umin/(k-k0+1)
                return
            end

            if k == N
                y[N] = vmin + umin
                return
            end
        end

        if x[k+1] + umin < vmin - lambda
            y[k0:kmin] .= vmin
            kmin += 1
            k = k0 = kplus = kmin
            vmin = x[k]; vmax = x[k] + 2*lambda;
            umin = lambda; umax = -lambda;
        elseif x[k+1] + umax > vmax + lambda
            y[k0:kplus] .= vmax
            kplus += 1
            k = k0 = kmin = kplus
            vmin = x[k] - 2*lambda; vmax = x[k];
            umin = lambda; umax = -lambda;
        else
            k += 1
            umin = umin + x[k] - vmin
            umax = umax + x[k] - vmax
            if umin >= lambda
                vmin = vmin + (umin-lambda)/(k-k0+1)
                umin = lambda; kmin = k;
            end

            if umax <= -lambda
                vmax += (umax+lambda)/(k-k0+1)
                umax = -lambda; kplus = k;
            end
        end
    end
end

function prox!(y::AbstractArray{T}, f::TotalVariation1D, x::AbstractArray{T}, gamma::Real=1.0) where T <: Real
    a = gamma * f.lambda
    tvnorm_prox_condat(y, x, a)
    return f.lambda * norm(y[2:end] - y[1:end-1], 1)
end

fun_name(f::TotalVariation1D) = "1D Total Variation"
fun_dom(f::TotalVariation1D) = "AbstractArray{Real}"
fun_expr(f::TotalVariation1D) = "x ↦ λ ∑_{i=2}^{n} |x_i - x_{i-1}|"
fun_params(f::TotalVariation1D) = "λ = $(f.lambda)"
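
A minimal usage sketch (assuming the standard ProximalOperators.jl `prox` interface, which returns the proximal point together with the function value at it):

```julia
using ProximalOperators

# a noisy, roughly piecewise-constant signal
x = [1.0, 1.1, 0.9, 1.0, 3.0, 3.2, 2.9, 3.1]

f = TotalVariation1D(0.5)   # λ = 0.5
f(x)                        # evaluates 0.5 * Σ_{i≥2} |x[i] - x[i-1]|

# proximal step with gamma = 1.0:
# y ≈ arg min_z 0.5 * Σ |z[i] - z[i-1]| + 1/2 * ||z - x||^2
y, fy = prox(f, x, 1.0)     # y is a flatter, more piecewise-constant signal
```

The in-place variant `prox!(y, f, x, gamma)` defined above writes the result into `y` and returns the value of `f` there.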
37 changes: 37 additions & 0 deletions test/test_optimality_conditions.jl
@@ -21,6 +21,26 @@ check_optimality(f::IndBallL1, x, gamma, y) = begin
    return all(sign_is_correct) && check_optimality(IndSimplex(f.r), abs.(x), gamma, abs.(y))
end

check_optimality(f::TotalVariation1D, x, gamma, y) = begin
    N = length(x)
    # compute the dual solution by accumulating the residuals x - y
    R = real(eltype(x))
    u = zeros(R, N+1)
    u[1] = 0
    for k in 2:N+1
        u[k] = x[k-1] - y[k-1] + u[k-1]
    end

    # check whether all duals lie in the interval [-gamma*lambda, gamma*lambda]
    c1 = all(abs.(u) .<= gamma*f.lambda + 10*eps(R))
    # check whether the last dual equals 0 (the first does by construction)
    c2 = isapprox(u[end], 0, atol=10*eps(R))
    # check whether the duals equal ±gamma*lambda wherever y jumps
    h = sign.(y[1:end-1] - y[2:end])
    c3 = all(isapprox.(u[2:end-1] .* abs.(h), h * f.lambda * gamma))
    return c1 && c2 && c3
end
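
For reference, the three checks above encode the usual dual characterization of the 1D total variation prox: with `u` the cumulative sum of the residuals, `u[k+1] = Σ_{j≤k} (x[j] - y[j])`, optimality amounts to (a sketch, in the notation of the test)

```math
u_1 = u_{N+1} = 0, \qquad |u_k| ≤ γλ, \qquad u_{k+1} = γλ\,\mathrm{sign}(y_k - y_{k+1}) \ \text{whenever}\ y_k ≠ y_{k+1}.
```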

test_cases = [
    Dict(
        "f" => LeastSquares(randn(20, 10), randn(20)),
@@ -147,6 +167,23 @@ test_cases = [
"x" => randn(3, 5),
"gamma" => 0.1 + rand(),
),

    Dict(
        "f" => TotalVariation1D(0.01),
        "x" => vcat(LinRange(1., -1., 10), -1*ones(3), LinRange(-1., 1., 10)),
        "gamma" => 1.,
    ),

    Dict(
        "f" => TotalVariation1D(1.0),
        "x" => [-2.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 2.4375],
        "gamma" => 1.0,
    ),
    Dict(
        "f" => TotalVariation1D(1.0),
        "x" => [0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 2.4375],
        "gamma" => 1.0,
    ),
]

@testset "Optimality conditions" begin
