Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge pull request #20 from LuxDL/ap/rngs
Browse files Browse the repository at this point in the history
Handle default rngs differently
  • Loading branch information
avik-pal authored Dec 17, 2023
2 parents 9bb0859 + 1e743a6 commit 64803e0
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 76 deletions.
9 changes: 0 additions & 9 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@ steps:
matrix:
setup:
julia:
- "1.6"
- "1"
- "nightly"
adjustments:
- with:
julia: "1.6"
soft_fail: true
- with:
julia: "nightly"
soft_fail: true
Expand Down Expand Up @@ -77,15 +73,10 @@ steps:
matrix:
setup:
julia:
- "1.6"
- "1"
repo:
- "Lux"
- "Boltz"
adjustments:
- with:
julia: "1.6"
soft_fail: true

- group: ":julia: AMD GPU"
steps:
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
matrix:
version:
- "1"
- "1.6"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.10"
version = "0.1.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down Expand Up @@ -36,12 +35,11 @@ LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
LuxCore = "0.1.4"
Metal = "0.4, 0.5"
PackageExtensionCompat = "1"
Preferences = "1"
Random = "<0.0.1, 1"
SparseArrays = "<0.0.1, 1"
Zygote = "0.6"
julia = "1.6"
julia = "1.9"

[extras]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
4 changes: 4 additions & 0 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
## To GPU
adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x)
adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocRAND.RNG()

## Is this a correct thing to do?
adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

## Chain Rules
CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ))
Expand Down
4 changes: 4 additions & 0 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
## To GPU
adapt_storage(::LuxCUDAAdaptor, x) = cu(x)
adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng()

## Is this a correct thing to do?
adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()

## To CPU
adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x)
Expand Down
6 changes: 6 additions & 0 deletions ext/LuxDeviceUtilsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ __init__() = reset_gpu_device!()
LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()

__default_rng() = Metal.GPUArrays.default_rng(MtlArray)

# Device Transfer
## To GPU
adapt_storage(::LuxMetalAdaptor, x) = mtl(x)
adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = __default_rng()

## Is this a correct thing to do?
adapt_storage(::LuxCPUAdaptor, rng::Metal.GPUArrays.RNG) = Random.default_rng()

## Chain Rules
CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ))
Expand Down
5 changes: 0 additions & 5 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ module LuxDeviceUtils
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
import Adapt: adapt, adapt_storage

using PackageExtensionCompat
function __init__()
@require_extensions
end

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
Expand Down
7 changes: 3 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.14.1"
julia = "1.6"
13 changes: 6 additions & 7 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ end
using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1),
b=ones(10, 1),
e=:c,
d="string",
rng=Random.default_rng(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
farray=Fill(1.0f0, (2, 3)))
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = LuxAMDGPU.functional() ? ROCArray : Array
rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if LuxAMDGPU.functional()
Expand All @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if LuxAMDGPU.functional()
Expand Down
13 changes: 6 additions & 7 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ end
using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1),
b=ones(10, 1),
e=:c,
d="string",
rng=Random.default_rng(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
farray=Fill(1.0f0, (2, 3)))
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = LuxCUDA.functional() ? CuArray : Array
rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if LuxCUDA.functional()
Expand All @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if LuxCUDA.functional()
Expand Down
13 changes: 6 additions & 7 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ end
using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1),
b=ones(10, 1),
e=:c,
d="string",
rng=Random.default_rng(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)),
farray=Fill(1.0f0, (2, 3)))
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = Metal.functional() ? MtlArray : Array
rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if Metal.functional()
Expand All @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if Metal.functional()
Expand Down
41 changes: 9 additions & 32 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,29 @@ using LuxCore, LuxDeviceUtils

const GROUP = get(ENV, "GROUP", "CUDA")

@info "Installing Accelerator Packages..."

GROUP == "CUDA" && Pkg.add("LuxCUDA")

@static if VERSION v"1.9"
GROUP == "AMDGPU" && Pkg.add("LuxAMDGPU")

GROUP == "Metal" && Pkg.add("Metal")
else
if GROUP != "CUDA"
@warn "AMDGPU and Metal are only available on Julia 1.9+"
end
end

@info "Installed Accelerator Packages!"

@info "Starting Tests..."

@testset "LuxDeviceUtils Tests" begin
if GROUP == "CUDA"
@safetestset "CUDA" begin
include("cuda.jl")
end
end

@static if VERSION v"1.9"
if GROUP == "AMDGPU"
@safetestset "AMDGPU" begin
include("amdgpu.jl")
end
end

if GROUP == "Metal"
@safetestset "Metal" begin
include("metal.jl")
end
if GROUP == "AMDGPU"
@safetestset "AMDGPU" begin
include("amdgpu.jl")
end
end

if VERSION v"1.9"
@testset "Aqua Tests" begin
Aqua.test_all(LuxDeviceUtils; piracies=false)
if GROUP == "Metal"
@safetestset "Metal" begin
include("metal.jl")
end
end

@testset "Others" begin
@testset "Aqua Tests" begin
Aqua.test_all(LuxDeviceUtils)
end
@safetestset "Component Arrays" begin
include("component_arrays.jl")
end
Expand Down

2 comments on commit 64803e0

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/97267

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.11 -m "<description of version>" 64803e08bf7e342027583f1ce21971b4fe614464
git push origin v0.1.11

Please sign in to comment.