diff --git a/base/linalg/lu.jl b/base/linalg/lu.jl index 8e6630a00bd41..8e0c0451f41b4 100644 --- a/base/linalg/lu.jl +++ b/base/linalg/lu.jl @@ -241,34 +241,55 @@ function show(io::IO, F::LU) print(io, "\nsuccessful: $(issuccess(F))") end +_apply_ipiv!(A::LU, B::StridedVecOrMat) = _ipiv!(A, 1 : length(A.ipiv), B) +_apply_inverse_ipiv!(A::LU, B::StridedVecOrMat) = _ipiv!(A, length(A.ipiv) : -1 : 1, B) + +function _ipiv!(A::LU, order::OrdinalRange, B::StridedVecOrMat) + for i = order + if i != A.ipiv[i] + _swap_rows!(B, i, A.ipiv[i]) + end + end + B +end + +function _swap_rows!(B::StridedVector, i::Integer, j::Integer) + B[i], B[j] = B[j], B[i] + B +end + +function _swap_rows!(B::StridedMatrix, i::Integer, j::Integer) + for col = 1 : size(B, 2) + B[i,col], B[j,col] = B[j,col], B[i,col] + end + B +end + A_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = @assertnonsingular LAPACK.getrs!('N', A.factors, A.ipiv, B) A.info -A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector) = - A_ldiv_B!(UpperTriangular(A.factors), - A_ldiv_B!(UnitLowerTriangular(A.factors), b[ipiv2perm(A.ipiv, length(b))])) -A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix) = - A_ldiv_B!(UpperTriangular(A.factors), - A_ldiv_B!(UnitLowerTriangular(A.factors), B[ipiv2perm(A.ipiv, size(B, 1)),:])) + +function A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat) + _apply_ipiv!(A, B) + A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(UnitLowerTriangular(A.factors), B)) +end At_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = @assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, B) A.info -At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector) = - At_ldiv_B!(UnitLowerTriangular(A.factors), - At_ldiv_B!(UpperTriangular(A.factors), b))[invperm(ipiv2perm(A.ipiv, length(b)))] -At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix) = - At_ldiv_B!(UnitLowerTriangular(A.factors), - At_ldiv_B!(UpperTriangular(A.factors), B))[invperm(ipiv2perm(A.ipiv, size(B,1))),:] + +function At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat) + At_ldiv_B!(UnitLowerTriangular(A.factors), At_ldiv_B!(UpperTriangular(A.factors), B)) + _apply_inverse_ipiv!(A, B) +end Ac_ldiv_B!(F::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:Real} = At_ldiv_B!(F, B) Ac_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasComplex} = @assertnonsingular LAPACK.getrs!('C', A.factors, A.ipiv, B) A.info -Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector) = - Ac_ldiv_B!(UnitLowerTriangular(A.factors), - Ac_ldiv_B!(UpperTriangular(A.factors), b))[invperm(ipiv2perm(A.ipiv, length(b)))] -Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix) = - Ac_ldiv_B!(UnitLowerTriangular(A.factors), - Ac_ldiv_B!(UpperTriangular(A.factors), B))[invperm(ipiv2perm(A.ipiv, size(B,1))),:] + +function Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat) + Ac_ldiv_B!(UnitLowerTriangular(A.factors), Ac_ldiv_B!(UpperTriangular(A.factors), B)) + _apply_inverse_ipiv!(A, B) +end At_ldiv_Bt(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = @assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, transpose(B)) A.info diff --git a/test/linalg/lu.jl b/test/linalg/lu.jl index 6730a6cd07fd6..dd5a84b7904d1 100644 --- a/test/linalg/lu.jl +++ b/test/linalg/lu.jl @@ -84,11 +84,30 @@ dimg = randn(n)/2 @test norm(a*(lua\c) - c, 1) < ε*κ*n # c is a vector @test norm(a'*(lua'\c) - c, 1) < ε*κ*n # c is a vector @test AbstractArray(lua) ≈ a - if eltya <: Real && eltyb <: Real - @test norm(a.'*(lua.'\b) - b,1) < ε*κ*n*2 # Two because the right hand side has two columns - @test norm(a.'*(lua.'\c) - c,1) < ε*κ*n - end + @test norm(a.'*(lua.'\b) - b,1) < ε*κ*n*2 # Two because the right hand side has two columns + @test norm(a.'*(lua.'\c) - c,1) < ε*κ*n end + + # Test whether Ax_ldiv_B!(y, LU, x) indeed overwrites y + resultT = typeof(oneunit(eltyb) / oneunit(eltya)) + + b_dest = similar(b, resultT) + c_dest = similar(c, resultT) + + A_ldiv_B!(b_dest, lua, b) + A_ldiv_B!(c_dest, lua, c) + @test norm(b_dest - lua \ b, 1) < ε*κ*2n + @test norm(c_dest - lua \ c, 1) < ε*κ*n + + At_ldiv_B!(b_dest, lua, b) + At_ldiv_B!(c_dest, lua, c) + @test norm(b_dest - lua.' \ b, 1) < ε*κ*2n + @test norm(c_dest - lua.' \ c, 1) < ε*κ*n + + Ac_ldiv_B!(b_dest, lua, b) + Ac_ldiv_B!(c_dest, lua, c) + @test norm(b_dest - lua' \ b, 1) < ε*κ*2n + @test norm(c_dest - lua' \ c, 1) < ε*κ*n end if eltya <: BlasFloat && eltyb <: BlasFloat e = rand(eltyb,n,n)