Skip to content

Commit

Permalink
LinearAlgebra: Speed up the trace function (#47585)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Nov 24, 2022
1 parent 113efb6 commit 25b2746
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 1 deletion.
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ switch_dim12(B::AbstractArray) = PermutedDimsArray(B, (2, 1, ntuple(Base.Fix1(+,
(-)(A::Adjoint) = Adjoint( -A.parent)
(-)(A::Transpose) = Transpose(-A.parent)

tr(A::Adjoint) = adjoint(tr(parent(A)))
tr(A::Transpose) = transpose(tr(parent(A)))

## multiplication *

function _dot_nonrecursive(u, v)
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)
similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo)
similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...)

tr(B::Bidiagonal) = sum(B.dv)

function kron(A::Diagonal, B::Bidiagonal)
# `_droplast!` is only guaranteed to work with `Vector`
kdv = _makevector(kron(diag(A), B.dv))
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ diagm(m::Integer, n::Integer, v::AbstractVector) = diagm(m, n, 0 => v)
function tr(A::Matrix{T}) where T
n = checksquare(A)
t = zero(T)
for i=1:n
@inbounds @simd for i in 1:n
t += A[i,i]
end
t
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
Base.copy(A::Transpose{<:Any,<:Hermitian}) =
Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))

tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian) = real(tr(A.data))

Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo)
Expand Down
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ function -(A::UnitUpperTriangular)
UpperTriangular(Anew)
end

tr(A::LowerTriangular) = tr(A.data)
tr(A::UnitLowerTriangular) = size(A, 1) * oneunit(eltype(A))
tr(A::UpperTriangular) = tr(A.data)
tr(A::UnitUpperTriangular) = size(A, 1) * oneunit(eltype(A))

# copy and scale
function copyto!(A::T, B::T) where T<:Union{UpperTriangular,UnitUpperTriangular}
n = size(B,1)
Expand Down
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(ad
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
issymmetric(S::SymTridiagonal) = true

tr(S::SymTridiagonal) = sum(S.dv)

function diag(M::SymTridiagonal{T}, n::Integer=0) where T<:Number
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down Expand Up @@ -747,6 +749,8 @@ function triu!(M::Tridiagonal{T}, k::Integer=0) where T
return M
end

tr(M::Tridiagonal) = sum(M.d)

###################
# Generic methods #
###################
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -636,4 +636,11 @@ end
@test mapreduce(string, *, [1 2; 3 4]') == mapreduce(string, *, copy([1 2; 3 4]')) == "1234"
end

@testset "trace" begin
for T in (Float64, ComplexF64), t in (adjoint, transpose)
A = randn(T, 10, 10)
@test tr(t(A)) == tr(copy(t(A))) == t(tr(A))
end
end

end # module TestAdjointTranspose
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,17 @@ Random.seed!(1)
end
end

@testset "trace" begin
for uplo in (:U, :L)
B = Bidiagonal(dv, ev, uplo)
if relty <: Integer
@test tr(B) == tr(Matrix(B))
else
@test tr(B) tr(Matrix(B)) rtol=2eps(relty)
end
end
end

Tfull = Array(T)
@testset "Linear solves" begin
if relty <: AbstractFloat
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo
# diag
@test diag(A1) == diag(Matrix(A1))

# tr
@test tr(A1)::elty1 == tr(Matrix(A1))

# real
@test real(A1) == real(Matrix(A1))
@test imag(A1) == imag(Matrix(A1))
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,13 @@ end
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
end
@testset "trace" begin
if real(elty) <: Integer
@test tr(A) == tr(fA)
else
@test tr(A) tr(fA) rtol=2eps(real(elty))
end
end
@testset "Idempotent tests" begin
for func in (conj, transpose, adjoint)
@test func(func(A)) == A
Expand Down

0 comments on commit 25b2746

Please sign in to comment.