Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster A_mul_B! and * involving bidiagonal and tridiagonal matrices #15505

Merged
merged 1 commit into from
May 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 172 additions & 2 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,142 @@ end
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.isupper)
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)

SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B)

BiTriSym = Union{Bidiagonal, Tridiagonal, SymTridiagonal}
BiTri = Union{Bidiagonal, Tridiagonal}
A_mul_B!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTri, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these 5 lines should probably be expressed with a Union. And the A::BiTri line is redundant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mhh I don't know a simpler way then these 5 lines. Its easy to produce ambiguous definitions. For example if I remove the BiTri line I get

WARNING: New definition 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 1}}) at linalg/bidiag.jl:236
is ambiguous with: 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal}) at linalg/bidiag.jl:232.
To fix, define 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal})
before the new definition.
WARNING: New definition 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 1}}) at linalg/bidiag.jl:236
is ambiguous with: 
    A_mul_B!(AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal}) at linalg/bidiag.jl:234.
To fix, define 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal})
before the new definition.
WARNING: New definition 
    A_mul_B!(Union{AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 1}}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 1}}) at linalg/bidiag.jl:237
is ambiguous with: 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal}) at linalg/bidiag.jl:232.
To fix, define 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal})
before the new definition.
WARNING: New definition 
    A_mul_B!(Union{AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 1}}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 1}}) at linalg/bidiag.jl:237
is ambiguous with: 
    A_mul_B!(AbstractArray{T<:Any, 2}, AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal}) at linalg/bidiag.jl:234.
To fix, define 
    A_mul_B!(AbstractArray{T<:Any, 2}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal}, Union{Base.LinAlg.Tridiagonal, Base.LinAlg.Bidiagonal, Base.LinAlg.SymTridiagonal})

One big union also definitely does not work. In fact I started with AbstractMatrix (which the Union boils down to)

 +A_mul_B!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B) 

and these five lines are what I had to do to fix ambiguities.

A_mul_B!(C::AbstractVector, A::BiTri, B::AbstractVector) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVecOrMat, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)


function check_A_mul_B!_sizes(C, A, B)
nA, mA = size(A)
nB, mB = size(B)
nC, mC = size(C)
if !(nA == nC)
throw(DimensionMismatch("Sizes size(A)=$(size(A)) and size(C) = $(size(C)) must match at first entry."))
elseif !(mA == nB)
throw(DimensionMismatch("Second entry of size(A)=$(size(A)) and first entry of size(B) = $(size(B)) must match."))
elseif !(mB == mC)
throw(DimensionMismatch("Sizes size(B)=$(size(B)) and size(C) = $(size(C)) must match at first second entry."))
end
end

function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return A_mul_B!(C, full(A), full(B))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's any advantage to using full here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is stuff like C[2,4] = A[2,3]*B[3,4] in the code, which violates bounds of small matrices. Thats why n <= 3 is special.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd likely get much better performance by hoisting out and manually doing the transformation between A[2,3] vs Adu = A.du outside the loop then referring to Adu[2] within the loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tkelman I think I already did this. For indexing I used the following rules:

  1. Outside loops explicit indexing like A[1,2] is allowed for clarity. The performance overhead is negligible.
  2. Inside loops do all indexing of BiTriSym matrices manually.

If you ever find code like A[i, j] inside a loop, its because A is not BiTriSym e.g. full or something.

fill!(C, zero(eltype(C)))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that fill!(C, zero(eltype(C))) is not an ideal way to fill a matrix with zeros. It is inefficient if C is sparse. Is there a more efficient standard way to fill an arbitrary matrix with zeros? Otherwise I would create a zeros! method, which does it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have to be inefficient, if there's a specialized fill! method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok currently fill! loops over all indices. I guess this is the optimal thing to do, except for cases of certain special matrices and the fillvalue zero.

