Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: use fast paths for adapt
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2024
1 parent 89958db commit 79b395b
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "1.4.2"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

Expand Down Expand Up @@ -54,7 +53,6 @@ Compat = "4.15"
FillArrays = "1"
Functors = "0.4.8"
GPUArrays = "10, 11"
LinearAlgebra = "1.10"
MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Expand Down
3 changes: 1 addition & 2 deletions ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
dev = get_device(x)
y = Adapt.adapt_storage(to, x)
if dev === nothing || dev isa UnknownDevice
Expand Down
1 change: 0 additions & 1 deletion src/MLDataDevices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Functors: Functors, fleaves
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random
using Compat: @compat
using LinearAlgebra: Transpose, Adjoint

abstract type AbstractDevice <: Function end
abstract type AbstractCPUDevice <: AbstractDevice end
Expand Down
16 changes: 3 additions & 13 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,21 +345,11 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
if isleaf(x)
(isbitstype(T) || Internal.special_aos(x)) && return Adapt.adapt(D, x)
return map(D, x)
elseif x isa Adapt.WrappedArray # fast path
return Adapt.adapt(D, x)
end
return Functors.fmap(D, x; exclude=isleaf)
end
# Fast Paths else we don't get type stability
function (D::$(ldev))(x::Transpose{T, <:AbstractArray{T}}) where {T}
return transpose(D(parent(x)))
end
function (D::$(ldev))(x::Adjoint{T, <:AbstractArray{T}}) where {T}
return adjoint(D(parent(x)))
end
function (D::$(ldev))(x::PermutedDimsArray{
T, N, perm, iperm, <:AbstractArray{T}}) where {T, N, perm, iperm}
y = D(parent(x))
return PermutedDimsArray{eltype(y), N, perm, iperm, typeof(y)}(y)
end

function (D::$(ldev))(x::Union{Tuple, NamedTuple})
isleaf(x) && map(D, x)
Expand Down Expand Up @@ -418,4 +408,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct
isleaf(x) = Functors.isleaf(x)

isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
isleaf(::Adapt.WrappedArray) = false
13 changes: 13 additions & 0 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,16 @@ end
end
end
end

@testset "Zygote.gradient(wrapped arrays)" begin
using Zygote

x = rand(4, 4)
cdev = cpu_device()

@test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64}

gdev = gpu_device()

@test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
end

0 comments on commit 79b395b

Please sign in to comment.