diff --git a/base/linalg/bidiag.jl b/base/linalg/bidiag.jl index 98117ee41b7357..a56959040610be 100644 --- a/base/linalg/bidiag.jl +++ b/base/linalg/bidiag.jl @@ -224,8 +224,142 @@ end /(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.isupper) ==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper) -SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular} -*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B) + +BiTriSym = Union{Bidiagonal, Tridiagonal, SymTridiagonal} +BiTri = Union{Bidiagonal, Tridiagonal} +A_mul_B!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractMatrix, A::BiTri, B::BiTriSym) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractVector, A::BiTri, B::AbstractVector) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractMatrix, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B) +A_mul_B!(C::AbstractVecOrMat, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B) + + +function check_A_mul_B!_sizes(C, A, B) + nA, mA = size(A) + nB, mB = size(B) + nC, mC = size(C) + if !(nA == nC) + throw(DimensionMismatch("Sizes size(A)=$(size(A)) and size(C) = $(size(C)) must match at first entry.")) + elseif !(mA == nB) + throw(DimensionMismatch("Second entry of size(A)=$(size(A)) and first entry of size(B) = $(size(B)) must match.")) + elseif !(mB == mC) + throw(DimensionMismatch("Sizes size(B)=$(size(B)) and size(C) = $(size(C)) must match at first second entry.")) + end +end + +function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) + check_A_mul_B!_sizes(C, A, B) + n = size(A,1) + n <= 3 && return A_mul_B!(C, full(A), full(B)) + fill!(C, zero(eltype(C))) + Al = diag(A, -1) + Ad = diag(A, 0) + Au = diag(A, 1) + Bl = diag(B, -1) + Bd = diag(B, 0) + Bu = diag(B, 1) + @inbounds begin + # first row of C + C[1,1] = A[1,1]*B[1,1] + A[1, 2]*B[2, 1] + C[1,2] = A[1,1]*B[1,2] + A[1,2]*B[2,2] + C[1,3] = A[1,2]*B[2,3] + # second row of C + C[2,1] = A[2,1]*B[1,1] + A[2,2]*B[2,1] + C[2,2] = A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2] + C[2,3] = A[2,2]*B[2,3] + A[2,3]*B[3,3] + C[2,4] = A[2,3]*B[3,4] + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j,j-2] = Ajj₋1*Bj₋1j₋2 + C[j, j-1] = Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1 + C[j, j ] = Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j + C[j, j+1] = Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1 + C[j, j+2] = Ajj₊1*Bj₊1j₊2 + end + # row before last of C + C[n-1,n-3] = A[n-1,n-2]*B[n-2,n-3] + C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2] + C[n-1,n-1] = A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1] + C[n-1,n ] = A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ] + # last row of C + C[n,n-2] = A[n,n-1]*B[n-1,n-2] + C[n,n-1] = A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1] + C[n,n ] = A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ] + end #inbounds + C +end + +function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat) + nA = size(A,1) + nB = size(B,2) + if !(size(C,1) == size(B,1) == nA) + throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match")) + end + if size(C,2) != nB + throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match")) + end + nA <= 3 && return A_mul_B!(C, full(A), full(B)) + l = diag(A, -1) + d = diag(A, 0) + u = diag(A, 1) + @inbounds begin + for j = 1:nB + b₀, b₊ = B[1, j], B[2, j] + C[1, j] = d[1]*b₀ + u[1]*b₊ + for i = 2:nA - 1 + b₋, b₀, b₊ = b₀, b₊, B[i + 1, j] + C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊ + end + C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊ + end + end + C +end + +function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) + check_A_mul_B!_sizes(C, A, B) + n = size(A,1) + n <= 3 && return A_mul_B!(C, full(A), full(B)) + m = size(B,2) + Bl = diag(B, -1) + Bd = diag(B, 0) + Bu = diag(B, 1) + @inbounds begin + # first and last column of C + B11 = Bd[1] + B21 = Bl[1] + Bmm = Bd[m] + Bm₋1m = Bu[m-1] + for i in 1:n + C[i, 1] = A[i,1] * B11 + A[i, 2] * B21 + C[i, m] = A[i, m-1] * Bm₋1m + A[i, m] * Bmm + end + # middle columns of C + for j = 2:m-1 + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + for i = 1:n + C[i, j] = A[i, j-1] * Bj₋1j + A[i, j]*Bjj + A[i, j+1] * Bj₊1j + end + end + end#inbounds + C +end #Generic multiplication for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc) @@ -329,3 +463,43 @@ function eigvecs{T}(M::Bidiagonal{T}) Q #Actually Triangular end eigfact(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M)) + +# fill! methods +_valuefields{T <: Diagonal}(S::Type{T}) = [:diag] +_valuefields{T <: Bidiagonal}(S::Type{T}) = [:dv, :ev] +_valuefields{T <: Tridiagonal}(S::Type{T}) = [:dl, :d, :du] +_valuefields{T <: SymTridiagonal}(S::Type{T}) = [:dv, :ev] +_valuefields{T <: AbstractTriangular}(S::Type{T}) = [:data] + +SpecialArrays = Union{Diagonal, + Bidiagonal, + Tridiagonal, + SymTridiagonal, + AbstractTriangular} + +@generated function fillslots!(A::SpecialArrays, x) + ex = quote + xT=convert(eltype(A), x) + end + for field in _valuefields(A) + ex_field = :(fill!(A.$(field), xT);) + append!(ex.args, ex_field.args) + end + append!(ex.args, :(return A).args) + return ex +end + +# for historical reasons: +fill!(a::AbstractTriangular, x) = fillslots!(a, x); +fill!(D::Diagonal, x) = fillslots!(D, x); + +_small_enough(A::Bidiagonal) = size(A, 1) <= 1 +_small_enough(A::Tridiagonal) = size(A, 1) <= 2 +_small_enough(A::SymTridiagonal) = size(A, 1) <= 2 + +function fill!(A::Union{Bidiagonal, Tridiagonal, SymTridiagonal} ,x) + xT = convert(eltype(A), x) + (xT == zero(eltype(A)) || _small_enough(A)) && return fillslots!(A, xT) + throw(ArgumentError("Array A of type $(typeof(A)) and size $(size(A)) can + not be filled with x=$x since some of its entries are constrained.")) +end diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index d08e716dfb253b..eaa0d782482557 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -29,8 +29,6 @@ function size(D::Diagonal,d::Integer) return d<=2 ? length(D.diag) : 1 end -fill!(D::Diagonal, x) = (fill!(D.diag, x); D) - full(D::Diagonal) = diagm(D.diag) @inline function getindex(D::Diagonal, i::Int, j::Int) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 6a58d538c6257a..b2a792943dba56 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -47,8 +47,6 @@ imag(A::UnitUpperTriangular) = UpperTriangular(triu!(imag(A.data),1)) full(A::AbstractTriangular) = convert(Matrix, A) parent(A::AbstractTriangular) = A.data -fill!(A::AbstractTriangular, x) = (fill!(A.data, x); A) - # then handle all methods that requires specific handling of upper/lower and unit diagonal function convert{Tret,T,S}(::Type{Matrix{Tret}}, A::LowerTriangular{T,S}) @@ -376,6 +374,8 @@ scale!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = scale!(A,c) ###################### A_mul_B!(A::Tridiagonal, B::AbstractTriangular) = A*full!(B) +A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::Tridiagonal) = A_mul_B!(C, full(A), B) +A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractTriangular) = A_mul_B!(C, A, full(B)) A_mul_B!(C::AbstractVector, A::AbstractTriangular, B::AbstractVector) = A_mul_B!(A, copy!(C, B)) A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B)) A_mul_B!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B)) diff --git a/base/linalg/tridiag.jl b/base/linalg/tridiag.jl index e2be1b59092335..b88250375fafbb 100644 --- a/base/linalg/tridiag.jl +++ b/base/linalg/tridiag.jl @@ -497,33 +497,3 @@ function convert{T}(::Type{SymTridiagonal{T}}, M::Tridiagonal) throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal")) end end - -A_mul_B!(C::AbstractVector, A::Tridiagonal, B::AbstractVector) = A_mul_B_td!(C, A, B) -A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B) -A_mul_B!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B) - -function A_mul_B_td!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat) - nA = size(A,1) - nB = size(B,2) - if !(size(C,1) == size(B,1) == nA) - throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match")) - end - if size(C,2) != nB - throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match")) - end - l = A.dl - d = A.d - u = A.du - @inbounds begin - for j = 1:nB - b₀, b₊ = B[1, j], B[2, j] - C[1, j] = d[1]*b₀ + u[1]*b₊ - for i = 2:nA - 1 - b₋, b₀, b₊ = b₀, b₊, B[i + 1, j] - C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊ - end - C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊ - end - end - C -end diff --git a/base/sparse/sparsevector.jl b/base/sparse/sparsevector.jl index b8c2c87c04d22d..c08892018021ff 100644 --- a/base/sparse/sparsevector.jl +++ b/base/sparse/sparsevector.jl @@ -1681,3 +1681,35 @@ droptol!(x::SparseVector, tol, trim::Bool = true) = fkeep!(x, (i, x, tol) -> abs dropzeros!(x::SparseVector, trim::Bool = true) = fkeep!(x, (i, x, other) -> x != 0, nothing, trim) dropzeros(x::SparseVector, trim::Bool = true) = dropzeros!(copy(x), trim) + +function _fillnonzero!{Tv,Ti}(arr::SparseMatrixCSC{Tv, Ti}, val) + m, n = size(arr) + arr.colptr = convert(Vector{Ti}, collect(1:m:n*m+1)) + arr.rowval = convert(Vector{Ti}, vcat([1:m for _ in 1:n]...)) + resize!(arr.nzval, n*m) + fill!(arr.nzval, val) + arr +end + +function _fillnonzero!{Tv,Ti}(arr::SparseVector{Tv,Ti}, val) + n = arr.n + resize!(arr.nzind, n) + resize!(arr.nzval, n) + for i in 1:n + arr.nzind[i] = Ti(i) + end + fill!(arr.nzval, val) + arr +end + +import Base.fill! +function fill!(A:: Union{SparseVector, SparseMatrixCSC}, x) + T = eltype(A) + xT = convert(T, x) + if xT == zero(T) + fill!(A.nzval, xT) + else + _fillnonzero!(A, xT) + end + return A +end diff --git a/test/linalg/bidiag.jl b/test/linalg/bidiag.jl index a0492544317911..168790a2dddd10 100644 --- a/test/linalg/bidiag.jl +++ b/test/linalg/bidiag.jl @@ -236,3 +236,49 @@ C = Tridiagonal(rand(Float64,9),rand(Float64,10),rand(Float64,9)) @test promote_rule(Matrix{Float64}, Bidiagonal{Float64}) == Matrix{Float64} @test promote(B,A) == (B,convert(Matrix{Float64},full(A))) @test promote(C,A) == (C,Tridiagonal(zeros(Float64,9),convert(Vector{Float64},A.dv),convert(Vector{Float64},A.ev))) + +import Base.LinAlg: fillslots!, UnitLowerTriangular +let #fill! + let # fillslots! + A = Tridiagonal(randn(2), randn(3), randn(2)) + @test fillslots!(A, 3) == Tridiagonal([3, 3.], [3, 3, 3.], [3, 3.]) + B = Bidiagonal(randn(3), randn(2), true) + @test fillslots!(B, 2) == Bidiagonal([2.,2,2], [2,2.], true) + S = SymTridiagonal(randn(3), randn(2)) + @test fillslots!(S, 1) == SymTridiagonal([1,1,1.], [1,1.]) + Ult = UnitLowerTriangular(randn(3,3)) + @test fillslots!(Ult, 3) == UnitLowerTriangular([1 0 0; 3 1 0; 3 3 1]) + end + let # fill!(exotic, 0) + exotic_arrays = Any[Tridiagonal(randn(3), randn(4), randn(3)), + Bidiagonal(randn(3), randn(2), rand(Bool)), + SymTridiagonal(randn(3), randn(2)), + sparse(randn(3,4)), + Diagonal(randn(5)), + sparse(rand(3)), + LowerTriangular(randn(3,3)), + UpperTriangular(randn(3,3)) + ] + for A in exotic_arrays + fill!(A, 0) + for a in A + @test a == 0 + end + end + end + let # fill!(small, x) + val = randn() + b = Bidiagonal(randn(1,1), true) + st = SymTridiagonal(randn(1,1)) + for x in (b, st) + @test full(fill!(x, val)) == fill!(full(x), val) + end + b = Bidiagonal(randn(2,2), true) + st = SymTridiagonal(randn(3), randn(2)) + t = Tridiagonal(randn(3,3)) + for x in (b, t, st) + @test_throws ArgumentError fill!(x, val) + @test full(fill!(x, 0)) == fill!(full(x), 0) + end + end +end diff --git a/test/linalg/matmul.jl b/test/linalg/matmul.jl index a336b92f7ad5d1..ec7070d7d6e4cd 100644 --- a/test/linalg/matmul.jl +++ b/test/linalg/matmul.jl @@ -332,3 +332,59 @@ a = [RootInt(2),RootInt(10)] @test a*a' == [4 20; 20 100] A = [RootInt(3) RootInt(5)] @test A*a == [56] + +function test_mul(C, A, B) + A_mul_B!(C, A, B) + @test full(A) * full(B) ≈ C + @test A*B ≈ C +end + +let + eltypes = [Float32, Float64, Int64] + for k in [3, 4, 10] + T = rand(eltypes) + bi1 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool)) + bi2 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool)) + tri1 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1)) + tri2 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1)) + stri1 = SymTridiagonal(rand(T, k), rand(T, k-1)) + stri2 = SymTridiagonal(rand(T, k), rand(T, k-1)) + C = rand(T, k, k) + specialmatrices = (bi1, bi2, tri1, tri2, stri1, stri2) + for A in specialmatrices + B = specialmatrices[rand(1:length(specialmatrices))] + test_mul(C, A, B) + end + for S in specialmatrices + l = rand(1:6) + B = randn(k, l) + C = randn(k, l) + test_mul(C, S, B) + A = randn(l, k) + C = randn(l, k) + test_mul(C, A, S) + end + end + for T in eltypes + A = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool)) + B = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool)) + C = randn(2,2) + test_mul(C, A, B) + B = randn(2, 9) + C = randn(2, 9) + test_mul(C, A, B) + end + let + tri44 = Tridiagonal(randn(3), randn(4), randn(3)) + tri33 = Tridiagonal(randn(2), randn(3), randn(2)) + full43 = randn(4, 3) + full24 = randn(2, 4) + full33 = randn(3, 3) + full44 = randn(4, 4) + @test_throws DimensionMismatch A_mul_B!(full43, tri44, tri33) + @test_throws DimensionMismatch A_mul_B!(full44, tri44, tri33) + @test_throws DimensionMismatch A_mul_B!(full44, tri44, full43) + @test_throws DimensionMismatch A_mul_B!(full43, tri33, full43) + @test_throws DimensionMismatch A_mul_B!(full43, full43, tri44) + end +end diff --git a/test/sparsedir/sparsevector.jl b/test/sparsedir/sparsevector.jl index 07d00954372ad6..8f260e328fee00 100644 --- a/test/sparsedir/sparsevector.jl +++ b/test/sparsedir/sparsevector.jl @@ -950,3 +950,20 @@ x = sparsevec(1:7, [3., 2., -1., 1., -2., -3., 3.], 15) @test collect(sort(x, by=abs)) == sort(collect(x), by=abs) @test collect(sort(x, by=sign)) == sort(collect(x), by=sign) @test collect(sort(x, by=inv)) == sort(collect(x), by=inv) + +#fill! +for Tv in [Float32, Float64, Int64, Int32, Complex128] + for Ti in [Int16, Int32, Int64, BigInt] + sptypes = (SparseMatrixCSC{Tv, Ti}, SparseVector{Tv, Ti}) + sizes = [(3, 4), (3,)] + for (siz, Sp) in zip(sizes, sptypes) + arr = rand(Tv, siz...) + sparr = Sp(arr) + fillval = rand(Tv) + fill!(sparr, fillval) + @test full(sparr) == fillval * ones(arr) + fill!(sparr, 0) + @test full(sparr) == zeros(arr) + end + end +end