Skip to content

Commit

Permalink
Fix: allow weights in SqrNormL2 to be zero (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Mar 11, 2022
1 parent 7c029e9 commit dfd7075
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/functions/sqrNormL2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ export SqrNormL2
"""
SqrNormL2(λ=1)
With a positive scalar `λ`, return the squared Euclidean norm
With a nonnegative scalar `λ`, return the squared Euclidean norm
```math
f(x) = \\tfrac{λ}{2}\\|x\\|^2.
```
With a positive array `λ`, return the weighted squared Euclidean norm
With a nonnegative array `λ`, return the weighted squared Euclidean norm
```math
f(x) = \\tfrac{1}{2}∑_i λ_i x_i^2.
```
"""
struct SqrNormL2{T}
struct SqrNormL2{T,SC}
lambda::T
function SqrNormL2{T}(lambda::T) where T
if any(lambda .<= 0)
error("coefficients in λ must be positive")
function SqrNormL2{T,SC}(lambda::T) where {T,SC}
if any(lambda .< 0)
error("coefficients in λ must be nonnegative")
else
new(lambda)
end
Expand All @@ -29,9 +29,9 @@ is_convex(f::Type{<:SqrNormL2}) = true
is_smooth(f::Type{<:SqrNormL2}) = true
is_separable(f::Type{<:SqrNormL2}) = true
is_generalized_quadratic(f::Type{<:SqrNormL2}) = true
is_strongly_convex(f::Type{<:SqrNormL2}) = true
is_strongly_convex(f::Type{SqrNormL2{T,SC}}) where {T,SC} = SC

SqrNormL2(lambda::T=1) where T = SqrNormL2{T}(lambda)
SqrNormL2(lambda::T=1) where T = SqrNormL2{T,all(lambda .> 0)}(lambda)

function (f::SqrNormL2{S})(x) where {S <: Real}
return f.lambda / real(eltype(x))(2) * norm(x)^2
Expand Down
1 change: 1 addition & 0 deletions test/test_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ test_cases_spec = [
( (), randn(Float64, 10) ),
( (rand(Float32),), randn(Float32, 10) ),
( (rand(Float64),), randn(Float64, 10) ),
( (Float64[1, 1, 1, 1, 0],), randn(Float64, 5) ),
( (rand(Float32, 20),), randn(Float32, 20) ),
( (rand(Float64, 20),), randn(Float64, 20) ),
( (rand(30),), rand(Complex{Float64}, 30) ),
Expand Down

0 comments on commit dfd7075

Please sign in to comment.