Skip to content

Commit

Permalink
RFC: Extract Bunch-Kaufman factors and use them for printing (#22601)
Browse files Browse the repository at this point in the history
* Add getindex method for extraction of factors in BunchKaufman

Improve BunchKaufman printing

* Adjust bkfact signatures to use Symmetric and Hermitian

Adjust tests

* Wrap LAPACK functions for reconstruction of Bunch-Kaufman with rook pivoting

* Update documentation and use Tridiagonal for storing D
  • Loading branch information
andreasnoack authored Jul 3, 2017
1 parent 4b345c1 commit 709d65e
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 15 deletions.
126 changes: 125 additions & 1 deletion base/linalg/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,135 @@ size(B::BunchKaufman, d::Integer) = size(B.LD, d)
issymmetric(B::BunchKaufman) = B.symmetric
ishermitian(B::BunchKaufman) = !B.symmetric

function _ipiv2perm_bk(v::AbstractVector{T}, maxi::Integer, uplo::Char) where T
p = T[1:maxi;]
uploL = uplo == 'L'
i = uploL ? 1 : maxi
# if uplo == 'U' we construct the permutation backwards
@inbounds while 1 <= i <= length(v)
vi = v[i]
if vi > 0 # the 1x1 blocks
p[i], p[vi] = p[vi], p[i]
i += uploL ? 1 : -1
else # the 2x2 blocks
if uploL
p[i + 1], p[-vi] = p[-vi], p[i + 1]
i += 2
else # 'U'
p[i - 1], p[-vi] = p[-vi], p[i - 1]
i -= 2
end
end
end
return p
end

"""
getindex(B::BunchKaufman, d::Symbol)
Extract the factors of the Bunch-Kaufman factorization `B`. The factorization can take the
two forms `L*D*L'` or `U*D*U'` (or `.'` in the complex symmetric case) where `L` is a
`UnitLowerTriangular` matrix, `U` is a `UnitUpperTriangular`, and `D` is a block diagonal
symmetric or Hermitian matrix with 1x1 or 2x2 blocks. The argument `d` can be
- `:D`: the block diagonal matrix
- `:U`: the upper triangular factor (if factorization is `U*D*U'`)
- `:L`: the lower triangular factor (if factorization is `L*D*L'`)
- `:p`: permutation vector
- `:P`: permutation matrix
# Examples
```jldoctest
julia> A = [1 2 3; 2 1 2; 3 2 1]
3×3 Array{Int64,2}:
1 2 3
2 1 2
3 2 1
julia> F = bkfact(Symmetric(A, :L))
Base.LinAlg.BunchKaufman{Float64,Array{Float64,2}}
D factor:
3×3 Tridiagonal{Float64}:
1.0 3.0 ⋅
3.0 1.0 0.0
⋅ 0.0 -1.0
L factor:
3×3 Base.LinAlg.UnitLowerTriangular{Float64,Array{Float64,2}}:
1.0 0.0 0.0
0.0 1.0 0.0
0.5 0.5 1.0
permutation:
3-element Array{Int64,1}:
1
3
2
successful: true
julia> F[:L]*F[:D]*F[:L].' - A[F[:p], F[:p]]
3×3 Array{Float64,2}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
julia> F = bkfact(Symmetric(A));
julia> F[:U]*F[:D]*F[:U].' - F[:P]*A*F[:P]'
3×3 Array{Float64,2}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
```
"""
function getindex(B::BunchKaufman{T}, d::Symbol) where {T<:BlasFloat}
n = size(B, 1)
if d == :p
return _ipiv2perm_bk(B.ipiv, n, B.uplo)
elseif d == :P
return eye(T, n)[:,invperm(B[:p])]
elseif d == :L || d == :U || d == :D
if B.rook
# syconvf_rook just added to LAPACK 3.7.0. Uncomment and remove error when we distribute LAPACK 3.7.0
# LUD, od = LAPACK.syconvf_rook!(B.uplo, 'C', copy(B.LD), B.ipiv)
throw(ArgumentError("reconstruction rook pivoted Bunch-Kaufman factorization not implemented yet"))
else
LUD, od = LAPACK.syconv!(B.uplo, copy(B.LD), B.ipiv)
end
if d == :D
if B.uplo == 'L'
odl = od[1:n - 1]
return Tridiagonal(odl, diag(LUD), B.symmetric ? odl : conj.(odl))
else # 'U'
odu = od[2:n]
return Tridiagonal(B.symmetric ? odu : conj.(odu), diag(LUD), odu)
end
elseif d == :L
if B.uplo == 'L'
return UnitLowerTriangular(LUD)
else
throw(ArgumentError("factorization is U*D*U.' but you requested L"))
end
else # :U
if B.uplo == 'U'
return UnitUpperTriangular(LUD)
else
throw(ArgumentError("factorization is L*D*L.' but you requested U"))
end
end
else
throw(KeyError(d))
end
end

issuccess(B::BunchKaufman) = B.info == 0

function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, B::BunchKaufman)
println(io, summary(B))
print(io, "successful: $(issuccess(B))")
println(io, "D factor:")
show(io, mime, B[:D])
println(io, "\n$(B.uplo) factor:")
show(io, mime, B[Symbol(B.uplo)])
println(io, "\npermutation:")
show(io, mime, B[:p])
print(io, "\nsuccessful: $(issuccess(B))")
end

