From c65200430eac0f4b6a157bcf08607664988ac7be Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 12 Dec 2023 20:40:02 +0530 Subject: [PATCH 1/2] Pass mutable copies to inplace LinearAlgebra functions --- stdlib/LinearAlgebra/src/generic.jl | 10 +++++----- stdlib/LinearAlgebra/test/generic.jl | 13 +++++++++++++ test/testhelpers/FillArrays.jl | 2 ++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 929618b95ce0a..ca01646820ccf 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -338,7 +338,7 @@ julia> triu(a) 0.0 0.0 0.0 1.0 ``` """ -triu(M::AbstractMatrix) = triu!(copy(M)) +triu(M::AbstractMatrix) = triu!(copymutable(M)) """ tril(M) @@ -362,7 +362,7 @@ julia> tril(a) 1.0 1.0 1.0 1.0 ``` """ -tril(M::AbstractMatrix) = tril!(copy(M)) +tril(M::AbstractMatrix) = tril!(copymutable(M)) """ triu(M, k::Integer) @@ -393,7 +393,7 @@ julia> triu(a,-3) 1.0 1.0 1.0 1.0 ``` """ -triu(M::AbstractMatrix,k::Integer) = triu!(copy(M),k) +triu(M::AbstractMatrix,k::Integer) = triu!(copymutable(M),k) """ tril(M, k::Integer) @@ -424,7 +424,7 @@ julia> tril(a,-3) 1.0 0.0 0.0 0.0 ``` """ -tril(M::AbstractMatrix,k::Integer) = tril!(copy(M),k) +tril(M::AbstractMatrix,k::Integer) = tril!(copymutable(M),k) """ triu!(M) @@ -1760,7 +1760,7 @@ Calculates the determinant of a matrix using the [Bareiss Algorithm](https://en.wikipedia.org/wiki/Bareiss_algorithm). Also refer to [`det_bareiss!`](@ref). """ -det_bareiss(M) = det_bareiss!(copy(M)) +det_bareiss(M) = det_bareiss!(copymutable(M)) diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index c79173ad1011a..b8cb15ff695cb 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -15,6 +15,9 @@ using .Main.OffsetArrays isdefined(Main, :DualNumbers) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "DualNumbers.jl")) using .Main.DualNumbers +isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl")) +using .Main.FillArrays + Random.seed!(123) n = 5 # should be odd @@ -647,4 +650,14 @@ end end end +@testset "immutable arrays" begin + A = FillArrays.Fill(big(3), (4, 4)) + M = Array(A) + @test triu(A) == triu(M) + @test triu(A, -1) == triu(M, -1) + @test tril(A) == tril(M) + @test tril(A, 1) == tril(M, 1) + @test det(A) == det(M) +end + end # module TestGeneric diff --git a/test/testhelpers/FillArrays.jl b/test/testhelpers/FillArrays.jl index 1f36a77bf8c12..7ba18f22307ca 100644 --- a/test/testhelpers/FillArrays.jl +++ b/test/testhelpers/FillArrays.jl @@ -9,6 +9,8 @@ Fill(v, size::Vararg{Integer}) = Fill(v, size) Base.size(F::Fill) = F.size +Base.copy(F::Fill) = F + @inline getindex_value(F::Fill) = F.value @inline function Base.getindex(F::Fill{<:Any,N}, i::Vararg{Int,N}) where {N} From 242fabd2021c131eb464c5086c96dd7ebd261a60 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 13 Dec 2023 11:53:45 +0530 Subject: [PATCH 2/2] Access filled indices in inplace triu for UpperTri --- stdlib/LinearAlgebra/src/triangular.jl | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 6d703b2436b23..f057133a2533c 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -332,7 +332,15 @@ function tril!(A::UpperTriangular{T}, k::Integer=0) where {T} return UpperTriangular(tril!(A.data,k)) end end -triu!(A::UpperTriangular, k::Integer=0) = UpperTriangular(triu!(A.data, k)) +function triu!(A::UpperTriangular, k::Integer=0) + n = size(A,1) + if k > 0 + for j in 1:n, i in max(1,j-k+1):j + A.data[i,j] = zero(eltype(A)) + end + end + return A +end function tril!(A::UnitUpperTriangular{T}, k::Integer=0) where {T} n = size(A,1) @@ -375,7 +383,15 @@ function triu!(A::LowerTriangular{T}, k::Integer=0) where {T} end end -tril!(A::LowerTriangular, k::Integer=0) = LowerTriangular(tril!(A.data, k)) +function tril!(A::LowerTriangular, k::Integer=0) + n = size(A,1) + if k < 0 + for j in 1:n, i in j:min(j-k-1,n) + A.data[i, j] = zero(eltype(A)) + end + end + A +end function triu!(A::UnitLowerTriangular{T}, k::Integer=0) where T n = size(A,1)