From 6c3a4a77139a1bd2dca33284c37bc4f41e6337ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 11:35:08 -0400 Subject: [PATCH] feat: add fallbacks for unknown objects --- Project.toml | 2 +- ext/MLDataDevicesChainRulesCoreExt.jl | 11 ++++++--- src/internal.jl | 34 +++++++++++++++++++++++---- src/public.jl | 17 +++++++++++++- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 1cb1875..41f3134 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.2.1" +version = "1.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl index c6b9560..6a770b8 100644 --- a/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/ext/MLDataDevicesChainRulesCoreExt.jl @@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable -using MLDataDevices: AbstractDevice, get_device, get_device_type +using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type @non_differentiable get_device(::Any) @non_differentiable get_device_type(::Any) function ChainRulesCore.rrule( ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + ∇adapt_storage = let dev = get_device(x) + if dev === nothing || dev isa UnknownDevice + @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1 + Δ -> (NoTangent(), NoTangent(), Δ) + else + Δ -> (NoTangent(), NoTangent(), dev(Δ)) + end end return Adapt.adapt_storage(to, x), ∇adapt_storage end diff --git a/src/internal.jl b/src/internal.jl index e13b716..bcc8cab 100644 --- a/src/internal.jl +++ b/src/internal.jl @@ -5,8 +5,8 @@ using Preferences: load_preference using Random: AbstractRNG using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends, - GPU_DEVICES, loaded, functional + MetalDevice, oneAPIDevice, XLADevice, UnknownDevice, + supported_gpu_backends, GPU_DEVICES, loaded, functional for dev in (CPUDevice, MetalDevice, oneAPIDevice) msg = "`device_id` is not applicable for `$dev`." @@ -107,16 +107,22 @@ special_aos(::AbstractArray) = false recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) combine_devices(::Nothing, ::Nothing) = nothing -combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing combine_devices(::Nothing, dev::AbstractDevice) = dev -combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T combine_devices(dev::AbstractDevice, ::Nothing) = dev -combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) dev1 == dev2 && return dev1 + dev1 isa UnknownDevice && return dev2 + dev2 isa UnknownDevice && return dev1 throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) end + +combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T +combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) end @@ -147,6 +153,13 @@ for op in (:get_device, :get_device_type) length(x) == 0 && return $(op == :get_device ? nothing : Nothing) return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x)) end + + function $(op)(f::F) where {F <: Function} + Base.issingletontype(F) && + return $(op == :get_device ? UnknownDevice() : UnknownDevice) + return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, + map(Base.Fix1(getfield, f), fieldnames(F))) + end end for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) @@ -154,6 +167,17 @@ for op in (:get_device, :get_device_type) end end +get_device(_) = UnknownDevice() +get_device_type(_) = UnknownDevice + +fast_structure(::AbstractArray) = true +fast_structure(::Union{Tuple, NamedTuple}) = true +for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) + @eval fast_structure(::$(T)) = true +end +fast_structure(::Function) = true +fast_structure(_) = false + function unrolled_mapreduce(f::F, op::O, itr) where {F, O} return unrolled_mapreduce(f, op, itr, static_length(itr)) end diff --git a/src/public.jl b/src/public.jl index 178c6f9..07deeaa 100644 --- a/src/public.jl +++ b/src/public.jl @@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end # TODO: Later we might want to add the client field here? struct XLADevice <: AbstractAcceleratorDevice end +# Fallback for when we don't know the device type +struct UnknownDevice <: AbstractDevice end + """ functional(x::AbstractDevice) -> Bool functional(::Type{<:AbstractDevice}) -> Bool @@ -245,6 +248,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur $(GET_DEVICE_ADMONITIONS) +## Special Retuened Values + + - `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice()` -- denotes that the device type is unknown + See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. """ @@ -258,6 +267,12 @@ itself. This value is often a compile time constant and is recommended to be use of [`get_device`](@ref) where ever defining dispatches based on the device type. $(GET_DEVICE_ADMONITIONS) + +## Special Retuened Values + + - `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice` -- denotes that the device type is unknown """ function get_device_type end @@ -345,7 +360,7 @@ end for op in (:get_device, :get_device_type) @eval function $(op)(x) - hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + Internal.fast_structure(x) && return Internal.$(op)(x) return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) end end