From 414dba754c8c0dc23eadaf28e023425315276061 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 27 Aug 2024 12:55:21 +0530 Subject: [PATCH] sum for OneElement (#375) * sum for OneElement * Add tests * Accept dims in sum * Add tests * Bump version to v1.13.0 * Ensure that init kwarg works * Update tests for v1.6 --------- Co-authored-by: Sheehan Olver --- Project.toml | 2 +- src/oneelement.jl | 17 +++++++++------- test/runtests.jl | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index f22c04ab..4ec9f49a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.12.0" +version = "1.13.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/oneelement.jl b/src/oneelement.jl index a46f53a4..39687a2a 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -159,13 +159,6 @@ function *(A::OneElementMatrix, B::OneElementVecOrMat) OneElement(val, (A.ind[1], B.ind[2:end]...), (axes(A,1), axes(B)[2:end]...)) end -function *(A::AbstractFillMatrix, x::OneElementVector) - check_matmul_sizes(A, x) - val = getindex_value(A) * getindex_value(x) - Fill(val, (axes(A,1),)) -end -*(A::AbstractZerosMatrix, x::OneElementVector) = mult_zeros(A, x) - *(A::OneElementMatrix, x::AbstractZerosVector) = mult_zeros(A, x) function *(A::OneElementMatrix, B::AbstractFillVector) @@ -448,3 +441,13 @@ _maybesize(t) = t Base.show(io::IO, A::OneElement) = print(io, OneElement, "(", A.val, ", ", A.ind, ", ", _maybesize(axes(A)), ")") Base.show(io::IO, A::OneElement{<:Any,1,Tuple{Int},Tuple{Base.OneTo{Int}}}) = print(io, OneElement, "(", A.val, ", ", A.ind[1], ", ", size(A,1), ")") + +# mapreduce +Base.sum(O::OneElement; dims=:, kw...) = _sum(O, dims; kw...) +_sum(O::OneElement, ::Colon; kw...) = sum((getindex_value(O),); kw...) +function _sum(O::OneElement, dims; kw...) + v = _sum(O, :; kw...) + ax = Base.reduced_indices(axes(O), dims) + ind = ntuple(x -> x in dims ? first(ax[x]) + (O.ind[x] in axes(O)[x]) - 1 : O.ind[x], ndims(O)) + OneElement(v, ind, ax) +end diff --git a/test/runtests.jl b/test/runtests.jl index c0f36e7f..cd6555b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2702,6 +2702,56 @@ end @test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))" end + @testset "sum" begin + @testset "OneElement($v, $ind, $sz)" for (v, ind, sz) in ( + (Int8(2), 3, 4), + (3.0, 5, 4), + (3.0, 0, 0), + (SMatrix{2,2}(1:4), (4, 2), (12,6)), + ) + O = OneElement(v,ind,sz) + A = Array(O) + if VERSION >= v"1.10" + @test @inferred(sum(O)) === sum(A) + else + @test @inferred(sum(O)) == sum(A) + end + @test @inferred(sum(O, init=zero(eltype(O)))) === sum(A, init=zero(eltype(O))) + @test @inferred(sum(x->1, O, init=0)) === sum(Fill(1, axes(O)), init=0) + end + + @testset for O in (OneElement(Int8(2), (1,2), (2,4)), + OneElement(3, (1,2,3), (2,4,4)), + OneElement(2.0, (3,2,5), (2,3,2)), + OneElement(SMatrix{2,2}(1:4), (1,2), (2,4)), + ) + A = Array(O) + init = sum((zero(FillArrays.getindex_value(O)),)) + for i in 1:3 + @test @inferred(sum(O, dims=i)) == sum(A, dims=i) + @test @inferred(sum(O, dims=i, init=init)) == sum(A, dims=i, init=init) + @test @inferred(sum(x->1, O, dims=i, init=0)) == sum(Fill(1, axes(O)), dims=i, init=0) + end + @test @inferred(sum(O, dims=1:1)) == sum(A, dims=1:1) + @test @inferred(sum(O, dims=1:2)) == sum(A, dims=1:2) + @test @inferred(sum(O, dims=1:3)) == sum(A, dims=1:3) + @test @inferred(sum(O, dims=(1,))) == sum(A, dims=(1,)) + @test @inferred(sum(O, dims=(1,2))) == sum(A, dims=(1,2)) + @test @inferred(sum(O, dims=(1,3))) == sum(A, dims=(1,3)) + @test @inferred(sum(O, dims=(2,3))) == sum(A, dims=(2,3)) + @test @inferred(sum(O, dims=(1,2,3))) == sum(A, dims=(1,2,3)) + @test @inferred(sum(O, dims=1:1, init=init)) == sum(A, dims=1:1, init=init) + @test @inferred(sum(O, dims=1:2, init=init)) == sum(A, dims=1:2, init=init) + @test @inferred(sum(O, dims=1:3, init=init)) == sum(A, dims=1:3, init=init) + @test @inferred(sum(O, dims=(1,), init=init)) == sum(A, dims=(1,), init=init) + @test @inferred(sum(O, dims=(1,2), init=init)) == sum(A, dims=(1,2), init=init) + @test @inferred(sum(O, dims=(1,3), init=init)) == sum(A, dims=(1,3), init=init) + @test @inferred(sum(O, dims=(2,3), init=init)) == sum(A, dims=(2,3), init=init) + @test @inferred(sum(O, dims=(1,2,3), init=init)) == sum(A, dims=(1,2,3), init=init) + @test @inferred(sum(x->1, O, dims=(1,2,3), init=0)) == sum(Fill(1, axes(O)), dims=(1,2,3), init=0) + end + end + @testset "diag" begin @testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz) O = OneElement(4, Tuple(ind), sz)