diff --git a/.gitignore b/.gitignore index c2b7741..2fd7d52 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ Manifest.toml +*.cov generated build .vscode diff --git a/Project.toml b/Project.toml index 78889f7..09aca5d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,4 +1,4 @@ -name = "LuxDeviceUtils" +name = "DeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] version = "0.1.26" @@ -17,28 +17,28 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] -LuxDeviceUtilsAMDGPUExt = "AMDGPU" -LuxDeviceUtilsCUDAExt = "CUDA" -LuxDeviceUtilsFillArraysExt = "FillArrays" -LuxDeviceUtilsGPUArraysExt = "GPUArrays" -LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] -LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" -LuxDeviceUtilsReverseDiffExt = "ReverseDiff" -LuxDeviceUtilsSparseArraysExt = "SparseArrays" -LuxDeviceUtilsTrackerExt = "Tracker" -LuxDeviceUtilsZygoteExt = "Zygote" -LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] +DeviceUtilsAMDGPUExt = "AMDGPU" +DeviceUtilsCUDAExt = "CUDA" +DeviceUtilsFillArraysExt = "FillArrays" +DeviceUtilsGPUArraysExt = "GPUArrays" +DeviceUtilsMetalExt = ["GPUArrays", "Metal"] +DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +DeviceUtilsReverseDiffExt = "ReverseDiff" +DeviceUtilsSparseArraysExt = "SparseArrays" +DeviceUtilsTrackerExt = "Tracker" +DeviceUtilsZygoteExt = "Zygote" +DeviceUtilscuDNNExt = ["CUDA", "cuDNN"] +DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.9.6" @@ -54,7 +54,6 @@ FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" GPUArrays = "10" -LuxCUDA = "0.3.2" LuxCore = "0.1.4" Metal = "1" Pkg = "1.10" @@ -68,9 +67,11 @@ Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" Zygote = "0.6.69" +cuDNN = "1.3" julia = "1.10" oneAPI = "1.5" + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/README.md b/README.md index 0fae7fd..f377cff 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,19 @@ -# LuxDeviceUtils +# DeviceUtils [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/DeviceUtils) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/DeviceUtils) -[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) -[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) +[![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml) +[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl) +[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across -devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead. +`DeviceUtils.jl` is a lightweight package defining rules for transferring data across +devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/). Currently we provide support for the following backends: diff --git a/ext/DeviceUtilsAMDGPUExt.jl b/ext/DeviceUtilsAMDGPUExt.jl new file mode 100644 index 0000000..365a119 --- /dev/null +++ b/ext/DeviceUtilsAMDGPUExt.jl @@ -0,0 +1,89 @@ +module DeviceUtilsAMDGPUExt + +using Adapt: Adapt +using AMDGPU: AMDGPU +using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device! +using Random: Random + +__init__() = reset_gpu_device!() + +const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_amdgpu!() + USE_AMD_GPU[] === nothing || return + + USE_AMD_GPU[] = AMDGPU.functional() + if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ + available." maxlog=1 + end + return +end + +DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true +function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool + _check_use_amdgpu!() + return USE_AMD_GPU[] +end + +function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing) + return AMDGPUDevice(nothing) +end +function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) + id > length(AMDGPU.devices()) && + throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) + old_dev = AMDGPU.device() + AMDGPU.device!(AMDGPU.devices()[id]) + device = AMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(old_dev) + return device +end + +DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) + +# Default RNG +DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() + +# Query Device from Array +function DeviceUtils.get_device(x::AMDGPU.AnyROCArray) + parent_x = parent(x) + parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) + return DeviceUtils.get_device(parent_x) +end + +# Set Device +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) + return AMDGPU.device!(dev) +end +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer) + return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) +end +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) + id = mod1(rank + 1, length(AMDGPU.devices())) + return DeviceUtils.set_device!(AMDGPUDevice, id) +end + +# Device Transfer +## To GPU +Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) + old_dev = AMDGPU.device() # remember the current device + dev = DeviceUtils.get_device(x) + if !(dev isa AMDGPUDevice) + AMDGPU.device!(to.device) + x_new = AMDGPU.roc(x) + AMDGPU.device!(old_dev) + return x_new + elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) + return x + else + AMDGPU.device!(to.device) + x_new = copy(x) + AMDGPU.device!(old_dev) + return x_new + end +end + +Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() + +end diff --git a/ext/DeviceUtilsCUDAExt.jl b/ext/DeviceUtilsCUDAExt.jl new file mode 100644 index 0000000..b51fa4f --- /dev/null +++ b/ext/DeviceUtilsCUDAExt.jl @@ -0,0 +1,85 @@ +module DeviceUtilsCUDAExt + +using Adapt: Adapt +using CUDA: CUDA +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector +using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice, reset_gpu_device! +using Random: Random + +function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) + id > length(CUDA.devices()) && + throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) + old_dev = CUDA.device() + CUDA.device!(id - 1) + device = CUDADevice(CUDA.device()) + CUDA.device!(old_dev) + return device +end + +function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing) + return CUDADevice(nothing) +end + +DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 + +# Default RNG +DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng() + +# Query Device from Array +function DeviceUtils.get_device(x::CUDA.AnyCuArray) + parent_x = parent(x) + parent_x === x && return CUDADevice(CUDA.device(x)) + return DeviceUtils.get_device(parent_x) +end +function DeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) + return CUDADevice(CUDA.device(x.nzVal)) +end + +# Set Device +function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) + return CUDA.device!(dev) +end +function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer) + return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id]) +end +function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) + id = mod1(rank + 1, length(CUDA.devices())) + return DeviceUtils.set_device!(CUDADevice, id) +end + +# Device Transfer +Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) + old_dev = CUDA.device() # remember the current device + dev = DeviceUtils.get_device(x) + if !(dev isa CUDADevice) + CUDA.device!(to.device) + x_new = CUDA.cu(x) + CUDA.device!(old_dev) + return x_new + elseif dev.device == to.device + return x + else + CUDA.device!(to.device) + x_new = copy(x) + CUDA.device!(old_dev) + return x_new + end +end + +Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() + +# Defining as extensions seems to case precompilation errors +@static if isdefined(CUDA.CUSPARSE, :SparseArrays) + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix) + return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) + end + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector) + return CUDA.CUSPARSE.SparseArrays.SparseVector(x) + end +else + @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ + an issue in DeviceUtils.jl repository." +end + +end diff --git a/ext/DeviceUtilsFillArraysExt.jl b/ext/DeviceUtilsFillArraysExt.jl new file mode 100644 index 0000000..25a9d61 --- /dev/null +++ b/ext/DeviceUtilsFillArraysExt.jl @@ -0,0 +1,10 @@ +module DeviceUtilsFillArraysExt + +using Adapt: Adapt +using FillArrays: FillArrays, AbstractFill +using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice + +Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x +Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) + +end diff --git a/ext/DeviceUtilsGPUArraysExt.jl b/ext/DeviceUtilsGPUArraysExt.jl new file mode 100644 index 0000000..304b3f0 --- /dev/null +++ b/ext/DeviceUtilsGPUArraysExt.jl @@ -0,0 +1,10 @@ +module DeviceUtilsGPUArraysExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using DeviceUtils: CPUDevice +using Random: Random + +Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() + +end diff --git a/ext/DeviceUtilsMetalExt.jl b/ext/DeviceUtilsMetalExt.jl new file mode 100644 index 0000000..25724d6 --- /dev/null +++ b/ext/DeviceUtilsMetalExt.jl @@ -0,0 +1,25 @@ +module DeviceUtilsMetalExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device! +using Metal: Metal, MtlArray + +__init__() = reset_gpu_device!() + +DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true +function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}}) + return Metal.functional() +end + +# Default RNG +DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) + +# Query Device from Array +DeviceUtils.get_device(::MtlArray) = MetalDevice() + +# Device Transfer +## To GPU +Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) + +end diff --git a/ext/DeviceUtilsRecursiveArrayToolsExt.jl b/ext/DeviceUtilsRecursiveArrayToolsExt.jl new file mode 100644 index 0000000..6319220 --- /dev/null +++ b/ext/DeviceUtilsRecursiveArrayToolsExt.jl @@ -0,0 +1,21 @@ +module DeviceUtilsRecursiveArrayToolsExt + +using Adapt: Adapt, adapt +using DeviceUtils: DeviceUtils, AbstractDevice +using RecursiveArrayTools: VectorOfArray, DiffEqArray + +# We want to preserve the structure +function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray) + return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) +end + +function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) + # Don't move the `time` to the GPU + return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) +end + +function DeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) + return mapreduce(DeviceUtils.get_device, DeviceUtils.__combine_devices, x.u) +end + +end diff --git a/ext/DeviceUtilsReverseDiffExt.jl b/ext/DeviceUtilsReverseDiffExt.jl new file mode 100644 index 0000000..27b6bfe --- /dev/null +++ b/ext/DeviceUtilsReverseDiffExt.jl @@ -0,0 +1,13 @@ +module DeviceUtilsReverseDiffExt + +using DeviceUtils: DeviceUtils +using ReverseDiff: ReverseDiff + +@inline function DeviceUtils.get_device(x::ReverseDiff.TrackedArray) + return DeviceUtils.get_device(ReverseDiff.value(x)) +end +@inline function DeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return DeviceUtils.get_device(ReverseDiff.value.(x)) +end + +end diff --git a/ext/DeviceUtilsSparseArraysExt.jl b/ext/DeviceUtilsSparseArraysExt.jl new file mode 100644 index 0000000..6c3c15d --- /dev/null +++ b/ext/DeviceUtilsSparseArraysExt.jl @@ -0,0 +1,9 @@ +module DeviceUtilsSparseArraysExt + +using Adapt: Adapt +using DeviceUtils: CPUDevice +using SparseArrays: AbstractSparseArray + +Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x + +end diff --git a/ext/DeviceUtilsTrackerExt.jl b/ext/DeviceUtilsTrackerExt.jl new file mode 100644 index 0000000..9dbedbb --- /dev/null +++ b/ext/DeviceUtilsTrackerExt.jl @@ -0,0 +1,26 @@ +module DeviceUtilsTrackerExt + +using Adapt: Adapt +using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, + oneAPIDevice +using Tracker: Tracker + +@inline function DeviceUtils.get_device(x::Tracker.TrackedArray) + return DeviceUtils.get_device(Tracker.data(x)) +end +@inline function DeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal}) + return DeviceUtils.get_device(Tracker.data.(x)) +end + +@inline DeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true + +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) + @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) + @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ + to Tracker.TrackedArray." maxlog=1 + return to(Tracker.collect(x)) + end +end + +end diff --git a/ext/DeviceUtilsZygoteExt.jl b/ext/DeviceUtilsZygoteExt.jl new file mode 100644 index 0000000..5b7e6b0 --- /dev/null +++ b/ext/DeviceUtilsZygoteExt.jl @@ -0,0 +1,10 @@ +module DeviceUtilsZygoteExt + +using Adapt: Adapt +using DeviceUtils: AbstractDevice, CPUDevice +using Zygote: OneElement + +Adapt.adapt_structure(::CPUDevice, x::OneElement) = x +Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) + +end diff --git a/ext/DeviceUtilscuDNNExt.jl b/ext/DeviceUtilscuDNNExt.jl new file mode 100644 index 0000000..32ee3c9 --- /dev/null +++ b/ext/DeviceUtilscuDNNExt.jl @@ -0,0 +1,56 @@ +module DeviceUtilscuDNNExt + +using CUDA: CUDA +using cuDNN: cuDNN +using DeviceUtils: DeviceUtils, CUDADevice, reset_gpu_device! + +__init__() = reset_gpu_device!() + +const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_cuda!() + USE_CUDA_GPU[] === nothing || return + + USE_CUDA_GPU[] = CUDA.functional() + if USE_CUDA_GPU[] + + ### Uncomment the following and move all to CUDA extension + ### when we will ditch the cuDNN.jl dependency + + # cudnnid = Base.identify_package("cuDNN") + # cudnn_loaded = cudnnid in keys(Base.loaded_modules) + # if !cudnn_loaded + # @warn """ + # cuDNN is not loaded. Some functionality will not be available. + # Load cuDNN by running `using CUDA, cuDNN` or `using LuxCUDA`. + # """ maxlog=1 + # else + # cuDNN = Base.loaded_modules[cudnnid] + # if !cuDNN.has_cudnn() + # @warn """ + # cuDNN is not functional. Some functionality will not be available. + # """ maxlog=1 + # end + # end + + if !cuDNN.has_cudnn() + @warn """ + cuDNN is not functional. Some functionality will not be available. + """ maxlog=1 + + # We make the device selectable only if cuDNN is functional + # to avoid issues with convolutions and other deep learning operations + USE_CUDA_GPU[] = false + end + end + return +end + +DeviceUtils.loaded(::Union{CUDADevice, Type{<:CUDADevice}}) = true + +function DeviceUtils.functional(::Union{CUDADevice, Type{<:CUDADevice}})::Bool + _check_use_cuda!() + return USE_CUDA_GPU[] +end + +end diff --git a/ext/LuxDeviceUtilsoneAPIExt.jl b/ext/DeviceUtilsoneAPIExt.jl similarity index 57% rename from ext/LuxDeviceUtilsoneAPIExt.jl rename to ext/DeviceUtilsoneAPIExt.jl index f9da407..24ef8c4 100644 --- a/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/ext/DeviceUtilsoneAPIExt.jl @@ -1,8 +1,8 @@ -module LuxDeviceUtilsoneAPIExt +module DeviceUtilsoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device! +using DeviceUtils: DeviceUtils, oneAPIDevice, reset_gpu_device! using oneAPI: oneAPI, oneArray, oneL0 const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() @@ -16,23 +16,23 @@ function __init__() end end -LuxDeviceUtils.loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) +DeviceUtils.loaded(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) = true +function DeviceUtils.functional(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) return oneAPI.functional() end # Default RNG -LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneArray) +DeviceUtils.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -LuxDeviceUtils._get_device(::oneArray) = LuxoneAPIDevice() +DeviceUtils._get_device(::oneArray) = oneAPIDevice() -LuxDeviceUtils._get_device_type(::oneArray) = LuxoneAPIDevice +DeviceUtils._get_device_type(::oneArray) = oneAPIDevice # Device Transfer ## To GPU for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) - @eval function Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray{$(T1)}) + @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)}) if !SUPPORTS_FP64[oneAPI.device()] @warn LazyString( "Double type is not supported on this device. Using `", $(T2), "` instead.") @@ -41,6 +41,6 @@ for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) return oneArray(x) end end -Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray) = oneArray(x) +Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray) = oneArray(x) end diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl index 7f8efb3..e22feca 100644 --- a/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -2,7 +2,7 @@ module LuxDeviceUtilsAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice, reset_gpu_device! +using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device! using Random: Random __init__() = reset_gpu_device!() @@ -21,58 +21,58 @@ function _check_use_amdgpu!() return end -LuxDeviceUtils.loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}})::Bool +DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true +function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool _check_use_amdgpu!() return USE_AMD_GPU[] end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) - return LuxAMDGPUDevice(nothing) +function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing) + return AMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Integer) +function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() AMDGPU.device!(AMDGPU.devices()[id]) - device = LuxAMDGPUDevice(AMDGPU.device()) + device = AMDGPUDevice(AMDGPU.device()) AMDGPU.device!(old_dev) return device end -LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device) +DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) # Default RNG -LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() +DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -function LuxDeviceUtils._get_device(x::AMDGPU.AnyROCArray) +function DeviceUtils._get_device(x::AMDGPU.AnyROCArray) parent_x = parent(x) - parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x)) - return LuxDeviceUtils._get_device(parent_x) + parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) + return DeviceUtils._get_device(parent_x) end -LuxDeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice +DeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice # Set Device -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) return AMDGPU.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Integer) - return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer) + return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Integer) +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(AMDGPU.devices())) - return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id) + return DeviceUtils.set_device!(AMDGPUDevice, id) end # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) -function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) +Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - dev = LuxDeviceUtils.get_device(x) - if !(dev isa LuxAMDGPUDevice) + dev = DeviceUtils.get_device(x) + if !(dev isa AMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) AMDGPU.device!(old_dev) @@ -87,6 +87,6 @@ function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) end end -Adapt.adapt_storage(::LuxCPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() +Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() end diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index 8d86061..42f2a12 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -3,61 +3,61 @@ module LuxDeviceUtilsCUDAExt using Adapt: Adapt using CUDA: CUDA using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector -using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice +using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice using Random: Random -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Integer) +function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() CUDA.device!(id - 1) - device = LuxCUDADevice(CUDA.device()) + device = CUDADevice(CUDA.device()) CUDA.device!(old_dev) return device end -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) - return LuxCUDADevice(nothing) +function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing) + return CUDADevice(nothing) end -LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1 +DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 # Default RNG -LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() +DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng() # Query Device from Array -function LuxDeviceUtils._get_device(x::CUDA.AnyCuArray) +function DeviceUtils._get_device(x::CUDA.AnyCuArray) parent_x = parent(x) - parent_x === x && return LuxCUDADevice(CUDA.device(x)) - return LuxDeviceUtils.get_device(parent_x) + parent_x === x && return CUDADevice(CUDA.device(x)) + return DeviceUtils.get_device(parent_x) end -function LuxDeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) - return LuxCUDADevice(CUDA.device(x.nzVal)) +function DeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) + return CUDADevice(CUDA.device(x.nzVal)) end -function LuxDeviceUtils._get_device_type(::Union{ +function DeviceUtils._get_device_type(::Union{ <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) - return LuxCUDADevice + return CUDADevice end # Set Device -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) +function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Integer) - return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) +function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer) + return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Integer) +function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(CUDA.devices())) - return LuxDeviceUtils.set_device!(LuxCUDADevice, id) + return DeviceUtils.set_device!(CUDADevice, id) end # Device Transfer -Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) -function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) +Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - dev = LuxDeviceUtils.get_device(x) - if !(dev isa LuxCUDADevice) + dev = DeviceUtils.get_device(x) + if !(dev isa CUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) CUDA.device!(old_dev) @@ -72,19 +72,19 @@ function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) end end -Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng() +Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() # Defining as extensions seems to case precompilation errors @static if isdefined(CUDA.CUSPARSE, :SparseArrays) - function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix) + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix) return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) end - function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector) + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector) return CUDA.CUSPARSE.SparseArrays.SparseVector(x) end else @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ - an issue in LuxDeviceUtils.jl repository." + an issue in DeviceUtils.jl repository." end end diff --git a/ext/LuxDeviceUtilsFillArraysExt.jl b/ext/LuxDeviceUtilsFillArraysExt.jl deleted file mode 100644 index b596233..0000000 --- a/ext/LuxDeviceUtilsFillArraysExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsFillArraysExt - -using Adapt: Adapt -using FillArrays: FillArrays, AbstractFill -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice, AbstractLuxDevice - -Adapt.adapt_structure(::LuxCPUDevice, x::AbstractFill) = x -Adapt.adapt_structure(to::AbstractLuxDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) - -end diff --git a/ext/LuxDeviceUtilsGPUArraysExt.jl b/ext/LuxDeviceUtilsGPUArraysExt.jl deleted file mode 100644 index 1e8f9f9..0000000 --- a/ext/LuxDeviceUtilsGPUArraysExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsGPUArraysExt - -using Adapt: Adapt -using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxCPUDevice -using Random: Random - -Adapt.adapt_storage(::LuxCPUDevice, rng::GPUArrays.RNG) = Random.default_rng() - -end diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl deleted file mode 100644 index 4870710..0000000 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module LuxDeviceUtilsLuxCUDAExt - -using LuxCUDA: LuxCUDA -using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device! - -__init__() = reset_gpu_device!() - -LuxDeviceUtils.loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) - return LuxCUDA.functional() -end - -end diff --git a/ext/LuxDeviceUtilsMetalExt.jl b/ext/LuxDeviceUtilsMetalExt.jl index b2e188a..49c86d2 100644 --- a/ext/LuxDeviceUtilsMetalExt.jl +++ b/ext/LuxDeviceUtilsMetalExt.jl @@ -2,26 +2,26 @@ module LuxDeviceUtilsMetalExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxMetalDevice, reset_gpu_device! +using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device! using Metal: Metal, MtlArray __init__() = reset_gpu_device!() -LuxDeviceUtils.loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) +DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true +function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}}) return Metal.functional() end # Default RNG -LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) +DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) # Query Device from Array -LuxDeviceUtils._get_device(::MtlArray) = LuxMetalDevice() +DeviceUtils._get_device(::MtlArray) = MetalDevice() -LuxDeviceUtils._get_device_type(::MtlArray) = LuxMetalDevice +DeviceUtils._get_device_type(::MtlArray) = MetalDevice # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxMetalDevice, x::AbstractArray) = Metal.mtl(x) +Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) end diff --git a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 201ee44..086c955 100644 --- a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -1,23 +1,23 @@ module LuxDeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice +using DeviceUtils: DeviceUtils, AbstractDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure -function Adapt.adapt_structure(to::AbstractLuxDevice, x::VectorOfArray) +function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray) return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) end -function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) +function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) # Don't move the `time` to the GPU return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end for op in (:_get_device, :_get_device_type) - @eval function LuxDeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) + @eval function DeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) - return mapreduce(LuxDeviceUtils.$op, LuxDeviceUtils.__combine_devices, x.u) + return mapreduce(DeviceUtils.$op, DeviceUtils.__combine_devices, x.u) end end diff --git a/ext/LuxDeviceUtilsReverseDiffExt.jl b/ext/LuxDeviceUtilsReverseDiffExt.jl index 8a097d1..592fb49 100644 --- a/ext/LuxDeviceUtilsReverseDiffExt.jl +++ b/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -1,15 +1,15 @@ module LuxDeviceUtilsReverseDiffExt -using LuxDeviceUtils: LuxDeviceUtils +using DeviceUtils: DeviceUtils using ReverseDiff: ReverseDiff for op in (:_get_device, :_get_device_type) @eval begin - function LuxDeviceUtils.$op(x::ReverseDiff.TrackedArray) - return LuxDeviceUtils.$op(ReverseDiff.value(x)) + function DeviceUtils.$op(x::ReverseDiff.TrackedArray) + return DeviceUtils.$op(ReverseDiff.value(x)) end - function LuxDeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return LuxDeviceUtils.$op(ReverseDiff.value.(x)) + function DeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return DeviceUtils.$op(ReverseDiff.value.(x)) end end end diff --git a/ext/LuxDeviceUtilsSparseArraysExt.jl b/ext/LuxDeviceUtilsSparseArraysExt.jl deleted file mode 100644 index f337d2f..0000000 --- a/ext/LuxDeviceUtilsSparseArraysExt.jl +++ /dev/null @@ -1,9 +0,0 @@ -module LuxDeviceUtilsSparseArraysExt - -using Adapt: Adapt -using LuxDeviceUtils: LuxCPUDevice -using SparseArrays: AbstractSparseArray - -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractSparseArray) = x - -end diff --git a/ext/LuxDeviceUtilsTrackerExt.jl b/ext/LuxDeviceUtilsTrackerExt.jl index d41e832..868a3e5 100644 --- a/ext/LuxDeviceUtilsTrackerExt.jl +++ b/ext/LuxDeviceUtilsTrackerExt.jl @@ -1,23 +1,23 @@ module LuxDeviceUtilsTrackerExt using Adapt: Adapt -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, - LuxoneAPIDevice +using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, + oneAPIDevice using Tracker: Tracker for op in (:_get_device, :_get_device_type) @eval begin - LuxDeviceUtils.$op(x::Tracker.TrackedArray) = LuxDeviceUtils.$op(Tracker.data(x)) - function LuxDeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) - return LuxDeviceUtils.$op(Tracker.data.(x)) + DeviceUtils.$op(x::Tracker.TrackedArray) = DeviceUtils.$op(Tracker.data(x)) + function DeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) + return DeviceUtils.$op(Tracker.data.(x)) end end end -LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true +DeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true -for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ to Tracker.TrackedArray." maxlog=1 diff --git a/ext/LuxDeviceUtilsZygoteExt.jl b/ext/LuxDeviceUtilsZygoteExt.jl deleted file mode 100644 index ae61dc4..0000000 --- a/ext/LuxDeviceUtilsZygoteExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsZygoteExt - -using Adapt: Adapt -using LuxDeviceUtils: AbstractLuxDevice, LuxCPUDevice -using Zygote: OneElement - -Adapt.adapt_structure(::LuxCPUDevice, x::OneElement) = x -Adapt.adapt_structure(to::AbstractLuxDevice, x::OneElement) = Adapt.adapt(to, collect(x)) - -end diff --git a/src/LuxDeviceUtils.jl b/src/DeviceUtils.jl similarity index 75% rename from src/LuxDeviceUtils.jl rename to src/DeviceUtils.jl index f362ef0..a4861e4 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/DeviceUtils.jl @@ -1,4 +1,4 @@ -module LuxDeviceUtils +module DeviceUtils using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent @@ -13,19 +13,20 @@ const CRC = ChainRulesCore export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device -export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice + +export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice export get_device, get_device_type -abstract type AbstractLuxDevice <: Function end -abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end +abstract type AbstractDevice <: Function end +abstract type AbstractGPUDevice <: AbstractDevice end """ - functional(x::AbstractLuxDevice) -> Bool - functional(::Type{<:AbstractLuxDevice}) -> Bool + functional(x::AbstractDevice) -> Bool + functional(::Type{<:AbstractDevice}) -> Bool Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via -[`LuxDeviceUtils.loaded`](@ref)), the device may not be functional. +[`DeviceUtils.loaded`](@ref)), the device may not be functional. Note that while this function is not exported, it is considered part of the public API. """ @@ -34,12 +35,12 @@ Note that while this function is not exported, it is considered part of the publ Base.@deprecate __is_functional(x) functional(x) """ - loaded(x::AbstractLuxDevice) -> Bool - loaded(::Type{<:AbstractLuxDevice}) -> Bool + loaded(x::AbstractDevice) -> Bool + loaded(::Type{<:AbstractDevice}) -> Bool Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - `LuxCUDA.jl` for NVIDIA CUDA Support. + - Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. - `AMDGPU.jl` for AMD GPU ROCM Support. - `Metal.jl` for Apple Metal GPU Support. - `oneAPI.jl` for Intel oneAPI GPU Support. @@ -48,17 +49,17 @@ Checks if the trigger package for the device is loaded. Trigger packages are as Base.@deprecate __is_loaded(x) loaded(x) -struct LuxCPUDevice <: AbstractLuxDevice end -@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice +struct CPUDevice <: AbstractDevice end +@kwdef struct CUDADevice{D} <: AbstractGPUDevice device::D = nothing end -@kwdef struct LuxAMDGPUDevice{D} <: AbstractLuxGPUDevice +@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice device::D = nothing end -struct LuxMetalDevice <: AbstractLuxGPUDevice end -struct LuxoneAPIDevice <: AbstractLuxGPUDevice end +struct MetalDevice <: AbstractGPUDevice end +struct oneAPIDevice <: AbstractGPUDevice end -for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) +for dev in (CPUDevice, MetalDevice, oneAPIDevice) msg = "`device_id` is not applicable for `$dev`." @eval begin _with_device(::Type{$dev}, ::Nothing) = $dev() @@ -69,33 +70,33 @@ for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) end end -@inline functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -@inline loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +@inline functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true +@inline loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - tpkg = name === :CPU ? "" : (name == :CUDA ? "Lux$(name)" : string(name)) - ldev = eval(Symbol(:Lux, name, :Device)) + tpkg = name === :CPU ? "" : string(name) + ldev = eval(Symbol(name, :Device)) @eval begin @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) @inline _get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) end end -for T in (LuxCPUDevice, LuxCUDADevice{Nothing}, - LuxAMDGPUDevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) +for T in (CPUDevice, CUDADevice{Nothing}, + AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) @eval @inline _get_device_id(::$(T)) = nothing end -struct LuxDeviceSelectionException <: Exception end +struct DeviceSelectionException <: Exception end -function Base.showerror(io::IO, ::LuxDeviceSelectionException) - return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") +function Base.showerror(io::IO, ::DeviceSelectionException) + return print(io, "DeviceSelectionException(No functional GPU device found!!)") end # Order is important here -const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice) +const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) -const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) +const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) """ reset_gpu_device!() @@ -113,18 +114,13 @@ Return a tuple of supported GPU backends. !!! warning This is not the list of functional backends on the system, but rather backends which - `Lux.jl` supports. - -!!! danger - - `Metal.jl` and `oneAPI.jl` support is **extremely** experimental and most things are not - expected to work. + `DeviceUtils.jl` supports. """ @inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ gpu_device(device_id::Union{Nothing, Integer}=nothing; - force_gpu_usage::Bool=false) -> AbstractLuxDevice() + force_gpu_usage::Bool=false) -> AbstractDevice() Selects GPU device based on the following criteria: @@ -151,21 +147,28 @@ Selects GPU device based on the following criteria: `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` and `CPU` backends, `device_id` is ignored and a warning is printed. +!!! warning + + `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. + This is to ensure that deep learning operations work correctly. + Nonetheless, if cuDNN is not loaded you can still manually create a + `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). + ## Keyword Arguments - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; - force_gpu_usage::Bool=false)::AbstractLuxDevice + force_gpu_usage::Bool=false)::AbstractDevice device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) if GPU_DEVICE[] !== nothing dev = GPU_DEVICE[] if device_id === nothing force_gpu_usage && - !(dev isa AbstractLuxGPUDevice) && - throw(LuxDeviceSelectionException()) + !(dev isa AbstractGPUDevice) && + throw(DeviceSelectionException()) return dev else selected_device_id = _get_device_id(dev) @@ -228,24 +231,24 @@ function _get_gpu_device(; force_gpu_usage::Bool) end if force_gpu_usage - throw(LuxDeviceSelectionException()) + throw(DeviceSelectionException()) else @warn """No functional GPU backend found! Defaulting to CPU. 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. - a. `LuxCUDA.jl` for NVIDIA CUDA Support. + a. Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. b. `AMDGPU.jl` for AMD GPU ROCM Support. c. `Metal.jl` for Apple Metal GPU Support. (Experimental) d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 - return LuxCPUDevice + return CPUDevice end end """ gpu_backend!() = gpu_backend!("") gpu_backend!(backend) = gpu_backend!(string(backend)) - gpu_backend!(backend::AbstractLuxGPUDevice) + gpu_backend!(backend::AbstractGPUDevice) gpu_backend!(backend::String) Creates a `LocalPreferences.toml` file with the desired GPU backend. @@ -257,7 +260,7 @@ If a new backend is successfully set, then the Julia session must be restarted f change to take effect. """ gpu_backend!(backend) = gpu_backend!(string(backend)) -gpu_backend!(backend::AbstractLuxGPUDevice) = gpu_backend!(_get_device_name(backend)) +gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(_get_device_name(backend)) gpu_backend!() = gpu_backend!("") function gpu_backend!(backend::String) if backend == "" @@ -285,20 +288,20 @@ function gpu_backend!(backend::String) end """ - cpu_device() -> LuxCPUDevice() + cpu_device() -> CPUDevice() -Return a `LuxCPUDevice` object which can be used to transfer data to CPU. +Return a `CPUDevice` object which can be used to transfer data to CPU. """ -@inline cpu_device() = LuxCPUDevice() +@inline cpu_device() = CPUDevice() """ - default_device_rng(::AbstractLuxDevice) + default_device_rng(::AbstractDevice) Returns the default RNG for the device. This can be used to directly generate parameters and states on the device using [WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). """ -function default_device_rng(D::AbstractLuxDevice) +function default_device_rng(D::AbstractDevice) return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ either because: @@ -306,14 +309,14 @@ function default_device_rng(D::AbstractLuxDevice) 2. The trigger package for the device ($(_get_device_name(D)).jl) is not loaded. """) end -default_device_rng(::LuxCPUDevice) = Random.default_rng() +default_device_rng(::CPUDevice) = Random.default_rng() # Dispatches for Different Data Structures # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol("Lux$(dev)Device") + ldev = Symbol("$(dev)Device") @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} fn = Base.Fix1(Adapt.adapt, D) @@ -349,7 +352,7 @@ const GET_DEVICE_ADMONITIONS = """ # Query Device from Array """ - get_device(x) -> dev::AbstractLuxDevice | Exception | nothing + get_device(x) -> dev::AbstractDevice | Exception | Nothing If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. @@ -362,7 +365,7 @@ based on device type. function get_device end """ - get_device_type(x) -> Type{<:AbstractLuxDevice} | Exception | Type{Nothing} + get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} Similar to [`get_device`](@ref) but returns the type of the device instead of the device itself. This value is often a compile time constant and is recommended to be used instead @@ -374,7 +377,7 @@ function get_device_type end for op in (:get_device, :get_device_type) _op = Symbol("_", op) - cpu_ret_val = op == :get_device ? LuxCPUDevice() : LuxCPUDevice + cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice @eval begin function $(op)(x) hasmethod($(_op), Tuple{typeof(x)}) && return $(_op)(x) @@ -408,27 +411,27 @@ __recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number __combine_devices(::Nothing, ::Nothing) = nothing __combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing -__combine_devices(::Nothing, dev::AbstractLuxDevice) = dev -__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T -__combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev -__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractLuxDevice} = T -function __combine_devices(dev1::AbstractLuxDevice, dev2::AbstractLuxDevice) +__combine_devices(::Nothing, dev::AbstractDevice) = dev +__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +__combine_devices(dev::AbstractDevice, ::Nothing) = dev +__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +function __combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) dev1 == dev2 && return dev1 throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) end -__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractLuxDevice} = T +__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T function __combine_devices( - ::Type{T1}, ::Type{T2}) where {T1 <: AbstractLuxDevice, T2 <: AbstractLuxDevice} + ::Type{T1}, ::Type{T2}) where {T1 <: AbstractDevice, T2 <: AbstractDevice} throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) end # Set the device const SET_DEVICE_DOCS = """ -Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice` -and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not +Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice` +and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not loaded. - -Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device. + +Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device. """ const SET_DEVICE_DANGER = """ @@ -440,63 +443,56 @@ const SET_DEVICE_DANGER = """ """ """ - set_device!(T::Type{<:AbstractLuxDevice}, dev_or_id) + set_device!(T::Type{<:AbstractDevice}, dev_or_id) $SET_DEVICE_DOCS ## Arguments - - `T::Type{<:AbstractLuxDevice}`: The device type to set. + - `T::Type{<:AbstractDevice}`: The device type to set. - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it can be a `CuDevice`. If it is an integer, it is the device id to set. This is `1`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} - T === LuxCUDADevice && +function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} + T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." - T === LuxAMDGPUDevice && + T === AMDGPUDevice && @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." - T === LuxMetalDevice && + T === MetalDevice && @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." - T === LuxoneAPIDevice && + T === oneAPIDevice && @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." - T === LuxCPUDevice && - @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." + T === CPUDevice && + @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." return end """ - set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Integer) + set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer) $SET_DEVICE_DOCS ## Arguments - - `T::Type{<:AbstractLuxDevice}`: The device type to set. + - `T::Type{<:AbstractDevice}`: The device type to set. - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and must be `0`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractLuxDevice} +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice} return set_device!(T, rank) end # Adapt Interface -# In older versions we had corresponding Adapt functions, rn we directly dispatch on the -# device type. -for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - dev = Symbol(:Lux, name, :Device) - adaptor = Symbol(:Lux, name, :Adaptor) - @eval Base.@deprecate $(adaptor) $(dev) true -end -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::LuxCPUDevice, rng::AbstractRNG) = rng +Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng -for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) @eval begin function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) return default_device_rng(to) @@ -505,15 +501,15 @@ for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) end end -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x +Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x # Prevent Ambiguity -for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end # Chain Rules Core -function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray) +function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) ∇adapt_storage = let x = x Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) end diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index a290807..f5c3766 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -1,33 +1,33 @@ -using LuxDeviceUtils, Random, Test +using DeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(AMDGPUDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) - @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxAMDGPUDevice, nothing, 1) + @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( + AMDGPUDevice, nothing, 1) end using AMDGPU @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) @info "AMDGPU is functional" - @test gpu_device() isa LuxAMDGPUDevice - @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice + @test gpu_device() isa AMDGPUDevice + @test gpu_device(; force_gpu_usage=true) isa AMDGPUDevice else @info "AMDGPU is NOT functional" - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -40,13 +40,13 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array - rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? AMDGPU.rocRAND.RNG : + aType = DeviceUtils.functional(AMDGPUDevice) ? ROCArray : Array + rngType = DeviceUtils.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxAMDGPUDevice - @test get_device_type(ps_xpu) <: LuxAMDGPUDevice + @test get_device(ps_xpu) isa AMDGPUDevice + @test get_device_type(ps_xpu) <: AMDGPUDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -60,7 +60,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray else @@ -69,8 +69,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -86,7 +86,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -103,7 +103,7 @@ using FillArrays, Zygote # Extensions @test get_device(x_dev) isa parameterless_type(typeof(dev)) @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) dev2 = gpu_device(length(AMDGPU.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) @@ -123,18 +123,18 @@ using FillArrays, Zygote # Extensions end @testset "Wrapped Arrays" begin - if LuxDeviceUtils.functional(LuxAMDGPUDevice) - x = rand(10, 10) |> LuxAMDGPUDevice() - @test get_device(x) isa LuxAMDGPUDevice - @test get_device_type(x) <: LuxAMDGPUDevice + if DeviceUtils.functional(AMDGPUDevice) + x = rand(10, 10) |> AMDGPUDevice() + @test get_device(x) isa AMDGPUDevice + @test get_device_type(x) <: AMDGPUDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxAMDGPUDevice - @test get_device_type(x_view) <: LuxAMDGPUDevice + @test get_device(x_view) isa AMDGPUDevice + @test get_device_type(x_view) <: AMDGPUDevice end end @testset "Multiple Devices AMDGPU" begin - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -159,9 +159,9 @@ end end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) for i in 1:10 - @test_nowarn LuxDeviceUtils.set_device!(LuxAMDGPUDevice, nothing, i) + @test_nowarn DeviceUtils.set_device!(AMDGPUDevice, nothing, i) end end end diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index cd97a8e..9adfa2b 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -1,33 +1,33 @@ -using LuxDeviceUtils, Random, Functors, Test +using DeviceUtils, Random, Functors, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxCUDADevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(CUDADevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) - @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxCUDADevice, nothing, 1) + @test_throws Exception default_device_rng(CUDADevice(nothing)) + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( + CUDADevice, nothing, 1) end using LuxCUDA @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) @info "LuxCUDA is functional" - @test gpu_device() isa LuxCUDADevice - @test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice + @test gpu_device() isa CUDADevice + @test gpu_device(; force_gpu_usage=true) isa CUDADevice else @info "LuxCUDA is NOT functional" - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -40,12 +40,12 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array - rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG + aType = DeviceUtils.functional(CUDADevice) ? CuArray : Array + rngType = DeviceUtils.functional(CUDADevice) ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxCUDADevice - @test get_device_type(ps_xpu) <: LuxCUDADevice + @test get_device(ps_xpu) isa CUDADevice + @test get_device_type(ps_xpu) <: CUDADevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -59,7 +59,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) @test ps_xpu.one_elem isa CuArray @test ps_xpu.farray isa CuArray else @@ -68,8 +68,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -85,7 +85,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -100,22 +100,22 @@ using FillArrays, Zygote # Extensions Functors.@functor MyStruct data = MyStruct(rand(10)) - @test get_device(data) isa LuxCPUDevice - @test get_device_type(data) <: LuxCPUDevice + @test get_device(data) isa CPUDevice + @test get_device_type(data) <: CPUDevice data_dev = data |> device - if LuxDeviceUtils.functional(LuxCUDADevice) - @test get_device(data_dev) isa LuxCUDADevice - @test get_device_type(data_dev) <: LuxCUDADevice + if DeviceUtils.functional(CUDADevice) + @test get_device(data_dev) isa CUDADevice + @test get_device_type(data_dev) <: CUDADevice else - @test get_device(data_dev) isa LuxCPUDevice - @test get_device_type(data_dev) <: LuxCPUDevice + @test get_device(data_dev) isa CPUDevice + @test get_device_type(data_dev) <: CPUDevice end ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) - @test get_device(ps_mixed.st) isa LuxCPUDevice - @test get_device_type(ps_mixed.st) <: LuxCPUDevice - @test get_device(ps_mixed.c) isa LuxCPUDevice - @test get_device_type(ps_mixed.c) <: LuxCPUDevice + @test get_device(ps_mixed.st) isa CPUDevice + @test get_device_type(ps_mixed.st) <: CPUDevice + @test get_device(ps_mixed.c) isa CPUDevice + @test get_device_type(ps_mixed.c) <: CPUDevice @test_throws ArgumentError get_device(ps_mixed) @test_throws ArgumentError get_device_type(ps_mixed) @@ -125,7 +125,7 @@ using FillArrays, Zygote # Extensions @test get_device(x_dev) isa parameterless_type(typeof(dev)) @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) dev2 = gpu_device(length(CUDA.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) @@ -145,18 +145,18 @@ using FillArrays, Zygote # Extensions end @testset "Wrapped Arrays" begin - if LuxDeviceUtils.functional(LuxCUDADevice) - x = rand(10, 10) |> LuxCUDADevice() - @test get_device(x) isa LuxCUDADevice - @test get_device_type(x) <: LuxCUDADevice + if DeviceUtils.functional(CUDADevice) + x = rand(10, 10) |> CUDADevice() + @test get_device(x) isa CUDADevice + @test get_device_type(x) <: CUDADevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxCUDADevice - @test get_device_type(x_view) <: LuxCUDADevice + @test get_device(x_view) isa CUDADevice + @test get_device_type(x_view) <: CUDADevice end end @testset "Multiple Devices CUDA" begin - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -183,7 +183,7 @@ end using SparseArrays @testset "CUDA Sparse Arrays" begin - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) ps = (; weight=sprand(Float32, 10, 10, 0.1), bias=sprand(Float32, 10, 0.1)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -208,9 +208,9 @@ using SparseArrays end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) for i in 1:10 - @test_nowarn LuxDeviceUtils.set_device!(LuxCUDADevice, nothing, i) + @test_nowarn DeviceUtils.set_device!(CUDADevice, nothing, i) end end end diff --git a/test/metal_tests.jl b/test/metal_tests.jl index db5a2e1..ce97125 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -1,31 +1,31 @@ -using LuxDeviceUtils, Random, Test +using DeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxMetalDevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(MetalDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxMetalDevice()) + @test_throws Exception default_device_rng(MetalDevice()) end using Metal @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @info "Metal is functional" - @test gpu_device() isa LuxMetalDevice - @test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice + @test gpu_device() isa MetalDevice + @test gpu_device(; force_gpu_usage=true) isa MetalDevice else @info "Metal is NOT functional" - @test gpu_device() isa LuxMetalDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa MetalDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -38,13 +38,13 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxMetalDevice) ? MtlArray : Array - rngType = LuxDeviceUtils.functional(LuxMetalDevice) ? Metal.GPUArrays.RNG : + aType = DeviceUtils.functional(MetalDevice) ? MtlArray : Array + rngType = DeviceUtils.functional(MetalDevice) ? Metal.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxMetalDevice - @test get_device_type(ps_xpu) <: LuxMetalDevice + @test get_device(ps_xpu) isa MetalDevice + @test get_device_type(ps_xpu) <: MetalDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -58,7 +58,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @test ps_xpu.one_elem isa MtlArray @test ps_xpu.farray isa MtlArray else @@ -67,8 +67,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -84,7 +84,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -109,20 +109,20 @@ using FillArrays, Zygote # Extensions end @testset "Wrapper Arrays" begin - if LuxDeviceUtils.functional(LuxMetalDevice) - x = rand(Float32, 10, 10) |> LuxMetalDevice() - @test get_device(x) isa LuxMetalDevice - @test get_device_type(x) <: LuxMetalDevice + if DeviceUtils.functional(MetalDevice) + x = rand(Float32, 10, 10) |> MetalDevice() + @test get_device(x) isa MetalDevice + @test get_device_type(x) <: MetalDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxMetalDevice - @test get_device_type(x_view) <: LuxMetalDevice + @test get_device(x_view) isa MetalDevice + @test get_device_type(x_view) <: MetalDevice end end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @test_logs (:warn, - "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxMetalDevice, nothing, 1) + "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") DeviceUtils.set_device!( + MetalDevice, nothing, 1) end end diff --git a/test/misc_tests.jl b/test/misc_tests.jl index dd0ef8e..bbbd71c 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -1,12 +1,12 @@ -using Adapt, LuxDeviceUtils, ComponentArrays, Random +using Adapt, DeviceUtils, ComponentArrays, Random using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools using LuxCore -@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin - dev = LuxCPUDevice() +@testset "https://github.com/LuxDL/DeviceUtils.jl/issues/10 patch" begin + dev = CPUDevice() ps = (; weight=randn(10, 1), bias=randn(1)) ps_ca = ps |> ComponentArray @@ -25,23 +25,23 @@ end x = randn(Float32, 10) x_rdiff = ReverseDiff.track(x) - @test get_device(x_rdiff) isa LuxCPUDevice + @test get_device(x_rdiff) isa CPUDevice x_rdiff = ReverseDiff.track.(x) - @test get_device(x_rdiff) isa LuxCPUDevice + @test get_device(x_rdiff) isa CPUDevice gdev = gpu_device() x_tracker = Tracker.param(x) - @test get_device(x_tracker) isa LuxCPUDevice + @test get_device(x_tracker) isa CPUDevice x_tracker = Tracker.param.(x) - @test get_device(x_tracker) isa LuxCPUDevice + @test get_device(x_tracker) isa CPUDevice x_tracker_dev = Tracker.param(x) |> gdev @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) x_tracker_dev = Tracker.param.(x) |> gdev @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) x_fdiff = ForwardDiff.Dual.(x) - @test get_device(x_fdiff) isa LuxCPUDevice + @test get_device(x_fdiff) isa CPUDevice x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) end @@ -51,7 +51,7 @@ end test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) gdev = gpu_device() - if !(gdev isa LuxMetalDevice) # On intel devices causes problems + if !(gdev isa MetalDevice) # On intel devices causes problems x = randn(10) ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) @test ∂dev === nothing @@ -78,34 +78,34 @@ end gdev = gpu_device() diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10)) - @test get_device(diffeqarray) isa LuxCPUDevice + @test get_device(diffeqarray) isa CPUDevice diffeqarray_dev = diffeqarray |> gdev @test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev)) vecarray = VectorOfArray([rand(10) for _ in 1:10]) - @test get_device(vecarray) isa LuxCPUDevice + @test get_device(vecarray) isa CPUDevice vecarray_dev = vecarray |> gdev @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) end @testset "CPU default rng" begin - @test default_device_rng(LuxCPUDevice()) isa Random.TaskLocalRNG + @test default_device_rng(CPUDevice()) isa Random.TaskLocalRNG end @testset "CPU setdevice!" begin @test_logs (:warn, - "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxCPUDevice, nothing, 1) + "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting.") DeviceUtils.set_device!( + CPUDevice, nothing, 1) end @testset "get_device on Arrays" begin x = rand(10, 10) x_view = view(x, 1:5, 1:5) - @test get_device(x) isa LuxCPUDevice - @test get_device(x_view) isa LuxCPUDevice + @test get_device(x) isa CPUDevice + @test get_device(x_view) isa CPUDevice struct MyArrayType <: AbstractArray{Float32, 2} data::Array{Float32, 2} @@ -113,22 +113,22 @@ end x_custom = MyArrayType(rand(10, 10)) - @test get_device(x_custom) isa LuxCPUDevice + @test get_device(x_custom) isa CPUDevice end @testset "loaded and functional" begin - @test LuxDeviceUtils.loaded(LuxCPUDevice) - @test LuxDeviceUtils.functional(LuxCPUDevice) + @test DeviceUtils.loaded(CPUDevice) + @test DeviceUtils.functional(CPUDevice) end @testset "writing to preferences" begin @test_logs (:info, "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() - for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, LuxAMDGPUDevice(), - LuxCUDADevice(), LuxMetalDevice(), LuxoneAPIDevice()) + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, AMDGPUDevice(), + CUDADevice(), MetalDevice(), oneAPIDevice()) backend_name = backend isa Symbol ? string(backend) : - LuxDeviceUtils._get_device_name(backend) + DeviceUtils._get_device_name(backend) @test_logs (:info, "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) end diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index 40b3fb7..0394837 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -1,31 +1,31 @@ -using LuxDeviceUtils, Random, Test +using DeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxoneAPIDevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(oneAPIDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxoneAPIDevice()) + @test_throws Exception default_device_rng(oneAPIDevice()) end using oneAPI @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @info "oneAPI is functional" - @test gpu_device() isa LuxoneAPIDevice - @test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice + @test gpu_device() isa oneAPIDevice + @test gpu_device(; force_gpu_usage=true) isa oneAPIDevice else @info "oneAPI is NOT functional" - @test gpu_device() isa LuxoneAPIDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa oneAPIDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -38,13 +38,13 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneArray : Array - rngType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneAPI.GPUArrays.RNG : + aType = DeviceUtils.functional(oneAPIDevice) ? oneArray : Array + rngType = DeviceUtils.functional(oneAPIDevice) ? oneAPI.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxoneAPIDevice - @test get_device_type(ps_xpu) <: LuxoneAPIDevice + @test get_device(ps_xpu) isa oneAPIDevice + @test get_device_type(ps_xpu) <: oneAPIDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -58,7 +58,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @test ps_xpu.one_elem isa oneArray @test ps_xpu.farray isa oneArray else @@ -67,8 +67,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -84,7 +84,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -109,20 +109,20 @@ using FillArrays, Zygote # Extensions end @testset "Wrapper Arrays" begin - if LuxDeviceUtils.functional(LuxoneAPIDevice) - x = rand(10, 10) |> LuxoneAPIDevice() - @test get_device(x) isa LuxoneAPIDevice - @test get_device_type(x) <: LuxoneAPIDevice + if DeviceUtils.functional(oneAPIDevice) + x = rand(10, 10) |> oneAPIDevice() + @test get_device(x) isa oneAPIDevice + @test get_device_type(x) <: oneAPIDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxoneAPIDevice - @test get_device_type(x_view) <: LuxoneAPIDevice + @test get_device(x_view) isa oneAPIDevice + @test get_device_type(x_view) <: oneAPIDevice end end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @test_logs (:warn, - "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxoneAPIDevice, nothing, 1) + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") DeviceUtils.set_device!( + oneAPIDevice, nothing, 1) end end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index bc177fb..b08a873 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,17 +1,17 @@ -using Aqua, ExplicitImports, LuxDeviceUtils, Test +using Aqua, ExplicitImports, DeviceUtils, Test @testset "Aqua Tests" begin - Aqua.test_all(LuxDeviceUtils) + Aqua.test_all(DeviceUtils) end import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @testset "Explicit Imports" begin - @test check_no_implicit_imports(LuxDeviceUtils) === nothing - @test check_no_stale_explicit_imports(LuxDeviceUtils) === nothing - @test check_no_self_qualified_accesses(LuxDeviceUtils) === nothing - @test check_all_explicit_imports_via_owners(LuxDeviceUtils) === nothing - @test check_all_qualified_accesses_via_owners(LuxDeviceUtils) === nothing - @test_broken check_all_explicit_imports_are_public(LuxDeviceUtils) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(LuxDeviceUtils) === nothing # mostly upstream problem + @test check_no_implicit_imports(DeviceUtils) === nothing + @test check_no_stale_explicit_imports(DeviceUtils) === nothing + @test check_no_self_qualified_accesses(DeviceUtils) === nothing + @test check_all_explicit_imports_via_owners(DeviceUtils) === nothing + @test check_all_qualified_accesses_via_owners(DeviceUtils) === nothing + @test_broken check_all_explicit_imports_are_public(DeviceUtils) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(DeviceUtils) === nothing # mostly upstream problem end diff --git a/test/runtests.jl b/test/runtests.jl index 8b170d3..8448f4b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ import Pkg using SafeTestsets, Test -const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "NONE")) +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) const EXTRA_PKGS = String[] @@ -18,7 +18,7 @@ if !isempty(EXTRA_PKGS) Pkg.instantiate() end -@testset "LuxDeviceUtils Tests" begin +@testset "DeviceUtils Tests" begin file_names = BACKEND_GROUP == "all" ? ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl"] : (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"])