Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: adapt ranges to JuliaGPU/Adapt.jl#86
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2024
1 parent c8ef590 commit 5ae1237
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
dev = get_device(x)
y = Adapt.adapt_storage(to, x)
if dev === nothing || dev isa UnknownDevice
Expand Down
8 changes: 0 additions & 8 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
end
end

Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x
# Prevent Ambiguity
for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

"""
isleaf(x) -> Bool
Expand Down
4 changes: 2 additions & 2 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand Down Expand Up @@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
4 changes: 2 additions & 2 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand Down Expand Up @@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
4 changes: 2 additions & 2 deletions test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand Down Expand Up @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
4 changes: 2 additions & 2 deletions test/oneapi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand Down Expand Up @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down

0 comments on commit 5ae1237

Please sign in to comment.