From 7708986aa7e5992228a819bc3709d51fa8bd673c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 2 Oct 2024 18:35:49 +0530 Subject: [PATCH] Specialize triu/tril for StaticMatrix (#1241) --- src/StaticArrays.jl | 3 ++- src/linalg.jl | 34 ++++++++++++++++++++++++++++++++++ test/linalg.jl | 10 ++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index c4a0437e..be48eced 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -16,7 +16,8 @@ using LinearAlgebra import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr, kron, diag, norm, dot, diagm, lu, svd, svdvals, pinv, factorize, ishermitian, issymmetric, isposdef, issuccess, normalize, - normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \ + normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \, + triu, tril using LinearAlgebra: checksquare using PrecompileTools diff --git a/src/linalg.jl b/src/linalg.jl index a06b263c..2fe43081 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -522,3 +522,37 @@ end # Some shimming for special linear algebra matrix types @inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, uplo)) @inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Hermitian{eltype(A),typeof(A)}(A, uplo)) + +# triu/tril +function triu(S::StaticMatrix, k::Int=0) + if length(S) <= 32 + C = CartesianIndices(S) + t = Tuple(S) + for (linind, CI) in enumerate(C) + i, j = Tuple(CI) + if j-i < k + t = Base.setindex(t, zero(t[linind]), linind) + end + end + similar_type(S)(t) + else + M = triu!(copyto!(similar(S), S), k) + similar_type(S)(M) + end +end +function tril(S::StaticMatrix, k::Int=0) + if length(S) <= 32 + C = CartesianIndices(S) + t = Tuple(S) + for (linind, CI) in enumerate(C) + i, j = Tuple(CI) + if j-i > k + t = Base.setindex(t, zero(t[linind]), linind) + end + end + similar_type(S)(t) + else + M = tril!(copyto!(similar(S), S), k) + similar_type(S)(M) + end +end diff --git a/test/linalg.jl b/test/linalg.jl index 04b901dd..4e577bb1 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -471,4 +471,14 @@ end m23 = SA[1 2 3; 4 5 6] @test_inlined checksquare(m23) false end + + @testset "triu/tril" begin + for S in (SMatrix{7,5}(1:35), MMatrix{4,6}(1:24), SizedArray{Tuple{2,2}}([1 2; 3 4])) + M = Matrix(S) + for k in -10:10 + @test triu(S, k) == triu(M, k) + @test tril(S, k) == tril(M, k) + end + end + end end