diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index ef815b3ad708b..058b1992f6625 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -414,6 +414,9 @@ switch_dim12(B::AbstractArray) = PermutedDimsArray(B, (2, 1, ntuple(Base.Fix1(+, (-)(A::Adjoint) = Adjoint( -A.parent) (-)(A::Transpose) = Transpose(-A.parent) +tr(A::Adjoint) = adjoint(tr(parent(A))) +tr(A::Transpose) = transpose(tr(parent(A))) + ## multiplication * function _dot_nonrecursive(u, v) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index a452fe43987d4..218fd67a1b9d2 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -208,6 +208,8 @@ convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m) similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo) similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...) +tr(B::Bidiagonal) = sum(B.dv) + function kron(A::Diagonal, B::Bidiagonal) # `_droplast!` is only guaranteed to work with `Vector` kdv = _makevector(kron(diag(A), B.dv)) diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 12a77d7a662d9..7f5e44382f5c5 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -344,7 +344,7 @@ diagm(m::Integer, n::Integer, v::AbstractVector) = diagm(m, n, 0 => v) function tr(A::Matrix{T}) where T n = checksquare(A) t = zero(T) - for i=1:n + @inbounds @simd for i in 1:n t += A[i,i] end t diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 24da88ad20e87..038188139aa30 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -363,6 +363,7 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) = Base.copy(A::Transpose{<:Any,<:Hermitian}) = Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U)) +tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) tr(A::Hermitian) = real(tr(A.data)) Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 0879340220482..bd21471deeedd 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -428,6 +428,11 @@ function -(A::UnitUpperTriangular) UpperTriangular(Anew) end +tr(A::LowerTriangular) = tr(A.data) +tr(A::UnitLowerTriangular) = size(A, 1) * oneunit(eltype(A)) +tr(A::UpperTriangular) = tr(A.data) +tr(A::UnitUpperTriangular) = size(A, 1) * oneunit(eltype(A)) + # copy and scale function copyto!(A::T, B::T) where T<:Union{UpperTriangular,UnitUpperTriangular} n = size(B,1) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 649ab12ab5034..e43e9e699e3a9 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -172,6 +172,8 @@ Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(ad ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S)) issymmetric(S::SymTridiagonal) = true +tr(S::SymTridiagonal) = sum(S.dv) + function diag(M::SymTridiagonal{T}, n::Integer=0) where T<:Number # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n @@ -747,6 +749,8 @@ function triu!(M::Tridiagonal{T}, k::Integer=0) where T return M end +tr(M::Tridiagonal) = sum(M.d) + ################### # Generic methods # ################### diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index e96ea28531d37..7479057d9f027 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -636,4 +636,11 @@ end @test mapreduce(string, *, [1 2; 3 4]') == mapreduce(string, *, copy([1 2; 3 4]')) == "1234" end +@testset "trace" begin + for T in (Float64, ComplexF64), t in (adjoint, transpose) + A = randn(T, 10, 10) + @test tr(t(A)) == tr(copy(t(A))) == t(tr(A)) + end +end + end # module TestAdjointTranspose diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 22c070be13cb5..9866fce047dd1 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -218,6 +218,17 @@ Random.seed!(1) end end + @testset "trace" begin + for uplo in (:U, :L) + B = Bidiagonal(dv, ev, uplo) + if relty <: Integer + @test tr(B) == tr(Matrix(B)) + else + @test tr(B) ≈ tr(Matrix(B)) rtol=2eps(relty) + end + end + end + Tfull = Array(T) @testset "Linear solves" begin if relty <: AbstractFloat diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 4475dde1e543b..8c9f6494205a6 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -169,6 +169,9 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo # diag @test diag(A1) == diag(Matrix(A1)) + # tr + @test tr(A1)::elty1 == tr(Matrix(A1)) + # real @test real(A1) == real(Matrix(A1)) @test imag(A1) == imag(Matrix(A1)) diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 0fcd8744142be..590870d4dad0a 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -263,6 +263,13 @@ end @test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d) @test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl) end + @testset "trace" begin + if real(elty) <: Integer + @test tr(A) == tr(fA) + else + @test tr(A) ≈ tr(fA) rtol=2eps(real(elty)) + end + end @testset "Idempotent tests" begin for func in (conj, transpose, adjoint) @test func(func(A)) == A