-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathmapreduce.jl
134 lines (114 loc) · 4.94 KB
/
mapreduce.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# map-reduce

# Inputs accepted by the reductions in this file: materialized arrays or lazy broadcasts.
const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}

# GPUArrays' reductions are built on a variant of `Base.mapreducedim!` (reduce `A` into
# `R` using `f` and `op`) that takes an extra `init` keyword: when `init` is set to a
# value, `R` does not have to be initialized eagerly. This generic method only raises;
# presumably the concrete array back-ends provide the real implementation.
# COV_EXCL_START
function mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArrayOrBroadcasted;
                       init=nothing)
    error("Not implemented")
end
# COV_EXCL_STOP

# Forward Base's entry point to the method above. There is one method per input kind
# (array vs. broadcast) to resolve method ambiguities.
function Base.mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArray)
    return mapreducedim!(f, op, R, A)
end
function Base.mapreducedim!(f, op, R::AnyGPUArray, A::Broadcast.Broadcasted)
    return mapreducedim!(f, op, R, A)
end
# `neutral_element(op, T)` returns the identity element `e` of the reduction operator `op`
# for values of type `T` (the value with `op(e, x) == x`). Reductions use it as their
# starting value when no explicit `init` is given. Operators without a registered method
# hit this fallback, which raises an explanatory error.
neutral_element(op, T) =
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
Please pass it as an explicit argument to `GPUArrays.mapreducedim!`,
or register it globally by defining `GPUArrays.neutral_element(::typeof($op), T)`.""")
neutral_element(::typeof(Base.:(|)), T) = zero(T)
neutral_element(::typeof(Base.:(+)), T) = zero(T)
neutral_element(::typeof(Base.add_sum), T) = zero(T)
# The identity of bitwise `&` has every bit set. `~zero(T)` is `true` for `Bool` and
# all-ones for integers (e.g. `-1` for `Int`, `0xff` for `UInt8`). `one(T)` is only
# correct for `Bool`: for integers it clears all but the lowest bit (`1 & 6 == 0`).
neutral_element(::typeof(Base.:(&)), T) = ~zero(T)
neutral_element(::typeof(Base.:(*)), T) = one(T)
neutral_element(::typeof(Base.mul_prod), T) = one(T)
neutral_element(::typeof(Base.min), T) = typemax(T)
neutral_element(::typeof(Base.max), T) = typemin(T)
# `extrema` reduces `(min, max)` pairs, so its identity is the pair `(typemax, typemin)`.
neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = typemax(T), typemin(T)
# resolve ambiguities: a GPU-backed `mapreduce` is entered either with a GPU array or with
# a broadcast carrying a GPU array style as its first input. Both methods hand the inputs,
# `dims` and `init` unchanged to `_mapreduce`.
function Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...;
                        dims=:, init=nothing)
    return _mapreduce(f, op, A, As...; dims=dims, init=init)
end
function Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle},
                        As::AbstractArrayOrBroadcasted...; dims=:, init=nothing)
    return _mapreduce(f, op, A, As...; dims=dims, init=init)
end
# Shared implementation of the GPU `mapreduce` methods. It applies `f` element-wise across
# the inputs `As`, then reduces with `op` over `dims` (`:` reduces everything and returns a
# scalar; otherwise a container is returned with the reduced dimensions set to size 1).
# When `init === nothing`, the starting value is `neutral_element(op, ET)` for the
# inferred result type `ET`.
function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,N,D}
    # Like `map`, `mapreduce` consumes its inputs as iterators: inputs whose sizes differ
    # are flattened and truncated to the length of the shortest one.
    bc = if !allequal(size.(As)...)
        # TODO: can we avoid the reshape + view?
        n = minimum(length.(LinearIndices.(As)))
        flat = map(A -> view(reshape(A, length(A)), 1:n), As)
        Broadcast.instantiate(Broadcast.broadcasted(f, flat...))
    else
        Broadcast.instantiate(Broadcast.broadcasted(f, As...))
    end

    # Element type of the destination container: taken from the initializer when one is
    # given, otherwise inferred through the map and reduce functions.
    if init !== nothing
        ET = typeof(init)
    else
        mapped = Broadcast.combine_eltypes(f, As)
        ET = Base.promote_op(op, mapped, mapped)
        if ET === Union{} || ET === Any
            error("mapreduce cannot figure the output element type, please pass an explicit init value")
        end
        init = neutral_element(op, ET)
    end

    # Output shape: every reduced dimension collapses to length 1.
    sz = size(bc)
    reduce_all = dims == Colon()
    red = ntuple(d -> (reduce_all || d in dims) ? 1 : sz[d], length(sz))
    R = similar(bc, ET, red)

    if iszero(prod(sz))
        # Nothing to reduce: every output slot is just the initial value.
        fill!(R, init)
    else
        mapreducedim!(identity, op, R, bc; init=init)
    end

    return dims === Colon() ? @allowscalar(R[]) : R
end
# Boolean reductions over GPU arrays: `any` folds with `|`, `all` folds with `&`. They
# either use the (Bool) elements directly or first apply the predicate `f`.
function Base.any(A::AnyGPUArray{Bool})
    return mapreduce(identity, |, A)
end
function Base.all(A::AnyGPUArray{Bool})
    return mapreduce(identity, &, A)
end
function Base.any(f::Function, A::AnyGPUArray)
    return mapreduce(f, |, A)
end
function Base.all(f::Function, A::AnyGPUArray)
    return mapreduce(f, &, A)
end

# Counting sums the predicate results with `Base.add_sum`, starting from `init` (0 by
# default) and optionally along `dims`.
function Base.count(pred::Function, A::AnyGPUArray; dims=:, init=0)
    return mapreduce(pred, Base.add_sum, A; dims=dims, init=init)
end
# In-place reductions (`sum!`, `prod!`, `maximum!`, `minimum!`, `all!`, `any!`) on GPU
# arrays go straight to `mapreducedim!` with an explicit `init`, avoiding a call into
# Base's `initarray!`.
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
                    (:maximum, :(Base.max)), (:minimum, :(Base.min)),
                    (:all, :&), (:any, :|)]
    fname! = Symbol(fname, '!')
    @eval begin
        # `init=true` (Base's default for these functions) seeds every output slot with
        # the neutral element of the operator in the *output* element type `eltype(r)`.
        # The results are stored into `r`, so that is the domain the identity must live
        # in; `eltype(A)` breaks when `f` changes the type (e.g. `any!(isnan, r, A)` on
        # floats would evaluate `0.0 | true`). `init=false` keeps the current contents
        # of `r` as starting values (`mapreducedim!` with `init=nothing`).
        Base.$(fname!)(f::Function, r::AnyGPUArray, A::AnyGPUArray; init::Bool=true) =
            GPUArrays.mapreducedim!(f, $(op), r, A;
                                    init=(init ? neutral_element($(op), eltype(r)) : nothing))
    end
end
# Hermitian test for GPU matrices: `A` must be square and equal, element by element, to
# its adjoint.
function LinearAlgebra.ishermitian(A::AbstractGPUMatrix)
    # The GPU `mapreduce` flattens and truncates inputs of different shapes instead of
    # throwing. A non-square `A` would therefore be compared with `adjoint(A)` in linear
    # order: a real 1×n row has the same linear sequence as its n×1 adjoint and would be
    # reported Hermitian. Reject non-square input up front, as LinearAlgebra does for
    # ordinary matrices.
    axes(A, 1) == axes(A, 2) || return false
    return mapreduce(==, &, A, adjoint(A))
end
# comparisons

# `isequal` for GPU arrays. Elements are compared with `isequal`, which treats `missing`
# like an ordinary value (`isequal(missing, missing)` is `true`), so the result is always
# a `Bool`.
function Base.isequal(A::AnyGPUArray, B::AnyGPUArray)
    # The same object is trivially equal to itself; skip the device reduction.
    A === B && return true
    # Arrays with different axes can never be equal.
    axes(A) == axes(B) || return false
    return mapreduce(isequal, &, A, B; init=true)
end
# Three-valued `==` for GPU arrays, following Base's rule for collections:
# - `false` if the axes differ or any pair of elements compares definitely unequal;
# - otherwise `missing` if at least one comparison returned `missing`;
# - otherwise `true`.
function Base.:(==)(A::AnyGPUArray, B::AnyGPUArray)
    if axes(A) != axes(B)
        return false
    end
    # Each element pair maps to two flags:
    # - `is_missing`: did this comparison return `missing`?
    # - `is_equal`: do the definite comparisons all agree?
    # A `missing` comparison contributes the neutral `is_equal=true`, so it can neither
    # hide nor fake a definite inequality elsewhere in the arrays.
    function mapper(a, b)
        eq = (a == b)
        if ismissing(eq)
            (; is_missing=true, is_equal=true)
        else
            (; is_missing=false, is_equal=eq)
        end
    end
    # Combine the flags field-wise. `|`/`&` make the reduction associative and
    # commutative, with the `init` value below as its identity.
    function reducer(a, b)
        (; is_missing=a.is_missing | b.is_missing, is_equal=a.is_equal & b.is_equal)
    end
    res = mapreduce(mapper, reducer, A, B; init=(; is_missing=false, is_equal=true))
    # A definite inequality wins over any `missing` comparisons.
    if !res.is_equal
        return false
    end
    return res.is_missing ? missing : true
end