Skip to content

Commit

Permalink
chore: run the formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 11, 2024
1 parent bcf9f2e commit e35d643
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
5 changes: 3 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
module MLDataDevicesChainRulesExt

using Adapt: Adapt
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
ReactantDevice
using ChainRules: OneElement

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
CUDADevice{Nothing}, AMDGPUDevice{Nothing})
CUDADevice{Nothing}, AMDGPUDevice{Nothing})
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
end
Expand Down
6 changes: 3 additions & 3 deletions lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
module MLDataDevicesZygoteExt

using Adapt: Adapt
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
ReactantDevice
using Zygote: OneElement

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
CUDADevice{Nothing}, AMDGPUDevice{Nothing})
CUDADevice{Nothing}, AMDGPUDevice{Nothing})
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
end

end

24 changes: 12 additions & 12 deletions lib/MLDataDevices/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ end
@test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
end

@testset "Zygote and ChainRules OneElement #1016" begin
using Zygote

cpu = cpu_device()
gpu = gpu_device()

g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1, 2, 3])[1]
@test g isa Vector{Float32}
g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9])[1]
@test g isa Matrix{Float32}
end

@testset "OneHotArrays" begin
using OneHotArrays

Expand All @@ -241,15 +253,3 @@ end
@test x_rd isa Reactant.ConcreteRArray{Bool, 2}
end
end

@testset "Zygote and ChainRules OneElement" begin
# Issue #1016
using Zygote
cpu = cpu_device()
gpu = gpu_device()

g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1]
@test g isa Vector{Float32}
g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1]
@test g isa Matrix{Float32}
end

0 comments on commit e35d643

Please sign in to comment.