Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: add fallbacks for unknown objects (#87)
Browse files Browse the repository at this point in the history
* feat: add fallbacks for unknown objects

* feat: handle RNGs and undef arrays gracefully

* test: RNG movement

* test: functions and closures
  • Loading branch information
avik-pal authored Oct 18, 2024
1 parent 0d6c6a8 commit 17bc9aa
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
steps:
- label: "Triggering Pipelines (Pull Request)"
if: "build.pull_request.base_branch == 'main'"
if: build.branch != "main" && build.tag == null
agents:
queue: "juliagpu"
plugins:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.1"
version = "1.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 2 additions & 0 deletions ext/MLDataDevicesAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return Internal.get_device(parent_x)
end
Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device())

Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice

# Set Device
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
Expand Down
4 changes: 4 additions & 0 deletions ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray)
return MLDataDevices.get_device(parent_x)
end
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))
Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device())
Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device())

Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice
Internal.get_device_type(::CUDA.RNG) = CUDADevice
Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice

# Set Device
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
Expand Down
11 changes: 8 additions & 3 deletions ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt
using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable

using MLDataDevices: AbstractDevice, get_device, get_device_type
using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type

@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
∇adapt_storage = let dev = get_device(x)
if dev === nothing || dev isa UnknownDevice
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
Δ -> (NoTangent(), NoTangent(), Δ)
else
Δ -> (NoTangent(), NoTangent(), dev(Δ))
end
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end
Expand Down
5 changes: 4 additions & 1 deletion ext/MLDataDevicesGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: CPUDevice
using MLDataDevices: Internal, CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()

Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state)
Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state)

end
39 changes: 32 additions & 7 deletions src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ using Preferences: load_preference
using Random: AbstractRNG

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends,
GPU_DEVICES, loaded, functional
MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
supported_gpu_backends, GPU_DEVICES, loaded, functional

for dev in (CPUDevice, MetalDevice, oneAPIDevice)
msg = "`device_id` is not applicable for `$dev`."
Expand Down Expand Up @@ -107,31 +107,38 @@ special_aos(::AbstractArray) = false
recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)

combine_devices(::Nothing, ::Nothing) = nothing
combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Nothing, dev::AbstractDevice) = dev
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(dev::AbstractDevice, ::Nothing) = dev
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice)
dev1 == dev2 && return dev1
dev1 isa UnknownDevice && return dev2
dev2 isa UnknownDevice && return dev1
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
end

combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice
function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice})
throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
end

for op in (:get_device, :get_device_type)
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \
$(cpu_ret_val)..."
$(unknown_ret_val)..."

@eval begin
function $(op)(x::AbstractArray{T}) where {T}
if recursive_array_eltype(T)
if any(!isassigned(x, i) for i in eachindex(x))
@warn $(not_assigned_msg)
return $(cpu_ret_val)
return $(unknown_ret_val)
end
return mapreduce(MLDataDevices.$(op), combine_devices, x)
end
Expand All @@ -147,13 +154,31 @@ for op in (:get_device, :get_device_type)
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
end

function $(op)(f::F) where {F <: Function}
Base.issingletontype(F) &&
return $(op == :get_device ? UnknownDevice() : UnknownDevice)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
map(Base.Fix1(getfield, f), fieldnames(F)))
end
end

for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end

get_device(_) = UnknownDevice()
get_device_type(_) = UnknownDevice

fast_structure(::AbstractArray) = true
fast_structure(::Union{Tuple, NamedTuple}) = true
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval fast_structure(::$(T)) = true
end
fast_structure(::Function) = true
fast_structure(_) = false

function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
return unrolled_mapreduce(f, op, itr, static_length(itr))
end
Expand Down
22 changes: 16 additions & 6 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end
# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end

"""
functional(x::AbstractDevice) -> Bool
functional(::Type{<:AbstractDevice}) -> Bool
Expand Down Expand Up @@ -229,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """
!!! note
Trigger Packages must be loaded for this to return the correct device.
!!! warning
RNG types currently don't participate in device determination. We will remove this
restriction in the future.
"""

# Query Device from Array
Expand All @@ -245,6 +243,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur
$(GET_DEVICE_ADMONITIONS)
## Special Returned Values
- `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
range, etc.
- `UnknownDevice()` -- denotes that the device type is unknown
See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch
based on device type.
"""
Expand All @@ -258,6 +262,12 @@ itself. This value is often a compile time constant and is recommended to be use
of [`get_device`](@ref) where ever defining dispatches based on the device type.
$(GET_DEVICE_ADMONITIONS)
## Special Returned Values
- `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
range, etc.
- `UnknownDevice` -- denotes that the device type is unknown
"""
function get_device_type end

Expand Down Expand Up @@ -345,7 +355,7 @@ end

for op in (:get_device, :get_device_type)
@eval function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
Internal.fast_structure(x) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end
end
Expand Down
29 changes: 29 additions & 0 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa AMDGPUDevice
@test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_xpu.one_elem isa ROCArray
Expand All @@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_cpu.one_elem isa Array
Expand Down Expand Up @@ -118,6 +126,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(AMDGPUDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> AMDGPUDevice()
@test get_device(ff_xpu) isa AMDGPUDevice
@test get_device_type(ff_xpu) <: AMDGPUDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(AMDGPUDevice)
x = rand(10, 10) |> AMDGPUDevice()
Expand Down
29 changes: 29 additions & 0 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa CUDADevice
@test get_device_type(ps_xpu.rng_default) <: CUDADevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_xpu.one_elem isa CuArray
Expand All @@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_cpu.one_elem isa Array
Expand Down Expand Up @@ -143,6 +151,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(CUDADevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> CUDADevice()
@test get_device(ff_xpu) isa CUDADevice
@test get_device_type(ff_xpu) <: CUDADevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(CUDADevice)
x = rand(10, 10) |> CUDADevice()
Expand Down
29 changes: 29 additions & 0 deletions test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa MetalDevice
@test get_device_type(ps_xpu.rng_default) <: MetalDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_xpu.one_elem isa MtlArray
Expand All @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_cpu.one_elem isa Array
Expand All @@ -107,6 +115,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(MetalDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> MetalDevice()
@test get_device(ff_xpu) isa MetalDevice
@test get_device_type(ff_xpu) <: MetalDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapper Arrays" begin
if MLDataDevices.functional(MetalDevice)
x = rand(Float32, 10, 10) |> MetalDevice()
Expand Down
4 changes: 2 additions & 2 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,6 @@ end
@testset "undefined references array" begin
x = Matrix{Any}(undef, 10, 10)

@test get_device(x) isa CPUDevice
@test get_device_type(x) <: CPUDevice
@test get_device(x) isa MLDataDevices.UnknownDevice
@test get_device_type(x) <: MLDataDevices.UnknownDevice
end
Loading

2 comments on commit 17bc9aa

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/117584

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.3.0 -m "<description of version>" 17bc9aabbc4c574f8751c3afbf48220b508ffd73
git push origin v1.3.0

Please sign in to comment.