Skip to content

Commit

Permalink
Merge pull request #15505 from jw3126/bitrimatmul
Browse files Browse the repository at this point in the history
faster A_mul_B! and * involving bidiagonal and tridiagonal matrices
  • Loading branch information
tkelman committed May 16, 2016
2 parents 5d52f02 + 7182b87 commit 97dc858
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 36 deletions.
174 changes: 172 additions & 2 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,142 @@ end
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.isupper)
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)

SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B)

BiTriSym = Union{Bidiagonal, Tridiagonal, SymTridiagonal}
BiTri = Union{Bidiagonal, Tridiagonal}
A_mul_B!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTri, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVector, A::BiTri, B::AbstractVector) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVecOrMat, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)


function check_A_mul_B!_sizes(C, A, B)
nA, mA = size(A)
nB, mB = size(B)
nC, mC = size(C)
if !(nA == nC)
throw(DimensionMismatch("Sizes size(A)=$(size(A)) and size(C) = $(size(C)) must match at first entry."))
elseif !(mA == nB)
throw(DimensionMismatch("Second entry of size(A)=$(size(A)) and first entry of size(B) = $(size(B)) must match."))
elseif !(mB == mC)
throw(DimensionMismatch("Sizes size(B)=$(size(B)) and size(C) = $(size(C)) must match at first second entry."))
end
end

function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return A_mul_B!(C, full(A), full(B))
fill!(C, zero(eltype(C)))
Al = diag(A, -1)
Ad = diag(A, 0)
Au = diag(A, 1)
Bl = diag(B, -1)
Bd = diag(B, 0)
Bu = diag(B, 1)
@inbounds begin
# first row of C
C[1,1] = A[1,1]*B[1,1] + A[1, 2]*B[2, 1]
C[1,2] = A[1,1]*B[1,2] + A[1,2]*B[2,2]
C[1,3] = A[1,2]*B[2,3]
# second row of C
C[2,1] = A[2,1]*B[1,1] + A[2,2]*B[2,1]
C[2,2] = A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]
C[2,3] = A[2,2]*B[2,3] + A[2,3]*B[3,3]
C[2,4] = A[2,3]*B[3,4]
for j in 3:n-2
Ajj₋1 = Al[j-1]
Ajj = Ad[j]
Ajj₊1 = Au[j]
Bj₋1j₋2 = Bl[j-2]
Bj₋1j₋1 = Bd[j-1]
Bj₋1j = Bu[j-1]
Bjj₋1 = Bl[j-1]
Bjj = Bd[j]
Bjj₊1 = Bu[j]
Bj₊1j = Bl[j]
Bj₊1j₊1 = Bd[j+1]
Bj₊1j₊2 = Bu[j+1]
C[j,j-2] = Ajj₋1*Bj₋1j₋2
C[j, j-1] = Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1
C[j, j ] = Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j
C[j, j+1] = Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1
C[j, j+2] = Ajj₊1*Bj₊1j₊2
end
# row before last of C
C[n-1,n-3] = A[n-1,n-2]*B[n-2,n-3]
C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]
C[n-1,n-1] = A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]
C[n-1,n ] = A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]
# last row of C
C[n,n-2] = A[n,n-1]*B[n-1,n-2]
C[n,n-1] = A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]
C[n,n ] = A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]
end # inbounds
C
end

function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat)
nA = size(A,1)
nB = size(B,2)
if !(size(C,1) == size(B,1) == nA)
throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
end
if size(C,2) != nB
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
end
nA <= 3 && return A_mul_B!(C, full(A), full(B))
l = diag(A, -1)
d = diag(A, 0)
u = diag(A, 1)
@inbounds begin
for j = 1:nB
b₀, b₊ = B[1, j], B[2, j]
C[1, j] = d[1]*b₀ + u[1]*b₊
for i = 2:nA - 1
b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
end
C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
end
end
C
end

function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return A_mul_B!(C, full(A), full(B))
m = size(B,2)
Bl = diag(B, -1)
Bd = diag(B, 0)
Bu = diag(B, 1)
@inbounds begin
# first and last column of C
B11 = Bd[1]
B21 = Bl[1]
Bmm = Bd[m]
Bm₋1m = Bu[m-1]
for i in 1:n
C[i, 1] = A[i,1] * B11 + A[i, 2] * B21
C[i, m] = A[i, m-1] * Bm₋1m + A[i, m] * Bmm
end
# middle columns of C
for j = 2:m-1
Bj₋1j = Bu[j-1]
Bjj = Bd[j]
Bj₊1j = Bl[j]
for i = 1:n
C[i, j] = A[i, j-1] * Bj₋1j + A[i, j]*Bjj + A[i, j+1] * Bj₊1j
end
end
end # inbounds
C
end

