Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: move ForwardDiff.jl into main deps
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 13, 2024
1 parent ea75484 commit a60f5ee
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 89 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
Expand All @@ -23,15 +24,13 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
LuxLibAMDGPUExt = "AMDGPU"
LuxLibCUDAExt = "CUDA"
LuxLibForwardDiffExt = "ForwardDiff"
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
LuxLibTrackerExt = "Tracker"
Expand Down Expand Up @@ -76,7 +75,6 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -89,4 +87,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"]
85 changes: 0 additions & 85 deletions ext/LuxLibForwardDiffExt.jl

This file was deleted.

2 changes: 2 additions & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using DispatchDoctor: @stable
using EnzymeCore: EnzymeCore, EnzymeRules
using FastBroadcast: @..
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LinearAlgebra: LinearAlgebra, BLAS, mul!
using LuxCore: LuxCore
using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice,
Expand All @@ -31,6 +32,7 @@ include("impl/normalization.jl")
include("impl/fused_dense.jl")
include("impl/fused_conv.jl")
include("impl/fast_activation.jl")
include("impl/forward_diff.jl")

# User Facing
include("api/batchnorm.jl")
Expand Down
50 changes: 50 additions & 0 deletions src/impl/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter]
luxlibop = Symbol("__$(op)")

@eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims;
kwargs...) where {N, Tag, V, P}
value_fn(x) = ForwardDiff.value(Tag, x)
partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)

y = $(luxlibop)(value_fn.(x1), x2, cdims; kwargs...)
dys = ntuple(i -> $(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P)

partials = ForwardDiff.Partials.(tuple.(dys...))
return ForwardDiff.Dual{Tag, V, P}.(y, partials)
end

@eval function NNlib.$(op)(x1::AbstractArray{<:Real, N},
x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P}
value_fn(x) = ForwardDiff.value(Tag, x)
partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)

y = $(luxlibop)(x1, value_fn.(x2), cdims; kwargs...)
dys = ntuple(i -> $(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P)

partials = ForwardDiff.Partials.(tuple.(dys...))
return ForwardDiff.Dual{Tag, V, P}.(y, partials)
end

@eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N},
x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P}
value_fn(x) = ForwardDiff.value(Tag, x)
partial_fn(x, i) = ForwardDiff.partials(Tag, x, i)

x1_data, x2_data = value_fn.(x1), value_fn.(x2)

y = $(luxlibop)(x1_data, x2_data, cdims; kwargs...)

dys₁ = ntuple(P) do i
dys₁ᵢ = $(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...)
dys₂ᵢ = $(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...)
dys₁ᵢ .+= dys₂ᵢ
return dys₁ᵢ
end

partials = ForwardDiff.Partials.(tuple.(dys₁...))
return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials)
end
end
13 changes: 13 additions & 0 deletions src/impl/fused_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ function __get_conv_input_weight(
::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T}
return __materialize_subarray(x), __materialize_subarray(weight)
end
function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual},
::Type{T}, x, weight) where {T}
return __materialize_subarray(x), __materialize_subarray(weight)
end
function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T},
::Type{<:ForwardDiff.Dual}, x, weight) where {T}
return __materialize_subarray(x), __materialize_subarray(weight)
end
function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual},
::Type{<:ForwardDiff.Dual}, x, weight)
return __materialize_subarray(x), __materialize_subarray(weight)
end

function __get_conv_input_weight(
::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT}
return __materialize_subarray(x), __materialize_subarray(weight)
Expand Down
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ CRC.@non_differentiable __is_immutable_array_val(::Any...)
EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing

__has_dual(x) = false
__has_dual(::ForwardDiff.Dual) = true
__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true

__is_immutable_array_or_dual(x) = __is_immutable_array(x) || __has_dual(x)
function __is_immutable_array_or_dual_val(x::Tuple)
return Val(unrolled_any(__is_immutable_array_or_dual, x))
Expand Down Expand Up @@ -189,4 +192,8 @@ __value(x::Number) = x
__value(x::AbstractArray) = x
__value(::Type{T}) where {T <: Number} = T

__value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T)

__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl

0 comments on commit a60f5ee

Please sign in to comment.