From a60f5ee114ff7c3dcfbee537aa4e60a238ddba19 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 22:18:49 -0700 Subject: [PATCH] refactor: move ForwardDiff.jl into main deps --- Project.toml | 6 +-- ext/LuxLibForwardDiffExt.jl | 85 ------------------------------------- src/LuxLib.jl | 2 + src/impl/forward_diff.jl | 50 ++++++++++++++++++++++ src/impl/fused_conv.jl | 13 ++++++ src/utils.jl | 7 +++ 6 files changed, 74 insertions(+), 89 deletions(-) delete mode 100644 ext/LuxLibForwardDiffExt.jl create mode 100644 src/impl/forward_diff.jl diff --git a/Project.toml b/Project.toml index ff1b8255..01ab63ea 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" @@ -23,7 +24,6 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -31,7 +31,6 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] LuxLibAMDGPUExt = "AMDGPU" LuxLibCUDAExt = "CUDA" -LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" @@ -76,7 +75,6 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -ForwardDiff = 
"f6369f11-7733-5829-9624-2563aa707210" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -89,4 +87,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] diff --git a/ext/LuxLibForwardDiffExt.jl b/ext/LuxLibForwardDiffExt.jl deleted file mode 100644 index 20ca3054..00000000 --- a/ext/LuxLibForwardDiffExt.jl +++ /dev/null @@ -1,85 +0,0 @@ -module LuxLibForwardDiffExt - -using ForwardDiff: ForwardDiff -using LuxLib: LuxLib -using LuxDeviceUtils: AbstractLuxGPUDevice -using NNlib: NNlib - -LuxLib.__has_dual(::ForwardDiff.Dual) = true -LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true - -# Convolutions: We might want to capture these further down in `conv!` -# NOTE: In principle we can concatenate all of the partials along the batch dimension -# and cut down substantially on the time to compute jacobians. -for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] - luxlibop = Symbol("__$(op)") - - @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, - x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; - kwargs...) where {N, Tag, V, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - y = LuxLib.$(luxlibop)(value_fn.(x1), x2, cdims; kwargs...) 
- dys = ntuple(i -> LuxLib.$(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) - - partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, V, P}.(y, partials) - end - - @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, - x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - y = LuxLib.$(luxlibop)(x1, value_fn.(x2), cdims; kwargs...) - dys = ntuple(i -> LuxLib.$(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) - - partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, V, P}.(y, partials) - end - - @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, - x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - x1_data, x2_data = value_fn.(x1), value_fn.(x2) - - y = LuxLib.$(luxlibop)(x1_data, x2_data, cdims; kwargs...) - - dys₁ = ntuple(P) do i - dys₁ᵢ = LuxLib.$(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...) - dys₂ᵢ = LuxLib.$(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...) 
- dys₁ᵢ .+= dys₂ᵢ - return dys₁ᵢ - end - - partials = ForwardDiff.Partials.(tuple.(dys₁...)) - return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials) - end -end - -# Don't try to promote the input types -function LuxLib.__get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, - ::Type{T}, x, weight) where {T} - return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) -end -function LuxLib.__get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T}, - ::Type{<:ForwardDiff.Dual}, x, weight) where {T} - return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) -end -function LuxLib.__get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, - ::Type{<:ForwardDiff.Dual}, x, weight) - return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) -end - -LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) -LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -LuxLib.__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T) - -end diff --git a/src/LuxLib.jl b/src/LuxLib.jl index c27d0859..8ce35303 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -6,6 +6,7 @@ using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules using FastBroadcast: @.. using FastClosures: @closure +using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra, BLAS, mul! 
using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, @@ -31,6 +32,7 @@ include("impl/normalization.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") include("impl/fast_activation.jl") +include("impl/forward_diff.jl") # User Facing include("api/batchnorm.jl") diff --git a/src/impl/forward_diff.jl b/src/impl/forward_diff.jl new file mode 100644 index 00000000..8e8cd64a --- /dev/null +++ b/src/impl/forward_diff.jl @@ -0,0 +1,50 @@ +for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] + luxlibop = Symbol("__$(op)") + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; + kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(luxlibop)(value_fn.(x1), x2, cdims; kwargs...) + dys = ntuple(i -> $(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, V, P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(luxlibop)(x1, value_fn.(x2), cdims; kwargs...) + dys = ntuple(i -> $(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, V, P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, + cdims::NNlib.ConvDims; kwargs...) 
where {N, Tag, Vₓ, Vₚ, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + x1_data, x2_data = value_fn.(x1), value_fn.(x2) + + y = $(luxlibop)(x1_data, x2_data, cdims; kwargs...) + + dys₁ = ntuple(P) do i + dys₁ᵢ = $(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = $(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...) + dys₁ᵢ .+= dys₂ᵢ + return dys₁ᵢ + end + + partials = ForwardDiff.Partials.(tuple.(dys₁...)) + return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials) + end +end diff --git a/src/impl/fused_conv.jl b/src/impl/fused_conv.jl index 4595490f..29c747e0 100644 --- a/src/impl/fused_conv.jl +++ b/src/impl/fused_conv.jl @@ -11,6 +11,19 @@ function __get_conv_input_weight( ::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} return __materialize_subarray(x), __materialize_subarray(weight) end +function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, + ::Type{T}, x, weight) where {T} + return __materialize_subarray(x), __materialize_subarray(weight) +end +function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T}, + ::Type{<:ForwardDiff.Dual}, x, weight) where {T} + return __materialize_subarray(x), __materialize_subarray(weight) +end +function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, + ::Type{<:ForwardDiff.Dual}, x, weight) + return __materialize_subarray(x), __materialize_subarray(weight) +end + function __get_conv_input_weight( ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} return __materialize_subarray(x), __materialize_subarray(weight) diff --git a/src/utils.jl b/src/utils.jl index c7f93036..12eeae4f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -49,6 +49,9 @@ CRC.@non_differentiable __is_immutable_array_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) 
= nothing __has_dual(x) = false +__has_dual(::ForwardDiff.Dual) = true +__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true + __is_immutable_array_or_dual(x) = __is_immutable_array(x) || __has_dual(x) function __is_immutable_array_or_dual_val(x::Tuple) return Val(unrolled_any(__is_immutable_array_or_dual, x)) @@ -189,4 +192,8 @@ __value(x::Number) = x __value(x::AbstractArray) = x __value(::Type{T}) where {T <: Number} = T +__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T) + __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl