Skip to content

Commit

Permalink
check sizes of arguments in dot; fixes #28617
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Aug 15, 2018
1 parent eabb601 commit c87ede6
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 18 deletions.
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

function dot(x::BitVector, y::BitVector)
# simplest way to mimic Array dot behavior
length(x) == length(y) || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
s = 0
xc = x.chunks
yc = y.chunks
Expand Down
12 changes: 6 additions & 6 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,24 +328,24 @@ end
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
@assert !has_offset_axes(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if size(DX) != size(DY)
throw(DimensionMismatch("The first array has size $(size(DX)) which does not match the size of the second, $(size(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
GC.@preserve DX DY dot(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
@assert !has_offset_axes(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if size(DX) != size(DY)
throw(DimensionMismatch("The first array has size $(size(DX)) which does not match the size of the second, $(size(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
GC.@preserve DX DY dotc(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
@assert !has_offset_axes(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if size(DX) != size(DY)
throw(DimensionMismatch("The first array has size $(size(DX)) which does not match the size of the second, $(size(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
GC.@preserve DX DY dotu(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
Expand Down
15 changes: 8 additions & 7 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,11 @@ dot(x::Number, y::Number) = conj(x) * y
dot(x, y)
x ⋅ y
Compute the dot product between two vectors. For complex vectors, the first
vector is conjugated. When the vectors have equal lengths, calling `dot` is
semantically equivalent to `sum(dot(vx,vy) for (vx,vy) in zip(x, y))`.
Compute the dot product between two arrays of the same size as if they were
vectors. For complex arrays, the elements of the first array are conjugated.
This is the classical dot product for vectors and the Hilbert-Schmidt dot
product `tr(x' * y)` for matrices. When the arrays have equal sizes, calling
`dot` is semantically equivalent to `sum(dot(vx,vy) for (vx,vy) in zip(x, y))`.
# Examples
```jldoctest
Expand All @@ -697,11 +699,10 @@ julia> dot([im; im], [1; 1])
```
"""
function dot(x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
if lx == 0
if length(x) == 0
return dot(zero(eltype(x)), zero(eltype(y)))
end
s = zero(dot(first(x), first(y)))
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,23 @@ Random.seed!(100)
x2 = convert(Vector{elty}, randn(n))
@test BLAS.dot(x1,x2) sum(x1.*x2)
@test_throws DimensionMismatch BLAS.dot(x1,rand(elty, n + 1))
y1 = convert(Matrix{elty}, randn(4,4))
y2 = convert(Matrix{elty}, randn(2,8))
@test_throws DimensionMismatch BLAS.dot(y1, y2)
@test sum(y1[i] * y2[i] for i in 1:16) BLAS.dot(vec(y1), vec(y2))
else
z1 = convert(Vector{elty}, complex.(randn(n),randn(n)))
z2 = convert(Vector{elty}, complex.(randn(n),randn(n)))
@test BLAS.dotc(z1,z2) sum(conj(z1).*z2)
@test BLAS.dotu(z1,z2) sum(z1.*z2)
@test_throws DimensionMismatch BLAS.dotc(z1,rand(elty, n + 1))
@test_throws DimensionMismatch BLAS.dotu(z1,rand(elty, n + 1))
y1 = convert(Matrix{elty}, complex.(randn(4,4),randn(4,4)))
y2 = convert(Matrix{elty}, complex.(randn(2,8),randn(2,8)))
@test_throws DimensionMismatch BLAS.dotc(y1, y2)
@test_throws DimensionMismatch BLAS.dotu(y1, y2)
@test sum(conj(y1[i]) * y2[i] for i in 1:16) BLAS.dotc(vec(y1), vec(y2))
@test sum(y1[i] * y2[i] for i in 1:16) BLAS.dotu(vec(y1), vec(y2))
end
end
@testset "iamax" begin
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ end
@test dot(X, Y) == convert(elty, 35.0)
Z = convert(Vector{Matrix{elty}},[reshape(1:4, 2, 2), fill(1, 2, 2)])
@test dot(Z, Z) == convert(elty, 34.0)
Y2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5])
@test_throws DimensionMismatch dot(X, Y2)
@test_throws DimensionMismatch dot(vec(X), Y2)
@test dot(X, Y) == dot(vec(X), vec(Y2))
end

dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
Expand All @@ -251,6 +255,21 @@ dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y)
end
end
end
for elty in (Float32, Float64, ComplexF32, ComplexF64)
XX = convert(Matrix{elty},[1.0 2.0; 3.0 4.0])
YY = convert(Matrix{elty},[1.5 2.5; 3.5 4.5])
YY2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5])
for X in (copy(XX), view(XX, 1:2, 1:2)), Y in (copy(YY), view(YY, 1:2, 1:2)), Y2 in (copy(YY2), view(YY2, 1:1, 1:4))
@test dot1(X, Y) == convert(elty, 35.0)
@test dot2(X, Y) == convert(elty, 35.0)
@test dot1(X, Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes
@test_throws DimensionMismatch dot2(X, Y2)
@test dot1(vec(X), Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes
@test_throws DimensionMismatch dot2(vec(X), Y2)
@test dot1(X, Y) == dot1(vec(X), vec(Y2))
@test dot2(X, Y) == dot2(vec(X), vec(Y2))
end
end
end

@testset "Issue 11978" begin
Expand Down
4 changes: 3 additions & 1 deletion stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ end
# Frobenius dot/inner product: trace(A'B)
function dot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S1,S2}
m, n = size(A)
size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions"))
if size(B) != (m,n)
throw(DimensionMismatch("The first array has size $(size(A)) which does not match the size of the second, $(size(B)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
r = dot(zero(T1), zero(T2))
@inbounds for j = 1:n
ia = A.colptr[j]; ia_nxt = A.colptr[j+1]
Expand Down
12 changes: 9 additions & 3 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1398,7 +1398,9 @@ end
function dot(x::AbstractVector{Tx}, y::SparseVectorUnion{Ty}) where {Tx<:Number,Ty<:Number}
@assert !has_offset_axes(x, y)
n = length(x)
length(y) == n || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
nzind = nonzeroinds(y)
nzval = nonzeros(y)
s = dot(zero(Tx), zero(Ty))
Expand All @@ -1411,7 +1413,9 @@ end
function dot(x::SparseVectorUnion{Tx}, y::AbstractVector{Ty}) where {Tx<:Number,Ty<:Number}
@assert !has_offset_axes(x, y)
n = length(y)
length(x) == n || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
nzind = nonzeroinds(x)
nzval = nonzeros(x)
s = dot(zero(Tx), zero(Ty))
Expand Down Expand Up @@ -1445,7 +1449,9 @@ end
function dot(x::SparseVectorUnion{<:Number}, y::SparseVectorUnion{<:Number})
x === y && return sum(abs2, x)
n = length(x)
length(y) == n || throw(DimensionMismatch())
if size(x) != size(y)
throw(DimensionMismatch("The first array has size $(size(x)) which does not match the size of the second, $(size(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end

xnzind = nonzeroinds(x)
ynzind = nonzeroinds(y)
Expand Down

0 comments on commit c87ede6

Please sign in to comment.