Skip to content

Commit

Permalink
Add messages to DimensionMismatch errors (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Oct 31, 2023
1 parent 81fc6f3 commit f455a8e
Showing 1 changed file with 64 additions and 31 deletions.
95 changes: 64 additions & 31 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
end

function _spmatmul!(C, A, B, α, β)
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
size(A, 2) == size(B, 1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))"))
size(A, 1) == size(C, 1) ||
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))"))
size(B, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
nzv = nonzeros(A)
rv = rowvals(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
Expand All @@ -76,9 +79,12 @@ end
(T = promote_op(matprod, TA, eltype(B)); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B))

function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
size(A, 2) == size(C, 1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))"))
size(A, 1) == size(B, 1) ||
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))"))
size(B, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
nzv = nonzeros(A)
rv = rowvals(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
Expand Down Expand Up @@ -110,9 +116,12 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::Strided
end
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixCSC, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
size(A, 2) == size(C, 2) || throw(DimensionMismatch())
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
mX == size(C, 1) ||
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
size(A, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
rv = rowvals(A)
nzv = nonzeros(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
Expand All @@ -127,9 +136,12 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixC
end
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::AbstractSparseMatrixCSC, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
size(A, 2) == size(C, 2) || throw(DimensionMismatch())
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
mX == size(C, 1) ||
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
size(A, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
rv = rowvals(A)
nzv = nonzeros(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
Expand All @@ -149,9 +161,12 @@ end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::AbstractSparseMatrixCSC, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) || throw(DimensionMismatch())
mA == size(C, 1) || throw(DimensionMismatch())
size(B, 1) == size(C, 2) || throw(DimensionMismatch())
nA == size(B, 2) ||
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
mA == size(C, 1) ||
throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))"))
size(B, 1) == size(C, 2) ||
throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
rv = rowvals(B)
nzv = nonzeros(B)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
Expand Down Expand Up @@ -201,7 +216,8 @@ function spmatmul(A::SparseOrTri, B::Union{SparseOrTri,AbstractCompressedVector,
Ti = promote_type(indtype(A), indtype(B))
mA, nA = size(A)
nB = size(B, 2)
nA == size(B, 1) || throw(DimensionMismatch())
mB = size(B, 1)
nA == mB || throw(DimensionMismatch("second dimension of A, $nA, does not match the first dimension of B, $mB"))

nnzC = min(estimate_mulsize(mA, nnz(A), nA, nnz(B), nB) * 11 ÷ 10 + mA, mA*nB)
colptrC = Vector{Ti}(undef, nB+1)
Expand Down Expand Up @@ -346,14 +362,15 @@ end
function dot(x::AbstractVector{T1}, A::AbstractSparseMatrixCSC{T2}, y::AbstractVector{T3}) where {T1,T2,T3}
require_one_based_indexing(x, y)
m, n = size(A)
(length(x) == m && n == length(y)) || throw(DimensionMismatch())
(length(x) == m && n == length(y)) ||
throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))"))
s = dot(zero(T1), zero(T2), zero(T3))
T = typeof(s)
(iszero(m) || iszero(n)) && return s

rowvals = getrowval(A)
nzvals = getnzval(A)

@inbounds @simd for col in 1:n
ycol = y[col]
for j in nzrange(A, col)
Expand All @@ -366,7 +383,8 @@ function dot(x::AbstractVector{T1}, A::AbstractSparseMatrixCSC{T2}, y::AbstractV
end
function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector)
m, n = size(A)
length(x) == m && n == length(y) || throw(DimensionMismatch())
length(x) == m && n == length(y) ||
throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))"))
if iszero(m) || iszero(n)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
Expand Down Expand Up @@ -403,7 +421,7 @@ function dot(A::Union{DenseMatrixUnion,WrapperMatrixTypes{<:Any,Union{DenseMatri
T = promote_type(eltype(A), eltype(B))
(m, n) = size(A)
if (m, n) != size(B)
throw(DimensionMismatch())
throw(DimensionMismatch("A has size ($m, $n) but B has size $(size(B))"))
end
s = zero(T)
if m * n == 0
Expand Down Expand Up @@ -872,7 +890,8 @@ end
function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::StridedVecOrMat{T}, A, B, α, β) where T
n = size(A, 2)
m = size(B, 2)
n == size(B, 1) == size(C, 1) && m == size(C, 2) || throw(DimensionMismatch())
n == size(B, 1) == size(C, 1) && m == size(C, 2) ||
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
rv = rowvals(A)
nzv = nonzeros(A)
let z = T(0), sumcol=z, αxj=z, aarc=z, α = α
Expand Down Expand Up @@ -916,7 +935,8 @@ dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC
function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, rangefun::Function, diagop::Function, odiagop::Function)
require_one_based_indexing(x, y)
m, n = size(A)
(length(x) == m && n == length(y)) || throw(DimensionMismatch())
(length(x) == m && n == length(y)) ||
throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))"))
if iszero(m) || iszero(n)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
Expand Down Expand Up @@ -946,7 +966,8 @@ dot(x::SparseVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC},
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real)
function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rangefun::Function, diagop::Function)
m, n = size(A)
length(x) == m && n == length(y) || throw(DimensionMismatch())
length(x) == m && n == length(y) ||
throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))"))
if iszero(m) || iszero(n)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
Expand Down Expand Up @@ -1591,7 +1612,9 @@ end
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal)
m, n = size(A)
b = D.diag
(n==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
lb = length(b)
n == lb || throw(DimensionMismatch("A has size ($m, $n) but D has size ($lb, $lb)"))
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
copyinds!(C, A)
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
Expand All @@ -1605,7 +1628,9 @@ end
function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC)
m, n = size(A)
b = D.diag
(m==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
lb = length(b)
m == lb || throw(DimensionMismatch("D has size ($lb, $lb) but A has size ($m, $n)"))
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
copyinds!(C, A)
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
Expand All @@ -1618,15 +1643,15 @@ function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCS
end

function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, b::Number)
size(A)==size(C) || throw(DimensionMismatch())
size(A)==size(C) || throw(DimensionMismatch("A has size $(size(A)) but C has size $(size(C))"))
copyinds!(C, A)
resize!(nonzeros(C), length(nonzeros(A)))
mul!(nonzeros(C), nonzeros(A), b)
C
end