function inv(B::BunchKaufman{<:BlasReal})
Expand Down
91 changes: 85 additions & 6 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3990,9 +3990,9 @@ for (syconv, sysv, sytrf, sytri, sytrs, elty) in
end

# Rook-pivoting variants of symmetric-matrix algorithms
for (sysv, sytrf, sytri, sytrs, elty) in
((:dsysv_rook_,:dsytrf_rook_,:dsytri_rook_,:dsytrs_rook_,:Float64),
(:ssysv_rook_,:ssytrf_rook_,:ssytri_rook_,:ssytrs_rook_,:Float32))
for (sysv, sytrf, sytri, sytrs, syconvf, elty) in
((:dsysv_rook_,:dsytrf_rook_,:dsytri_rook_,:dsytrs_rook_,:dsyconvf_rook_,:Float64),
(:ssysv_rook_,:ssytrf_rook_,:ssytri_rook_,:ssytrs_rook_,:ssyconvf_rook_,:Float32))
@eval begin
# SUBROUTINE DSYSV_ROOK(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK,
# LWORK, INFO )
Expand Down Expand Up @@ -4107,6 +4107,45 @@ for (sysv, sytrf, sytri, sytrs, elty) in
chklapackerror(info[])
B
end

# SUBROUTINE DSYCONVF_ROOK( UPLO, WAY, N, A, LDA, IPIV, E, INFO )
#
# .. Scalar Arguments ..
# CHARACTER UPLO, WAY
# INTEGER INFO, LDA, N
# ..
# .. Array Arguments ..
# INTEGER IPIV( * )
# DOUBLE PRECISION A( LDA, * ), E( * )
function syconvf_rook!(uplo::Char, way::Char, A::StridedMatrix{$elty},
ipiv::StridedVector{BlasInt}, e::StridedVector{$elty})
# extract
n = checksquare(A)

# check
chkuplo(uplo)
if way != :C && way != :R
throw(ArgumentError("way must be :C or :R"))
end
if length(ipiv) != n
throw(ArgumentError("length of pivot vector was $(length(ipiv)) but should have been $n"))
end
if length(e) != n
throw(ArgumentError("length of e vector was $(length(ipiv)) but should have been $n"))
end

# allocate
info = Ref{BlasInt}()

