Skip to content

Commit

Permalink
repeat for AbstractFill
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed May 16, 2023
1 parent 5463d9e commit 5dd7296
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 1 deletion.
39 changes: 38 additions & 1 deletion src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
show, view, in, mapreduce, one, reverse, promote_op, promote_rule
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
Expand Down Expand Up @@ -762,6 +762,43 @@ Base.@propagate_inbounds function view(A::AbstractFill, I::Vararg{Real})
fillsimilar(A)
end

# repeat

_first(t::Tuple) = t[1]
_first(t::Tuple{}) = 1

_maybetail(t::Tuple) = Base.tail(t)
_maybetail(t::Tuple{}) = t

_match_size(sz::Tuple{}, inner::Tuple{}, outer::Tuple{}) = ()
function _match_size(sz::Tuple, inner::Tuple, outer::Tuple)
t1 = (_first(sz), _first(inner), _first(outer))
t2 = _match_size(_maybetail(sz), _maybetail(inner), _maybetail(outer))
(t1, t2...)
end

function _repeat_size(sz::Tuple, inner::Tuple, outer::Tuple)
t = _match_size(sz, inner, outer)
map(*, getindex.(t, 1), getindex.(t, 2), getindex.(t, 3))
end

function _repeat(A; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
Base.require_one_based_indexing(A)
length(inner) >= ndims(A) ||
throw(ArgumentError("number of inner repetitions $(length(inner)) cannot be "*
"less than number of dimensions of input array $(ndims(A))"))
length(outer) >= ndims(A) ||
throw(ArgumentError("number of outer repetitions $(length(outer)) cannot be "*
"less than number of dimensions of input array $(ndims(A))"))
sz = _repeat_size(size(A), Tuple(inner), Tuple(outer))
fillsimilar(A, sz)
end

repeat(A::AbstractFill, count::Integer...) = _repeat(A, outer=count)
function repeat(A::AbstractFill; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
_repeat(A, inner=inner, outer=outer)
end

include("oneelement.jl")

end # module
106 changes: 106 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1573,3 +1573,109 @@ end
@test Base.setindex(Zeros(5,3), 2, 2, 3) OneElement(2.0, (2,3), (5,3))
@test_throws BoundsError Base.setindex(Zeros(5), 2, 6)
end

@testset "repeat" begin
@testset "0D" begin
@test repeat(Zeros()) isa Zeros
@test repeat(Zeros()) == repeat(zeros())
@test repeat(Ones()) isa Ones
@test repeat(Ones()) == repeat(ones())
@test repeat(Fill(3)) isa Fill
@test repeat(Fill(3)) == repeat(fill(3))

@test repeat(Zeros(), inner=(), outer=()) isa Zeros
@test repeat(Zeros(), inner=(), outer=()) == repeat(zeros(), inner=(), outer=())
@test repeat(Ones(), inner=(), outer=()) isa Ones
@test repeat(Ones(), inner=(), outer=()) == repeat(ones(), inner=(), outer=())
@test repeat(Fill(4), inner=(), outer=()) isa Fill
@test repeat(Fill(4), inner=(), outer=()) == repeat(fill(4), inner=(), outer=())

@test repeat(Zeros{Bool}(), 2, 3) isa Zeros{Bool}
@test repeat(Zeros{Bool}(), 2, 3) == repeat(zeros(Bool), 2, 3)
@test repeat(Ones{Bool}(), 2, 3) isa Ones{Bool}
@test repeat(Ones{Bool}(), 2, 3) == repeat(ones(Bool), 2, 3)
@test repeat(Fill(false), 2, 3) isa Fill
@test repeat(Fill(false), 2, 3) == repeat(fill(false), 2, 3)

@test repeat(Zeros(), inner=(2,2), outer=5) isa Zeros
@test repeat(Zeros(), inner=(2,2), outer=5) == repeat(zeros(), inner=(2,2), outer=5)
@test repeat(Ones(), inner=(2,2), outer=5) isa Ones
@test repeat(Ones(), inner=(2,2), outer=5) == repeat(ones(), inner=(2,2), outer=5)
@test repeat(Fill(2), inner=(2,2), outer=5) isa Fill
@test repeat(Fill(2), inner=(2,2), outer=5) == repeat(fill(2), inner=(2,2), outer=5)

@test repeat(Zeros(), inner=(2,2), outer=(2,3)) isa Zeros
@test repeat(Zeros(), inner=(2,2), outer=(2,3)) == repeat(zeros(), inner=(2,2), outer=(2,3))
@test repeat(Ones(), inner=(2,2), outer=(2,3)) isa Ones
@test repeat(Ones(), inner=(2,2), outer=(2,3)) == repeat(ones(), inner=(2,2), outer=(2,3))
@test repeat(Fill("a"), inner=(2,2), outer=(2,3)) isa Fill
@test repeat(Fill("a"), inner=(2,2), outer=(2,3)) == repeat(fill("a"), inner=(2,2), outer=(2,3))
end
@testset "1D" begin
@test repeat(Zeros(2), 2, 3) isa Zeros
@test repeat(Zeros(2), 2, 3) == repeat(zeros(2), 2, 3)
@test repeat(Ones(2), 2, 3) isa Ones
@test repeat(Ones(2), 2, 3) == repeat(ones(2), 2, 3)
@test repeat(Fill(2,3), 2, 3) isa Fill
@test repeat(Fill(2,3), 2, 3) == repeat(fill(2,3), 2, 3)

@test repeat(Zeros(2), inner=2, outer=4) isa Zeros
@test repeat(Zeros(2), inner=2, outer=4) == repeat(zeros(2), inner=2, outer=4)
@test repeat(Ones(2), inner=2, outer=4) isa Ones
@test repeat(Ones(2), inner=2, outer=4) == repeat(ones(2), inner=2, outer=4)
@test repeat(Fill(2,3), inner=2, outer=4) isa Fill
@test repeat(Fill(2,3), inner=2, outer=4) == repeat(fill(2,3), inner=2, outer=4)

@test repeat(Zeros(2), inner=2, outer=(2,3)) isa Zeros
@test repeat(Zeros(2), inner=2, outer=(2,3)) == repeat(zeros(2), inner=2, outer=(2,3))
@test repeat(Ones(2), inner=2, outer=(2,3)) isa Ones
@test repeat(Ones(2), inner=2, outer=(2,3)) == repeat(ones(2), inner=2, outer=(2,3))
@test repeat(Fill("b",3), inner=2, outer=(2,3)) isa Fill
@test repeat(Fill("b",3), inner=2, outer=(2,3)) == repeat(fill("b",3), inner=2, outer=(2,3))

@test repeat(Zeros(Int, 2), inner=(2,), outer=(2,3)) isa Zeros
@test repeat(Zeros(Int, 2), inner=(2,), outer=(2,3)) == repeat(zeros(Int, 2), inner=(2,), outer=(2,3))
@test repeat(Ones(Int, 2), inner=(2,), outer=(2,3)) isa Ones
@test repeat(Ones(Int, 2), inner=(2,), outer=(2,3)) == repeat(ones(Int, 2), inner=(2,), outer=(2,3))
@test repeat(Fill(2,3), inner=(2,), outer=(2,3)) isa Fill
@test repeat(Fill(2,3), inner=(2,), outer=(2,3)) == repeat(fill(2,3), inner=(2,), outer=(2,3))

@test repeat(Zeros(2), inner=(2,2,1,4), outer=(2,3)) isa Zeros
@test repeat(Zeros(2), inner=(2,2,1,4), outer=(2,3)) == repeat(zeros(2), inner=(2,2,1,4), outer=(2,3))
@test repeat(Ones(2), inner=(2,2,1,4), outer=(2,3)) isa Ones
@test repeat(Ones(2), inner=(2,2,1,4), outer=(2,3)) == repeat(ones(2), inner=(2,2,1,4), outer=(2,3))
@test repeat(Fill(2,3), inner=(2,2,1,4), outer=(2,3)) isa Fill
@test repeat(Fill(2,3), inner=(2,2,1,4), outer=(2,3)) == repeat(fill(2,3), inner=(2,2,1,4), outer=(2,3))

@test_throws ArgumentError repeat(Fill(2,3), inner=())
@test_throws ArgumentError repeat(Fill(2,3), outer=())
end

@testset "2D" begin
@test repeat(Zeros(2,3), 2, 3) isa Zeros
@test repeat(Zeros(2,3), 2, 3) == repeat(zeros(2,3), 2, 3)
@test repeat(Ones(2,3), 2, 3) isa Ones
@test repeat(Ones(2,3), 2, 3) == repeat(ones(2,3), 2, 3)
@test repeat(Fill(2,3,4), 2, 3) isa Fill
@test repeat(Fill(2,3,4), 2, 3) == repeat(fill(2,3,4), 2, 3)

@test repeat(Zeros(2,3), inner=(1,2), outer=(4,2)) isa Zeros
@test repeat(Zeros(2,3), inner=(1,2), outer=(4,2)) == repeat(zeros(2,3), inner=(1,2), outer=(4,2))
@test repeat(Ones(2,3), inner=(1,2), outer=(4,2)) isa Ones
@test repeat(Ones(2,3), inner=(1,2), outer=(4,2)) == repeat(ones(2,3), inner=(1,2), outer=(4,2))
@test repeat(Fill(2,3,4), inner=(1,2), outer=(4,2)) isa Fill
@test repeat(Fill(2,3,4), inner=(1,2), outer=(4,2)) == repeat(fill(2,3,4), inner=(1,2), outer=(4,2))

@test repeat(Zeros(2,3), inner=(2,2,1,4), outer=(2,1,3)) isa Zeros
@test repeat(Zeros(2,3), inner=(2,2,1,4), outer=(2,1,3)) == repeat(zeros(2,3), inner=(2,2,1,4), outer=(2,1,3))
@test repeat(Ones(2,3), inner=(2,2,1,4), outer=(2,1,3)) isa Ones
@test repeat(Ones(2,3), inner=(2,2,1,4), outer=(2,1,3)) == repeat(ones(2,3), inner=(2,2,1,4), outer=(2,1,3))
@test repeat(Fill(2,3,4), inner=(2,2,1,4), outer=(2,1,3)) isa Fill
@test repeat(Fill(2,3,4), inner=(2,2,1,4), outer=(2,1,3)) == repeat(fill(2,3,4), inner=(2,2,1,4), outer=(2,1,3))

@test_throws ArgumentError repeat(Fill(2,3,4), inner=())
@test_throws ArgumentError repeat(Fill(2,3,4), outer=())
@test_throws ArgumentError repeat(Fill(2,3,4), inner=(1,))
@test_throws ArgumentError repeat(Fill(2,3,4), outer=(1,))
end
end

0 comments on commit 5dd7296

Please sign in to comment.