diff --git a/src/linalg.jl b/src/linalg.jl index 45273d0ed..1e4267a96 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -224,35 +224,25 @@ end @inline dot(a::StaticVector, b::StaticVector) = _vecdot(same_size(a, b), a, b) @generated function _vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S} - if prod(S) == 0 - return :(zero(promote_op(*, eltype(a), eltype(b)))) - end - - expr = :(conj(a[1]) * b[1]) - for j = 2:prod(S) - expr = :($expr + conj(a[$j]) * b[$j]) - end - return quote @_inline_meta - @inbounds return $expr + s = zero(promote_op(*, eltype(a), eltype(b))) + @inbounds @simd for j = 1:$(prod(S)) + s += conj(a[j]) * b[j] + end + return s end end @inline bilinear_vecdot(a::StaticArray, b::StaticArray) = _bilinear_vecdot(same_size(a, b), a, b) @generated function _bilinear_vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S} - if prod(S) == 0 - return :(zero(promote_op(*, eltype(a), eltype(b)))) - end - - expr = :(a[1] * b[1]) - for j = 2:prod(S) - expr = :($expr + a[$j] * b[$j]) - end - return quote @_inline_meta - @inbounds return $expr + s = zero(promote_op(*, eltype(a), eltype(b))) + @inbounds @simd for j = 1:$(prod(S)) + s += a[j] * b[j] + end + return s end end @@ -260,18 +250,13 @@ end @inline norm(a::StaticArray) = _norm(Size(a), a) @generated function _norm(::Size{S}, a::StaticArray) where {S} - if prod(S) == 0 - return zero(real(eltype(a))) - end - - expr = :(abs2(a[1])) - for j = 2:prod(S) - expr = :($expr + abs2(a[$j])) - end - return quote - $(Expr(:meta, :inline)) - @inbounds return sqrt($expr) + @_inline_meta + s = zero(real(eltype(a))) + @inbounds @simd for j = 1:$(prod(S)) + s += abs2(a[j]) + end + return sqrt(s) end end @@ -279,32 +264,25 @@ _norm_p0(x) = x == 0 ? zero(x) : one(x) @inline norm(a::StaticArray, p::Real) = _norm(Size(a), a, p) @generated function _norm(::Size{S}, a::StaticArray, p::Real) where {S} - if prod(S) == 0 - return zero(real(eltype(a))) - end - - expr = :(abs(a[1])^p) - for j = 2:prod(S) - expr = :($expr + abs(a[$j])^p) - end - - expr_p1 = :(abs(a[1])) - for j = 2:prod(S) - expr_p1 = :($expr_p1 + abs(a[$j])) - end - return quote - $(Expr(:meta, :inline)) + @_inline_meta + s = zero(real(eltype(a))) if p == Inf return mapreduce(abs, max, a; init=$(zero(real(eltype(a))))) elseif p == 1 - @inbounds return $expr_p1 + @inbounds @simd for j = 1:$(prod(S)) + s += abs(a[j]) + end + return s elseif p == 2 return norm(a) elseif p == 0 return mapreduce(_norm_p0, +, a; init=$(zero(real(eltype(a))))) else - @inbounds return ($expr)^(inv(p)) + @inbounds @simd for j = 1:$(prod(S)) + s += abs(a[j])^p + end + return s^(inv(p)) end end end @@ -321,16 +299,13 @@ end throw(DimensionMismatch("matrix is not square")) end - if S[1] == 0 - return zero(eltype(a)) - end - - exprs = [:(a[$(LinearIndices(S)[i, i])]) for i = 1:S[1]] - total = reduce((ex1, ex2) -> :($ex1 + $ex2), exprs) - return quote @_inline_meta - @inbounds return $total + s = zero(eltype(a)) + @inbounds @simd for i in 1:$(S[1]) + s += a[i,i] + end + return s end end