function mul!(C::AbstractSparseMatrixCSC, b::Number, A::AbstractSparseMatrixCSC)
size(A)==size(C) || throw(DimensionMismatch())
size(A)==size(C) || throw(DimensionMismatch("A has size $(size(A)) but C has size $(size(C))"))
copyinds!(C, A)
resize!(nonzeros(C), length(nonzeros(A)))
mul!(nonzeros(C), b, nonzeros(A))
Expand All @@ -1645,7 +1670,8 @@ end

function rmul!(A::AbstractSparseMatrixCSC, D::Diagonal)
m, n = size(A)
(n == size(D, 1)) || throw(DimensionMismatch())
szD = size(D, 1)
(n == szD) || throw(DimensionMismatch("A has size ($m, $n) but D has size ($szD, $szD)"))
Anzval = nonzeros(A)
@inbounds for col in 1:n, p in nzrange(A, col)
Anzval[p] = Anzval[p] * D.diag[col]
Expand All @@ -1655,7 +1681,8 @@ end

function lmul!(D::Diagonal, A::AbstractSparseMatrixCSC)
m, n = size(A)
(m == size(D, 2)) || throw(DimensionMismatch())
ds2 = size(D, 2)
(m == ds2) || throw(DimensionMismatch("D has size ($ds2, $ds2) but A has size ($m, $n)"))
Anzval = nonzeros(A)
Arowval = rowvals(A)
@inbounds for col in 1:n, p in nzrange(A, col)
Expand All @@ -1667,7 +1694,10 @@ end
function ldiv!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC)
m, n = size(A)
b = D.diag
(m==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
lb = length(b)
m==lb || throw(DimensionMismatch("D has size ($lb, $lb) but A has size ($m, $n)"))
szC = size(C)
size(A) == szC || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $szC"))
copyinds!(C, A)
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
Expand All @@ -1682,7 +1712,10 @@ end
function LinearAlgebra._rdiv!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal)
m, n = size(A)
b = D.diag
(n==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
lb = length(b)
n == lb || throw(DimensionMismatch("A has size ($m, $n) but D has size ($lb, $lb)"))
szC = size(C)
size(A) == szC || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $szC"))
copyinds!(C, A)
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
Expand Down

0 comments on commit f455a8e

Please sign in to comment.