diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index c72c2fe70..f9893771c 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.6.3" +version = "1.6.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index 92a29e6d3..ea44806db 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -77,7 +77,11 @@ function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray) end # Device Transfer -Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) + MLDataDevices.get_device_type(x) <: AMDGPUDevice && return x + return AMDGPU.roc(x) +end + function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device dev = MLDataDevices.get_device(x) diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index b96056bda..d0f3c55ea 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -57,7 +57,10 @@ function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray) end # Device Transfer -Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) + MLDataDevices.get_device_type(x) <: CUDADevice && return x + return CUDA.cu(x) +end function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl index 6e03abc51..e6c7da3f4 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -29,6 +29,9 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray) end # Device Transfer -Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) +function Adapt.adapt_storage(::MetalDevice, x::AbstractArray) + MLDataDevices.get_device_type(x) <: MetalDevice && return x + return Metal.mtl(x) +end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl index 600c80e1f..2ce1579d8 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -42,6 +42,7 @@ end # Device Transfer for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)}) + MLDataDevices.get_device_type(x) <: oneAPIDevice && return x if !SUPPORTS_FP64[oneAPI.device()] @warn LazyString( "Double type is not supported on this device. Using `", $(T2), "` instead.") @@ -50,6 +51,9 @@ for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) return oneArray(x) end end -Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray) = oneArray(x) +function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray) + MLDataDevices.get_device_type(x) <: oneAPIDevice && return x + return oneArray(x) +end end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 6b87ef422..068b8abf9 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -375,7 +375,10 @@ for op in (:get_device, :get_device_type) end # Adapt Interface -Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Array(x) +function Adapt.adapt_storage(::CPUDevice, x::AbstractArray) + get_device_type(x) <: CPUDevice && return x + return Array(x) +end Adapt.adapt_storage(to::AbstractDevice, ::Random.TaskLocalRNG) = default_device_rng(to) Adapt.adapt_storage(::AbstractDevice, rng::AbstractRNG) = rng diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index a771ada6e..210d8c117 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -124,6 +124,12 @@ using FillArrays, Zygote # Extensions return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} end + + @testset "Issue #1129: no new object" begin + x = rand(Float32, 10, 10) |> device + y = x |> device + @test x === y + end end @testset "Functions" begin diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 2fce4806a..538048b1a 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -149,6 +149,12 @@ using FillArrays, Zygote # Extensions return_val2(x) = Val(get_device(x)) @test_throws ErrorException @inferred(return_val2(ps)) end + + @testset "Issue #1129: no new object" begin + x = rand(Float32, 10, 10) |> device + y = x |> device + @test x === y + end end @testset "Functions" begin diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 2bc884553..329d3fc2f 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -113,6 +113,12 @@ using FillArrays, Zygote # Extensions return_val2(x) = Val(get_device(x)) @test @inferred(return_val2(ps)) isa Val{get_device(x)} end + + @testset "Issue #1129: no new object" begin + x = rand(Float32, 10, 10) |> device + y = x |> device + @test x === y + end end @testset "Functions" begin diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 2169869d3..0fb707ad7 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -113,6 +113,12 @@ using FillArrays, Zygote # Extensions return_val2(x) = Val(get_device(x)) @test @inferred(return_val2(ps)) isa Val{get_device(x)} end + + @testset "Issue #1129: no new object" begin + x = rand(Float32, 10, 10) |> device + y = x |> device + @test x === y + end end @testset "Functions" begin diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl index dd59af96e..30377c828 100644 --- a/lib/MLDataDevices/test/xla_tests.jl +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -113,6 +113,12 @@ using FillArrays, Zygote # Extensions return_val2(x) = Val(get_device(x)) @test_throws TypeError @inferred(return_val2(ps)) end + + @testset "Issue #1129: no new object" begin + x = rand(Float32, 10, 10) |> device + y = x |> device + @test x === y + end end @testset "Functions" begin