Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

refactor: general cleanup to follow Lux structure #68

Merged
merged 3 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions .buildkite/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ steps:
agents:
queue: "juliagpu"
cuda: "*"
env:
RETESTITEMS_NWORKERS: 2
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test"
timeout_in_minutes: 60
matrix:
Expand Down Expand Up @@ -161,9 +159,4 @@ steps:
- "1"

env:
RETESTITEMS_NWORKERS: 8
RETESTITEMS_NWORKER_THREADS: 2
RETESTITEMS_TESTITEM_TIMEOUT: 3600
JULIA_PKG_SERVER: ""
JULIA_NUM_THREADS: 4
SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw=="
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,3 @@ jobs:

env:
BACKEND_GROUP: "CPU"
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
32 changes: 1 addition & 31 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.1"
version = "1.0.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -42,50 +42,20 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
[compat]
AMDGPU = "0.9.6, 1"
Adapt = "4"
Aqua = "0.8.4"
ArrayInterface = "7.11"
CUDA = "5.2"
ChainRulesCore = "1.23"
ChainRulesTestUtils = "1.13.0"
ComponentArrays = "0.15.8"
ExplicitImports = "1.9.0"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
GPUArrays = "10"
Metal = "1"
Pkg = "1.10"
Preferences = "1.4"
Random = "1.10"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SafeTestsets = "0.1"
SparseArrays = "1.10"
Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
cuDNN = "1.3"
julia = "1.10"
oneAPI = "1.5"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "Tracker", "Zygote"]
20 changes: 9 additions & 11 deletions ext/MLDataDevicesAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ module MLDataDevicesAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, Internal, AMDGPUDevice, CPUDevice, reset_gpu_device!
using Random: Random

__init__() = reset_gpu_device!()

# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package.
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)

function _check_use_amdgpu!()
function check_use_amdgpu!()
USE_AMD_GPU[] === nothing || return

USE_AMD_GPU[] = AMDGPU.functional()
Expand All @@ -23,14 +23,12 @@ end

MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true
function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
_check_use_amdgpu!()
check_use_amdgpu!()
return USE_AMD_GPU[]
end

function MLDataDevices._with_device(::Type{AMDGPUDevice}, ::Nothing)
return AMDGPUDevice(nothing)
end
function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer)
Internal.with_device(::Type{AMDGPUDevice}, ::Nothing) = AMDGPUDevice(nothing)
function Internal.with_device(::Type{AMDGPUDevice}, id::Integer)
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
Expand All @@ -40,19 +38,19 @@ function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer)
return device
end

MLDataDevices._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)
Internal.get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG
MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
function MLDataDevices._get_device(x::AMDGPU.AnyROCArray)
function Internal.get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return MLDataDevices._get_device(parent_x)
return Internal.get_device(parent_x)
end

MLDataDevices._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice

# Set Device
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
Expand Down
28 changes: 9 additions & 19 deletions ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ module MLDataDevicesCUDAExt

using Adapt: Adapt
using CUDA: CUDA
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using MLDataDevices: MLDataDevices, CUDADevice, CPUDevice
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector, AbstractCuSparseArray
using MLDataDevices: MLDataDevices, Internal, CUDADevice, CPUDevice
using Random: Random

function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer)
Internal.with_device(::Type{CUDADevice}, ::Nothing) = CUDADevice(nothing)
function Internal.with_device(::Type{CUDADevice}, id::Integer)
id > length(CUDA.devices()) &&
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
old_dev = CUDA.device()
Expand All @@ -16,34 +17,23 @@ function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer)
return device
end

function MLDataDevices._with_device(::Type{CUDADevice}, ::Nothing)
return CUDADevice(nothing)
end

MLDataDevices._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1
Internal.get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1

# Default RNG
MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng()

# Query Device from Array
function MLDataDevices._get_device(x::CUDA.AnyCuArray)
function Internal.get_device(x::CUDA.AnyCuArray)
parent_x = parent(x)
parent_x === x && return CUDADevice(CUDA.device(x))
return MLDataDevices.get_device(parent_x)
end
function MLDataDevices._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return CUDADevice(CUDA.device(x.nzVal))
end
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))

function MLDataDevices._get_device_type(::Union{
<:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray})
return CUDADevice
end
Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice

# Set Device
function MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
end
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer)
return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id])
end
Expand Down
10 changes: 4 additions & 6 deletions ext/MLDataDevicesMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: MLDataDevices, MetalDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device!
using Metal: Metal, MtlArray

