From dfd70754f93ba043b9c1d4460183810ea51205f6 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 11 Mar 2022 20:44:58 +0100 Subject: [PATCH] Fix: allow weights in SqrNormL2 to be zero (#140) --- src/functions/sqrNormL2.jl | 16 ++++++++-------- test/test_calls.jl | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/functions/sqrNormL2.jl b/src/functions/sqrNormL2.jl index 546800e..0069c59 100644 --- a/src/functions/sqrNormL2.jl +++ b/src/functions/sqrNormL2.jl @@ -5,20 +5,20 @@ export SqrNormL2 """ SqrNormL2(λ=1) -With a positive scalar `λ`, return the squared Euclidean norm +With a nonnegative scalar `λ`, return the squared Euclidean norm ```math f(x) = \\tfrac{λ}{2}\\|x\\|^2. ``` -With a positive array `λ`, return the weighted squared Euclidean norm +With a nonnegative array `λ`, return the weighted squared Euclidean norm ```math f(x) = \\tfrac{1}{2}∑_i λ_i x_i^2. ``` """ -struct SqrNormL2{T} +struct SqrNormL2{T,SC} lambda::T - function SqrNormL2{T}(lambda::T) where T - if any(lambda .<= 0) - error("coefficients in λ must be positive") + function SqrNormL2{T,SC}(lambda::T) where {T,SC} + if any(lambda .< 0) + error("coefficients in λ must be nonnegative") else new(lambda) end @@ -29,9 +29,9 @@ is_convex(f::Type{<:SqrNormL2}) = true is_smooth(f::Type{<:SqrNormL2}) = true is_separable(f::Type{<:SqrNormL2}) = true is_generalized_quadratic(f::Type{<:SqrNormL2}) = true -is_strongly_convex(f::Type{<:SqrNormL2}) = true +is_strongly_convex(f::Type{SqrNormL2{T,SC}}) where {T,SC} = SC -SqrNormL2(lambda::T=1) where T = SqrNormL2{T}(lambda) +SqrNormL2(lambda::T=1) where T = SqrNormL2{T,all(lambda .> 0)}(lambda) function (f::SqrNormL2{S})(x) where {S <: Real} return f.lambda / real(eltype(x))(2) * norm(x)^2 diff --git a/test/test_calls.jl b/test/test_calls.jl index dc5b45b..ee24121 100644 --- a/test/test_calls.jl +++ b/test/test_calls.jl @@ -442,6 +442,7 @@ test_cases_spec = [ ( (), randn(Float64, 10) ), ( (rand(Float32),), randn(Float32, 10) ), ( (rand(Float64),), randn(Float64, 10) ), + ( (Float64[1, 1, 1, 1, 0],), randn(Float64, 5) ), ( (rand(Float32, 20),), randn(Float32, 20) ), ( (rand(Float64, 20),), randn(Float64, 20) ), ( (rand(30),), rand(Complex{Float64}, 30) ),