Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
change to DeviceUtils.jl
Browse files Browse the repository at this point in the history
cleanup

cleanup

fix

improve docstring

require cuDNN

none

functional only if cuDNN is functional

separate cuDNN extension

cleanup
  • Loading branch information
CarloLucibello committed Jul 13, 2024
1 parent b914979 commit 30dcabc
Show file tree
Hide file tree
Showing 34 changed files with 719 additions and 419 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Manifest.toml
*.cov
generated
build
.vscode
Expand Down
31 changes: 16 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name = "LuxDeviceUtils"
name = "DeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.26"
Expand All @@ -17,28 +17,28 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
LuxDeviceUtilsAMDGPUExt = "AMDGPU"
LuxDeviceUtilsCUDAExt = "CUDA"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsGPUArraysExt = "GPUArrays"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"]
LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
LuxDeviceUtilsReverseDiffExt = "ReverseDiff"
LuxDeviceUtilsSparseArraysExt = "SparseArrays"
LuxDeviceUtilsTrackerExt = "Tracker"
LuxDeviceUtilsZygoteExt = "Zygote"
LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]
DeviceUtilsAMDGPUExt = "AMDGPU"
DeviceUtilsCUDAExt = "CUDA"
DeviceUtilsFillArraysExt = "FillArrays"
DeviceUtilsGPUArraysExt = "GPUArrays"
DeviceUtilsMetalExt = ["GPUArrays", "Metal"]
DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
DeviceUtilsReverseDiffExt = "ReverseDiff"
DeviceUtilsSparseArraysExt = "SparseArrays"
DeviceUtilsTrackerExt = "Tracker"
DeviceUtilsZygoteExt = "Zygote"
DeviceUtilscuDNNExt = ["CUDA", "cuDNN"]
DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]

[compat]
AMDGPU = "0.9.6"
Expand All @@ -54,7 +54,6 @@ FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
GPUArrays = "10"
LuxCUDA = "0.3.2"
LuxCore = "0.1.4"
Metal = "1"
Pkg = "1.10"
Expand All @@ -68,9 +67,11 @@ Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
cuDNN = "1.3"
julia = "1.10"
oneAPI = "1.5"


[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# LuxDeviceUtils
# DeviceUtils

[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/DeviceUtils)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/DeviceUtils)

[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml)
[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl)
[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl)
[![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml)
[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl)
[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across
devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead.
`DeviceUtils.jl` is a lightweight package defining rules for transferring data across
devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/).

Currently we provide support for the following backends:

Expand Down
89 changes: 89 additions & 0 deletions ext/DeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
module DeviceUtilsAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device!
using Random: Random

__init__() = reset_gpu_device!()

const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)

function _check_use_amdgpu!()
USE_AMD_GPU[] === nothing || return

USE_AMD_GPU[] = AMDGPU.functional()
if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen)
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \
available." maxlog=1
end
return
end

DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true
function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
_check_use_amdgpu!()
return USE_AMD_GPU[]
end

function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing)
return AMDGPUDevice(nothing)
end
function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer)
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
AMDGPU.device!(AMDGPU.devices()[id])
device = AMDGPUDevice(AMDGPU.device())
AMDGPU.device!(old_dev)
return device
end

DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG
DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
function DeviceUtils.get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return DeviceUtils.get_device(parent_x)
end

# Set Device
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
return AMDGPU.device!(dev)
end
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer)
return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id])
end
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)
id = mod1(rank + 1, length(AMDGPU.devices()))
return DeviceUtils.set_device!(AMDGPUDevice, id)
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
old_dev = AMDGPU.device() # remember the current device
dev = DeviceUtils.get_device(x)
if !(dev isa AMDGPUDevice)
AMDGPU.device!(to.device)
x_new = AMDGPU.roc(x)
AMDGPU.device!(old_dev)
return x_new
elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device)
return x
else
AMDGPU.device!(to.device)
x_new = copy(x)
AMDGPU.device!(old_dev)
return x_new
end
end

Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

end
85 changes: 85 additions & 0 deletions ext/DeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
module DeviceUtilsCUDAExt

using Adapt: Adapt
using CUDA: CUDA
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice, reset_gpu_device!
using Random: Random