So you suggest to have a fill! method for sparse arrays which does the current thing for nonzero fill value and something special (e.g. delete the entries) for fill value zero?
There is another issue. For say Tridiagonal matrices fill! - values other then zero don't make sense anyway. And fill! currently always throws an error on Tridiagonal matrices. So we would need to check for Tridiagonal whether the fill value is zero.

I think zero is special enough that it is reasonable to add a zeros! function anyway. The question is should fill! check for zero and call zeros! in this case or leave fill! as it is and call zeros! manually?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The advantage of going through fill! is that we don't have to export another function, and the runtime check would be extremely cheap (i.e., essentially unmeasurable). But I'm not particularly opposed to zeros!.

It's been debated multiple times, but for SparseMatrixCSC I think the trend is to distinguish structural zeros from value zeros, and say that SparseMatrixCSC is happy to keep value zeros. By that logic, the implementation (once you've checked that the value is 0) is simply fill!(A.nzval, 0). In other words, don't delete any entries, just set them to 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fill!(sparse, 0) is now fast. I also added a fill! method for various special matrices that fills all 'data slots'. E.g. it behaves as follows:

import Base.LinAlg.UnitLowerTriangular
a = UnitLowerTriangular(randn(3,3))
fill!(a, 42)
3x3 Base.LinAlg.UnitLowerTriangular{Float64,Array{Float64,2}}:
  1.0   0.0  0.0
 42.0   1.0  0.0
 42.0  42.0  1.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I did not even add that particular method, it existed before. I added analogues for Bidiagonal etc.

Initially I called this kind of function fillslots!, until I discovered that it already existed in the Triangular case under the name fill!. It would probably be saver if fill! would always commute with full or throw an error if this is not possible.
full(fill!(A, a)) == fill!(full(A), a)
Anyway I don't have a strong opinion about renaming fill! in cases where it does not commute with full and certainly did not want to break existing behaviour.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouch, that seems wrong. I agree that should commute. If the fill-value is nonzero then I think fill! should error if the matrix is large enough that there are implicit zeros present.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I think the behaviour of 'fill!' should be its own issue. For the matrix multiplication stuff here one only needs cases of zero filling which commute with full anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't need to be fixed here, but don't add any new incorrect methods of fill for structured matrix types here without checking the fill value. zero filling cannot work on UnitTriangular so that case also needs special handling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All fill! methods added by me now raise errors if they do not commute with full.

Al = diag(A, -1)
Ad = diag(A, 0)
Au = diag(A, 1)
Bl = diag(B, -1)
Bd = diag(B, 0)
Bu = diag(B, 1)
@inbounds begin
# first row of C
C[1,1] = A[1,1]*B[1,1] + A[1, 2]*B[2, 1]
C[1,2] = A[1,1]*B[1,2] + A[1,2]*B[2,2]
C[1,3] = A[1,2]*B[2,3]
# second row of C
C[2,1] = A[2,1]*B[1,1] + A[2,2]*B[2,1]
C[2,2] = A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]
C[2,3] = A[2,2]*B[2,3] + A[2,3]*B[3,3]
C[2,4] = A[2,3]*B[3,4]
for j in 3:n-2
Ajj₋1 = Al[j-1]
Ajj = Ad[j]
Ajj₊1 = Au[j]
Bj₋1j₋2 = Bl[j-2]
Bj₋1j₋1 = Bd[j-1]
Bj₋1j = Bu[j-1]
Bjj₋1 = Bl[j-1]
Bjj = Bd[j]
Bjj₊1 = Bu[j]
Bj₊1j = Bl[j]
Bj₊1j₊1 = Bd[j+1]
Bj₊1j₊2 = Bu[j+1]
C[j,j-2] = Ajj₋1*Bj₋1j₋2
C[j, j-1] = Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1
C[j, j ] = Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j
C[j, j+1] = Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1
C[j, j+2] = Ajj₊1*Bj₊1j₊2
end
# row before last of C
C[n-1,n-3] = A[n-1,n-2]*B[n-2,n-3]
C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]
C[n-1,n-1] = A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]
C[n-1,n ] = A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]
# last row of C
C[n,n-2] = A[n,n-1]*B[n-1,n-2]
C[n,n-1] = A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]
C[n,n ] = A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]
end # inbounds
C
end

