diff --git a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 1628b53..201ee44 100644 --- a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -14,14 +14,11 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end -function LuxDeviceUtils._get_device(x::Union{VectorOfArray, DiffEqArray}) - length(x.u) == 0 && return nothing - return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) -end - -function LuxDeviceUtils._get_device_type(x::Union{VectorOfArray, DiffEqArray}) - length(x.u) == 0 && return Nothing - return mapreduce(LuxDeviceUtils._get_device_type, LuxDeviceUtils.__combine_devices, x.u) +for op in (:_get_device, :_get_device_type) + @eval function LuxDeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) + return mapreduce(LuxDeviceUtils.$op, LuxDeviceUtils.__combine_devices, x.u) + end end end diff --git a/ext/LuxDeviceUtilsReverseDiffExt.jl b/ext/LuxDeviceUtilsReverseDiffExt.jl index f0d1b04..8a097d1 100644 --- a/ext/LuxDeviceUtilsReverseDiffExt.jl +++ b/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -1,11 +1,17 @@ module LuxDeviceUtilsReverseDiffExt -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice +using LuxDeviceUtils: LuxDeviceUtils using ReverseDiff: ReverseDiff -LuxDeviceUtils._get_device(::ReverseDiff.TrackedArray) = LuxCPUDevice() -LuxDeviceUtils._get_device(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice() -LuxDeviceUtils._get_device_type(::ReverseDiff.TrackedArray) = LuxCPUDevice -LuxDeviceUtils._get_device_type(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice +for op in (:_get_device, :_get_device_type) + @eval begin + function LuxDeviceUtils.$op(x::ReverseDiff.TrackedArray) + return LuxDeviceUtils.$op(ReverseDiff.value(x)) + end + function LuxDeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return LuxDeviceUtils.$op(ReverseDiff.value.(x)) + end + end +end end diff --git a/ext/LuxDeviceUtilsTrackerExt.jl b/ext/LuxDeviceUtilsTrackerExt.jl index c68cebf..d41e832 100644 --- a/ext/LuxDeviceUtilsTrackerExt.jl +++ b/ext/LuxDeviceUtilsTrackerExt.jl @@ -5,21 +5,16 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDe LuxoneAPIDevice using Tracker: Tracker -@inline function LuxDeviceUtils._get_device(x::Tracker.TrackedArray) - return LuxDeviceUtils.get_device(Tracker.data(x)) -end -@inline function LuxDeviceUtils._get_device(x::AbstractArray{<:Tracker.TrackedReal}) - return LuxDeviceUtils.get_device(Tracker.data.(x)) -end - -@inline function LuxDeviceUtils._get_device_type(x::Tracker.TrackedArray) - return LuxDeviceUtils._get_device_type(Tracker.data(x)) -end -@inline function LuxDeviceUtils._get_device_type(x::AbstractArray{<:Tracker.TrackedReal}) - return LuxDeviceUtils._get_device_type(Tracker.data.(x)) +for op in (:_get_device, :_get_device_type) + @eval begin + LuxDeviceUtils.$op(x::Tracker.TrackedArray) = LuxDeviceUtils.$op(Tracker.data(x)) + function LuxDeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) + return LuxDeviceUtils.$op(Tracker.data.(x)) + end + end end -@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true +LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice)