diff --git a/Project.toml b/Project.toml
index 93c18549..88429e92 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 version = "10.2.3"
 
 [deps]
+AbstractNumbers = "85c772de-338a-5e7f-b815-41e76c26ac1f"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
@@ -14,6 +15,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
+AbstractNumbers = "0.2"
 Adapt = "4.0"
 GPUArraysCore = "= 0.1.6"
 LLVM = "3.9, 4, 5, 6, 7, 8"
diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl
index 2d4f1bd9..0bc6fd8c 100644
--- a/src/GPUArrays.jl
+++ b/src/GPUArrays.jl
@@ -1,5 +1,7 @@
 module GPUArrays
 
+import AbstractNumbers
+
 using Serialization
 using Random
 using LinearAlgebra
@@ -25,6 +27,7 @@ include("device/synchronization.jl")
 # host abstractions
 include("host/abstractarray.jl")
 include("host/construction.jl")
+include("host/gpunumber.jl")
 ## integrations and specialized methods
 include("host/base.jl")
 include("host/indexing.jl")
diff --git a/src/host/gpunumber.jl b/src/host/gpunumber.jl
new file mode 100644
index 00000000..e2660740
--- /dev/null
+++ b/src/host/gpunumber.jl
@@ -0,0 +1,57 @@
+"""
+    AsyncNumber(val::AbstractGPUArray)
+
+`AbstractNumbers.AbstractNumber` wrapper around a 1-element GPU array.
+
+`AbstractNumbers.number(g)` and `g[]` read the wrapped element under `@allowscalar`.
+The constructor throws an error unless `length(val) == 1`.
+"""
+struct AsyncNumber{T <: AbstractGPUArray} <: AbstractNumbers.AbstractNumber{T}
+    val::T
+
+    function AsyncNumber(val::T) where T <: AbstractGPUArray
+        length(val) != 1 && error(
+            "`AsyncNumber` accepts only 1-element GPU arrays, " *
+            "instead `$(length(val))`-element array was given.")
+        new{T}(val)
+    end
+end
+
+# Read the single wrapped element as a plain value.
+AbstractNumbers.number(g::AsyncNumber) = @allowscalar g.val[]
+
+# Unwrap an `AsyncNumber` to its plain value; return anything else unchanged.
+maybe_number(g::AsyncNumber) = AbstractNumbers.number(g)
+maybe_number(g) = g
+
+# Element type of the wrapped array.
+number_type(::AsyncNumber{T}) where T = eltype(T)
+
+# When operations involve other `::Number` types,
+# do not convert back to `AsyncNumber`.
+AbstractNumbers.like(::Type{<: AsyncNumber}, x) = x
+
+# When broadcasting, just pass the array itself.
+Base.broadcastable(g::AsyncNumber) = g.val
+
+# Overload to avoid copies.
+Base.one(g::AsyncNumber) = one(number_type(g))
+Base.one(::Type{AsyncNumber{T}}) where T = one(eltype(T))
+Base.zero(g::AsyncNumber) = zero(number_type(g))
+Base.zero(::Type{AsyncNumber{T}}) where T = zero(eltype(T))
+Base.identity(g::AsyncNumber) = g
+
+# `g[]` unwraps to the plain value.
+Base.getindex(g::AsyncNumber) = AbstractNumbers.number(g)
+
+# The two-`AsyncNumber` method is required: `AsyncNumber <: Number`, so without
+# it `isequal(a, b)` on two wrappers is ambiguous between the mixed methods.
+Base.isequal(g::AsyncNumber, v::Number) = isequal(g[], v)
+Base.isequal(v::Number, g::AsyncNumber) = isequal(v, g[])
+Base.isequal(a::AsyncNumber, b::AsyncNumber) = isequal(a[], b[])
+
+Base.nextpow(a, x::AsyncNumber) = nextpow(a, x[])
+Base.nextpow(a::AsyncNumber, x) = nextpow(a[], x)
+Base.nextpow(a::AsyncNumber, x::AsyncNumber) = nextpow(a[], x[])
+
+Base.convert(::Type{Number}, g::AsyncNumber) = g[]
diff --git a/src/host/indexing.jl b/src/host/indexing.jl
index 659fb029..31c943cb 100644
--- a/src/host/indexing.jl
+++ b/src/host/indexing.jl
@@ -203,8 +203,7 @@ function Base.findfirst(f::Function, A::AnyGPUArray)
         return (false, dummy_index)
     end
 
-    res = mapreduce((x, y)->(f(x), y), reduction, A, indices;
-                    init = (false, dummy_index))
+    res = mapreduce((x, y)->(f(x), y), reduction, A, indices; init = (false, dummy_index))
     if res[1]
         # out of consistency with Base.findarray, return a CartesianIndex
         # when the input is a multidimensional array
@@ -230,7 +229,8 @@ function findminmax(binop, A::AnyGPUArray; init, dims)
     end
 
     if dims == Colon()
-        res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))
+        res = mapreduce(tuple, reduction, A, indices;
+                        init = (init, dummy_index))
 
         # out of consistency with Base.findarray, return a CartesianIndex
         # when the input is a multidimensional array
diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl
index 32520ebc..594bbf5b 100644
--- a/src/host/mapreduce.jl
+++ b/src/host/mapreduce.jl
@@ -68,20 +68,23 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
     end
 
     if dims === Colon()
-        @allowscalar R[]
+        # Return an `AsyncNumber` for `Number` eltypes; otherwise transfer to host.
+        eltype(R) <: Number ?
+            AsyncNumber(reshape(R, :)) :
+            @allowscalar(R[])
     else
         R
     end
 end
 
-Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A)
-Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A)
+Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A)[]
+Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A)[]
 
-Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A)
-Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
+Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A)[]
+Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)[]
 
 Base.count(pred::Function, A::AnyGPUArray; dims=:, init=0) =
-    mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
+    mapreduce(pred, Base.add_sum, A; init=init, dims=dims) |> maybe_number
 
 # avoid calling into `initarray!`
 for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
@@ -94,7 +97,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
     end
 end
 
-LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A))
+LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A))[]
 
 # comparisons
 
@@ -105,7 +108,7 @@ function Base.isequal(A::AnyGPUArray, B::AnyGPUArray)
     if axes(A) != axes(B)
         return false
     end
-    mapreduce(isequal, &, A, B; init=true)
+    mapreduce(isequal, &, A, B; init=true)[]
 end
 
 # returns `missing` when missing values are involved
@@ -129,6 +132,7 @@ function Base.:(==)(A::AnyGPUArray, B::AnyGPUArray)
             (; is_missing=false, is_equal=a.is_equal & b.is_equal)
         end
     end
-    res = mapreduce(mapper, reducer, A, B; init=(; is_missing=false, is_equal=true))
+    res = mapreduce(mapper, reducer, A, B;
+        init=(; is_missing=false, is_equal=true))
     res.is_missing ? missing : res.is_equal
 end