#Generic multiplication
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
Expand Down Expand Up @@ -329,3 +463,39 @@ function eigvecs{T}(M::Bidiagonal{T})
Q #Actually Triangular
end
eigfact(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M))

# fill! methods
_valuefields{T <: Diagonal}(S::Type{T}) = [:diag]
_valuefields{T <: Bidiagonal}(S::Type{T}) = [:dv, :ev]
_valuefields{T <: Tridiagonal}(S::Type{T}) = [:dl, :d, :du]
_valuefields{T <: SymTridiagonal}(S::Type{T}) = [:dv, :ev]
_valuefields{T <: AbstractTriangular}(S::Type{T}) = [:data]

SpecialArrays = Union{Diagonal,
Bidiagonal,
Tridiagonal,
SymTridiagonal,
AbstractTriangular}

@generated function fillslots!(A::SpecialArrays, x)
ex = :(xT = convert(eltype(A), x))
for field in _valuefields(A)
ex = :($ex; fill!(A.$field, xT))
end
:($ex;return A)
end

# for historical reasons:
fill!(a::AbstractTriangular, x) = fillslots!(a, x);
fill!(D::Diagonal, x) = fillslots!(D, x);

_small_enough(A::Bidiagonal) = size(A, 1) <= 1
_small_enough(A::Tridiagonal) = size(A, 1) <= 2
_small_enough(A::SymTridiagonal) = size(A, 1) <= 2

function fill!(A::Union{Bidiagonal, Tridiagonal, SymTridiagonal} ,x)
xT = convert(eltype(A), x)
(xT == zero(eltype(A)) || _small_enough(A)) && return fillslots!(A, xT)
throw(ArgumentError("Array A of type $(typeof(A)) and size $(size(A)) can
not be filled with x=$x, since some of its entries are constrained."))
end
2 changes: 0 additions & 2 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ function size(D::Diagonal,d::Integer)
return d<=2 ? length(D.diag) : 1
end

fill!(D::Diagonal, x) = (fill!(D.diag, x); D)

full(D::Diagonal) = diagm(D.diag)

@inline function getindex(D::Diagonal, i::Int, j::Int)
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ imag(A::UnitUpperTriangular) = UpperTriangular(triu!(imag(A.data),1))
full(A::AbstractTriangular) = convert(Matrix, A)
parent(A::AbstractTriangular) = A.data

fill!(A::AbstractTriangular, x) = (fill!(A.data, x); A)

# then handle all methods that requires specific handling of upper/lower and unit diagonal

function convert{Tret,T,S}(::Type{Matrix{Tret}}, A::LowerTriangular{T,S})
Expand Down Expand Up @@ -380,6 +378,8 @@ scale!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = scale!(A,c)
######################

A_mul_B!(A::Tridiagonal, B::AbstractTriangular) = A*full!(B)
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::Tridiagonal) = A_mul_B!(C, full(A), B)
A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractTriangular) = A_mul_B!(C, A, full(B))
A_mul_B!(C::AbstractVector, A::AbstractTriangular, B::AbstractVector) = A_mul_B!(A, copy!(C, B))
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B))
A_mul_B!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B))
Expand Down
30 changes: 0 additions & 30 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,33 +497,3 @@ function convert{T}(::Type{SymTridiagonal{T}}, M::Tridiagonal)
throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal"))
end
end

A_mul_B!(C::AbstractVector, A::Tridiagonal, B::AbstractVector) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)

