From 4ad663d1a410cf179ea23c5c04abd28532db4595 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <44583944+putianyi889@users.noreply.github.com> Date: Sat, 18 Mar 2023 16:10:58 +0000 Subject: [PATCH] add tests for and fix compatibility with StaticArrays.jl (#222) * add tests for compatibility with StaticArrays.jl * fix test * fix compatibility Co-Authored-By: Christopher Rackauckas * add Zeros x StaticArray compatibility * typo * bump version --------- Co-authored-by: Christopher Rackauckas --- Project.toml | 2 +- src/fillalgebra.jl | 18 ++++++++++++------ test/runtests.jl | 32 ++++++++++++++++++++------------ 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 1ce598e3..496208c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.8" +version = "0.13.9" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 5d5e133a..fb70fadd 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -216,7 +216,9 @@ function +(a::Zeros{T}, b::Zeros{V}) where {T, V} # for disambiguity promote_shape(a,b) return elconvert(promote_op(+,T,V),a) end -for TYPE in (:AbstractArray, :AbstractFill) # AbstractFill for disambiguity +# no AbstractArray. Otherwise incompatible with StaticArrays.jl +# AbstractFill for disambiguity +for TYPE in (:Array, :AbstractFill, :AbstractRange, :Diagonal) @eval function +(a::$TYPE{T}, b::Zeros{V}) where {T, V} promote_shape(a,b) return elconvert(promote_op(+,T,V),a) @@ -236,13 +238,17 @@ end -(a::Ones, b::Ones) = Zeros(a) + Zeros(b) -# necessary for AbstractRange, Diagonal, etc +# no AbstractArray. Otherwise incompatible with StaticArrays.jl +for TYPE in (:Array, :AbstractRange) + @eval begin + +(a::$TYPE, b::AbstractFill) = fill_add(a, b) + -(a::$TYPE, b::AbstractFill) = a + (-b) + +(a::AbstractFill, b::$TYPE) = fill_add(b, a) + -(a::AbstractFill, b::$TYPE) = a + (-b) + end +end +(a::AbstractFill, b::AbstractFill) = fill_add(a, b) -+(a::AbstractFill, b::AbstractArray) = fill_add(b, a) -+(a::AbstractArray, b::AbstractFill) = fill_add(a, b) -(a::AbstractFill, b::AbstractFill) = a + (-b) --(a::AbstractFill, b::AbstractArray) = a + (-b) --(a::AbstractArray, b::AbstractFill) = a + (-b) @inline function fill_add(a, b::AbstractFill) promote_shape(a, b) diff --git a/test/runtests.jl b/test/runtests.jl index 2c23ebe1..f4dcede8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -308,26 +308,30 @@ as_array(x::AbstractArray) = Array(x) as_array(x::UniformScaling) = x function test_addition_and_subtraction(As, Bs, Tout::Type) for A in As, B in Bs - @test A + B isa Tout{promote_type(eltype(A), eltype(B))} - @test as_array(A + B) == as_array(A) + as_array(B) + @testset "$A ± $B" begin + @test A + B isa Tout{promote_type(eltype(A), eltype(B))} + @test as_array(A + B) == as_array(A) + as_array(B) - @test A - B isa Tout{promote_type(eltype(A), eltype(B))} - @test as_array(A - B) == as_array(A) - as_array(B) + @test A - B isa Tout{promote_type(eltype(A), eltype(B))} + @test as_array(A - B) == as_array(A) - as_array(B) - @test B + A isa Tout{promote_type(eltype(B), eltype(A))} - @test as_array(B + A) == as_array(B) + as_array(A) + @test B + A isa Tout{promote_type(eltype(B), eltype(A))} + @test as_array(B + A) == as_array(B) + as_array(A) - @test B - A isa Tout{promote_type(eltype(B), eltype(A))} - @test as_array(B - A) == as_array(B) - as_array(A) + @test B - A isa Tout{promote_type(eltype(B), eltype(A))} + @test as_array(B - A) == as_array(B) - as_array(A) + end end end # Check that all permutations of + / - throw a `DimensionMismatch` exception. function test_addition_and_subtraction_dim_mismatch(a, b) - @test_throws DimensionMismatch a + b - @test_throws DimensionMismatch a - b - @test_throws DimensionMismatch b + a - @test_throws DimensionMismatch b - a + @testset "$a ± $b" begin + @test_throws DimensionMismatch a + b + @test_throws DimensionMismatch a - b + @test_throws DimensionMismatch b + a + @test_throws DimensionMismatch b - a + end end @testset "FillArray addition and subtraction" begin @@ -368,6 +372,10 @@ end test_addition_and_subtraction_dim_mismatch(A, B) end + # FillArray + StaticArray should not have ambiguities + A_svec, B_svec = SVector{5}(rand(5)), SVector(1, 2, 3, 4, 5) + test_addition_and_subtraction((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5}) + # Optimizations for Zeros and RectOrDiagonal{<:Any, <:AbstractFill} As_special_square = ( Zeros(3, 3), Zeros{Int}(4, 4),