ccall((@blasfunc($syconvf), liblapack), Void,
(Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &way, &n, A,
&lda, ipiv, e, info)

chklapackerror(info[])
return A, e
end
end
end

Expand Down Expand Up @@ -4548,9 +4587,9 @@ for (sysv, sytrf, sytri, sytrs, elty, relty) in
end
end

for (sysv, sytrf, sytri, sytrs, elty, relty) in
((:zsysv_rook_,:zsytrf_rook_,:zsytri_rook_,:zsytrs_rook_,:Complex128, :Float64),
(:csysv_rook_,:csytrf_rook_,:csytri_rook_,:csytrs_rook_,:Complex64, :Float32))
for (sysv, sytrf, sytri, sytrs, syconvf, elty, relty) in
((:zsysv_rook_,:zsytrf_rook_,:zsytri_rook_,:zsytrs_rook_,:zsyconvf_rook_,:Complex128, :Float64),
(:csysv_rook_,:csytrf_rook_,:csytri_rook_,:csytrs_rook_,:csyconvf_rook_,:Complex64, :Float32))
@eval begin
# SUBROUTINE ZSYSV_ROOK(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK,
# $ LWORK, INFO )
Expand Down Expand Up @@ -4667,6 +4706,46 @@ for (sysv, sytrf, sytri, sytrs, elty, relty) in
chklapackerror(info[])
B
end

# SUBROUTINE ZSYCONVF_ROOK( UPLO, WAY, N, A, LDA, IPIV, E, INFO )
#
# .. Scalar Arguments ..
# CHARACTER UPLO, WAY
# INTEGER INFO, LDA, N
# ..
# .. Array Arguments ..
# INTEGER IPIV( * )
# COMPLEX*16 A( LDA, * ), E( * )
function syconvf_rook!(uplo::Char, way::Char, A::StridedMatrix{$elty},
ipiv::StridedVector{BlasInt}, e::StridedVector{$elty} = Vector{$elty}(length(ipiv)))
# extract
n = checksquare(A)
lda = stride(A, 2)

# check
chkuplo(uplo)
if way != 'C' && way != 'R'
throw(ArgumentError("way must be 'C' or 'R'"))
end
if length(ipiv) != n
throw(ArgumentError("length of pivot vector was $(length(ipiv)) but should have been $n"))
end
if length(e) != n
throw(ArgumentError("length of e vector was $(length(ipiv)) but should have been $n"))
end

# allocate
info = Ref{BlasInt}()

ccall((@blasfunc($syconvf), liblapack), Void,
(Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &way, &n, A,
&max(1, lda), ipiv, e, info)

chklapackerror(info[])
return A, e
end
end
end

Expand Down
18 changes: 17 additions & 1 deletion doc/src/manual/linear-algebra.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,23 @@ julia> B = [1.5 2 -4; 2 -1 -3; -4 -3 5]
-4.0 -3.0 5.0
julia> factorize(B)
Base.LinAlg.BunchKaufman{Float64,Array{Float64,2}}([-1.64286 0.142857 -0.8; 2.0 -2.8 -0.6; -4.0 -3.0 5.0], [1, 2, 3], 'U', true, false, 0)
Base.LinAlg.BunchKaufman{Float64,Array{Float64,2}}
D factor:
3×3 Tridiagonal{Float64}:
-1.64286 0.0 ⋅
0.0 -2.8 0.0
⋅ 0.0 5.0
U factor:
3×3 Base.LinAlg.UnitUpperTriangular{Float64,Array{Float64,2}}:
1.0 0.142857 -0.8
0.0 1.0 -0.6
0.0 0.0 1.0
permutation:
3-element Array{Int64,1}:
1
2
3
successful: true
```

Here, Julia was able to detect that `B` is in fact symmetric, and used a more appropriate factorization.
Expand Down
22 changes: 15 additions & 7 deletions test/linalg/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ bimg = randn(n,2)/2

@testset for eltyb in (Float32, Float64, Complex64, Complex128, Int)
b = eltyb == Int ? rand(1:5, n, 2) : convert(Matrix{eltyb}, eltyb <: Complex ? complex.(breal, bimg) : breal)

# check that factorize gives a Bunch-Kaufman
@test isa(factorize(asym), LinAlg.BunchKaufman)
@test isa(factorize(aher), LinAlg.BunchKaufman)

@testset for btype in ("Array", "SubArray")
if btype == "Array"
b = b
Expand All @@ -49,10 +54,6 @@ bimg = randn(n,2)/2
εb = eps(abs(float(one(eltyb))))
ε = max(εa,εb)

# check that factorize gives a Bunch-Kaufman
@test isa(factorize(asym), LinAlg.BunchKaufman)
@test isa(factorize(aher), LinAlg.BunchKaufman)

@testset "$uplo Bunch-Kaufman factor of indefinite matrix" for uplo in (:L, :U)
bc1 = bkfact(Hermitian(aher, uplo))
@test LinAlg.issuccess(bc1)
Expand All @@ -73,6 +74,15 @@ bimg = randn(n,2)/2
@test_throws ArgumentError bkfact(a)
end
end
# Test extraction of factors
# syconvf_rook just added to LAPACK 3.7.0. Test when we distribute LAPACK 3.7.0
@test bc1[uplo]*bc1[:D]*bc1[uplo]' aher[bc1[:p], bc1[:p]]
@test bc1[uplo]*bc1[:D]*bc1[uplo]' bc1[:P]*aher*bc1[:P]'
if eltya <: Complex
bc1 = bkfact(Symmetric(asym, uplo))
@test bc1[uplo]*bc1[:D]*bc1[uplo].' asym[bc1[:p], bc1[:p]]
@test bc1[uplo]*bc1[:D]*bc1[uplo].' bc1[:P]*asym*bc1[:P]'
end
end

@testset "$uplo Bunch-Kaufman factors of a pos-def matrix" for uplo in (:U, :L)
Expand Down Expand Up @@ -122,9 +132,7 @@ end
end
end


# test example due to @timholy in PR 15354
let
@testset "test example due to @timholy in PR 15354" begin
A = rand(6,5); A = complex(A'*A) # to avoid calling the real-lhs-complex-rhs method
F = cholfact(A);
v6 = rand(Complex128, 6)
Expand Down

2 comments on commit 709d65e

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @jrevels

Please sign in to comment.