function A_mul_B_td!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat)
nA = size(A,1)
nB = size(B,2)
if !(size(C,1) == size(B,1) == nA)
throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
end
if size(C,2) != nB
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
end
l = A.dl
d = A.d
u = A.du
@inbounds begin
for j = 1:nB
b₀, b₊ = B[1, j], B[2, j]
C[1, j] = d[1]*b₀ + u[1]*b₊
for i = 2:nA - 1
b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
end
C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
end
end
C
end
40 changes: 40 additions & 0 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1680,3 +1680,43 @@ droptol!(x::SparseVector, tol, trim::Bool = true) = fkeep!(x, (i, x) -> abs(x) >

dropzeros!(x::SparseVector, trim::Bool = true) = fkeep!(x, (i, x) -> x != 0, trim)
dropzeros(x::SparseVector, trim::Bool = true) = dropzeros!(copy(x), trim)

function _fillnonzero!{Tv,Ti}(arr::SparseMatrixCSC{Tv, Ti}, val)
m, n = size(arr)
resize!(arr.colptr, n+1)
resize!(arr.rowval, m*n)
resize!(arr.nzval, m*n)
copy!(arr.colptr, 1:m:n*m+1)
fill!(arr.nzval, val)
index = 1
@inbounds for _ in 1:n
for i in 1:m
arr.rowval[index] = Ti(i)
index += 1
end
end
arr
end

function _fillnonzero!{Tv,Ti}(arr::SparseVector{Tv,Ti}, val)
n = arr.n
resize!(arr.nzind, n)
resize!(arr.nzval, n)
@inbounds for i in 1:n
arr.nzind[i] = Ti(i)
end
fill!(arr.nzval, val)
arr
end

import Base.fill!
function fill!(A::Union{SparseVector, SparseMatrixCSC}, x)
T = eltype(A)
xT = convert(T, x)
if xT == zero(T)
fill!(A.nzval, xT)
else
_fillnonzero!(A, xT)
end
return A
end
46 changes: 46 additions & 0 deletions test/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,49 @@ C = Tridiagonal(rand(Float64,9),rand(Float64,10),rand(Float64,9))
@test promote_rule(Matrix{Float64}, Bidiagonal{Float64}) == Matrix{Float64}
@test promote(B,A) == (B,convert(Matrix{Float64},full(A)))
@test promote(C,A) == (C,Tridiagonal(zeros(Float64,9),convert(Vector{Float64},A.dv),convert(Vector{Float64},A.ev)))

import Base.LinAlg: fillslots!, UnitLowerTriangular
let #fill!
let # fillslots!
A = Tridiagonal(randn(2), randn(3), randn(2))
@test fillslots!(A, 3) == Tridiagonal([3, 3.], [3, 3, 3.], [3, 3.])
B = Bidiagonal(randn(3), randn(2), true)
@test fillslots!(B, 2) == Bidiagonal([2.,2,2], [2,2.], true)
S = SymTridiagonal(randn(3), randn(2))
@test fillslots!(S, 1) == SymTridiagonal([1,1,1.], [1,1.])
Ult = UnitLowerTriangular(randn(3,3))
@test fillslots!(Ult, 3) == UnitLowerTriangular([1 0 0; 3 1 0; 3 3 1])
end
let # fill!(exotic, 0)
exotic_arrays = Any[Tridiagonal(randn(3), randn(4), randn(3)),
Bidiagonal(randn(3), randn(2), rand(Bool)),
SymTridiagonal(randn(3), randn(2)),
sparse(randn(3,4)),
Diagonal(randn(5)),
sparse(rand(3)),
LowerTriangular(randn(3,3)),
UpperTriangular(randn(3,3))
]
for A in exotic_arrays
fill!(A, 0)
for a in A
@test a == 0
end
end
end
let # fill!(small, x)
val = randn()
b = Bidiagonal(randn(1,1), true)
st = SymTridiagonal(randn(1,1))
for x in (b, st)
@test full(fill!(x, val)) == fill!(full(x), val)
end
b = Bidiagonal(randn(2,2), true)
st = SymTridiagonal(randn(3), randn(2))
t = Tridiagonal(randn(3,3))
for x in (b, t, st)
@test_throws ArgumentError fill!(x, val)
@test full(fill!(x, 0)) == fill!(full(x), 0)
end
end
end
56 changes: 56 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,59 @@ a = [RootInt(2),RootInt(10)]
@test a*a' == [4 20; 20 100]
A = [RootInt(3) RootInt(5)]
@test A*a == [56]

function test_mul(C, A, B)
A_mul_B!(C, A, B)
@test full(A) * full(B) C
@test A*B C
end

let
eltypes = [Float32, Float64, Int64]
for k in [3, 4, 10]
T = rand(eltypes)
bi1 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool))
bi2 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool))
tri1 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1))
tri2 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1))
stri1 = SymTridiagonal(rand(T, k), rand(T, k-1))
stri2 = SymTridiagonal(rand(T, k), rand(T, k-1))
C = rand(T, k, k)
specialmatrices = (bi1, bi2, tri1, tri2, stri1, stri2)
for A in specialmatrices
B = specialmatrices[rand(1:length(specialmatrices))]
test_mul(C, A, B)
end
for S in specialmatrices
l = rand(1:6)
B = randn(k, l)
C = randn(k, l)
test_mul(C, S, B)
A = randn(l, k)
C = randn(l, k)
test_mul(C, A, S)
end
end
for T in eltypes
A = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool))
B = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool))
C = randn(2,2)
test_mul(C, A, B)
B = randn(2, 9)
C = randn(2, 9)
test_mul(C, A, B)
end
let
tri44 = Tridiagonal(randn(3), randn(4), randn(3))
tri33 = Tridiagonal(randn(2), randn(3), randn(2))
full43 = randn(4, 3)
full24 = randn(2, 4)
full33 = randn(3, 3)
full44 = randn(4, 4)
@test_throws DimensionMismatch A_mul_B!(full43, tri44, tri33)
@test_throws DimensionMismatch A_mul_B!(full44, tri44, tri33)
@test_throws DimensionMismatch A_mul_B!(full44, tri44, full43)
@test_throws DimensionMismatch A_mul_B!(full43, tri33, full43)
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
end
end
Loading

0 comments on commit 97dc858

Please sign in to comment.