function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat)
nA = size(A,1)
nB = size(B,2)
if !(size(C,1) == size(B,1) == nA)
throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
end
if size(C,2) != nB
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
end
nA <= 3 && return A_mul_B!(C, full(A), full(B))
l = diag(A, -1)
d = diag(A, 0)
u = diag(A, 1)
@inbounds begin
for j = 1:nB
b₀, b₊ = B[1, j], B[2, j]
C[1, j] = d[1]*b₀ + u[1]*b₊
for i = 2:nA - 1
b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
end
C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
end
end
C
end

function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return A_mul_B!(C, full(A), full(B))
m = size(B,2)
Bl = diag(B, -1)
Bd = diag(B, 0)
Bu = diag(B, 1)
@inbounds begin
# first and last column of C
B11 = Bd[1]
B21 = Bl[1]
Bmm = Bd[m]
Bm₋1m = Bu[m-1]
for i in 1:n
C[i, 1] = A[i,1] * B11 + A[i, 2] * B21
C[i, m] = A[i, m-1] * Bm₋1m + A[i, m] * Bmm
end
# middle columns of C
for j = 2:m-1
Bj₋1j = Bu[j-1]
Bjj = Bd[j]
Bj₊1j = Bl[j]
for i = 1:n
C[i, j] = A[i, j-1] * Bj₋1j + A[i, j]*Bjj + A[i, j+1] * Bj₊1j
end
end
end # inbounds
C
end

#Generic multiplication
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
Expand Down Expand Up @@ -329,3 +463,39 @@ function eigvecs{T}(M::Bidiagonal{T})
Q #Actually Triangular
end
eigfact(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M))

# fill! methods
_valuefields{T <: Diagonal}(S::Type{T}) = [:diag]
_valuefields{T <: Bidiagonal}(S::Type{T}) = [:dv, :ev]
_valuefields{T <: Tridiagonal}(S::Type{T}) = [:dl, :d, :du]
_valuefields{T <: SymTridiagonal}(S::Type{T}) = [:dv, :ev]
_valuefields{T <: AbstractTriangular}(S::Type{T}) = [:data]

SpecialArrays = Union{Diagonal,
Bidiagonal,
Tridiagonal,
SymTridiagonal,
AbstractTriangular}

@generated function fillslots!(A::SpecialArrays, x)
ex = :(xT = convert(eltype(A), x))
for field in _valuefields(A)
ex = :($ex; fill!(A.$field, xT))
end
:($ex;return A)
end

# for historical reasons:
fill!(a::AbstractTriangular, x) = fillslots!(a, x);
fill!(D::Diagonal, x) = fillslots!(D, x);

_small_enough(A::Bidiagonal) = size(A, 1) <= 1
_small_enough(A::Tridiagonal) = size(A, 1) <= 2
_small_enough(A::SymTridiagonal) = size(A, 1) <= 2

