From d7a01684ac50de1be859263e0863372ddb559c10 Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Tue, 3 Aug 2021 16:13:12 +0200 Subject: [PATCH 1/8] add initial normTV --- src/functions/normTV.jl | 111 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 src/functions/normTV.jl diff --git a/src/functions/normTV.jl b/src/functions/normTV.jl new file mode 100644 index 0000000..92136d5 --- /dev/null +++ b/src/functions/normTV.jl @@ -0,0 +1,111 @@ +# Total variation norm (times a constant) + +export NormTV + +""" +**``TV`` norm** + + NormTV(λ=1) + +With a nonnegative scalar parameter λ, returns the function +```math +f(x) = λ ∑_{i=2}^{n} |x_i - x_{i-1}|. +``` +""" +struct NormTV{T <: Real} <: ProximableFunction + lambda::T + function NormTV{T}(lambda::T) where {T <: Real} + if lambda < 0 + error("parameter λ must be nonnegative") + else + new(lambda) + end + end +end + +is_separable(f::NormTV) = false +is_convex(f::NormTV) = true +is_positively_homogeneous(f::NormTV) = true + +NormTV(lambda::R=1) where {R <: Real} = NormTV{R}(lambda) + +function (f::NormTV)(x::AbstractArray) + return f.lambda * norm(x[2:end] - x[1:end-1], 1) +end + +# Condat algorithm +# https://lcondat.github.io/publis/Condat-fast_TV-SPL-2013.pdf +function condat(y::AbstractArray, x::AbstractArray, lambda::Real) + N = length(x); + + k=k0=kmin=kplus=1; + vmin = x[1] - lambda; + vmax = x[1] + lambda; + umin = lambda; + umax = -lambda; + + while 0 < 1 + while k == N + if umin < 0 + y[k0:kmin] .= vmin; + kmin += 1; + k = k0 = kmin; + vmin = x[k]; umin = lambda; + umax = x[k] + lambda - vmax; + elseif umax > 0 + y[k0:kplus] .= vmax; + kplus +=1; + k=k0=kplus; + vmax = x[k]; umax = -lambda; + umin = x[k] - lambda - vmin; + else + y[k0:N] .= vmin + umin/(k-k0+1); + return + end + + if k==N + y[N] = vmin + umin; + return + end + end + + + if x[k+1] + umin < vmin - lambda + y[k0:kmin] .= vmin; + kmin +=1; + k = k0 =kplus =kmin; + vmin = x[k]; vmax = x[k] + 2*lambda; + umin = lambda; umax = -lambda; + elseif x[k+1] + umax > vmax + lambda + y[k0:kplus] .= vmax; + kplus +=1; + k = k0 =kmin = kplus; + vmin = x[k] - 2*lambda; vmax = x[k]; + umin = lambda; umax = -lambda; + else + k +=1 ; + umin = umin + x[k] - vmin; + umax = umax + x[k] - vmax; + if umin >= lambda + vmin = vmin + (umin-lambda)/(k-k0+1); + umin = lambda; kmin = k; + end + + if umax <= -lambda + vmax += (umax+lambda)/(k-k0+1); + umax = -lambda; kplus = k; + end + end + end +end + +function prox!(y::AbstractArray{T}, f::NormTV, x::AbstractArray{T}, gamma::Real=1.0) where T <: Real + a = gamma * f.lambda + y = condat(y, x, a) + return y +end + +fun_name(f::NormTV) = "1D Total variation norm" +fun_dom(f::NormTV) = "AbstractArray{Real}" +fun_expr(f::NormTV) = "x ↦ λ ∑_{i=2}^{n} |x_i - x_{i-1}|" +fun_params(f::NormTV) = "λ = $(f.lambda)" From b1d7f96addefcdc6c0a7635db147d9a0cba38a57 Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Wed, 4 Aug 2021 09:26:39 +0200 Subject: [PATCH 2/8] fix all issues --- src/functions/normTV.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/functions/normTV.jl b/src/functions/normTV.jl index 92136d5..c6eb4d4 100644 --- a/src/functions/normTV.jl +++ b/src/functions/normTV.jl @@ -3,7 +3,7 @@ export NormTV """ -**``TV`` norm** +** 1-dimensional ``TV`` norm** NormTV(λ=1) @@ -35,7 +35,7 @@ end # Condat algorithm # https://lcondat.github.io/publis/Condat-fast_TV-SPL-2013.pdf -function condat(y::AbstractArray, x::AbstractArray, lambda::Real) +function tvnorm_prox_condat(y::AbstractArray, 
x::AbstractArray, lambda::Real) N = length(x); k=k0=kmin=kplus=1; @@ -101,8 +101,8 @@ end function prox!(y::AbstractArray{T}, f::NormTV, x::AbstractArray{T}, gamma::Real=1.0) where T <: Real a = gamma * f.lambda - y = condat(y, x, a) - return y + y = tvnorm_prox_condat(y, x, a) + return f.lambda * norm(y[2:end] - y[1:end-1], 1) end fun_name(f::NormTV) = "1D Total variation norm" From 7596cb81257e7582241a2fbeab1b3b006b5c8133 Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Wed, 4 Aug 2021 16:58:47 +0200 Subject: [PATCH 3/8] add to docs --- docs/src/functions.md | 1 + src/functions/normTV.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/functions.md b/docs/src/functions.md index 0e3a974..55ff0dd 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -65,6 +65,7 @@ NormL21 NormLinf NuclearNorm SqrNormL2 +NormTV ``` ## Penalties and other functions diff --git a/src/functions/normTV.jl b/src/functions/normTV.jl index c6eb4d4..85ba0f2 100644 --- a/src/functions/normTV.jl +++ b/src/functions/normTV.jl @@ -36,6 +36,7 @@ end # Condat algorithm # https://lcondat.github.io/publis/Condat-fast_TV-SPL-2013.pdf function tvnorm_prox_condat(y::AbstractArray, x::AbstractArray, lambda::Real) + # solves y = arg min_z lambda*sum_k |z_{k+1}-z_k| + 1/2 * ||z-x||^2 N = length(x); k=k0=kmin=kplus=1; @@ -101,7 +102,7 @@ end function prox!(y::AbstractArray{T}, f::NormTV, x::AbstractArray{T}, gamma::Real=1.0) where T <: Real a = gamma * f.lambda - y = tvnorm_prox_condat(y, x, a) + tvnorm_prox_condat(y, x, a) return f.lambda * norm(y[2:end] - y[1:end-1], 1) end From f0d6a2a0c9aa27a3d18290cff6057ad90e59f46c Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Wed, 4 Aug 2021 17:11:35 +0200 Subject: [PATCH 4/8] add opt conditions test --- test/test_optimality_conditions.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_optimality_conditions.jl b/test/test_optimality_conditions.jl index 5ea11af..beb7f16 100644 --- a/test/test_optimality_conditions.jl +++ b/test/test_optimality_conditions.jl @@ -21,6 +21,25 @@ check_optimality(f::IndBallL1, x, gamma, y) = begin return all(sign_is_correct) && check_optimality(IndSimplex(f.r), abs.(x), gamma, abs.(y)) end +check_optimality(f::NormTV, x, gamma, y) = begin + N = length(x) + # compute dual solution + u = zeros(N+1) + u[1] = 0 + for k in 2:N+1 + u[k] = x[k-1]-y[k-1]+u[k-1] + end + + # check whether all duals in interval + c1 = all(abs.(u) .<= gamma*f.lambda + 1e-10) + # check whether last equals 0 (first by construction) + c2 = isapprox(u[end], 0, atol=1e-12) + # check whether equal +- gamma*lambda + h = sign.(y[1:end-1] - y[2:end]) + c3 = all(isapprox.( u[2:end-1] .* abs.(h) , h *f.lambda*gamma)) + return c1 && c2 && c3 +end + test_cases = [ Dict( "f" => LeastSquares(randn(20, 10), randn(20)), From e10a6c3c48dc0fb5240909302f4d4ba3c69d5de8 Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Thu, 5 Aug 2021 09:21:37 +0200 Subject: [PATCH 5/8] rename to TotalVariation1D, add to include list --- docs/src/functions.md | 2 +- src/ProximalOperators.jl | 1 + .../{normTV.jl => TotalVariation1D.jl} | 32 +++++++++---------- test/test_optimality_conditions.jl | 2 +- 4 files changed, 19 insertions(+), 18 deletions(-) rename src/functions/{normTV.jl => TotalVariation1D.jl} (73%) diff --git a/docs/src/functions.md b/docs/src/functions.md index 55ff0dd..3d48381 100644 --- a/docs/src/functions.md +++ b/docs/src/functions.md @@ -65,7 +65,7 @@ NormL21 NormLinf NuclearNorm SqrNormL2 -NormTV +TotalVariation1D ``` ## Penalties and 
other functions diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl index a626774..ce5717a 100644 --- a/src/ProximalOperators.jl +++ b/src/ProximalOperators.jl @@ -66,6 +66,7 @@ include("functions/sqrNormL2.jl") include("functions/sumPositive.jl") include("functions/sqrHingeLoss.jl") include("functions/crossEntropy.jl") +include("functions/TotalVariation1D.jl") # Calculus rules diff --git a/src/functions/normTV.jl b/src/functions/TotalVariation1D.jl similarity index 73% rename from src/functions/normTV.jl rename to src/functions/TotalVariation1D.jl index 85ba0f2..e1b780f 100644 --- a/src/functions/normTV.jl +++ b/src/functions/TotalVariation1D.jl @@ -1,20 +1,20 @@ -# Total variation norm (times a constant) +# 1-dimensional Total Variation (times a constant) -export NormTV +export TotalVariation1D """ -** 1-dimensional ``TV`` norm** +** 1-dimensional Total Variation** - NormTV(λ=1) + TotalVariation1D(λ=1) With a nonnegative scalar parameter λ, returns the function ```math f(x) = λ ∑_{i=2}^{n} |x_i - x_{i-1}|. ``` """ -struct NormTV{T <: Real} <: ProximableFunction +struct TotalVariation1D{T <: Real} <: ProximableFunction lambda::T - function NormTV{T}(lambda::T) where {T <: Real} + function TotalVariation1D{T}(lambda::T) where {T <: Real} if lambda < 0 error("parameter λ must be nonnegative") else @@ -23,13 +23,13 @@ struct NormTV{T <: Real} <: ProximableFunction end end -is_separable(f::NormTV) = false -is_convex(f::NormTV) = true -is_positively_homogeneous(f::NormTV) = true +is_separable(f::TotalVariation1D) = false +is_convex(f::TotalVariation1D) = true +is_positively_homogeneous(f::TotalVariation1D) = true -NormTV(lambda::R=1) where {R <: Real} = NormTV{R}(lambda) +TotalVariation1D(lambda::R=1) where {R <: Real} = TotalVariation1D{R}(lambda) -function (f::NormTV)(x::AbstractArray) +function (f::TotalVariation1D)(x::AbstractArray) return f.lambda * norm(x[2:end] - x[1:end-1], 1) end @@ -100,13 +100,13 @@ function tvnorm_prox_condat(y::AbstractArray, x::AbstractArray, lambda::Real) end end -function prox!(y::AbstractArray{T}, f::NormTV, x::AbstractArray{T}, gamma::Real=1.0) where T <: Real +function prox!(y::AbstractArray{T}, f::TotalVariation1D, x::AbstractArray{T}, gamma::Real=1.0) where T <: Real a = gamma * f.lambda tvnorm_prox_condat(y, x, a) return f.lambda * norm(y[2:end] - y[1:end-1], 1) end -fun_name(f::NormTV) = "1D Total variation norm" -fun_dom(f::NormTV) = "AbstractArray{Real}" -fun_expr(f::NormTV) = "x ↦ λ ∑_{i=2}^{n} |x_i - x_{i-1}|" -fun_params(f::NormTV) = "λ = $(f.lambda)" +fun_name(f::TotalVariation1D) = "1D Total Variation" +fun_dom(f::TotalVariation1D) = "AbstractArray{Real}" +fun_expr(f::TotalVariation1D) = "x ↦ λ ∑_{i=2}^{n} |x_i - x_{i-1}|" +fun_params(f::TotalVariation1D) = "λ = $(f.lambda)" diff --git a/test/test_optimality_conditions.jl b/test/test_optimality_conditions.jl index beb7f16..c4d0245 100644 --- a/test/test_optimality_conditions.jl +++ b/test/test_optimality_conditions.jl @@ -21,7 +21,7 @@ check_optimality(f::IndBallL1, x, gamma, y) = begin return all(sign_is_correct) && check_optimality(IndSimplex(f.r), abs.(x), gamma, abs.(y)) end -check_optimality(f::NormTV, x, gamma, y) = begin +check_optimality(f::TotalVariation1D, x, gamma, y) = begin N = length(x) # compute dual solution u = zeros(N+1) From 2b8a70fe38722c056ba69a706f87cfa3623f2ab4 Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Thu, 5 Aug 2021 11:12:53 +0200 Subject: [PATCH 6/8] add tests to list --- test/test_optimality_conditions.jl | 19 ++++++++++++++++--- 1 file changed, 16 
insertions(+), 3 deletions(-) diff --git a/test/test_optimality_conditions.jl b/test/test_optimality_conditions.jl index c4d0245..161413f 100644 --- a/test/test_optimality_conditions.jl +++ b/test/test_optimality_conditions.jl @@ -24,16 +24,17 @@ end check_optimality(f::TotalVariation1D, x, gamma, y) = begin N = length(x) # compute dual solution - u = zeros(N+1) + R = real(eltype(x)) + u = zeros(R, N+1) u[1] = 0 for k in 2:N+1 u[k] = x[k-1]-y[k-1]+u[k-1] end # check whether all duals in interval - c1 = all(abs.(u) .<= gamma*f.lambda + 1e-10) + c1 = all(abs.(u) .<= gamma*f.lambda + 10*eps(R)) # check whether last equals 0 (first by construction) - c2 = isapprox(u[end], 0, atol=1e-12) + c2 = isapprox(u[end], 0, atol=10*eps(R)) # check whether equal +- gamma*lambda h = sign.(y[1:end-1] - y[2:end]) c3 = all(isapprox.( u[2:end-1] .* abs.(h) , h *f.lambda*gamma)) @@ -154,6 +155,18 @@ test_cases = [ "x" => randn(3, 5), "gamma" => 0.1 + rand(), ), + + Dict( + "f" => TotalVariation1D(1.), + "x" => randn(5), + "gamma" => 0.1 + rand(), + ), + + Dict( + "f" => TotalVariation1D(0.1), + "x" => [0.5, 0.4, 0.3, 0.2], + "gamma" => 1.0, + ), ] @testset "Optimality conditions" begin From 82ff56b63ca26c5761c739a1e11a8f90d6990d80 Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Thu, 5 Aug 2021 17:41:28 +0200 Subject: [PATCH 7/8] remove semicolon and change test --- src/functions/TotalVariation1D.jl | 52 +++++++++++++++--------------- test/test_optimality_conditions.jl | 6 ++-- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/functions/TotalVariation1D.jl b/src/functions/TotalVariation1D.jl index e1b780f..494e140 100644 --- a/src/functions/TotalVariation1D.jl +++ b/src/functions/TotalVariation1D.jl @@ -37,63 +37,63 @@ end # https://lcondat.github.io/publis/Condat-fast_TV-SPL-2013.pdf function tvnorm_prox_condat(y::AbstractArray, x::AbstractArray, lambda::Real) # solves y = arg min_z lambda*sum_k |z_{k+1}-z_k| + 1/2 * ||z-x||^2 - N = length(x); + N = length(x) - k=k0=kmin=kplus=1; - vmin = x[1] - lambda; - vmax = x[1] + lambda; - umin = lambda; - umax = -lambda; + k=k0=kmin=kplus=1 + vmin = x[1] - lambda + vmax = x[1] + lambda + umin = lambda + umax = -lambda while 0 < 1 while k == N if umin < 0 - y[k0:kmin] .= vmin; - kmin += 1; - k = k0 = kmin; + y[k0:kmin] .= vmin + kmin += 1 + k = k0 = kmin vmin = x[k]; umin = lambda; - umax = x[k] + lambda - vmax; + umax = x[k] + lambda - vmax elseif umax > 0 - y[k0:kplus] .= vmax; - kplus +=1; - k=k0=kplus; + y[k0:kplus] .= vmax + kplus +=1 + k=k0=kplus vmax = x[k]; umax = -lambda; - umin = x[k] - lambda - vmin; + umin = x[k] - lambda - vmin else - y[k0:N] .= vmin + umin/(k-k0+1); + y[k0:N] .= vmin + umin/(k-k0+1) return end if k==N - y[N] = vmin + umin; + y[N] = vmin + umin return end end if x[k+1] + umin < vmin - lambda - y[k0:kmin] .= vmin; - kmin +=1; - k = k0 =kplus =kmin; + y[k0:kmin] .= vmin + kmin +=1 + k = k0 =kplus =kmin vmin = x[k]; vmax = x[k] + 2*lambda; umin = lambda; umax = -lambda; elseif x[k+1] + umax > vmax + lambda - y[k0:kplus] .= vmax; - kplus +=1; - k = k0 =kmin = kplus; + y[k0:kplus] .= vmax + kplus +=1 + k = k0 =kmin = kplus vmin = x[k] - 2*lambda; vmax = x[k]; umin = lambda; umax = -lambda; else k +=1 ; - umin = umin + x[k] - vmin; - umax = umax + x[k] - vmax; + umin = umin + x[k] - vmin + umax = umax + x[k] - vmax if umin >= lambda - vmin = vmin + (umin-lambda)/(k-k0+1); + vmin = vmin + (umin-lambda)/(k-k0+1) umin = lambda; kmin = k; end if umax <= -lambda - vmax += (umax+lambda)/(k-k0+1); + vmax += (umax+lambda)/(k-k0+1) umax = 
-lambda; kplus = k; end end diff --git a/test/test_optimality_conditions.jl b/test/test_optimality_conditions.jl index 161413f..994e61b 100644 --- a/test/test_optimality_conditions.jl +++ b/test/test_optimality_conditions.jl @@ -157,9 +157,9 @@ test_cases = [ ), Dict( - "f" => TotalVariation1D(1.), - "x" => randn(5), - "gamma" => 0.1 + rand(), + "f" => TotalVariation1D(0.01), + "x" => vcat(LinRange(1., -1., 10), -1*ones(3), LinRange(-1., 1., 10)), + "gamma" => 1., ), Dict( From 58daebf7b65059b38e12dec2956c742d6bbb92bb Mon Sep 17 00:00:00 2001 From: fabian-sp Date: Fri, 6 Aug 2021 12:00:44 +0200 Subject: [PATCH 8/8] add pathological case --- test/test_optimality_conditions.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/test_optimality_conditions.jl b/test/test_optimality_conditions.jl index 994e61b..c66bf7e 100644 --- a/test/test_optimality_conditions.jl +++ b/test/test_optimality_conditions.jl @@ -163,8 +163,13 @@ test_cases = [ ), Dict( - "f" => TotalVariation1D(0.1), - "x" => [0.5, 0.4, 0.3, 0.2], + "f" => TotalVariation1D(1.0), + "x" => [-2.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 2.4375], + "gamma" => 1.0, + ), + Dict( + "f" => TotalVariation1D(1.0), + "x" => [0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 2.4375], "gamma" => 1.0, ), ]
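
Usage sketch (illustrative, not part of the patches above): the new term is exercised through the package's standard `prox`/`prox!` interface, where `prox` allocates the output and dispatches to the in-place `prox!` defined in TotalVariation1D.jl. The values of x, λ and γ below are arbitrary examples.

    using ProximalOperators    # built from this branch, which exports TotalVariation1D

    f = TotalVariation1D(0.5)                # λ = 0.5
    x = [1.0, 1.2, 3.0, 2.9, 0.1]

    fx = f(x)                                # λ ∑_{i=2}^{n} |x_i - x_{i-1}|

    # y = argmin_z  λ ∑_i |z_{i+1} - z_i| + 1/(2γ) ‖z - x‖², computed by Condat's algorithm
    y, fy = prox(f, x, 1.0)                  # γ = 1.0; fy equals f(y)

    # in-place form, matching the prox! signature added here
    y2 = similar(x)
    fy2 = prox!(y2, f, x, 1.0)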
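
For reference, the optimality check added to the test suite encodes the KKT conditions of the prox subproblem: with the partial sums u_k = ∑_{j≤k} (x_j - y_j), the output y is optimal exactly when every u_k lies in [-γλ, γλ], the full sum vanishes, and u_k sits on the boundary ±γλ (with the sign of y_k - y_{k+1}) at every jump of y. A rough standalone version, with the helper name and tolerance chosen here purely for illustration:

    # Mirrors check_optimality(f::TotalVariation1D, x, gamma, y) from the tests.
    function tv_prox_is_optimal(x, y, lambda, gamma; tol=1e-10)
        u = cumsum(x .- y)                        # u_k = ∑_{j≤k} (x_j - y_j); u_0 = 0 is implicit
        in_box  = all(abs.(u) .<= gamma*lambda + tol)
        ends_ok = isapprox(u[end], 0, atol=tol)   # ∑ (x_j - y_j) must telescope to zero
        h = sign.(y[1:end-1] - y[2:end])          # ±1 at the jumps of y, 0 on flat pieces
        on_bdry = all(isapprox.(u[1:end-1] .* abs.(h), h .* (gamma*lambda), atol=tol))
        return in_box && ends_ok && on_bdry
    end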