Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: clean up device and type code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 13, 2024
1 parent 9c96ad9 commit 6b4d1d7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 26 deletions.
13 changes: 5 additions & 8 deletions ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray)
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

function LuxDeviceUtils._get_device(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return nothing
return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u)
end

function LuxDeviceUtils._get_device_type(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return Nothing
return mapreduce(LuxDeviceUtils._get_device_type, LuxDeviceUtils.__combine_devices, x.u)
for op in (:_get_device, :_get_device_type)
@eval function LuxDeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing)
return mapreduce(LuxDeviceUtils.$op, LuxDeviceUtils.__combine_devices, x.u)
end
end

end
16 changes: 11 additions & 5 deletions ext/LuxDeviceUtilsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
module LuxDeviceUtilsReverseDiffExt

using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice
using LuxDeviceUtils: LuxDeviceUtils
using ReverseDiff: ReverseDiff

LuxDeviceUtils._get_device(::ReverseDiff.TrackedArray) = LuxCPUDevice()
LuxDeviceUtils._get_device(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice()
LuxDeviceUtils._get_device_type(::ReverseDiff.TrackedArray) = LuxCPUDevice
LuxDeviceUtils._get_device_type(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice
for op in (:_get_device, :_get_device_type)
@eval begin
function LuxDeviceUtils.$op(x::ReverseDiff.TrackedArray)
return LuxDeviceUtils.$op(ReverseDiff.value(x))
end
function LuxDeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal})
return LuxDeviceUtils.$op(ReverseDiff.value.(x))
end
end
end

end
21 changes: 8 additions & 13 deletions ext/LuxDeviceUtilsTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDe
LuxoneAPIDevice
using Tracker: Tracker

@inline function LuxDeviceUtils._get_device(x::Tracker.TrackedArray)
return LuxDeviceUtils.get_device(Tracker.data(x))
end
@inline function LuxDeviceUtils._get_device(x::AbstractArray{<:Tracker.TrackedReal})
return LuxDeviceUtils.get_device(Tracker.data.(x))
end

@inline function LuxDeviceUtils._get_device_type(x::Tracker.TrackedArray)
return LuxDeviceUtils._get_device_type(Tracker.data(x))
end
@inline function LuxDeviceUtils._get_device_type(x::AbstractArray{<:Tracker.TrackedReal})
return LuxDeviceUtils._get_device_type(Tracker.data.(x))
for op in (:_get_device, :_get_device_type)
@eval begin
LuxDeviceUtils.$op(x::Tracker.TrackedArray) = LuxDeviceUtils.$op(Tracker.data(x))
function LuxDeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal})
return LuxDeviceUtils.$op(Tracker.data.(x))
end
end
end

@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true
LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice,
LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice)
Expand Down

0 comments on commit 6b4d1d7

Please sign in to comment.