Skip to content

Commit

Permalink
check sizes of arguments in dot; fixes #28617
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha authored and andreasnoack committed Sep 28, 2021
1 parent 84cc901 commit 6e97d2a
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 15 deletions.
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

function dot(x::BitVector, y::BitVector)
# simplest way to mimic Array dot behavior
length(x) == length(y) || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
s = 0
xc = x.chunks
yc = y.chunks
Expand Down
12 changes: 6 additions & 6 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,24 +400,24 @@ end
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
require_one_based_indexing(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if size(DX) != size(DY)
throw(DimensionMismatch("The first array has size $(size(DX)) which does not match the size of the second, $(size(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
return dot(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
require_one_based_indexing(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if size(DX) != size(DY)
throw(DimensionMismatch("The first array has size $(size(DX)) which does not match the size of the second, $(size(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
return dotc(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
require_one_based_indexing(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if size(DX) != size(DY)
throw(DimensionMismatch("The first array has size $(size(DX)) which does not match the size of the second, $(size(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
return dotu(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
Expand Down
26 changes: 22 additions & 4 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -904,12 +904,30 @@ end

dot(x::Number, y::Number) = conj(x) * y

"""
dot(x, y)
x ⋅ y
Compute the dot product between two arrays of the same size as if they were
vectors. For complex arrays, the elements of the first array are conjugated.
This is the classical dot product for vectors and the Hilbert-Schmidt dot
product `tr(x' * y)` for matrices. When the arrays have equal sizes, calling
`dot` is semantically equivalent to `sum(dot(vx,vy) for (vx,vy) in zip(x, y))`.
# Examples
```jldoctest
julia> dot([1; 1], [2; 3])
5
julia> dot([im; im], [1; 1])
0 - 2im
```
"""
function dot(x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
if lx == 0
if length(x) == 0
return dot(zero(eltype(x)), zero(eltype(y)))
end
s = zero(dot(first(x), first(y)))
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,23 @@ Random.seed!(100)
x2 = convert(Vector{elty}, randn(n))
@test BLAS.dot(x1,x2) sum(x1.*x2)
@test_throws DimensionMismatch BLAS.dot(x1,rand(elty, n + 1))
y1 = convert(Matrix{elty}, randn(4,4))
y2 = convert(Matrix{elty}, randn(2,8))
@test_throws DimensionMismatch BLAS.dot(y1, y2)
@test sum(y1[i] * y2[i] for i in 1:16) BLAS.dot(vec(y1), vec(y2))
else
z1 = convert(Vector{elty}, complex.(randn(n),randn(n)))
z2 = convert(Vector{elty}, complex.(randn(n),randn(n)))
@test BLAS.dotc(z1,z2) sum(conj(z1).*z2)
@test BLAS.dotu(z1,z2) sum(z1.*z2)
@test_throws DimensionMismatch BLAS.dotc(z1,rand(elty, n + 1))
@test_throws DimensionMismatch BLAS.dotu(z1,rand(elty, n + 1))
y1 = convert(Matrix{elty}, complex.(randn(4,4),randn(4,4)))
y2 = convert(Matrix{elty}, complex.(randn(2,8),randn(2,8)))
@test_throws DimensionMismatch BLAS.dotc(y1, y2)
@test_throws DimensionMismatch BLAS.dotu(y1, y2)
@test sum(conj(y1[i]) * y2[i] for i in 1:16) BLAS.dotc(vec(y1), vec(y2))
@test sum(y1[i] * y2[i] for i in 1:16) BLAS.dotu(vec(y1), vec(y2))
end
end
@testset "iamax" begin
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ end
@test dot(X, Y) == convert(elty, 35.0)
Z = convert(Vector{Matrix{elty}},[reshape(1:4, 2, 2), fill(1, 2, 2)])
@test dot(Z, Z) == convert(elty, 34.0)
Y2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5])
@test_throws DimensionMismatch dot(X, Y2)
@test_throws DimensionMismatch dot(vec(X), Y2)
@test dot(X, Y) == dot(vec(X), vec(Y2))
end

dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
Expand All @@ -454,6 +458,21 @@ dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y)
end
end
end
for elty in (Float32, Float64, ComplexF32, ComplexF64)
XX = convert(Matrix{elty},[1.0 2.0; 3.0 4.0])
YY = convert(Matrix{elty},[1.5 2.5; 3.5 4.5])
YY2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5])
for X in (copy(XX), view(XX, 1:2, 1:2)), Y in (copy(YY), view(YY, 1:2, 1:2)), Y2 in (copy(YY2), view(YY2, 1:1, 1:4))
@test dot1(X, Y) == convert(elty, 35.0)
@test dot2(X, Y) == convert(elty, 35.0)
@test dot1(X, Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes
@test_throws DimensionMismatch dot2(X, Y2)
@test dot1(vec(X), Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes
@test_throws DimensionMismatch dot2(vec(X), Y2)
@test dot1(X, Y) == dot1(vec(X), vec(Y2))
@test dot2(X, Y) == dot2(vec(X), vec(Y2))
end
end
end

@testset "Issue 11978" begin
Expand Down
4 changes: 3 additions & 1 deletion stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ ilog2(n::Integer) = sizeof(n)<<3 - leading_zeros(n)
# Frobenius dot/inner product: trace(A'B)
function dot(A::AbstractSparseMatrixCSC{T1,S1},B::AbstractSparseMatrixCSC{T2,S2}) where {T1,T2,S1,S2}
m, n = size(A)
size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions"))
if size(B) != (m,n)
throw(DimensionMismatch("The first array has size $(size(A)) which does not match the size of the second, $(size(B)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
r = dot(zero(T1), zero(T2))
@inbounds for j = 1:n
ia = getcolptr(A)[j]; ia_nxt = getcolptr(A)[j+1]
Expand Down
12 changes: 9 additions & 3 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,9 @@ end
function dot(x::AbstractVector{Tx}, y::SparseVectorUnion{Ty}) where {Tx<:Number,Ty<:Number}
require_one_based_indexing(x, y)
n = length(x)
length(y) == n || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
nzind = nonzeroinds(y)
nzval = nonzeros(y)
s = dot(zero(Tx), zero(Ty))
Expand All @@ -1500,7 +1502,9 @@ end
function dot(x::SparseVectorUnion{Tx}, y::AbstractVector{Ty}) where {Tx<:Number,Ty<:Number}
require_one_based_indexing(x, y)
n = length(y)
length(x) == n || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
nzind = nonzeroinds(x)
nzval = nonzeros(x)
s = dot(zero(Tx), zero(Ty))
Expand Down Expand Up @@ -1534,7 +1538,9 @@ end
function dot(x::SparseVectorUnion{<:Number}, y::SparseVectorUnion{<:Number})
x === y && return sum(abs2, x)
n = length(x)
length(y) == n || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end

xnzind = nonzeroinds(x)
ynzind = nonzeroinds(y)
Expand Down

0 comments on commit 6e97d2a

Please sign in to comment.