diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl index e625dc1..38e6013 100644 --- a/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/ext/MLDataDevicesChainRulesCoreExt.jl @@ -8,7 +8,7 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type @non_differentiable get_device(::Any) @non_differentiable get_device_type(::Any) -function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray) +function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) dev = get_device(x) y = Adapt.adapt_storage(to, x) if dev === nothing || dev isa UnknownDevice diff --git a/src/public.jl b/src/public.jl index b6ee2c4..535557f 100644 --- a/src/public.jl +++ b/src/public.jl @@ -376,14 +376,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice) end end -Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x -Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x -# Prevent Ambiguity -for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, - CUDADevice{Nothing}, MetalDevice, oneAPIDevice) - @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) -end - """ isleaf(x) -> Bool diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index 41a8797..a771ada 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index 1f95831..2fce480 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/metal_tests.jl b/test/metal_tests.jl index aeb596a..2bc8845 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index 8bb6026..2169869 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG