Skip to content

Commit

Permalink
_fill_dot support general vectors (#229)
Browse files Browse the repository at this point in the history
* Update fillalgebra.jl

* promote_op

* add breaking test

* add breaking test

* fix

* accept round-off errors

* Update test/runtests.jl

Co-authored-by: Sheehan Olver <[email protected]>

* update

* support inf and nan

* fix 1.6

* Update fillalgebra.jl

* Update fillalgebra.jl

* trying to fix Julia 1.6

* comments

* Update runtests.jl

* add @inferred

---------

Co-authored-by: Sheehan Olver <[email protected]>
  • Loading branch information
putianyi889 and dlfivefifty authored Mar 30, 2023
1 parent fea49f6 commit d8e1f9a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 47 deletions.
30 changes: 7 additions & 23 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,38 +159,22 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::ZerosVector{T}) where T<:Rea
end
*(a::Transpose{T, <:AbstractMatrix{T}}, b::ZerosVector{T}) where T<:Real = mult_zeros(a, b)

# treat zero separately to support ∞-vectors
function _zero_dot(a, b)
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
zero(promote_type(eltype(a),eltype(b)))
end

_fill_dot(a::Zeros, b::Zeros) = _zero_dot(a, b)
_fill_dot(a::Zeros, b) = _zero_dot(a, b)
_fill_dot(a, b::Zeros) = _zero_dot(a, b)
_fill_dot(a::Zeros, b::AbstractFill) = _zero_dot(a, b)
_fill_dot(a::AbstractFill, b::Zeros) = _zero_dot(a, b)

function _fill_dot(a::AbstractFill, b::AbstractFill)
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
getindex_value(a)getindex_value(b)*length(b)
end

# support types with fast sum
function _fill_dot(a::AbstractFill, b)
# infinite cases should be supported in InfiniteArrays.jl
# type issues of Bool dot are ignored at present.
function _fill_dot(a::AbstractFillVector{T}, b::AbstractVector{V}) where {T,V}
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
getindex_value(a)sum(b)
dot(getindex_value(a), sum(b))
end

function _fill_dot(a, b::AbstractFill)
function _fill_dot_rev(a::AbstractVector{T}, b::AbstractFillVector{V}) where {T,V}
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
sum(a)getindex_value(b)
dot(sum(a), getindex_value(b))
end


dot(a::AbstractFillVector, b::AbstractFillVector) = _fill_dot(a, b)
dot(a::AbstractFillVector, b::AbstractVector) = _fill_dot(a, b)
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot(a, b)
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot_rev(a, b)

function dot(u::AbstractVector, E::Eye, v::AbstractVector)
length(u) == size(E,1) && length(v) == size(E,2) ||
Expand Down
62 changes: 38 additions & 24 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,20 +329,32 @@ end
# type, and produce numerically correct results.
as_array(x::AbstractArray) = Array(x)
as_array(x::UniformScaling) = x
function test_addition_and_subtraction(As, Bs, Tout::Type)
equal_or_undef(a::Number, b::Number) = (a == b) || isequal(a, b)
equal_or_undef(a, b) = all(equal_or_undef.(a, b))
function test_addition_subtraction_dot(As, Bs, Tout::Type)
for A in As, B in Bs
@testset "$(typeof(A)) ± $(typeof(B))" begin
@testset "$(typeof(A)) and $(typeof(B))" begin
@test A + B isa Tout{promote_type(eltype(A), eltype(B))}
@test as_array(A + B) == as_array(A) + as_array(B)
@test equal_or_undef(as_array(A + B), as_array(A) + as_array(B))

@test A - B isa Tout{promote_type(eltype(A), eltype(B))}
@test as_array(A - B) == as_array(A) - as_array(B)
@test equal_or_undef(as_array(A - B), as_array(A) - as_array(B))

@test B + A isa Tout{promote_type(eltype(B), eltype(A))}
@test as_array(B + A) == as_array(B) + as_array(A)
@test equal_or_undef(as_array(B + A), as_array(B) + as_array(A))

@test B - A isa Tout{promote_type(eltype(B), eltype(A))}
@test as_array(B - A) == as_array(B) - as_array(A)
@test equal_or_undef(as_array(B - A), as_array(B) - as_array(A))

# Julia 1.6 doesn't support dot(UniformScaling)
if VERSION < v"1.6.0" || VERSION >= v"1.8.0"
d1 = dot(A, B)
d2 = dot(as_array(A), as_array(B))
d3 = dot(B, A)
d4 = dot(as_array(B), as_array(A))
@test d1 d2 || d1 d2
@test d3 d4 || d3 d4
end
end
end
end
Expand Down Expand Up @@ -372,37 +384,37 @@ end
@test -A_fill === Fill(-A_fill.value, 5)

# FillArray +/- FillArray should construct a new FillArray.
test_addition_and_subtraction((A_fill, B_fill), (A_fill, B_fill), Fill)
test_addition_subtraction_dot((A_fill, B_fill), (A_fill, B_fill), Fill)
test_addition_and_subtraction_dim_mismatch(A_fill, Fill(randn(rng), 5, 2))

# FillArray + Array (etc) should construct a new Array using `getindex`.
A_dense, B_dense = randn(rng, 5), [5, 4, 3, 2, 1]
test_addition_and_subtraction((A_fill, B_fill), (A_dense, B_dense), Array)
B_dense = (randn(rng, 5), [5, 4, 3, 2, 1], fill(Inf, 5), fill(NaN, 5))
test_addition_subtraction_dot((A_fill, B_fill), B_dense, Array)
test_addition_and_subtraction_dim_mismatch(A_fill, randn(rng, 5, 2))

# FillArray + StepLenRange / UnitRange (etc) should yield an AbstractRange.
A_ur, B_ur = 1.0:5.0, 6:10
test_addition_and_subtraction((A_fill, B_fill), (A_ur, B_ur), AbstractRange)
test_addition_subtraction_dot((A_fill, B_fill), (A_ur, B_ur), AbstractRange)
test_addition_and_subtraction_dim_mismatch(A_fill, 1.0:6.0)
test_addition_and_subtraction_dim_mismatch(A_fill, 5:10)

# FillArray + UniformScaling should yield a Matrix in general
As_fill_square = (Fill(randn(rng, Float64), 3, 3), Fill(5, 4, 4))
Bs_us = (UniformScaling(2.3), UniformScaling(3))
test_addition_and_subtraction(As_fill_square, Bs_us, Matrix)
test_addition_subtraction_dot(As_fill_square, Bs_us, Matrix)
As_fill_nonsquare = (Fill(randn(rng, Float64), 3, 2), Fill(5, 3, 4))
for A in As_fill_nonsquare, B in Bs_us
test_addition_and_subtraction_dim_mismatch(A, B)
end

# FillArray + StaticArray should not have ambiguities
A_svec, B_svec = SVector{5}(rand(5)), SVector(1, 2, 3, 4, 5)
test_addition_and_subtraction((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5})
test_addition_subtraction_dot((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5})

# Issue #224
A_matmat, B_matmat = Fill(rand(3,3),5), [rand(3,3) for n=1:5]
test_addition_and_subtraction((A_matmat,), (A_matmat,), Fill)
test_addition_and_subtraction((B_matmat,), (A_matmat,), Vector)
test_addition_subtraction_dot((A_matmat,), (A_matmat,), Fill)
test_addition_subtraction_dot((B_matmat,), (A_matmat,), Vector)

# Optimizations for Zeros and RectOrDiagonal{<:Any, <:AbstractFill}
As_special_square = (
Expand All @@ -412,7 +424,7 @@ end
RectDiagonal(Fill(randn(rng, Float64), 3), 3, 3), RectDiagonal(Fill(3, 4), 4, 4)
)
DiagonalAbstractFill{T} = Diagonal{T, <:AbstractFill{T, 1}}
test_addition_and_subtraction(As_special_square, Bs_us, DiagonalAbstractFill)
test_addition_subtraction_dot(As_special_square, Bs_us, DiagonalAbstractFill)
As_special_nonsquare = (
Zeros(3, 2), Zeros{Int}(3, 4),
Eye(3, 2), Eye{Int}(3, 4),
Expand Down Expand Up @@ -537,7 +549,7 @@ end
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
@test_throws DimensionMismatch randn(4)' * Zeros(3)
@test Zeros(5)' * randn(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
@test Zeros(5)' * randn(5) Zeros(5)' * Zeros(5) Zeros(5)' * Ones(5) 0.0
@test abs(Zeros(5)' * randn(5)) abs(Zeros(5)' * Zeros(5)) abs(Zeros(5)' * Ones(5)) 0.0
@test Zeros(5) * Zeros(6)' Zeros(5,1) * Zeros(6)' Zeros(5,6)
@test randn(5) * Zeros(6)' randn(5,1) * Zeros(6)' Zeros(5,6)
@test Zeros(5) * randn(6)' Zeros(5,6)
Expand All @@ -552,7 +564,7 @@ end
@test transpose([1, 2, 3]) * Zeros{Int}(3) === zero(Int)
@test_throws DimensionMismatch transpose(randn(4)) * Zeros(3)
@test transpose(Zeros(5)) * randn(5,3) transpose(Zeros(5))*Zeros(5,3) transpose(Zeros(5))*Ones(5,3) transpose(Zeros(3))
@test transpose(Zeros(5)) * randn(5) transpose(Zeros(5)) * Zeros(5) transpose(Zeros(5)) * Ones(5) 0.0
@test abs(transpose(Zeros(5)) * randn(5)) abs(transpose(Zeros(5)) * Zeros(5)) abs(transpose(Zeros(5)) * Ones(5)) 0.0
@test randn(5) * transpose(Zeros(6)) randn(5,1) * transpose(Zeros(6)) Zeros(5,6)
@test Zeros(5) * transpose(randn(6)) Zeros(5,6)
@test transpose(randn(5)) * Zeros(5) 0.0
Expand All @@ -571,13 +583,13 @@ end
@test +(z1) === z1
@test -(z1) === z1

test_addition_and_subtraction((z1, z2), (z1, z2), Zeros)
test_addition_subtraction_dot((z1, z2), (z1, z2), Zeros)
test_addition_and_subtraction_dim_mismatch(z1, Zeros{Float64}(4, 2))
end

# `Zeros` +/- `Fill`s should yield `Fills`.
fill1, fill2 = Fill(5.0, 4), Fill(5, 4)
test_addition_and_subtraction((z1, z2), (fill1, fill2), Fill)
test_addition_subtraction_dot((z1, z2), (fill1, fill2), Fill)
test_addition_and_subtraction_dim_mismatch(z1, Fill(5, 5))

X = randn(3, 5)
Expand Down Expand Up @@ -1326,17 +1338,19 @@ end
Random.seed!(5)
u = rand(n)
v = rand(n)
c = rand(ComplexF16, n)

@test dot(u, D, v) == dot(u, v)
@test dot(u, 2D, v) == 2dot(u, v)
@test dot(u, Z, v) == 0

@test dot(Zeros(5), Zeros{ComplexF16}(5)) zero(ComplexF64)
@test dot(Zeros(5), Ones{ComplexF16}(5)) zero(ComplexF64)
@test dot(Ones{ComplexF16}(5), Zeros(5)) zero(ComplexF64)
@test dot(randn(5), Zeros{ComplexF16}(5)) dot(Zeros{ComplexF16}(5), randn(5)) zero(ComplexF64)
@test @inferred(dot(Zeros(5), Zeros{ComplexF16}(5))) zero(ComplexF64)
@test @inferred(dot(Zeros(5), Ones{ComplexF16}(5))) zero(ComplexF64)
@test abs(@inferred(dot(Ones{ComplexF16}(5), Zeros(5)))) abs(@inferred(dot(randn(5), Zeros{ComplexF16}(5)))) abs(@inferred(dot(Zeros{ComplexF16}(5), randn(5)))) zero(Float64) # 0.0 !≡ -0.0
@test @inferred(dot(c, Fill(1 + im, 15))) (@inferred(dot(Fill(1 + im, 15), c)))' @inferred(dot(c, fill(1 + im, 15)))

@test dot(Fill(1,5), Fill(2.0,5)) 10.0
@test @inferred(dot(Fill(1,5), Fill(2.0,5))) 10.0
@test_skip dot(Fill(true,5), Fill(Int8(1),5)) isa Int8 # not working at present

let N = 2^big(1000) # fast dot for fast sum
@test dot(Fill(2,N),1:N) == dot(Fill(2,N),1:N) == dot(1:N,Fill(2,N)) == 2*sum(1:N)
Expand Down

0 comments on commit d8e1f9a

Please sign in to comment.