function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer)
id > length(CUDA.devices()) &&
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
old_dev = CUDA.device()
CUDA.device!(id - 1)
device = CUDADevice(CUDA.device())
CUDA.device!(old_dev)
return device
end

function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing)
return CUDADevice(nothing)
end

DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1

# Default RNG
DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng()

# Query Device from Array
function DeviceUtils.get_device(x::CUDA.AnyCuArray)
parent_x = parent(x)
parent_x === x && return CUDADevice(CUDA.device(x))
return DeviceUtils.get_device(parent_x)
end
function DeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return CUDADevice(CUDA.device(x.nzVal))
end

# Set Device
function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
end
function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer)
return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id])
end
function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
id = mod1(rank + 1, length(CUDA.devices()))
return DeviceUtils.set_device!(CUDADevice, id)
end

# Device Transfer
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
old_dev = CUDA.device() # remember the current device
dev = DeviceUtils.get_device(x)
if !(dev isa CUDADevice)
CUDA.device!(to.device)
x_new = CUDA.cu(x)
CUDA.device!(old_dev)
return x_new
elseif dev.device == to.device
return x
else
CUDA.device!(to.device)
x_new = copy(x)
CUDA.device!(old_dev)
return x_new
end
end

Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng()

# Defining as extensions seems to case precompilation errors
@static if isdefined(CUDA.CUSPARSE, :SparseArrays)
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix)
return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x)
end
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector)
return CUDA.CUSPARSE.SparseArrays.SparseVector(x)
end
else
@warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \
an issue in DeviceUtils.jl repository."
end

end
10 changes: 10 additions & 0 deletions ext/DeviceUtilsFillArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module DeviceUtilsFillArraysExt

using Adapt: Adapt
using FillArrays: FillArrays, AbstractFill
using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice

Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))

end
10 changes: 10 additions & 0 deletions ext/DeviceUtilsGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module DeviceUtilsGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using DeviceUtils: CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()

end
25 changes: 25 additions & 0 deletions ext/DeviceUtilsMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module DeviceUtilsMetalExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device!
using Metal: Metal, MtlArray

__init__() = reset_gpu_device!()

DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true
function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}})
return Metal.functional()
end

# Default RNG
DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
DeviceUtils.get_device(::MtlArray) = MetalDevice()

# Device Transfer
## To GPU
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)

end
21 changes: 21 additions & 0 deletions ext/DeviceUtilsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module DeviceUtilsRecursiveArrayToolsExt

using Adapt: Adapt, adapt
using DeviceUtils: DeviceUtils, AbstractDevice
using RecursiveArrayTools: VectorOfArray, DiffEqArray

# We want to preserve the structure
function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray)
return VectorOfArray(map(Base.Fix1(adapt, to), x.u))
end

function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray)
# Don't move the `time` to the GPU
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

function DeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray})
return mapreduce(DeviceUtils.get_device, DeviceUtils.__combine_devices, x.u)
end

end
13 changes: 13 additions & 0 deletions ext/DeviceUtilsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DeviceUtilsReverseDiffExt

using DeviceUtils: DeviceUtils
using ReverseDiff: ReverseDiff

@inline function DeviceUtils.get_device(x::ReverseDiff.TrackedArray)
return DeviceUtils.get_device(ReverseDiff.value(x))
end
@inline function DeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal})
return DeviceUtils.get_device(ReverseDiff.value.(x))
end

end
9 changes: 9 additions & 0 deletions ext/DeviceUtilsSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module DeviceUtilsSparseArraysExt

using Adapt: Adapt
using DeviceUtils: CPUDevice
using SparseArrays: AbstractSparseArray

Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x

end
26 changes: 26 additions & 0 deletions ext/DeviceUtilsTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module DeviceUtilsTrackerExt

using Adapt: Adapt
using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice,
oneAPIDevice
using Tracker: Tracker

@inline function DeviceUtils.get_device(x::Tracker.TrackedArray)
return DeviceUtils.get_device(Tracker.data(x))
end
@inline function DeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal})
return DeviceUtils.get_device(Tracker.data.(x))
end

@inline DeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal})
@warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \
to Tracker.TrackedArray." maxlog=1
return to(Tracker.collect(x))
end
end

end
Loading

0 comments on commit 30dcabc

Please sign in to comment.