Skip to content

Commit

Permalink
Fix tall qr multiplication (#38002) (#38360)
Browse files Browse the repository at this point in the history
(cherry picked from commit 24750c6)

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
KristofferC and dkarrasch authored Nov 9, 2020
1 parent 006853a commit 04ba8e0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
14 changes: 13 additions & 1 deletion stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,19 @@ function *(adjA::Adjoint{<:Any,<:StridedVecOrMat}, adjQ::Adjoint{<:Any,<:Abstrac
end

### mul!
mul!(C::StridedVecOrMat{T}, Q::AbstractQ{T}, B::StridedVecOrMat{T}) where {T} = lmul!(Q, copyto!(C, B))
function mul!(C::StridedVecOrMat{T}, Q::AbstractQ{T}, B::StridedVecOrMat{T}) where {T}
require_one_based_indexing(C, B)
mB = size(B, 1)
mC = size(C, 1)
if mB < mC
inds = CartesianIndices(B)
copyto!(C, inds, B, inds)
C[CartesianIndices((mB+1:mC, axes(C, 2)))] .= zero(T)
return lmul!(Q, C)
else
return lmul!(Q, copyto!(C, B))
end
end
mul!(C::StridedVecOrMat{T}, A::StridedVecOrMat{T}, Q::AbstractQ{T}) where {T} = rmul!(copyto!(C, A), Q)
mul!(C::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:AbstractQ{T}}, B::StridedVecOrMat{T}) where {T} = lmul!(adjQ, copyto!(C, B))
mul!(C::StridedVecOrMat{T}, A::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:AbstractQ{T}}) where {T} = rmul!(copyto!(C, A), adjQ)
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,12 @@ rectangularQ(Q::LinearAlgebra.AbstractQ) = convert(Array, Q)

b = similar(a); rand!(b)
c = similar(a)
d = similar(a[:,1:n1])
@test mul!(c, q, b) q*b
@test mul!(d, q, r) q*r a[:,qrpa.p]
@test mul!(c, q', b) q'*b
@test mul!(d, q', a[:,qrpa.p])[1:n1,:] r
@test all(x -> abs(x) < ε*norm(a), d[n1+1:end,:])
@test mul!(c, b, q) b*q
@test mul!(c, b, q') b*q'
@test_throws DimensionMismatch mul!(Matrix{eltya}(I, n+1, n), q, b)
Expand All @@ -196,7 +200,10 @@ rectangularQ(Q::LinearAlgebra.AbstractQ) = convert(Array, Q)
@test_throws DimensionMismatch q * Matrix{Int8}(I, n+4, n+4)

@test mul!(c, q, b) q*b
@test mul!(d, q, r) a[:,1:n1]
@test mul!(c, q', b) q'*b
@test mul!(d, q', a[:,1:n1])[1:n1,:] r
@test all(x -> abs(x) < ε*norm(a), d[n1+1:end,:])
@test mul!(c, b, q) b*q
@test mul!(c, b, q') b*q'
@test_throws DimensionMismatch mul!(Matrix{eltya}(I, n+1, n), q, b)
Expand Down

0 comments on commit 04ba8e0

Please sign in to comment.