Skip to content

Commit

Permalink
also do linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC committed Sep 20, 2018
1 parent 79e851e commit 5eb628a
Showing 1 changed file with 31 additions and 56 deletions.
87 changes: 31 additions & 56 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,87 +224,65 @@ end

@inline dot(a::StaticVector, b::StaticVector) = _vecdot(same_size(a, b), a, b)
@generated function _vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
if prod(S) == 0
return :(zero(promote_op(*, eltype(a), eltype(b))))
end

expr = :(conj(a[1]) * b[1])
for j = 2:prod(S)
expr = :($expr + conj(a[$j]) * b[$j])
end

return quote
@_inline_meta
@inbounds return $expr
s = zero(promote_op(*, eltype(a), eltype(b)))
@inbounds @simd for j = 1:$(prod(S))
s += conj(a[j]) * b[j]
end
return s
end
end

@inline bilinear_vecdot(a::StaticArray, b::StaticArray) = _bilinear_vecdot(same_size(a, b), a, b)
@generated function _bilinear_vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
if prod(S) == 0
return :(zero(promote_op(*, eltype(a), eltype(b))))
end

expr = :(a[1] * b[1])
for j = 2:prod(S)
expr = :($expr + a[$j] * b[$j])
end

return quote
@_inline_meta
@inbounds return $expr
s = zero(promote_op(*, eltype(a), eltype(b)))
@inbounds @simd for j = 1:$(prod(S))
s += a[j] * b[j]
end
return s
end
end

@inline LinearAlgebra.norm_sqr(v::StaticVector) = mapreduce(abs2, +, v; init=zero(real(eltype(v))))

@inline norm(a::StaticArray) = _norm(Size(a), a)
@generated function _norm(::Size{S}, a::StaticArray) where {S}
if prod(S) == 0
return zero(real(eltype(a)))
end

expr = :(abs2(a[1]))
for j = 2:prod(S)
expr = :($expr + abs2(a[$j]))
end

return quote
$(Expr(:meta, :inline))
@inbounds return sqrt($expr)
@_inline_meta
s = zero(real(eltype(a)))
@inbounds @simd for j = 1:$(prod(S))
s += abs2(a[j])
end
return sqrt(s)
end
end

_norm_p0(x) = x == 0 ? zero(x) : one(x)

@inline norm(a::StaticArray, p::Real) = _norm(Size(a), a, p)
@generated function _norm(::Size{S}, a::StaticArray, p::Real) where {S}
if prod(S) == 0
return zero(real(eltype(a)))
end

expr = :(abs(a[1])^p)
for j = 2:prod(S)
expr = :($expr + abs(a[$j])^p)
end

expr_p1 = :(abs(a[1]))
for j = 2:prod(S)
expr_p1 = :($expr_p1 + abs(a[$j]))
end

return quote
$(Expr(:meta, :inline))
@_inline_meta
s = zero(real(eltype(a)))
if p == Inf
return mapreduce(abs, max, a; init=$(zero(real(eltype(a)))))
elseif p == 1
@inbounds return $expr_p1
@inbounds @simd for j = 1:$(prod(S))
s += abs(a[j])
end
return s
elseif p == 2
return norm(a)
elseif p == 0
return mapreduce(_norm_p0, +, a; init=$(zero(real(eltype(a)))))
else
@inbounds return ($expr)^(inv(p))
@inbounds @simd for j = 1:$(prod(S))
s += abs(a[j])^p
end
return s^(inv(p))
end
end
end
Expand All @@ -321,16 +299,13 @@ end
throw(DimensionMismatch("matrix is not square"))
end

if S[1] == 0
return zero(eltype(a))
end

exprs = [:(a[$(LinearIndices(S)[i, i])]) for i = 1:S[1]]
total = reduce((ex1, ex2) -> :($ex1 + $ex2), exprs)

return quote
@_inline_meta
@inbounds return $total
s = zero(eltype(a))
@inbounds @simd for i in 1:$(S[1])
s += a[i,i]
end
return s
end
end

Expand Down

0 comments on commit 5eb628a

Please sign in to comment.