function fill!(A::Union{Bidiagonal, Tridiagonal, SymTridiagonal} ,x)
xT = convert(eltype(A), x)
(xT == zero(eltype(A)) || _small_enough(A)) && return fillslots!(A, xT)
throw(ArgumentError("Array A of type $(typeof(A)) and size $(size(A)) can
not be filled with x=$x, since some of its entries are constrained."))
end
2 changes: 0 additions & 2 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ function size(D::Diagonal,d::Integer)
return d<=2 ? length(D.diag) : 1
end

fill!(D::Diagonal, x) = (fill!(D.diag, x); D)

full(D::Diagonal) = diagm(D.diag)

@inline function getindex(D::Diagonal, i::Int, j::Int)
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ imag(A::UnitUpperTriangular) = UpperTriangular(triu!(imag(A.data),1))
full(A::AbstractTriangular) = convert(Matrix, A)
parent(A::AbstractTriangular) = A.data

fill!(A::AbstractTriangular, x) = (fill!(A.data, x); A)

# then handle all methods that requires specific handling of upper/lower and unit diagonal

function convert{Tret,T,S}(::Type{Matrix{Tret}}, A::LowerTriangular{T,S})
Expand Down Expand Up @@ -376,6 +374,8 @@ scale!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = scale!(A,c)
######################

A_mul_B!(A::Tridiagonal, B::AbstractTriangular) = A*full!(B)
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::Tridiagonal) = A_mul_B!(C, full(A), B)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that it's a good idea to have all these ! methods calling full. It used to be the case that (with very few exceptions) A_mul_B! didn't allocate for dense matrices and that these methods were the low-level "expert" methods. Therefore, it might be better to have these conversion at a higher level, i.e. *, and Ax_mul_B.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure either. However I need to define these A_mul_B! to resolve definition ambiguity. Do you have a suggestion what to use instead of full? Full is also not that bad here since I expect AbstractTriangular to be half full anyway.

A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractTriangular) = A_mul_B!(C, A, full(B))
A_mul_B!(C::AbstractVector, A::AbstractTriangular, B::AbstractVector) = A_mul_B!(A, copy!(C, B))
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B))
A_mul_B!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B))
Expand Down
30 changes: 0 additions & 30 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,33 +497,3 @@ function convert{T}(::Type{SymTridiagonal{T}}, M::Tridiagonal)
throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal"))
end
end

A_mul_B!(C::AbstractVector, A::Tridiagonal, B::AbstractVector) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)

