Skip to content

Commit

Permalink
add p-norm
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Dec 22, 2021
1 parent 82b2842 commit c6bc594
Showing 1 changed file with 37 additions and 31 deletions.
68 changes: 37 additions & 31 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -648,53 +648,59 @@ end
norm(::Missing, p::Real=2) = missing

# With dims keyword
norm0_dims!(B, A, dims) = count!(!iszero, B, A)
norm1_dims!(B, A, dims) = Base.mapreducedim!(norm, +, B, A)
normInf_dims!(B, A, dims) = Base.mapreducedim!(norm, max, B, A)
normMinusInf_dims!(B, A, dims) = Base.mapreducedim!(norm, min, B, A)
norm0_dims!(B, A) = count!(!iszero, B, A)
norm1_dims!(B, A) = Base.mapreducedim!(norm, +, B, A)
normInf_dims!(B, A) = Base.mapreducedim!(norm, max, B, A)
normMinusInf_dims!(B, A) = Base.mapreducedim!(norm, min, B, A)

function norm2_dims!(B::AbstractArray, A::AbstractArray, dims)
sum!(norm_sqr, B, A)
sum!(LinearAlgebra.norm_sqr, B, A)
map!(sqrt, B, B)
# Checking whether `A` is safe for the fast path is slower than taking it, check later:
# Checking whether `A` is safe for the fast path is slower than taking it,
# so check and fix any zero/infinite answers afterwards:
_norm_dims_check!(B, A, dims, LinearAlgebra.norm2)
B
end

function normp_dims!(B::AbstractArray, A::AbstractArray, p::Real, dims)
if p == 0.5
sum!(sqrt norm, B, A)
map!(abs2, B, B)
elseif p == 3
sum!(x -> norm(x)^3, B, A)
map!(cbrt, B, B)
else
sum!(x -> norm(x)^p, B, A)
invp = inv(p)
map!(x -> x^invp, B, B)
end
_norm_dims_check!(B, A, dims, LinearAlgebra.normp, p)
B
end

function _norm_dims_check!(B, A, dims, norm, args...)
if A isa AbstractVecOrMat && dims == 1
for (i,x) in zip(eachindex(B), eachcol(A))
!iszero(B[i]) && isfinite(B[i]) && continue
B[i] = norm2(x)
B[i] = norm(x, args...)
end
elseif A isa AbstractVecOrMat && dims == 2
for (i,x) in zip(eachindex(B), eachrow(A))
!iszero(B[i]) && isfinite(B[i]) && continue
B[i] = norm2(x)
B[i] = norm(x, args...)
end
# In general `eachslice(A; dims)` is not what we need here.
elseif all(y -> !iszero(y) && isfinite(y), B)
for I in CartesianIndices(B)
!iszero(B[I]) && isfinite(B[I]) && continue
# This path is quite slow, but hopefully rare.
# Unfortunately `eachslice(A; dims)` is not what we need here.
# This path is not type-stable, so quite slow, but hopefully rare.
J = ntuple(d -> d in dims ? Colon() : I[d], ndims(A))
B[I] = norm2(view(A, J...))
B[I] = norm(view(A, J...), args...)
end
end
B
end

function normp_dims!(B::AbstractArray, A::AbstractArray, p::Real, dims)
if A isa AbstractVecOrMat && dims == 1
for (i,x) in zip(eachindex(B), eachcol(A))
B[i] = normp(x, p)
end
elseif A isa AbstractVecOrMat && dims == 2
for (i,x) in zip(eachindex(B), eachrow(A))
B[i] = normp(x, p)
end
else
# This is slower, but doesn't affect type-stability of `norm`
copyto!(B, Base.mapslices(x -> normp(x,p), A; dims))
end
B
end

"""
norm(A::AbstractArray, [p]; dims)
Expand Down Expand Up @@ -758,13 +764,13 @@ function norm(A::AbstractArray, p::Real=2; dims=:)
if p == 2
norm2_dims!(B, A, dims)
elseif p == 1
norm1_dims!(B, A, dims)
norm1_dims!(B, A)
elseif p == Inf
normInf_dims!(B, A, dims)
normInf_dims!(B, A)
elseif p == 0
norm0_dims!(B, A, dims)
norm0_dims!(B, A)
elseif p == -Inf
normMinusInf_dims!(B, A, dims)
normMinusInf_dims!(B, A)
else
normp_dims!(B, A, p, dims)
end
Expand Down

0 comments on commit c6bc594

Please sign in to comment.