Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: add fallbacks for unknown objects
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 0d6c6a8 commit 6c3a4a7
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.1"
version = "1.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
11 changes: 8 additions & 3 deletions ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt
using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable

using MLDataDevices: AbstractDevice, get_device, get_device_type
using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type

@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
∇adapt_storage = let dev = get_device(x)
if dev === nothing || dev isa UnknownDevice
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
Δ -> (NoTangent(), NoTangent(), Δ)

Check warning on line 16 in ext/MLDataDevicesChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MLDataDevicesChainRulesCoreExt.jl#L15-L16

Added lines #L15 - L16 were not covered by tests
else
Δ -> (NoTangent(), NoTangent(), dev(Δ))
end
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end
Expand Down
34 changes: 29 additions & 5 deletions src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ using Preferences: load_preference
using Random: AbstractRNG

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends,
GPU_DEVICES, loaded, functional
MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
supported_gpu_backends, GPU_DEVICES, loaded, functional

for dev in (CPUDevice, MetalDevice, oneAPIDevice)
msg = "`device_id` is not applicable for `$dev`."
Expand Down Expand Up @@ -107,16 +107,22 @@ special_aos(::AbstractArray) = false
recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)

combine_devices(::Nothing, ::Nothing) = nothing
combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Nothing, dev::AbstractDevice) = dev
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(dev::AbstractDevice, ::Nothing) = dev
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice)
dev1 == dev2 && return dev1
dev1 isa UnknownDevice && return dev2
dev2 isa UnknownDevice && return dev1
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
end

combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T

Check warning on line 122 in src/internal.jl

View check run for this annotation

Codecov / codecov/patch

src/internal.jl#L122

Added line #L122 was not covered by tests
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice

Check warning on line 125 in src/internal.jl

View check run for this annotation

Codecov / codecov/patch

src/internal.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice})
throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
end
Expand Down Expand Up @@ -147,13 +153,31 @@ for op in (:get_device, :get_device_type)
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
end

function $(op)(f::F) where {F <: Function}
Base.issingletontype(F) &&

Check warning on line 158 in src/internal.jl

View check run for this annotation

Codecov / codecov/patch

src/internal.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
return $(op == :get_device ? UnknownDevice() : UnknownDevice)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,

Check warning on line 160 in src/internal.jl

View check run for this annotation

Codecov / codecov/patch

src/internal.jl#L160

Added line #L160 was not covered by tests
map(Base.Fix1(getfield, f), fieldnames(F)))
end
end

for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end

get_device(_) = UnknownDevice()
get_device_type(_) = UnknownDevice

Check warning on line 171 in src/internal.jl

View check run for this annotation

Codecov / codecov/patch

src/internal.jl#L170-L171

Added lines #L170 - L171 were not covered by tests

fast_structure(::AbstractArray) = true
fast_structure(::Union{Tuple, NamedTuple}) = true
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval fast_structure(::$(T)) = true
end
fast_structure(::Function) = true

Check warning on line 178 in src/internal.jl

View check run for this annotation

Codecov / codecov/patch

src/internal.jl#L178

Added line #L178 was not covered by tests
fast_structure(_) = false

function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
return unrolled_mapreduce(f, op, itr, static_length(itr))
end
Expand Down
17 changes: 16 additions & 1 deletion src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end
# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end

"""
functional(x::AbstractDevice) -> Bool
functional(::Type{<:AbstractDevice}) -> Bool
Expand Down Expand Up @@ -245,6 +248,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur
$(GET_DEVICE_ADMONITIONS)
## Special Retuened Values
- `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
range, etc.
- `UnknownDevice()` -- denotes that the device type is unknown
See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch
based on device type.
"""
Expand All @@ -258,6 +267,12 @@ itself. This value is often a compile time constant and is recommended to be use
of [`get_device`](@ref) where ever defining dispatches based on the device type.
$(GET_DEVICE_ADMONITIONS)
## Special Retuened Values
- `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
range, etc.
- `UnknownDevice` -- denotes that the device type is unknown
"""
function get_device_type end

Expand Down Expand Up @@ -345,7 +360,7 @@ end

for op in (:get_device, :get_device_type)
@eval function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
Internal.fast_structure(x) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end
end
Expand Down

0 comments on commit 6c3a4a7

Please sign in to comment.