Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Handle default rngs differently #20

Merged
merged 2 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@ steps:
matrix:
setup:
julia:
- "1.6"
- "1"
- "nightly"
adjustments:
- with:
julia: "1.6"
soft_fail: true
- with:
julia: "nightly"
soft_fail: true
Expand Down Expand Up @@ -77,15 +73,10 @@ steps:
matrix:
setup:
julia:
- "1.6"
- "1"
repo:
- "Lux"
- "Boltz"
adjustments:
- with:
julia: "1.6"
soft_fail: true

- group: ":julia: AMD GPU"
steps:
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
matrix:
version:
- "1"
- "1.6"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.10"
version = "0.1.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down Expand Up @@ -36,12 +35,11 @@ LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
LuxCore = "0.1.4"
Metal = "0.4, 0.5"
PackageExtensionCompat = "1"
Preferences = "1"
Random = "<0.0.1, 1"
SparseArrays = "<0.0.1, 1"
Zygote = "0.6"
julia = "1.6"
julia = "1.9"

[extras]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
4 changes: 4 additions & 0 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
## To GPU
adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x)
adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocRAND.RNG()

## Is this a correct thing to do?
adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

## Chain Rules
CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ))
Expand Down
4 changes: 4 additions & 0 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
## To GPU
adapt_storage(::LuxCUDAAdaptor, x) = cu(x)
adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng()

## Is this a correct thing to do?
adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()

## To CPU
adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x)
Expand Down
6 changes: 6 additions & 0 deletions ext/LuxDeviceUtilsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ __init__() = reset_gpu_device!()
LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()

__default_rng() = Metal.GPUArrays.default_rng(MtlArray)

# Device Transfer
## To GPU
adapt_storage(::LuxMetalAdaptor, x) = mtl(x)
adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = __default_rng()

## Is this a correct thing to do?
adapt_storage(::LuxCPUAdaptor, rng::Metal.GPUArrays.RNG) = Random.default_rng()

## Chain Rules
CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ))
Expand Down
5 changes: 0 additions & 5 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ module LuxDeviceUtils
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
import Adapt: adapt, adapt_storage

using PackageExtensionCompat
function __init__()
@require_extensions
end

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
Expand Down
7 changes: 3 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.14.1"
julia = "1.6"
13 changes: 6 additions & 7 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ end
using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1),
b=ones(10, 1),
e=:c,
d="string",
rng=Random.default_rng(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
farray=Fill(1.0f0, (2, 3)))
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = LuxAMDGPU.functional() ? ROCArray : Array
rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if LuxAMDGPU.functional()
Expand All @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if LuxAMDGPU.functional()
Expand Down
13 changes: 6 additions & 7 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ end
using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1),
b=ones(10, 1),
e=:c,
d="string",
rng=Random.default_rng(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
farray=Fill(1.0f0, (2, 3)))
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = LuxCUDA.functional() ? CuArray : Array
rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if LuxCUDA.functional()
Expand All @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if LuxCUDA.functional()
Expand Down
13 changes: 6 additions & 7 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ end
using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1),
b=ones(10, 1),
e=:c,
d="string",
rng=Random.default_rng(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
farray=Fill(1.0f0, (2, 3)))
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = Metal.functional() ? MtlArray : Array
rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if Metal.functional()
Expand All @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if Metal.functional()
Expand Down
41 changes: 9 additions & 32 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,29 @@ using LuxCore, LuxDeviceUtils

const GROUP = get(ENV, "GROUP", "CUDA")

@info "Installing Accelerator Packages..."

GROUP == "CUDA" && Pkg.add("LuxCUDA")

@static if VERSION ≥ v"1.9"
GROUP == "AMDGPU" && Pkg.add("LuxAMDGPU")

GROUP == "Metal" && Pkg.add("Metal")
else
if GROUP != "CUDA"
@warn "AMDGPU and Metal are only available on Julia 1.9+"
end
end

@info "Installed Accelerator Packages!"

@info "Starting Tests..."

@testset "LuxDeviceUtils Tests" begin
if GROUP == "CUDA"
@safetestset "CUDA" begin
include("cuda.jl")
end
end

@static if VERSION ≥ v"1.9"
if GROUP == "AMDGPU"
@safetestset "CUDA" begin
include("amdgpu.jl")
end
end

if GROUP == "Metal"
@safetestset "Metal" begin
include("metal.jl")
end
if GROUP == "AMDGPU"
@safetestset "CUDA" begin
include("amdgpu.jl")
end
end

if VERSION ≥ v"1.9"
@testset "Aqua Tests" begin
Aqua.test_all(LuxDeviceUtils; piracies=false)
if GROUP == "Metal"
@safetestset "Metal" begin
include("metal.jl")
end
end

@testset "Others" begin
@testset "Aqua Tests" begin
Aqua.test_all(LuxDeviceUtils)
end
@safetestset "Component Arrays" begin
include("component_arrays.jl")
end
Expand Down