From 877e7aaf76cefefd70aa0fb75ac0cddab54b756e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 22:05:28 -0700 Subject: [PATCH] refactor: move internal functions into separate modules --- Project.toml | 2 +- ext/MLDataDevicesAMDGPUExt.jl | 20 +- ext/MLDataDevicesCUDAExt.jl | 28 +- ext/MLDataDevicesMetalExt.jl | 10 +- ext/MLDataDevicesRecursiveArrayToolsExt.jl | 10 +- ext/MLDataDevicesReverseDiffExt.jl | 12 +- ext/MLDataDevicesTrackerExt.jl | 14 +- ext/MLDataDevicesoneAPIExt.jl | 6 +- src/MLDataDevices.jl | 495 +-------------------- src/internal.jl | 144 ++++++ src/public.jl | 347 +++++++++++++++ test/amdgpu_tests.jl | 5 +- test/cuda_tests.jl | 5 +- test/metal_tests.jl | 5 +- test/misc_tests.jl | 2 +- test/oneapi_tests.jl | 5 +- 16 files changed, 551 insertions(+), 559 deletions(-) create mode 100644 src/internal.jl create mode 100644 src/public.jl diff --git a/Project.toml b/Project.toml index 13649ab..f264895 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.0.1" +version = "1.0.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/MLDataDevicesAMDGPUExt.jl b/ext/MLDataDevicesAMDGPUExt.jl index 7769b84..e539a15 100644 --- a/ext/MLDataDevicesAMDGPUExt.jl +++ b/ext/MLDataDevicesAMDGPUExt.jl @@ -2,7 +2,7 @@ module MLDataDevicesAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, Internal, AMDGPUDevice, CPUDevice, reset_gpu_device! using Random: Random __init__() = reset_gpu_device!() @@ -10,7 +10,7 @@ __init__() = reset_gpu_device!() # This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. 
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) -function _check_use_amdgpu!() +function check_use_amdgpu!() USE_AMD_GPU[] === nothing || return USE_AMD_GPU[] = AMDGPU.functional() @@ -23,14 +23,12 @@ end MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool - _check_use_amdgpu!() + check_use_amdgpu!() return USE_AMD_GPU[] end -function MLDataDevices._with_device(::Type{AMDGPUDevice}, ::Nothing) - return AMDGPUDevice(nothing) -end -function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer) +Internal.with_device(::Type{AMDGPUDevice}, ::Nothing) = AMDGPUDevice(nothing) +function Internal.with_device(::Type{AMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() @@ -40,19 +38,19 @@ function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer) return device end -MLDataDevices._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) +Internal.get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) # Default RNG MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -function MLDataDevices._get_device(x::AMDGPU.AnyROCArray) +function Internal.get_device(x::AMDGPU.AnyROCArray) parent_x = parent(x) parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) - return MLDataDevices._get_device(parent_x) + return Internal.get_device(parent_x) end -MLDataDevices._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice # Set Device function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) diff --git a/ext/MLDataDevicesCUDAExt.jl b/ext/MLDataDevicesCUDAExt.jl index 6362f80..cc4cde4 100644 --- a/ext/MLDataDevicesCUDAExt.jl +++ b/ext/MLDataDevicesCUDAExt.jl @@ -2,11 +2,12 @@ module MLDataDevicesCUDAExt using Adapt: Adapt using CUDA: CUDA -using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector -using MLDataDevices: MLDataDevices, CUDADevice, CPUDevice +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector, AbstractCuSparseArray +using MLDataDevices: MLDataDevices, Internal, CUDADevice, CPUDevice using Random: Random -function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer) +Internal.with_device(::Type{CUDADevice}, ::Nothing) = CUDADevice(nothing) +function Internal.with_device(::Type{CUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() @@ -16,34 +17,23 @@ function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer) return device end -function MLDataDevices._with_device(::Type{CUDADevice}, ::Nothing) - return CUDADevice(nothing) -end - -MLDataDevices._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 +Internal.get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 # Default RNG MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng() # Query Device from Array -function MLDataDevices._get_device(x::CUDA.AnyCuArray) +function Internal.get_device(x::CUDA.AnyCuArray) parent_x = parent(x) parent_x === x && return CUDADevice(CUDA.device(x)) return MLDataDevices.get_device(parent_x) end -function MLDataDevices._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) - return CUDADevice(CUDA.device(x.nzVal)) -end 
+Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal)) -function MLDataDevices._get_device_type(::Union{ - <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) - return CUDADevice -end +Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice # Set Device -function MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) - return CUDA.device!(dev) -end +MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev) function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer) return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id]) end diff --git a/ext/MLDataDevicesMetalExt.jl b/ext/MLDataDevicesMetalExt.jl index 1c81689..87d0b0e 100644 --- a/ext/MLDataDevicesMetalExt.jl +++ b/ext/MLDataDevicesMetalExt.jl @@ -2,23 +2,21 @@ module MLDataDevicesMetalExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: MLDataDevices, MetalDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device! using Metal: Metal, MtlArray __init__() = reset_gpu_device!() MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true -function MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) - return Metal.functional() -end +MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) = Metal.functional() # Default RNG MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) # Query Device from Array -MLDataDevices._get_device(::MtlArray) = MetalDevice() +Internal.get_device(::MtlArray) = MetalDevice() -MLDataDevices._get_device_type(::MtlArray) = MetalDevice +Internal.get_device_type(::MtlArray) = MetalDevice # Device Transfer ## To GPU diff --git a/ext/MLDataDevicesRecursiveArrayToolsExt.jl b/ext/MLDataDevicesRecursiveArrayToolsExt.jl index 4277150..f0b29a2 100644 --- a/ext/MLDataDevicesRecursiveArrayToolsExt.jl +++ b/ext/MLDataDevicesRecursiveArrayToolsExt.jl @@ -1,7 +1,7 @@ module MLDataDevicesRecursiveArrayToolsExt using Adapt: Adapt, adapt -using MLDataDevices: MLDataDevices, AbstractDevice +using MLDataDevices: MLDataDevices, Internal, AbstractDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure @@ -14,10 +14,10 @@ function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end -for op in (:_get_device, :_get_device_type) - @eval function MLDataDevices.$op(x::Union{VectorOfArray, DiffEqArray}) - length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) - return mapreduce(MLDataDevices.$op, MLDataDevices.__combine_devices, x.u) +for op in (:get_device, :get_device_type) + @eval function Internal.$(op)(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return $(op == :get_device ? 
nothing : Nothing) + return mapreduce(Internal.$(op), Internal.combine_devices, x.u) end end diff --git a/ext/MLDataDevicesReverseDiffExt.jl b/ext/MLDataDevicesReverseDiffExt.jl index 9e6553e..eeb9442 100644 --- a/ext/MLDataDevicesReverseDiffExt.jl +++ b/ext/MLDataDevicesReverseDiffExt.jl @@ -1,16 +1,12 @@ module MLDataDevicesReverseDiffExt -using MLDataDevices: MLDataDevices +using MLDataDevices: Internal using ReverseDiff: ReverseDiff -for op in (:_get_device, :_get_device_type) +for op in (:get_device, :get_device_type) @eval begin - function MLDataDevices.$op(x::ReverseDiff.TrackedArray) - return MLDataDevices.$op(ReverseDiff.value(x)) - end - function MLDataDevices.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return MLDataDevices.$op(ReverseDiff.value.(x)) - end + Internal.$(op)(x::ReverseDiff.TrackedArray) = Internal.$(op)(ReverseDiff.value(x)) + Internal.$(op)(x::AbstractArray{<:ReverseDiff.TrackedReal}) = Internal.$(op)(ReverseDiff.value.(x)) end end diff --git a/ext/MLDataDevicesTrackerExt.jl b/ext/MLDataDevicesTrackerExt.jl index 49ef3ea..f9b90d9 100644 --- a/ext/MLDataDevicesTrackerExt.jl +++ b/ext/MLDataDevicesTrackerExt.jl @@ -1,19 +1,15 @@ module MLDataDevicesTrackerExt using Adapt: Adapt -using MLDataDevices: MLDataDevices, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice +using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice using Tracker: Tracker -for op in (:_get_device, :_get_device_type) - @eval begin - MLDataDevices.$op(x::Tracker.TrackedArray) = MLDataDevices.$op(Tracker.data(x)) - function MLDataDevices.$op(x::AbstractArray{<:Tracker.TrackedReal}) - return MLDataDevices.$op(Tracker.data.(x)) - end - end +for op in (:get_device, :get_device_type) + @eval Internal.$(op)(x::Tracker.TrackedArray) = Internal.$(op)(Tracker.data(x)) + @eval Internal.$(op)(x::AbstractArray{<:Tracker.TrackedReal}) = Internal.$(op)(Tracker.data.(x)) end -MLDataDevices.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true +Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) diff --git a/ext/MLDataDevicesoneAPIExt.jl b/ext/MLDataDevicesoneAPIExt.jl index ebffa02..4bda871 100644 --- a/ext/MLDataDevicesoneAPIExt.jl +++ b/ext/MLDataDevicesoneAPIExt.jl @@ -2,7 +2,7 @@ module MLDataDevicesoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: MLDataDevices, oneAPIDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, Internal, oneAPIDevice, reset_gpu_device! using oneAPI: oneAPI, oneArray, oneL0 const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() @@ -25,9 +25,9 @@ end MLDataDevices.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -MLDataDevices._get_device(::oneArray) = oneAPIDevice() +Internal.get_device(::oneArray) = oneAPIDevice() -MLDataDevices._get_device_type(::oneArray) = oneAPIDevice +Internal.get_device_type(::oneArray) = oneAPIDevice # Device Transfer ## To GPU diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index 556bfab..b7636db 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -2,13 +2,18 @@ module MLDataDevices using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent -using Functors: Functors, fmap, fleaves +using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! 
using Random: AbstractRNG, Random -using UnrolledUtilities: unrolled_mapreduce const CRC = ChainRulesCore +abstract type AbstractDevice <: Function end +abstract type AbstractGPUDevice <: AbstractDevice end + +include("public.jl") +include("internal.jl") + export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device @@ -16,490 +21,4 @@ export gpu_device, cpu_device export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice export get_device, get_device_type -abstract type AbstractDevice <: Function end -abstract type AbstractGPUDevice <: AbstractDevice end - -""" - functional(x::AbstractDevice) -> Bool - functional(::Type{<:AbstractDevice}) -> Bool - -Checks if the device is functional. This is used to determine if the device can be used for -computation. Note that even if the backend is loaded (as checked via -[`MLDataDevices.loaded`](@ref)), the device may not be functional. - -Note that while this function is not exported, it is considered part of the public API. -""" -@inline functional(x) = false - -""" - loaded(x::AbstractDevice) -> Bool - loaded(::Type{<:AbstractDevice}) -> Bool - -Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - - `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. - - `AMDGPU.jl` for AMD GPU ROCM Support. - - `Metal.jl` for Apple Metal GPU Support. - - `oneAPI.jl` for Intel oneAPI GPU Support. -""" -@inline loaded(x) = false - -struct CPUDevice <: AbstractDevice end -@kwdef struct CUDADevice{D} <: AbstractGPUDevice - device::D = nothing -end -@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice - device::D = nothing -end -struct MetalDevice <: AbstractGPUDevice end -struct oneAPIDevice <: AbstractGPUDevice end - -for dev in (CPUDevice, MetalDevice, oneAPIDevice) - msg = "`device_id` is not applicable for `$dev`." - @eval begin - _with_device(::Type{$dev}, ::Nothing) = $dev() - function _with_device(::Type{$dev}, device_id) - @warn $(msg) maxlog=1 - return $dev() - end - end -end - -@inline functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true -@inline loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true - -for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - tpkg = name === :CPU ? "" : string(name) - ldev = eval(Symbol(name, :Device)) - @eval begin - @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) - @inline _get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) - end -end - -for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) - @eval @inline _get_device_id(::$(T)) = nothing -end - -struct DeviceSelectionException <: Exception end - -function Base.showerror(io::IO, ::DeviceSelectionException) - return print(io, "DeviceSelectionException(No functional GPU device found!!)") -end - -# Order is important here -const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) - -const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) - -""" - reset_gpu_device!() - -Resets the selected GPU device. This is useful when automatic GPU selection needs to be -run again. -""" -@inline reset_gpu_device!() = (GPU_DEVICE[] = nothing) - -""" - supported_gpu_backends() -> Tuple{String, ...} - -Return a tuple of supported GPU backends. - -!!! warning - - This is not the list of functional backends on the system, but rather backends which - `MLDataDevices.jl` supports. 
-""" -@inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) - -""" - gpu_device(device_id::Union{Nothing, Integer}=nothing; - force_gpu_usage::Bool=false) -> AbstractDevice() - -Selects GPU device based on the following criteria: - - 1. If `gpu_backend` preference is set and the backend is functional on the system, then - that device is selected. - 2. Otherwise, an automatic selection algorithm is used. We go over possible device - backends in the order specified by `supported_gpu_backends()` and select the first - functional backend. - 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is - invoked. - 4. If nothing works, an error is thrown. - -## Arguments - - - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return - the last selected device or if none was selected then we run the autoselection and - choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If - `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in - contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to - `CUDA.device!(3)`. - -!!! warning - - `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` - and `CPU` backends, `device_id` is ignored and a warning is printed. - -!!! warning - - `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. - This is to ensure that deep learning operations work correctly. - Nonetheless, if cuDNN is not loaded you can still manually create a - `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). - -## Keyword Arguments - - - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU - device is found. -""" -function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; - force_gpu_usage::Bool=false)::AbstractDevice - device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) - - if GPU_DEVICE[] !== nothing - dev = GPU_DEVICE[] - if device_id === nothing - force_gpu_usage && - !(dev isa AbstractGPUDevice) && - throw(DeviceSelectionException()) - return dev - else - selected_device_id = _get_device_id(dev) - selected_device_id !== nothing && selected_device_id == device_id && return dev - end - end - - device_type = _get_gpu_device(; force_gpu_usage) - device = _with_device(device_type, device_id) - GPU_DEVICE[] = device - - return device -end - -function _get_gpu_device(; force_gpu_usage::Bool) - backend = @load_preference("gpu_backend", nothing) - - # If backend set with preferences, use it - if backend !== nothing - allowed_backends = supported_gpu_backends() - if backend ∉ allowed_backends - @warn "`gpu_backend` preference is set to $backend, which is not a valid \ - backend. Valid backends are $allowed_backends. Defaulting to automatic \ - GPU Backend selection." maxlog=1 - else - @debug "Using GPU backend set in preferences: $backend." - idx = findfirst(isequal(backend), allowed_backends) - device = GPU_DEVICES[idx] - if !loaded(device) - @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ - package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \ - Preferences backend!!! Please load the package and call this \ - function again to respect the Preferences backend." maxlog=1 - else - if functional(device) - @debug "Using GPU backend: $(_get_device_name(device))." - return device - else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \ - is not functional. 
Defaulting to automatic GPU Backend \ - selection." maxlog=1 - end - end - end - end - - @debug "Running automatic GPU backend selection..." - for device in GPU_DEVICES - if loaded(device) - @debug "Trying backend: $(_get_device_name(device))." - if functional(device) - @debug "Using GPU backend: $(_get_device_name(device))." - return device - end - @debug "GPU backend: $(_get_device_name(device)) is not functional." - else - @debug "Trigger package for backend ($(_get_device_name(device))): \ - $(_get_triggerpkg_name(device)) not loaded." - end - end - - if force_gpu_usage - throw(DeviceSelectionException()) - else - @warn """No functional GPU backend found! Defaulting to CPU. - - 1. If no GPU is available, nothing needs to be done. - 2. If GPU is available, load the corresponding trigger package. - a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. - b. `AMDGPU.jl` for AMD GPU ROCM Support. - c. `Metal.jl` for Apple Metal GPU Support. (Experimental) - d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 - return CPUDevice - end -end - -""" - gpu_backend!() = gpu_backend!("") - gpu_backend!(backend) = gpu_backend!(string(backend)) - gpu_backend!(backend::AbstractGPUDevice) - gpu_backend!(backend::String) - -Creates a `LocalPreferences.toml` file with the desired GPU backend. - -If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is -validated to be one of the possible backends and the preference is set to `backend`. - -If a new backend is successfully set, then the Julia session must be restarted for the -change to take effect. -""" -gpu_backend!(backend) = gpu_backend!(string(backend)) -gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(_get_device_name(backend)) -gpu_backend!() = gpu_backend!("") -function gpu_backend!(backend::String) - if backend == "" - @delete_preferences!("gpu_backend") - @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ - new backend." - return - end - - allowed_backends = supported_gpu_backends() - - set_backend = @load_preference("gpu_backend", nothing) - if set_backend == backend - @info "GPU backend is already set to $backend. No action is required." - return - end - - if backend ∉ allowed_backends - throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) - end - - @set_preferences!("gpu_backend"=>backend) - @info "GPU backend has been set to $backend. Restart Julia to use the new backend." - return -end - -""" - cpu_device() -> CPUDevice() - -Return a `CPUDevice` object which can be used to transfer data to CPU. -""" -@inline cpu_device() = CPUDevice() - -""" - default_device_rng(::AbstractDevice) - -Returns the default RNG for the device. This can be used to directly generate parameters -and states on the device using -[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). -""" -function default_device_rng(D::AbstractDevice) - return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ - either because: - - 1. The default RNG for this device is not known / officially provided. - 2. The trigger package for the device ($(_get_device_name(D)).jl) is not loaded. - """) -end -default_device_rng(::CPUDevice) = Random.default_rng() - -# Dispatches for Different Data Structures -# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability -# For all other types we rely on fmap which means we lose type stability. 
-# For Lux, typically models only has these 3 datastructures so we should be mostly fine. -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol("$(dev)Device") - @eval begin - function (D::$(ldev))(x::AbstractArray{T}) where {T} - fn = Base.Fix1(Adapt.adapt, D) - return isbitstype(T) || __special_aos(x) ? fn(x) : map(D, x) - end - (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) - function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) - return fmap(D, x) - end - end -end - -@inline __special_aos(x::AbstractArray) = false - -const GET_DEVICE_ADMONITIONS = """ -!!! note - - Trigger Packages must be loaded for this to return the correct device. - -!!! warning - - RNG types currently don't participate in device determination. We will remove this - restriction in the future. -""" - -# Query Device from Array -""" - get_device(x) -> dev::AbstractDevice | Exception | Nothing - -If all arrays (on the leaves of the structure) are on the same device, we return that -device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. - -$(GET_DEVICE_ADMONITIONS) - -See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch -based on device type. -""" -function get_device end - -""" - get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} - -Similar to [`get_device`](@ref) but returns the type of the device instead of the device -itself. This value is often a compile time constant and is recommended to be used instead -of [`get_device`](@ref) where ever defining dispatches based on the device type. - -$(GET_DEVICE_ADMONITIONS) -""" -function get_device_type end - -for op in (:get_device, :get_device_type) - _op = Symbol("_", op) - cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice - @eval begin - function $(op)(x) - hasmethod($(_op), Tuple{typeof(x)}) && return $(_op)(x) - return mapreduce($(_op), __combine_devices, fleaves(x)) - end - - CRC.@non_differentiable $op(::Any) - - function $(_op)(x::AbstractArray{T}) where {T} - __recursible_array_eltype(T) && return mapreduce($(op), __combine_devices, x) - if hasmethod(parent, Tuple{typeof(x)}) - parent_x = parent(x) - parent_x === x && return $(cpu_ret_val) - return $(_op)(parent_x) - end - return $(cpu_ret_val) - end - - function $(_op)(x::Union{Tuple, NamedTuple}) - length(x) == 0 && return $(op == :get_device ? nothing : Nothing) - return unrolled_mapreduce($(op), __combine_devices, values(x)) - end - end - - for T in (Number, AbstractRNG, Val, Symbol, String, Nothing) - @eval $(_op)(::$(T)) = $(op == :get_device ? 
nothing : Nothing) - end -end - -__recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) - -__combine_devices(::Nothing, ::Nothing) = nothing -__combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing -__combine_devices(::Nothing, dev::AbstractDevice) = dev -__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T -__combine_devices(dev::AbstractDevice, ::Nothing) = dev -__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T -function __combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) - dev1 == dev2 && return dev1 - throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) -end -__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T -function __combine_devices( - ::Type{T1}, ::Type{T2}) where {T1 <: AbstractDevice, T2 <: AbstractDevice} - throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) -end - -# Set the device -const SET_DEVICE_DOCS = """ -Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice` -and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not -loaded. - -Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device. -""" - -const SET_DEVICE_DANGER = """ -!!! danger - - This specific function should be considered experimental at this point and is currently - provided to support distributed training in Lux. As such please use - `Lux.DistributedUtils` instead of using this function. -""" - -""" - set_device!(T::Type{<:AbstractDevice}, dev_or_id) - -$SET_DEVICE_DOCS - -## Arguments - - - `T::Type{<:AbstractDevice}`: The device type to set. - - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it - can be a `CuDevice`. If it is an integer, it is the device id to set. This is - `1`-indexed. - -$SET_DEVICE_DANGER -""" -function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} - T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." - T === AMDGPUDevice && - @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." - T === MetalDevice && - @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." - T === oneAPIDevice && - @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." - T === CPUDevice && - @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." - return -end - -""" - set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer) - -$SET_DEVICE_DOCS - -## Arguments - - - `T::Type{<:AbstractDevice}`: The device type to set. - - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and - must be `0`-indexed. 
- -$SET_DEVICE_DANGER -""" -function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice} - return set_device!(T, rank) -end - -# Adapt Interface - -Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng - -for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) - @eval begin - function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) - return default_device_rng(to) - end - Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng - end -end - -Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x -# Prevent Ambiguity -for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, - CUDADevice{Nothing}, MetalDevice, oneAPIDevice) - @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) -end - -# Chain Rules Core -function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) - end - return Adapt.adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/src/internal.jl b/src/internal.jl new file mode 100644 index 0000000..664dc52 --- /dev/null +++ b/src/internal.jl @@ -0,0 +1,144 @@ +module Internal + +using Preferences: load_preference +using Random: AbstractRNG +using UnrolledUtilities: unrolled_mapreduce + +using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, + MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES, + loaded, functional + +for dev in (CPUDevice, MetalDevice, oneAPIDevice) + msg = "`device_id` is not applicable for `$dev`." + @eval begin + with_device(::Type{$dev}, ::Nothing) = $dev() + function with_device(::Type{$dev}, device_id) + @warn $(msg) maxlog=1 + return $dev() + end + end +end + +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + tpkg = name === :CPU ? "" : string(name) + ldev = Symbol(name, :Device) + @eval begin + get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) + get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) + end +end + +for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) + @eval get_device_id(::$(T)) = nothing +end + +struct DeviceSelectionException <: Exception end + +function Base.showerror(io::IO, ::DeviceSelectionException) + return print(io, "DeviceSelectionException(No functional GPU device found!!)") +end + +function get_gpu_device(; force_gpu_usage::Bool) + backend = load_preference(MLDataDevices, "gpu_backend", nothing) + + # If backend set with preferences, use it + if backend !== nothing + allowed_backends = supported_gpu_backends() + if backend ∉ allowed_backends + @warn "`gpu_backend` preference is set to $backend, which is not a valid \ + backend. Valid backends are $allowed_backends. Defaulting to automatic \ + GPU Backend selection." maxlog=1 + else + @debug "Using GPU backend set in preferences: $backend." + idx = findfirst(isequal(backend), allowed_backends) + device = GPU_DEVICES[idx] + if !loaded(device) + @warn "Trying to use backend: $(get_device_name(device)) but the trigger \ + package $(get_triggerpkg_name(device)) is not loaded. Ignoring the \ + Preferences backend!!! Please load the package and call this \ + function again to respect the Preferences backend." maxlog=1 + else + if functional(device) + @debug "Using GPU backend: $(get_device_name(device))." + return device + else + @warn "GPU backend: $(get_device_name(device)) set via Preferences.jl \ + is not functional. 
Defaulting to automatic GPU Backend \ + selection." maxlog=1 + end + end + end + end + + @debug "Running automatic GPU backend selection..." + for device in GPU_DEVICES + if loaded(device) + @debug "Trying backend: $(get_device_name(device))." + if functional(device) + @debug "Using GPU backend: $(get_device_name(device))." + return device + end + @debug "GPU backend: $(get_device_name(device)) is not functional." + else + @debug "Trigger package for backend ($(get_device_name(device))): \ + $(get_triggerpkg_name(device)) not loaded." + end + end + + force_gpu_usage && throw(DeviceSelectionException()) + @warn """No functional GPU backend found! Defaulting to CPU. + + 1. If no GPU is available, nothing needs to be done. + 2. If GPU is available, load the corresponding trigger package. + a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. + b. `AMDGPU.jl` for AMD GPU ROCM Support. + c. `Metal.jl` for Apple Metal GPU Support. (Experimental) + d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 + return CPUDevice +end + +special_aos(::AbstractArray) = false + +recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) + +combine_devices(::Nothing, ::Nothing) = nothing +combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing +combine_devices(::Nothing, dev::AbstractDevice) = dev +combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(dev::AbstractDevice, ::Nothing) = dev +combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) + dev1 == dev2 && return dev1 + throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) +end +combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T +function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) + throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) +end + +for op in (:get_device, :get_device_type) + cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + + @eval begin + function $(op)(x::AbstractArray{T}) where {T} + recursive_array_eltype(T) && return mapreduce($(op), combine_devices, x) + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return $(cpu_ret_val) + return $(op)(parent_x) + end + return $(cpu_ret_val) + end + + function $(op)(x::Union{Tuple, NamedTuple}) + length(x) == 0 && return $(op == :get_device ? nothing : Nothing) + return unrolled_mapreduce($(op), combine_devices, values(x)) + end + end + + for T in (Number, AbstractRNG, Val, Symbol, String, Nothing) + @eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing) + end +end + +end diff --git a/src/public.jl b/src/public.jl new file mode 100644 index 0000000..ac53ee5 --- /dev/null +++ b/src/public.jl @@ -0,0 +1,347 @@ +struct CPUDevice <: AbstractDevice end +@kwdef struct CUDADevice{D} <: AbstractGPUDevice + device::D = nothing +end +@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice + device::D = nothing +end +struct MetalDevice <: AbstractGPUDevice end +struct oneAPIDevice <: AbstractGPUDevice end + +""" + functional(x::AbstractDevice) -> Bool + functional(::Type{<:AbstractDevice}) -> Bool + +Checks if the device is functional. This is used to determine if the device can be used for +computation. Note that even if the backend is loaded (as checked via +[`MLDataDevices.loaded`](@ref)), the device may not be functional. 
+ +Note that while this function is not exported, it is considered part of the public API. +""" +functional(x) = false +functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true + +""" + loaded(x::AbstractDevice) -> Bool + loaded(::Type{<:AbstractDevice}) -> Bool + +Checks if the trigger package for the device is loaded. Trigger packages are as follows: + + - `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. + - `AMDGPU.jl` for AMD GPU ROCM Support. + - `Metal.jl` for Apple Metal GPU Support. + - `oneAPI.jl` for Intel oneAPI GPU Support. +""" +loaded(x) = false +loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true + +# Order is important here +const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) + +const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) + +""" + reset_gpu_device!() + +Resets the selected GPU device. This is useful when automatic GPU selection needs to be +run again. +""" +reset_gpu_device!() = (GPU_DEVICE[] = nothing) + +""" + supported_gpu_backends() -> Tuple{String, ...} + +Return a tuple of supported GPU backends. + +!!! warning + + This is not the list of functional backends on the system, but rather backends which + `MLDataDevices.jl` supports. +""" +supported_gpu_backends() = map(Internal.get_device_name, GPU_DEVICES) + +""" + gpu_device(device_id::Union{Nothing, Integer}=nothing; + force_gpu_usage::Bool=false) -> AbstractDevice() + +Selects GPU device based on the following criteria: + + 1. If `gpu_backend` preference is set and the backend is functional on the system, then + that device is selected. + 2. Otherwise, an automatic selection algorithm is used. We go over possible device + backends in the order specified by `supported_gpu_backends()` and select the first + functional backend. + 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is + invoked. + 4. If nothing works, an error is thrown. + +## Arguments + + - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return + the last selected device or if none was selected then we run the autoselection and + choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If + `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in + contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to + `CUDA.device!(3)`. + +!!! warning + + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` + and `CPU` backends, `device_id` is ignored and a warning is printed. + +!!! warning + + `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. + This is to ensure that deep learning operations work correctly. + Nonetheless, if cuDNN is not loaded you can still manually create a + `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). + +## Keyword Arguments + + - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU + device is found. 
+""" +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; + force_gpu_usage::Bool=false)::AbstractDevice + device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) + + if GPU_DEVICE[] !== nothing + dev = GPU_DEVICE[] + if device_id === nothing + force_gpu_usage && + !(dev isa AbstractGPUDevice) && + throw(Internal.DeviceSelectionException()) + return dev + else + selected_device_id = Internal.get_device_id(dev) + selected_device_id !== nothing && selected_device_id == device_id && return dev + end + end + + device_type = Internal.get_gpu_device(; force_gpu_usage) + device = Internal.with_device(device_type, device_id) + GPU_DEVICE[] = device + + return device +end + +""" + gpu_backend!() = gpu_backend!("") + gpu_backend!(backend) = gpu_backend!(string(backend)) + gpu_backend!(backend::AbstractGPUDevice) + gpu_backend!(backend::String) + +Creates a `LocalPreferences.toml` file with the desired GPU backend. + +If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is +validated to be one of the possible backends and the preference is set to `backend`. + +If a new backend is successfully set, then the Julia session must be restarted for the +change to take effect. +""" +gpu_backend!(backend) = gpu_backend!(string(backend)) +gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(Internal.get_device_name(backend)) +gpu_backend!() = gpu_backend!("") +function gpu_backend!(backend::String) + if backend == "" + @delete_preferences!("gpu_backend") + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ + new backend." + return + end + + allowed_backends = supported_gpu_backends() + + set_backend = @load_preference("gpu_backend", nothing) + if set_backend == backend + @info "GPU backend is already set to $backend. No action is required." + return + end + + if backend ∉ allowed_backends + throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) + end + + @set_preferences!("gpu_backend"=>backend) + @info "GPU backend has been set to $backend. Restart Julia to use the new backend." + return +end + +""" + cpu_device() -> CPUDevice() + +Return a `CPUDevice` object which can be used to transfer data to CPU. +""" +cpu_device() = CPUDevice() + +""" + default_device_rng(::AbstractDevice) + +Returns the default RNG for the device. This can be used to directly generate parameters +and states on the device using +[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). +""" +function default_device_rng(D::AbstractDevice) + return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ + either because: + + 1. The default RNG for this device is not known / officially provided. + 2. The trigger package for the device ($(Internal.get_device_name(D)).jl) is not loaded. + """) +end +default_device_rng(::CPUDevice) = Random.default_rng() + +const GET_DEVICE_ADMONITIONS = """ +!!! note + + Trigger Packages must be loaded for this to return the correct device. + +!!! warning + + RNG types currently don't participate in device determination. We will remove this + restriction in the future. +""" + +# Query Device from Array +""" + get_device(x) -> dev::AbstractDevice | Exception | Nothing + +If all arrays (on the leaves of the structure) are on the same device, we return that +device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. 
+ +$(GET_DEVICE_ADMONITIONS) + +See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch +based on device type. +""" +function get_device end + +""" + get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} + +Similar to [`get_device`](@ref) but returns the type of the device instead of the device +itself. This value is often a compile time constant and is recommended to be used instead +of [`get_device`](@ref) where ever defining dispatches based on the device type. + +$(GET_DEVICE_ADMONITIONS) +""" +function get_device_type end + +# Set the device +const SET_DEVICE_DOCS = """ +Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice` +and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not +loaded. + +Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device. +""" + +const SET_DEVICE_DANGER = """ +!!! danger + + This specific function should be considered experimental at this point and is currently + provided to support distributed training in Lux. As such please use + `Lux.DistributedUtils` instead of using this function. +""" + +""" + set_device!(T::Type{<:AbstractDevice}, dev_or_id) + +$SET_DEVICE_DOCS + +## Arguments + + - `T::Type{<:AbstractDevice}`: The device type to set. + - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it + can be a `CuDevice`. If it is an integer, it is the device id to set. This is + `1`-indexed. + +$SET_DEVICE_DANGER +""" +function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} + T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." + T === AMDGPUDevice && + @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." + T === MetalDevice && + @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." + T === oneAPIDevice && + @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." + T === CPUDevice && + @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." + return +end + +""" + set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer) + +$SET_DEVICE_DOCS + +## Arguments + + - `T::Type{<:AbstractDevice}`: The device type to set. + - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and + must be `0`-indexed. + +$SET_DEVICE_DANGER +""" +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice} + return set_device!(T, rank) +end + +# Dispatches for Different Data Structures +# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability +# For all other types we rely on fmap which means we lose type stability. +# For Lux, typically models only has these 3 datastructures so we should be mostly fine. +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + ldev = Symbol("$(dev)Device") + @eval begin + function (D::$(ldev))(x::AbstractArray{T}) where {T} + return (isbitstype(T) || Internal.special_aos(x)) ? 
Adapt.adapt(D, x) : + map(D, x) + end + (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) + function (D::$(ldev))(x) + Functors.isleaf(x) && return Adapt.adapt(D, x) + return Functors.fmap(D, x) + end + end +end + +for op in (:get_device, :get_device_type) + @eval begin + function $(op)(x) + hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) + end + + CRC.@non_differentiable $op(::Any) + end +end + +# Adapt Interface +Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng + +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) + @eval begin + function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) + return default_device_rng(to) + end + Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng + end +end + +Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x +# Prevent Ambiguity +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) + @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) +end + +# Chain Rules Core +function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) + ∇adapt_storage = let x = x + Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + end + return Adapt.adapt_storage(to, x), ∇adapt_storage +end diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index 0338031..a4cb8cf 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(AMDGPUDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( AMDGPUDevice, nothing, 1) @@ -23,7 +24,7 @@ using AMDGPU else @info "AMDGPU is NOT functional" @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index 7804183..c6cf533 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(CUDADevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(CUDADevice(nothing)) @test_logs (:warn, "`CUDA.jl` hasn't been loaded. 
Ignoring the device setting.") MLDataDevices.set_device!( CUDADevice, nothing, 1) @@ -23,7 +24,7 @@ using LuxCUDA else @info "LuxCUDA is NOT functional" @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing diff --git a/test/metal_tests.jl b/test/metal_tests.jl index 3bf98ec..a4dd887 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(MetalDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(MetalDevice()) end @@ -21,7 +22,7 @@ using Metal else @info "Metal is NOT functional" @test gpu_device() isa MetalDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing diff --git a/test/misc_tests.jl b/test/misc_tests.jl index e3f3ed8..aa39962 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -127,7 +127,7 @@ end for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, AMDGPUDevice(), CUDADevice(), MetalDevice(), oneAPIDevice()) backend_name = backend isa Symbol ? string(backend) : - MLDataDevices._get_device_name(backend) + MLDataDevices.Internal.get_device_name(backend) @test_logs (:info, "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) end diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index a9f25cf..f046498 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(oneAPIDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(oneAPIDevice()) end @@ -21,7 +22,7 @@ using oneAPI else @info "oneAPI is NOT functional" @test gpu_device() isa oneAPIDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing
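
For context (not part of the patch): below is a minimal usage sketch of the public API that this refactor reorganizes. The underscore-prefixed helpers (`_get_device`, `_with_device`, `_get_device_name`, ...) move into the new `Internal` module, while the exported surface (`gpu_device`, `cpu_device`, `get_device`, `get_device_type`, `default_device_rng`) documented in `src/public.jl` above is unchanged. The snippet assumes MLDataDevices v1.0.2 as in this patch; the GPU path additionally assumes a functional trigger package (e.g. `CUDA.jl` + `cuDNN.jl`) is loaded, otherwise `gpu_device()` falls back to the CPU as described in its docstring.

```julia
using MLDataDevices

# Select the best available backend; returns CPUDevice() when no GPU is functional
# (or throws if `force_gpu_usage=true`).
dev = gpu_device()
cpu = cpu_device()

# Devices are callable: they move arrays (and nested Tuples/NamedTuples/structs)
# onto the device via the Adapt.jl hooks defined in public.jl.
x  = rand(Float32, 4, 4)
xd = dev(x)

# The exported queries now forward to Internal.get_device / Internal.get_device_type.
get_device(xd)        # device instance holding `xd`
get_device_type(xd)   # its type, convenient for dispatch

# Default RNG for the selected device (Random.default_rng() on the CPU).
rng = default_device_rng(dev)

# Move data back to the host.
x_host = cpu(xd)
```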