function A_mul_B_td!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat)
nA = size(A,1)
nB = size(B,2)
if !(size(C,1) == size(B,1) == nA)
throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
end
if size(C,2) != nB
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
end
l = A.dl
d = A.d
u = A.du
@inbounds begin
for j = 1:nB
b₀, b₊ = B[1, j], B[2, j]
C[1, j] = d[1]*b₀ + u[1]*b₊
for i = 2:nA - 1
b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
end
C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
end
end
C
end
40 changes: 40 additions & 0 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1681,3 +1681,43 @@ droptol!(x::SparseVector, tol, trim::Bool = true) = fkeep!(x, (i, x, tol) -> abs

dropzeros!(x::SparseVector, trim::Bool = true) = fkeep!(x, (i, x, other) -> x != 0, nothing, trim)
dropzeros(x::SparseVector, trim::Bool = true) = dropzeros!(copy(x), trim)

function _fillnonzero!{Tv,Ti}(arr::SparseMatrixCSC{Tv, Ti}, val)
m, n = size(arr)
resize!(arr.colptr, n+1)
resize!(arr.rowval, m*n)
resize!(arr.nzval, m*n)
copy!(arr.colptr, 1:m:n*m+1)
fill!(arr.nzval, val)
index = 1
@inbounds for _ in 1:n
for i in 1:m
arr.rowval[index] = Ti(i)
index += 1
end
end
arr
end

function _fillnonzero!{Tv,Ti}(arr::SparseVector{Tv,Ti}, val)
n = arr.n
resize!(arr.nzind, n)
resize!(arr.nzval, n)
@inbounds for i in 1:n
arr.nzind[i] = Ti(i)
end
fill!(arr.nzval, val)
arr
end

import Base.fill!
function fill!(A::Union{SparseVector, SparseMatrixCSC}, x)
T = eltype(A)
xT = convert(T, x)
if xT == zero(T)
fill!(A.nzval, xT)
else
_fillnonzero!(A, xT)
end
return A
end
46 changes: 46 additions & 0 deletions test/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,49 @@ C = Tridiagonal(rand(Float64,9),rand(Float64,10),rand(Float64,9))
@test promote_rule(Matrix{Float64}, Bidiagonal{Float64}) == Matrix{Float64}
@test promote(B,A) == (B,convert(Matrix{Float64},full(A)))
@test promote(C,A) == (C,Tridiagonal(zeros(Float64,9),convert(Vector{Float64},A.dv),convert(Vector{Float64},A.ev)))

import Base.LinAlg: fillslots!, UnitLowerTriangular
let #fill!
let # fillslots!
A = Tridiagonal(randn(2), randn(3), randn(2))
@test fillslots!(A, 3) == Tridiagonal([3, 3.], [3, 3, 3.], [3, 3.])
B = Bidiagonal(randn(3), randn(2), true)
@test fillslots!(B, 2) == Bidiagonal([2.,2,2], [2,2.], true)
S = SymTridiagonal(randn(3), randn(2))
@test fillslots!(S, 1) == SymTridiagonal([1,1,1.], [1,1.])
Ult = UnitLowerTriangular(randn(3,3))
@test fillslots!(Ult, 3) == UnitLowerTriangular([1 0 0; 3 1 0; 3 3 1])
end
let # fill!(exotic, 0)
exotic_arrays = Any[Tridiagonal(randn(3), randn(4), randn(3)),
Bidiagonal(randn(3), randn(2), rand(Bool)),
SymTridiagonal(randn(3), randn(2)),
sparse(randn(3,4)),
Diagonal(randn(5)),
sparse(rand(3)),
LowerTriangular(randn(3,3)),
UpperTriangular(randn(3,3))
]
for A in exotic_arrays
fill!(A, 0)
for a in A
@test a == 0
end
end
end
let # fill!(small, x)
val = randn()
b = Bidiagonal(randn(1,1), true)
st = SymTridiagonal(randn(1,1))
for x in (b, st)
@test full(fill!(x, val)) == fill!(full(x), val)
end
b = Bidiagonal(randn(2,2), true)
st = SymTridiagonal(randn(3), randn(2))
t = Tridiagonal(randn(3,3))
for x in (b, t, st)
@test_throws ArgumentError fill!(x, val)
@test full(fill!(x, 0)) == fill!(full(x), 0)
end
end
end
56 changes: 56 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,59 @@ a = [RootInt(2),RootInt(10)]
@test a*a' == [4 20; 20 100]
A = [RootInt(3) RootInt(5)]
@test A*a == [56]

function test_mul(C, A, B)
A_mul_B!(C, A, B)
@test full(A) * full(B) ≈ C
@test A*B ≈ C
end

let
eltypes = [Float32, Float64, Int64]
for k in [3, 4, 10]
T = rand(eltypes)
bi1 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool))
bi2 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool))
tri1 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1))
tri2 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1))
stri1 = SymTridiagonal(rand(T, k), rand(T, k-1))
stri2 = SymTridiagonal(rand(T, k), rand(T, k-1))
C = rand(T, k, k)
specialmatrices = (bi1, bi2, tri1, tri2, stri1, stri2)
for A in specialmatrices
B = specialmatrices[rand(1:length(specialmatrices))]
test_mul(C, A, B)
end
for S in specialmatrices
l = rand(1:6)
B = randn(k, l)
C = randn(k, l)
test_mul(C, S, B)
A = randn(l, k)
C = randn(l, k)
test_mul(C, A, S)
end
end
for T in eltypes
A = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool))
B = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool))
C = randn(2,2)
test_mul(C, A, B)
B = randn(2, 9)
C = randn(2, 9)
test_mul(C, A, B)
end
let
tri44 = Tridiagonal(randn(3), randn(4), randn(3))
tri33 = Tridiagonal(randn(2), randn(3), randn(2))
full43 = randn(4, 3)
full24 = randn(2, 4)
full33 = randn(3, 3)
full44 = randn(4, 4)
@test_throws DimensionMismatch A_mul_B!(full43, tri44, tri33)
@test_throws DimensionMismatch A_mul_B!(full44, tri44, tri33)
@test_throws DimensionMismatch A_mul_B!(full44, tri44, full43)
@test_throws DimensionMismatch A_mul_B!(full43, tri33, full43)
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
end
end
Loading