From 04ba8e0790fba1408d2c2e85e8a8c203c2e9c966 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Mon, 9 Nov 2020 14:36:36 +0100 Subject: [PATCH] Fix tall qr multiplication (#38002) (#38360) (cherry picked from commit 24750c67ff9bf7b6fd929e393a9821ca97a79540) Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/qr.jl | 14 +++++++++++++- stdlib/LinearAlgebra/test/qr.jl | 7 +++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/qr.jl b/stdlib/LinearAlgebra/src/qr.jl index 995aa1ce8c895..f72ded31102a3 100644 --- a/stdlib/LinearAlgebra/src/qr.jl +++ b/stdlib/LinearAlgebra/src/qr.jl @@ -749,7 +749,19 @@ function *(adjA::Adjoint{<:Any,<:StridedVecOrMat}, adjQ::Adjoint{<:Any,<:Abstrac end ### mul! -mul!(C::StridedVecOrMat{T}, Q::AbstractQ{T}, B::StridedVecOrMat{T}) where {T} = lmul!(Q, copyto!(C, B)) +function mul!(C::StridedVecOrMat{T}, Q::AbstractQ{T}, B::StridedVecOrMat{T}) where {T} + require_one_based_indexing(C, B) + mB = size(B, 1) + mC = size(C, 1) + if mB < mC + inds = CartesianIndices(B) + copyto!(C, inds, B, inds) + C[CartesianIndices((mB+1:mC, axes(C, 2)))] .= zero(T) + return lmul!(Q, C) + else + return lmul!(Q, copyto!(C, B)) + end +end mul!(C::StridedVecOrMat{T}, A::StridedVecOrMat{T}, Q::AbstractQ{T}) where {T} = rmul!(copyto!(C, A), Q) mul!(C::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:AbstractQ{T}}, B::StridedVecOrMat{T}) where {T} = lmul!(adjQ, copyto!(C, B)) mul!(C::StridedVecOrMat{T}, A::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:AbstractQ{T}}) where {T} = rmul!(copyto!(C, A), adjQ) diff --git a/stdlib/LinearAlgebra/test/qr.jl b/stdlib/LinearAlgebra/test/qr.jl index 3d6247798067e..ee8d0c4bdeb8a 100644 --- a/stdlib/LinearAlgebra/test/qr.jl +++ b/stdlib/LinearAlgebra/test/qr.jl @@ -180,8 +180,12 @@ rectangularQ(Q::LinearAlgebra.AbstractQ) = convert(Array, Q) b = similar(a); rand!(b) c = similar(a) + d = similar(a[:,1:n1]) @test mul!(c, q, b) ≈ q*b + @test mul!(d, q, r) ≈ q*r ≈ a[:,qrpa.p] @test mul!(c, q', b) ≈ q'*b + @test mul!(d, q', a[:,qrpa.p])[1:n1,:] ≈ r + @test all(x -> abs(x) < ε*norm(a), d[n1+1:end,:]) @test mul!(c, b, q) ≈ b*q @test mul!(c, b, q') ≈ b*q' @test_throws DimensionMismatch mul!(Matrix{eltya}(I, n+1, n), q, b) @@ -196,7 +200,10 @@ rectangularQ(Q::LinearAlgebra.AbstractQ) = convert(Array, Q) @test_throws DimensionMismatch q * Matrix{Int8}(I, n+4, n+4) @test mul!(c, q, b) ≈ q*b + @test mul!(d, q, r) ≈ a[:,1:n1] @test mul!(c, q', b) ≈ q'*b + @test mul!(d, q', a[:,1:n1])[1:n1,:] ≈ r + @test all(x -> abs(x) < ε*norm(a), d[n1+1:end,:]) @test mul!(c, b, q) ≈ b*q @test mul!(c, b, q') ≈ b*q' @test_throws DimensionMismatch mul!(Matrix{eltya}(I, n+1, n), q, b)