__init__() = reset_gpu_device!()

MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true
function MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}})
return Metal.functional()
end
MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) = Metal.functional()

Check warning on line 11 in ext/MLDataDevicesMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MLDataDevicesMetalExt.jl#L11

Added line #L11 was not covered by tests

# Default RNG
MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
MLDataDevices._get_device(::MtlArray) = MetalDevice()
Internal.get_device(::MtlArray) = MetalDevice()

Check warning on line 17 in ext/MLDataDevicesMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MLDataDevicesMetalExt.jl#L17

Added line #L17 was not covered by tests

MLDataDevices._get_device_type(::MtlArray) = MetalDevice
Internal.get_device_type(::MtlArray) = MetalDevice

Check warning on line 19 in ext/MLDataDevicesMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MLDataDevicesMetalExt.jl#L19

Added line #L19 was not covered by tests

# Device Transfer
## To GPU
Expand Down
10 changes: 5 additions & 5 deletions ext/MLDataDevicesRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module MLDataDevicesRecursiveArrayToolsExt

using Adapt: Adapt, adapt
using MLDataDevices: MLDataDevices, AbstractDevice
using MLDataDevices: MLDataDevices, Internal, AbstractDevice
using RecursiveArrayTools: VectorOfArray, DiffEqArray

# We want to preserve the structure
Expand All @@ -14,10 +14,10 @@ function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray)
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

for op in (:_get_device, :_get_device_type)
@eval function MLDataDevices.$op(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing)
return mapreduce(MLDataDevices.$op, MLDataDevices.__combine_devices, x.u)
for op in (:get_device, :get_device_type)
@eval function Internal.$(op)(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :get_device ? nothing : Nothing)
return mapreduce(Internal.$(op), Internal.combine_devices, x.u)
end
end

Expand Down
12 changes: 4 additions & 8 deletions ext/MLDataDevicesReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
module MLDataDevicesReverseDiffExt

using MLDataDevices: MLDataDevices
using MLDataDevices: Internal
using ReverseDiff: ReverseDiff

for op in (:_get_device, :_get_device_type)
for op in (:get_device, :get_device_type)
@eval begin
function MLDataDevices.$op(x::ReverseDiff.TrackedArray)
return MLDataDevices.$op(ReverseDiff.value(x))
end
function MLDataDevices.$op(x::AbstractArray{<:ReverseDiff.TrackedReal})
return MLDataDevices.$op(ReverseDiff.value.(x))
end
Internal.$(op)(x::ReverseDiff.TrackedArray) = Internal.$(op)(ReverseDiff.value(x))
Internal.$(op)(x::AbstractArray{<:ReverseDiff.TrackedReal}) = Internal.$(op)(ReverseDiff.value.(x))
end
end

Expand Down
14 changes: 5 additions & 9 deletions ext/MLDataDevicesTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
module MLDataDevicesTrackerExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice
using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice
using Tracker: Tracker

for op in (:_get_device, :_get_device_type)
@eval begin
MLDataDevices.$op(x::Tracker.TrackedArray) = MLDataDevices.$op(Tracker.data(x))
function MLDataDevices.$op(x::AbstractArray{<:Tracker.TrackedReal})
return MLDataDevices.$op(Tracker.data.(x))
end
end
for op in (:get_device, :get_device_type)
@eval Internal.$(op)(x::Tracker.TrackedArray) = Internal.$(op)(Tracker.data(x))
@eval Internal.$(op)(x::AbstractArray{<:Tracker.TrackedReal}) = Internal.$(op)(Tracker.data.(x))
end

MLDataDevices.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true
Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
Expand Down
6 changes: 3 additions & 3 deletions ext/MLDataDevicesoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MLDataDevicesoneAPIExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: MLDataDevices, oneAPIDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, Internal, oneAPIDevice, reset_gpu_device!
using oneAPI: oneAPI, oneArray, oneL0

const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}()
Expand All @@ -25,9 +25,9 @@ end
MLDataDevices.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray)

# Query Device from Array
MLDataDevices._get_device(::oneArray) = oneAPIDevice()
Internal.get_device(::oneArray) = oneAPIDevice()

MLDataDevices._get_device_type(::oneArray) = oneAPIDevice
Internal.get_device_type(::oneArray) = oneAPIDevice

# Device Transfer
## To GPU
Expand Down
Loading
Loading