This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix: correctly handle adjoints of wrapped arrays #90

Merged 3 commits on Oct 25, 2024 · Changes from all commits
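What the diff does: the ChainRulesCore extension's `rrule` moves from `Adapt.adapt_storage` to `Adapt.adapt` and runs the device-moved cotangent through `ProjectTo`, so gradients keep the primal's wrapper structure; the device callables route any `Adapt.WrappedArray` through `Adapt.adapt`; `isleaf` now marks every `Adapt.WrappedArray` as a non-leaf instead of only `Transpose`, `Adjoint`, and `PermutedDimsArray`; and the range-specific `adapt_storage` overloads are removed alongside the compat bump to Adapt 4.1. A minimal sketch of the fixed behavior, mirroring the new test in test/misc_tests.jl (assumes a CPU session):

using Adapt, MLDataDevices, Zygote

x = rand(4, 4)
cdev = cpu_device()

# Differentiate through a device move of an Adjoint-wrapped matrix. The
# pullback now projects the cotangent back onto the wrapper structure, so
# the gradient with respect to `x'` is a plain dense matrix:
grad = only(Zygote.gradient(z -> sum(abs2, cdev(z)), x'))
@assert grad isa Matrix{Float64}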
Project.toml: 6 changes (2 additions, 4 deletions)
@@ -1,13 +1,12 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.4.1"
+version = "1.4.2"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -47,14 +46,13 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
 
 [compat]
 AMDGPU = "0.9.6, 1"
-Adapt = "4"
+Adapt = "4.1"
 CUDA = "5.2"
 ChainRulesCore = "1.23"
 Compat = "4.15"
 FillArrays = "1"
 Functors = "0.4.8"
 GPUArrays = "10, 11"
-LinearAlgebra = "1.10"
 MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
ext/MLDataDevicesChainRulesCoreExt.jl: 21 changes (12 additions, 9 deletions)
@@ -1,24 +1,27 @@
 module MLDataDevicesChainRulesCoreExt
 
 using Adapt: Adapt
-using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable
+using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable
 
 using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
 
 @non_differentiable get_device(::Any)
 @non_differentiable get_device_type(::Any)
 
-function ChainRulesCore.rrule(
-        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
-    ∇adapt_storage = let dev = get_device(x)
-        if dev === nothing || dev isa UnknownDevice
+function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
+    dev = get_device(x)
+    y = Adapt.adapt_storage(to, x)
+    if dev === nothing || dev isa UnknownDevice
+        dev isa UnknownDevice &&
             @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
-            Δ -> (NoTangent(), NoTangent(), Δ)
-        else
-            Δ -> (NoTangent(), NoTangent(), dev(Δ))
+        ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ)
+        return y, ∇adapt_storage_unknown
+    else
+        ∇adapt_storage = let dev = dev, x = x
+            Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ)))
         end
+        return Adapt.adapt_storage(to, x), ∇adapt_storage
     end
-    return Adapt.adapt_storage(to, x), ∇adapt_storage
 end
 
 end
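For context on the new pullback: `ProjectTo` (from ChainRulesCore) is what restores the primal's wrapper structure once the cotangent has been moved back to the source device as a plain array. A minimal sketch of the semantics being relied on, not code from this PR:

using ChainRulesCore, LinearAlgebra

x = rand(3, 3)
xadj = x'          # the primal is an Adjoint wrapper

# The pullback receives a cotangent as a plain dense array...
Δ = ones(3, 3)

# ...and a projector built from the primal restores the wrapper type:
Δproj = ProjectTo(xadj)(Δ)
@assert Δproj isa LinearAlgebra.Adjoint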
src/MLDataDevices.jl: 1 change (0 additions, 1 deletion)
@@ -5,7 +5,6 @@ using Functors: Functors, fleaves
 using Preferences: @delete_preferences!, @load_preference, @set_preferences!
 using Random: AbstractRNG, Random
 using Compat: @compat
-using LinearAlgebra: Transpose, Adjoint
 
 abstract type AbstractDevice <: Function end
 abstract type AbstractCPUDevice <: AbstractDevice end
src/public.jl: 16 changes (5 additions, 11 deletions)
@@ -342,8 +342,10 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
     ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
-            return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :
-                   map(D, x)
+            if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
+                return Adapt.adapt(D, x)
+            end
+            return map(D, x)
         end
         (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
         function (D::$(ldev))(x)
@@ -373,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
     end
 end
 
-Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
-Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x
-# Prevent Ambiguity
-for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-        CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
-    @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
-end
-
 """
     isleaf(x) -> Bool
 
@@ -399,4 +393,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`.
 isleaf(x) = Functors.isleaf(x)
 
 isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
-isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
+isleaf(::Adapt.WrappedArray) = false
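`Adapt.WrappedArray` is Adapt.jl's union of the array wrappers it knows how to unwrap (`Adjoint`, `Transpose`, `PermutedDimsArray`, `SubArray`, reshaped arrays, and others), so both changed methods now cover strictly more wrapper types than the old `Union{Transpose, Adjoint, PermutedDimsArray}`. A quick illustration of the dispatch, assuming a CPU-only session:

using Adapt, LinearAlgebra, MLDataDevices

x = rand(4, 4)

# These all hit the new `x isa Adapt.WrappedArray` branch and move via
# `Adapt.adapt`, which unwraps, adapts the parent storage, and rewraps:
@assert x' isa Adapt.WrappedArray
@assert view(x, 1:2, :) isa Adapt.WrappedArray
@assert PermutedDimsArray(x, (2, 1)) isa Adapt.WrappedArray

# And `isleaf` treats every known wrapper as a non-leaf, so Functors
# recurses into the parent array rather than stopping at the wrapper:
@assert !MLDataDevices.isleaf(x')
@assert !MLDataDevices.isleaf(view(x, 1:2, :))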
test/amdgpu_tests.jl: 4 changes (2 additions, 2 deletions)
@@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -83,7 +83,7 @@
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
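The `isa aType` to `isa AbstractRange` change in these device tests matches the removal of the range-specific `adapt_storage` overloads from src/public.jl: with the Adapt 4.1 compat bump, ranges presumably move between devices as ranges rather than being `collect`ed into dense device arrays. A sketch under that assumption, requiring a functional GPU backend:

using MLDataDevices

ps = (; range=1:10)
gdev = gpu_device()

# A range inside a named tuple is no longer materialized on the device;
# it round-trips through GPU and back to CPU still as a range:
@assert gdev(ps).range isa AbstractRange
@assert cpu_device()(gdev(ps)).range isa AbstractRange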
test/cuda_tests.jl: 4 changes (2 additions, 2 deletions)
@@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -82,7 +82,7 @@
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
test/metal_tests.jl: 4 changes (2 additions, 2 deletions)
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
test/misc_tests.jl: 22 changes (17 additions, 5 deletions)
@@ -50,17 +50,17 @@ end
 
 @testset "CRC Tests" begin
     dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
-    test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true)
+    test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true)
 
     gdev = gpu_device()
     if !(gdev isa MetalDevice) # On intel devices causes problems
         x = randn(10)
-        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x)
+        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x)
         @test ∂dev === nothing
         @test ∂x ≈ ones(10)
 
         x = randn(10) |> gdev
-        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x)
+        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x)
         @test ∂dev === nothing
         @test ∂x ≈ gdev(ones(10))
         @test get_device(∂x) isa parameterless_type(typeof(gdev))
@@ -181,7 +181,6 @@ end
 end
 
 @testset "shared parameters" begin
-    # from
     x = rand(1)
     m = (; a=x, b=x')
     count = Ref(0)
@@ -199,11 +198,24 @@ end
         y::Float64
     end
 
-    for x in [1.0, 'a', BitsType(1, 2.0)]
+    @testset for x in [1.0, 'a', BitsType(1, 2.0)]
         @test MLDataDevices.isleaf([x])
         @test !MLDataDevices.isleaf([x]')
         @test !MLDataDevices.isleaf(transpose([x]))
         @test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2)))
     end
 end
+
+@testset "Zygote.gradient(wrapped arrays)" begin
+    using Zygote
+
+    x = rand(4, 4)
+    cdev = cpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64}
+
+    gdev = gpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
+end
test/oneapi_tests.jl: 4 changes (2 additions, 2 deletions)
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG