From 89958db2ca41976cdf2976dc6dfe605d88966ad4 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 25 Oct 2024 15:37:19 -0400
Subject: [PATCH 1/3] fix: correctly handle adjoints of wrapped arrays

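Moving a wrapped array such as `x'` to a device used to fall through to
the generic `AbstractArray` method, and the `Adapt.adapt_storage` rrule
returned the cotangent without restoring the wrapper structure. This
adds fast-path dispatches for `Transpose`, `Adjoint`, and
`PermutedDimsArray` that transfer the parent array and rewrap it, and
projects the cotangent back onto the input type via `ProjectTo(x)`.

A minimal sketch of the behaviour being fixed (illustrative only; it
assumes a functional `gpu_device()` and mirrors the regression test
added later in this series):

    using MLDataDevices, Zygote

    x = rand(4, 4)
    gdev = gpu_device()

    # The gradient w.r.t. the `Adjoint` input should come back as a
    # plain `Matrix{Float64}`, not as a device array or a tangent with
    # the wrapper structure lost:
    only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}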
---
 Project.toml                          |  2 +-
 ext/MLDataDevicesChainRulesCoreExt.jl | 19 +++++++++++--------
 src/public.jl                         | 25 ++++++++++++++++++++++---
 test/misc_tests.jl                    |  3 +--
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/Project.toml b/Project.toml
index c85cb0d..391724d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal and contributors"]
-version = "1.4.1"
+version = "1.4.2"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl
index 6a770b8..2b8c9c8 100644
--- a/ext/MLDataDevicesChainRulesCoreExt.jl
+++ b/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -1,7 +1,7 @@
 module MLDataDevicesChainRulesCoreExt
 
 using Adapt: Adapt
-using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable
+using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable
 
 using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
 
@@ -10,15 +10,18 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
 
 function ChainRulesCore.rrule(
         ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
-    ∇adapt_storage = let dev = get_device(x)
-        if dev === nothing || dev isa UnknownDevice
-            @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
-            Δ -> (NoTangent(), NoTangent(), Δ)
-        else
-            Δ -> (NoTangent(), NoTangent(), dev(Δ))
+    dev = get_device(x)
+    y = Adapt.adapt_storage(to, x)
+    if dev === nothing || dev isa UnknownDevice
+        @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
+        ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ)
+        return y, ∇adapt_storage_unknown
+    else
+        ∇adapt_storage = let dev = dev, x = x
+            Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ)))
         end
+        return Adapt.adapt_storage(to, x), ∇adapt_storage
     end
-    return Adapt.adapt_storage(to, x), ∇adapt_storage
 end
 
 end
diff --git a/src/public.jl b/src/public.jl
index 104a424..6f7c8b8 100644
--- a/src/public.jl
+++ b/src/public.jl
@@ -342,10 +342,29 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
     ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
-            return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :
-                   map(D, x)
+            if isleaf(x)
+                (isbitstype(T) || Internal.special_aos(x)) && return Adapt.adapt(D, x)
+                return map(D, x)
+            end
+            return Functors.fmap(D, x; exclude=isleaf)
+        end
+        # Fast Paths else we don't get type stability
+        function (D::$(ldev))(x::Transpose{T, <:AbstractArray{T}}) where {T}
+            return transpose(D(parent(x)))
+        end
+        function (D::$(ldev))(x::Adjoint{T, <:AbstractArray{T}}) where {T}
+            return adjoint(D(parent(x)))
+        end
+        function (D::$(ldev))(x::PermutedDimsArray{
+            T, N, perm, iperm, <:AbstractArray{T}}) where {T, N, perm, iperm}
+            y = D(parent(x))
+            return PermutedDimsArray{eltype(y), N, perm, iperm, typeof(y)}(y)
+        end
+
+        function (D::$(ldev))(x::Union{Tuple, NamedTuple})
+            isleaf(x) && return map(D, x)
+            return Functors.fmap(D, x; exclude=isleaf)
         end
-        (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
         function (D::$(ldev))(x)
             isleaf(x) && return Adapt.adapt(D, x)
             return Functors.fmap(D, x; exclude=isleaf)
diff --git a/test/misc_tests.jl b/test/misc_tests.jl
index 9bec386..a1023cb 100644
--- a/test/misc_tests.jl
+++ b/test/misc_tests.jl
@@ -181,7 +181,6 @@ end
 end
 
 @testset "shared parameters" begin
-    # from
     x = rand(1)
     m = (; a=x, b=x')
     count = Ref(0)
@@ -199,7 +198,7 @@ end
         y::Float64
     end
 
-    for x in [1.0, 'a', BitsType(1, 2.0)]
+    @testset for x in [1.0, 'a', BitsType(1, 2.0)]
         @test MLDataDevices.isleaf([x])
         @test !MLDataDevices.isleaf([x]')
         @test !MLDataDevices.isleaf(transpose([x]))

From c8ef5908a3502352cf323a96de583ee5ef4ea258 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 25 Oct 2024 15:56:34 -0400
Subject: [PATCH 2/3] fix: use fast paths for adapt

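Enumerating wrapper types by hand does not scale and duplicates logic
that Adapt.jl already has: `Adapt.WrappedArray` covers the Base and
LinearAlgebra wrappers, and `Adapt.adapt` knows how to unwrap, transfer
the parent, and rewrap. Route all wrapped arrays through `Adapt.adapt`,
mark `Adapt.WrappedArray` as a non-leaf for `Functors.fmap`, move the
rrule from `Adapt.adapt_storage` to `Adapt.adapt`, and drop the now
unused LinearAlgebra dependency.

A short sketch of the leaf semantics this relies on (illustrative
only):

    using MLDataDevices, Adapt

    x = rand(3, 3)

    x' isa Adapt.WrappedArray            # true
    MLDataDevices.isleaf(x)              # true: transfer the array itself
    MLDataDevices.isleaf(x')             # false: recurse into parent(x')
    MLDataDevices.isleaf(view(x, :, 1))  # false: SubArray is wrapped too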
---
 Project.toml                          |  2 --
 ext/MLDataDevicesChainRulesCoreExt.jl |  3 +--
 src/MLDataDevices.jl                  |  1 -
 src/public.jl                         | 26 +++++---------------------
 test/misc_tests.jl                    | 15 ++++++++++++++-
 5 files changed, 20 insertions(+), 27 deletions(-)

diff --git a/Project.toml b/Project.toml
index 391724d..7d94339 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,7 +7,6 @@ version = "1.4.2"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -54,7 +53,6 @@ Compat = "4.15"
 FillArrays = "1"
 Functors = "0.4.8"
 GPUArrays = "10, 11"
-LinearAlgebra = "1.10"
 MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl
index 2b8c9c8..e625dc1 100644
--- a/ext/MLDataDevicesChainRulesCoreExt.jl
+++ b/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -8,8 +8,7 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
 @non_differentiable get_device(::Any)
 @non_differentiable get_device_type(::Any)
 
-function ChainRulesCore.rrule(
-        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
+function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
     dev = get_device(x)
     y = Adapt.adapt_storage(to, x)
     if dev === nothing || dev isa UnknownDevice
diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl
index c837887..108d8bf 100644
--- a/src/MLDataDevices.jl
+++ b/src/MLDataDevices.jl
@@ -5,7 +5,6 @@ using Functors: Functors, fleaves
 using Preferences: @delete_preferences!, @load_preference, @set_preferences!
 using Random: AbstractRNG, Random
 using Compat: @compat
-using LinearAlgebra: Transpose, Adjoint
 
 abstract type AbstractDevice <: Function end
 abstract type AbstractCPUDevice <: AbstractDevice end
diff --git a/src/public.jl b/src/public.jl
index 6f7c8b8..b6ee2c4 100644
--- a/src/public.jl
+++ b/src/public.jl
@@ -342,29 +342,13 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
     ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
-            if isleaf(x)
-                (isbitstype(T) || Internal.special_aos(x)) && return Adapt.adapt(D, x)
-                return map(D, x)
+            if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
+                return Adapt.adapt(D, x)
             end
-            return Functors.fmap(D, x; exclude=isleaf)
-        end
-        # Fast Paths else we don't get type stability
-        function (D::$(ldev))(x::Transpose{T, <:AbstractArray{T}}) where {T}
-            return transpose(D(parent(x)))
-        end
-        function (D::$(ldev))(x::Adjoint{T, <:AbstractArray{T}}) where {T}
-            return adjoint(D(parent(x)))
-        end
-        function (D::$(ldev))(x::PermutedDimsArray{
-            T, N, perm, iperm, <:AbstractArray{T}}) where {T, N, perm, iperm}
-            y = D(parent(x))
-            return PermutedDimsArray{eltype(y), N, perm, iperm, typeof(y)}(y)
+            return map(D, x)
         end
-
-        function (D::$(ldev))(x::Union{Tuple, NamedTuple})
-            isleaf(x) && return map(D, x)
-            return Functors.fmap(D, x; exclude=isleaf)
-        end
+        (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
         function (D::$(ldev))(x)
             isleaf(x) && return Adapt.adapt(D, x)
             return Functors.fmap(D, x; exclude=isleaf)
@@ -418,4 +402,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct
 isleaf(x) = Functors.isleaf(x)
 
 isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
-isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
+isleaf(::Adapt.WrappedArray) = false
diff --git a/test/misc_tests.jl b/test/misc_tests.jl
index a1023cb..89474b5 100644
--- a/test/misc_tests.jl
+++ b/test/misc_tests.jl
@@ -50,7 +50,7 @@ end
 
 @testset "CRC Tests" begin
     dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
-    test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true)
+    test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true)
 
     gdev = gpu_device()
     if !(gdev isa MetalDevice) # On intel devices causes problems
@@ -206,3 +206,16 @@ end
         end
     end
 end
+
+@testset "Zygote.gradient(wrapped arrays)" begin
+    using Zygote
+
+    x = rand(4, 4)
+    cdev = cpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64}
+
+    gdev = gpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
+end

From bb8388d6391f6998fc23d6d019fa283c7ec0f72f Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 25 Oct 2024 16:14:53 -0400
Subject: [PATCH 3/3] fix: adapt ranges to https://github.com/JuliaGPU/Adapt.jl/pull/86

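JuliaGPU/Adapt.jl#86 (released in Adapt v4.1) teaches `adapt` to keep
ranges as ranges instead of collecting them into dense device arrays,
so the package-local `adapt_storage(..., ::AbstractRange)` overloads
are redundant, and the tests now only assert that an `AbstractRange`
comes back. Additionally, only warn in the rrule when `get_device`
returns an `UnknownDevice`; `nothing` is an expected answer.

A rough sketch of the new range behaviour (illustrative only; assumes
a functional `gpu_device()`):

    using MLDataDevices

    gdev = gpu_device()

    r = gdev(1:0.5:10)   # stays an AbstractRange; nothing is collected
    cpu_device()(r)      # round-trips as an AbstractRange as well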
---
 Project.toml                          | 2 +-
 ext/MLDataDevicesChainRulesCoreExt.jl | 3 ++-
 src/public.jl                         | 9 ---------
 test/amdgpu_tests.jl                  | 4 ++--
 test/cuda_tests.jl                    | 4 ++--
 test/metal_tests.jl                   | 4 ++--
 test/misc_tests.jl                    | 4 ++--
 test/oneapi_tests.jl                  | 4 ++--
 8 files changed, 13 insertions(+), 21 deletions(-)

diff --git a/Project.toml b/Project.toml
index 7d94339..68d4325 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,7 +46,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
 
 [compat]
 AMDGPU = "0.9.6, 1"
-Adapt = "4"
+Adapt = "4.1"
 CUDA = "5.2"
 ChainRulesCore = "1.23"
 Compat = "4.15"
diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl
index e625dc1..518ff20 100644
--- a/ext/MLDataDevicesChainRulesCoreExt.jl
+++ b/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -12,7 +12,8 @@ function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::Abst
     dev = get_device(x)
     y = Adapt.adapt_storage(to, x)
     if dev === nothing || dev isa UnknownDevice
-        @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
+        dev isa UnknownDevice &&
+            @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
         ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ)
         return y, ∇adapt_storage_unknown
     else
diff --git a/src/public.jl b/src/public.jl
index b6ee2c4..6440ddb 100644
--- a/src/public.jl
+++ b/src/public.jl
@@ -347,7 +347,6 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
             end
             return map(D, x)
         end
-
         (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
         function (D::$(ldev))(x)
             isleaf(x) && return Adapt.adapt(D, x)
             return Functors.fmap(D, x; exclude=isleaf)
@@ -376,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
     end
 end
 
-Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
-Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x
-# Prevent Ambiguity
-for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-    CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
-    @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
-end
-
 """
     isleaf(x) -> Bool
diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl
index 41a8797..a771ada 100644
--- a/test/amdgpu_tests.jl
+++ b/test/amdgpu_tests.jl
@@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl
index 1f95831..2fce480 100644
--- a/test/cuda_tests.jl
+++ b/test/cuda_tests.jl
@@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
diff --git a/test/metal_tests.jl b/test/metal_tests.jl
index aeb596a..2bc8845 100644
--- a/test/metal_tests.jl
+++ b/test/metal_tests.jl
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
    @test ps_xpu.mixed[1] isa Float32
    @test ps_xpu.mixed[2] isa Float64
    @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
    @test ps_xpu.e == ps.e
    @test ps_xpu.d == ps.d
    @test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
    @test ps_cpu.mixed[1] isa Float32
    @test ps_cpu.mixed[2] isa Float64
    @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
    @test ps_cpu.e == ps.e
    @test ps_cpu.d == ps.d
    @test ps_cpu.rng_default isa Random.TaskLocalRNG
diff --git a/test/misc_tests.jl b/test/misc_tests.jl
index 89474b5..28275d3 100644
--- a/test/misc_tests.jl
+++ b/test/misc_tests.jl
@@ -55,12 +55,12 @@ end
     gdev = gpu_device()
     if !(gdev isa MetalDevice) # On intel devices causes problems
         x = randn(10)
-        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x)
+        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x)
         @test ∂dev === nothing
         @test ∂x ≈ ones(10)
 
         x = randn(10) |> gdev
-        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x)
+        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x)
         @test ∂dev === nothing
         @test ∂x ≈ gdev(ones(10))
         @test get_device(∂x) isa parameterless_type(typeof(gdev))
diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl
index 8bb6026..2169869 100644
--- a/test/oneapi_tests.jl
+++ b/test/oneapi_tests.jl
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
    @test ps_xpu.mixed[1] isa Float32
    @test ps_xpu.mixed[2] isa Float64
    @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
    @test ps_xpu.e == ps.e
    @test ps_xpu.d == ps.d
    @test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
    @test ps_cpu.mixed[1] isa Float32
    @test ps_cpu.mixed[2] isa Float64
    @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
    @test ps_cpu.e == ps.e
    @test ps_cpu.d == ps.d
    @test ps_cpu.rng_default isa Random.TaskLocalRNG