-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement accumulate and friends #702
Changes from 1 commit
3a26ce9
801d95c
4985bcf
4ca0144
df1caa8
df87a52
b665233
3db20ef
6bf5e05
5a7963d
8db9ad6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -285,3 +285,43 @@ end | |
@inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...))) | ||
end | ||
end | ||
|
||
struct _InitialValue end | ||
|
||
_maybeval(dims::Integer) = Val(Int(dims)) | ||
_maybeval(dims) = dims | ||
_valof(::Val{D}) where D = D | ||
c42f marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} = | ||
_accumulate(op, a, _maybeval(dims), init) | ||
|
||
@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} = | ||
_accumulate(op, a, _maybeval(dims), init) | ||
|
||
@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F} | ||
# Adjoin the initial value to `op`: | ||
rf(x, y) = x isa _InitialValue ? y : op(x, y) | ||
|
||
# StaticArrays' `reduce` is `foldl`: | ||
c42f marked this conversation as resolved.
Show resolved
Hide resolved
|
||
results = _reduce( | ||
a, | ||
dims, | ||
(init = (similar_type(a, Union{}, Size(0))(), init),), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One thing I noticed here is that for length-0 input, the result has eltype of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't think it's possible to solve this in general as the output value depends on the function As all static arrays (including My preference is:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it's possible to solve this in general as the output value depends on the function This is what #664 hacks around for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see. I misread the last half of your first comment. But I think one difficulty of op(::A, ::A) :: B
op(::B, ::A) :: C
op(::C, ::A) :: D
op(::D, ::A) :: D # finally "stabilized" But, even though osc_op(::Nothing, ::Any) = missing
osc_op(::Missing, ::Any) = nothing So, maybe we should use something like this? function f(op, acc, x)
T = typeof(acc)
while true
acc = op(acc, x)
T = promote_type(T, typeof(acc))
non_existing_value && return T
# `non_existing_value` modeling truncation arbitrary-sized input.
end
end
return_type(f, Tuple{typeof(op), typeof(init), eltype(a)}) Impressively, I'm not sure if we need to go this far, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right, it's confusing to know what the empty case should do here and it's worse than map. With your fix we have julia> accumulate((a,b)->(a+b)/2, SA[1,2])
2-element SArray{Tuple{2},Float64,1,2} with indices SOneTo(2):
1.0
1.5
julia> accumulate((a,b)->(a+b)/2, SA[1])
1-element SArray{Tuple{1},Int64,1,1} with indices SOneTo(1):
1
julia> accumulate((a,b)->(a+b)/2, SA{Int}[])
0-element SArray{Tuple{0},Int64,1,0} with indices SOneTo(0) It seems like we must make a somewhat arbitrary choice for the length-0 case with julia> accumulate((a,b)->(a+b)/2, SA{Int}[1], init=0)
1-element SArray{Tuple{1},Float64,1,1} with indices SOneTo(1):
0.5
julia> accumulate((a,b)->(a+b)/2, SA{Int}[], init=0)
0-element SArray{Tuple{0},Float64,1,0} with indices SOneTo(0) Overall do you feel like this is an improvement on using We do have oddities like julia> cumsum(Int8[1])
1-element Array{Int64,1}:
1
julia> cumsum(SA{Int8}[1])
1-element SArray{Tuple{1},Int8,1,1} with indices SOneTo(1):
1
julia> cumsum(SA{Int8}[1,2])
2-element SArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
1
3 I think that inconsistency between here and Base might be harmless though, and it's certainly not clear how to "fix" it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That's a tough question... 😅 Personally, I think we should be using But I think the consistency of the API within a package and with respect to |
||
) do (ys, acc), x | ||
y = rf(acc, x) | ||
(vcat(ys, SA[y]), y) | ||
c42f marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
dims === (:) && return first(results) | ||
|
||
ys = map(first, results) | ||
data = _map(a, CartesianIndices(a)) do _, CI | ||
c42f marked this conversation as resolved.
Show resolved
Hide resolved
|
||
D = _valof(dims) | ||
I = Tuple(CI) | ||
J = Base.setindex(I, 1, D) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already import |
||
ys[J...][I[D]] | ||
end | ||
return similar_type(a, eltype(data))(data) | ||
end | ||
|
||
@inline Base.cumsum(a::StaticArray; kw...) = accumulate(Base.add_sum, a; kw...) | ||
@inline Base.cumprod(a::StaticArray; kw...) = accumulate(Base.mul_prod, a; kw...) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
using StaticArrays, Test | ||
|
||
@testset "accumulate" begin | ||
@testset "cumsum(::$label)" for (label, T) in [ | ||
# label, T | ||
("SVector", SVector), | ||
("MVector", MVector), | ||
("SizedVector", SizedVector{3}), | ||
] | ||
a = T(SA[1, 2, 3]) | ||
@test cumsum(a) == cumsum(collect(a)) | ||
@test cumsum(a) isa similar_type(a) | ||
@inferred cumsum(a) | ||
end | ||
|
||
@testset "cumsum(::$label; dims=2)" for (label, T) in [ | ||
# label, T | ||
("SMatrix", SMatrix), | ||
("MMatrix", MMatrix), | ||
("SizedMatrix", SizedMatrix{3,2}), | ||
] | ||
a = T(SA[1 2; 3 4; 5 6]) | ||
@test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2) | ||
@test cumsum(a; dims = 2) isa similar_type(a) | ||
@inferred cumsum(a; dims = Val(2)) | ||
end | ||
|
||
@testset "cumsum(a::SArray; dims=$i); ndims(a) = $d" for d in 1:4, i in 1:d | ||
shape = Tuple(1:d) | ||
a = similar_type(SArray, Int, Size(shape))(1:prod(shape)) | ||
@test cumsum(a; dims = i) == cumsum(collect(a); dims = i) | ||
@test cumsum(a; dims = i) isa SArray | ||
@inferred cumsum(a; dims = Val(i)) | ||
end | ||
|
||
@testset "cumprod" begin | ||
a = SA[1, 2, 3] | ||
@test cumprod(a)::SArray == cumprod(collect(a)) | ||
@inferred cumprod(a) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it's just me, but I can't help reading this as
_maybe_eval
:-/ Can we change to_maybe_val
?