Skip to content

Commit

Permalink
sum for OneElement (#375)
Browse files Browse the repository at this point in the history
* sum for OneElement

* Add tests

* Accept dims in sum

* Add tests

* Bump version to v1.13.0

* Ensure that init kwarg works

* Update tests for v1.6

---------

Co-authored-by: Sheehan Olver <[email protected]>
  • Loading branch information
jishnub and dlfivefifty authored Aug 27, 2024
1 parent 7b2bb11 commit 414dba7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.12.0"
version = "1.13.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
17 changes: 10 additions & 7 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,6 @@ function *(A::OneElementMatrix, B::OneElementVecOrMat)
OneElement(val, (A.ind[1], B.ind[2:end]...), (axes(A,1), axes(B)[2:end]...))
end

function *(A::AbstractFillMatrix, x::OneElementVector)
check_matmul_sizes(A, x)
val = getindex_value(A) * getindex_value(x)
Fill(val, (axes(A,1),))
end
*(A::AbstractZerosMatrix, x::OneElementVector) = mult_zeros(A, x)

*(A::OneElementMatrix, x::AbstractZerosVector) = mult_zeros(A, x)

function *(A::OneElementMatrix, B::AbstractFillVector)
Expand Down Expand Up @@ -448,3 +441,13 @@ _maybesize(t) = t
Base.show(io::IO, A::OneElement) = print(io, OneElement, "(", A.val, ", ", A.ind, ", ", _maybesize(axes(A)), ")")
Base.show(io::IO, A::OneElement{<:Any,1,Tuple{Int},Tuple{Base.OneTo{Int}}}) =
print(io, OneElement, "(", A.val, ", ", A.ind[1], ", ", size(A,1), ")")

# mapreduce
Base.sum(O::OneElement; dims=:, kw...) = _sum(O, dims; kw...)
_sum(O::OneElement, ::Colon; kw...) = sum((getindex_value(O),); kw...)
function _sum(O::OneElement, dims; kw...)
v = _sum(O, :; kw...)
ax = Base.reduced_indices(axes(O), dims)
ind = ntuple(x -> x in dims ? first(ax[x]) + (O.ind[x] in axes(O)[x]) - 1 : O.ind[x], ndims(O))
OneElement(v, ind, ax)
end
50 changes: 50 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2702,6 +2702,56 @@ end
@test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))"
end

@testset "sum" begin
@testset "OneElement($v, $ind, $sz)" for (v, ind, sz) in (
(Int8(2), 3, 4),
(3.0, 5, 4),
(3.0, 0, 0),
(SMatrix{2,2}(1:4), (4, 2), (12,6)),
)
O = OneElement(v,ind,sz)
A = Array(O)
if VERSION >= v"1.10"
@test @inferred(sum(O)) === sum(A)
else
@test @inferred(sum(O)) == sum(A)
end
@test @inferred(sum(O, init=zero(eltype(O)))) === sum(A, init=zero(eltype(O)))
@test @inferred(sum(x->1, O, init=0)) === sum(Fill(1, axes(O)), init=0)
end

@testset for O in (OneElement(Int8(2), (1,2), (2,4)),
OneElement(3, (1,2,3), (2,4,4)),
OneElement(2.0, (3,2,5), (2,3,2)),
OneElement(SMatrix{2,2}(1:4), (1,2), (2,4)),
)
A = Array(O)
init = sum((zero(FillArrays.getindex_value(O)),))
for i in 1:3
@test @inferred(sum(O, dims=i)) == sum(A, dims=i)
@test @inferred(sum(O, dims=i, init=init)) == sum(A, dims=i, init=init)
@test @inferred(sum(x->1, O, dims=i, init=0)) == sum(Fill(1, axes(O)), dims=i, init=0)
end
@test @inferred(sum(O, dims=1:1)) == sum(A, dims=1:1)
@test @inferred(sum(O, dims=1:2)) == sum(A, dims=1:2)
@test @inferred(sum(O, dims=1:3)) == sum(A, dims=1:3)
@test @inferred(sum(O, dims=(1,))) == sum(A, dims=(1,))
@test @inferred(sum(O, dims=(1,2))) == sum(A, dims=(1,2))
@test @inferred(sum(O, dims=(1,3))) == sum(A, dims=(1,3))
@test @inferred(sum(O, dims=(2,3))) == sum(A, dims=(2,3))
@test @inferred(sum(O, dims=(1,2,3))) == sum(A, dims=(1,2,3))
@test @inferred(sum(O, dims=1:1, init=init)) == sum(A, dims=1:1, init=init)
@test @inferred(sum(O, dims=1:2, init=init)) == sum(A, dims=1:2, init=init)
@test @inferred(sum(O, dims=1:3, init=init)) == sum(A, dims=1:3, init=init)
@test @inferred(sum(O, dims=(1,), init=init)) == sum(A, dims=(1,), init=init)
@test @inferred(sum(O, dims=(1,2), init=init)) == sum(A, dims=(1,2), init=init)
@test @inferred(sum(O, dims=(1,3), init=init)) == sum(A, dims=(1,3), init=init)
@test @inferred(sum(O, dims=(2,3), init=init)) == sum(A, dims=(2,3), init=init)
@test @inferred(sum(O, dims=(1,2,3), init=init)) == sum(A, dims=(1,2,3), init=init)
@test @inferred(sum(x->1, O, dims=(1,2,3), init=0)) == sum(Fill(1, axes(O)), dims=(1,2,3), init=0)
end
end

@testset "diag" begin
@testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz)
O = OneElement(4, Tuple(ind), sz)
Expand Down

0 comments on commit 414dba7

Please sign in to comment.