Merge pull request #51 from LuxDL/ap/explicit_imports
Explicit Imports and Fast Closures
avik-pal authored Apr 15, 2024
2 parents 83e40e9 + 99e55b4 commit d63ef24
Showing 33 changed files with 664 additions and 572 deletions.
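The two changes the title names are mechanical but repo-wide: every `using Mod` becomes an explicit `using Mod: Mod` (or `using Mod: name`) so each binding is traceable to its source module, and closures on hot paths are wrapped in FastClosures' `@closure` to avoid Julia's boxed-capture penalty. A minimal sketch of both patterns, with hypothetical names (`Sketch`, `value_of`, `scaled_mapper`) standing in for the real LuxLib code:

```julia
module Sketch

# Explicit imports: bind only the module name and qualify each use,
# rather than letting `using ForwardDiff` pull every exported name into scope.
using ForwardDiff: ForwardDiff
using FastClosures: @closure

# Every call site now says where the name comes from.
value_of(x) = ForwardDiff.value.(x)

function scaled_mapper(scale)
    # `@closure` rewraps the anonymous function in a `let` block so the
    # captured `scale` is not boxed (the Julia #15276 performance issue),
    # keeping the returned closure type-stable.
    return @closure x -> scale * x
end

end # module
```

The new ExplicitImports.jl test dependency (see the Project.toml diff below) can then enforce the first pattern in CI, presumably via checks like `ExplicitImports.check_no_implicit_imports(LuxLib)`.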
1 change: 1 addition & 0 deletions lib/LuxLib/.JuliaFormatter.toml
@@ -6,3 +6,4 @@ indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
join_lines_based_on_source = false
3 changes: 2 additions & 1 deletion lib/LuxLib/.buildkite/pipeline.yml
@@ -18,6 +18,7 @@ steps:
cuda: "*"
env:
GROUP: "CUDA"
RETESTITEMS_NWORKERS: 0 # Distributed is causing stalling issues with CUDA
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
matrix:
@@ -160,6 +161,6 @@ steps:
- "Boltz"

env:
RETESTITEMS_NWORKERS: 2
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg=="
2 changes: 1 addition & 1 deletion lib/LuxLib/.github/workflows/Downgrade.yml
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
version: ['1.9']
version: ['1.10']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
47 changes: 31 additions & 16 deletions lib/LuxLib/Project.toml
@@ -1,11 +1,13 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.11"
version = "0.3.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -14,58 +16,71 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
LuxLibForwardDiffExt = "ForwardDiff"
LuxLibLuxAMDGPUTrackerExt = ["LuxAMDGPU", "Tracker"]
LuxLibLuxCUDAExt = "LuxCUDA"
LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"]
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
LuxLibTrackerExt = "Tracker"
LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"]
LuxLibcuDNNExt = ["CUDA", "cuDNN"]

[compat]
Aqua = "0.8"
AMDGPU = "0.8.4"
Aqua = "0.8.7"
CUDA = "5.2"
ChainRulesCore = "1.20"
ComponentArrays = "0.15.8"
ExplicitImports = "1.4.1"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
KernelAbstractions = "0.9.2"
KernelAbstractions = "0.9.15"
LuxAMDGPU = "0.2.1"
LuxCUDA = "0.3.1"
LuxCore = "0.1.13"
LuxTestUtils = "0.1.15"
Markdown = "1.9"
NNlib = "0.9.9"
Markdown = "1.10"
NNlib = "0.9.10"
PrecompileTools = "1.2"
Random = "1.9"
Random = "1.10"
ReTestItems = "1"
Reexport = "1"
ReverseDiff = "1.15"
StableRNGs = "1"
Statistics = "1.9"
Test = "1.9"
Tracker = "0.2.26"
Statistics = "1.10"
Test = "1.10"
Tracker = "0.2.31"
Zygote = "0.6.69"
julia = "1.9"
cuDNN = "1.3"
julia = "1.10"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["AMDGPU", "Aqua", "CUDA", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote", "cuDNN"]
11 changes: 6 additions & 5 deletions lib/LuxLib/README.md
@@ -18,13 +18,14 @@ Backend for [Lux.jl](http://lux.csail.mit.edu/).

This is a developer-facing project and most users **should not** depend on it directly. As
such, we don't have tutorials for this package. Instead, we recommend you check out the
[Lux tutorials](http://lux.csail.mit.edu/stable/).
[Lux tutorials](http://lux.csail.mit.edu/).

## What's the distinction from NNlib.jl?
## What's the distinction from [NNlib.jl](https://github.com/FluxML/NNlib.jl)?

Think of this package as a temporary location for functionalities that will move into
NNlib.jl. At the moment, this is supposed to be a heavier dependency than NNlib.jl, and
it makes no attempt to separate code across different architectures.
This is currently a place to hold more specialized kernels and layer implementations for
Lux.jl. Anyone is free to move these to NNlib.jl (this package is MIT licensed), but I
probably don't have the time to do so myself. If you do, open an issue here and let me
know, and I will delete the code from this package.

## Changelog

65 changes: 36 additions & 29 deletions lib/LuxLib/ext/LuxLibForwardDiffExt.jl
@@ -1,70 +1,77 @@
module LuxLibForwardDiffExt

using ForwardDiff, LuxLib, Statistics
import ForwardDiff: Dual
import LuxLib: AA
using ForwardDiff: ForwardDiff
using LuxLib: LuxLib
using NNlib: NNlib

# dropout
LuxLib._dropout_fptype(x::AA{<:Dual}) = ForwardDiff.valtype(eltype(x))
@inline function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual})
return ForwardDiff.valtype(eltype(x))
end

# Convolutions: We might want to capture these further down in `conv!`
# NOTE: In principle we can concatenate all of the partials along the batch dimension
# and cut down substantially on the time to compute jacobians.
for op in [:conv, :depthwiseconv]
op! = Symbol("$(op)!")

@eval function NNlib.$(op)(x::AA{<:Dual{Tag, V, P}, N},
w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P}
@eval function NNlib.$(op)(
x::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, w::AbstractArray{<:Real, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P}
x_ = ForwardDiff.value.(x)

y = $(op)(x_, w, cdims; kwargs...)
dys = ntuple(i -> $(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P)
y = NNlib.$(op)(x_, w, cdims; kwargs...)
dys = ntuple(i -> NNlib.$(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P)

return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y,
dys...)
return map(
(yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)),
y, dys...)
end

@eval function NNlib.$(op)(x::AA{<:Real, N}, w::AA{<:Dual{Tag, V, P}, N},
cdims::ConvDims; kwargs...) where {N, Tag, V, P}
@eval function NNlib.$(op)(
x::AbstractArray{<:Real, N}, w::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P}
w_ = ForwardDiff.value.(w)

y = $(op)(x, w_, cdims; kwargs...)
dys = ntuple(i -> $(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P)
y = NNlib.$(op)(x, w_, cdims; kwargs...)
dys = ntuple(i -> NNlib.$(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P)

return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y,
dys...)
return map(
(yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)),
y, dys...)
end

@eval function NNlib.$(op)(x::AA{<:Dual{Tag, Vₓ, P}, N},
w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims;
kwargs...) where {N, Tag, Vₓ, Vₚ, P}
@eval function NNlib.$(op)(x::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N},
w::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P}
x_ = ForwardDiff.value.(x)
w_ = ForwardDiff.value.(w)

y = $(op)(x_, w_, cdims; kwargs...)
y = NNlib.$(op)(x_, w_, cdims; kwargs...)

dys₁ = ntuple(
_ -> similar(x_, Vₓ, NNlib.output_size(cdims)...,
NNlib.channels_out(cdims), size(x, N)),
_ -> similar(
x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)),
P)
dys₂ = ntuple(
_ -> similar(x_, Vₓ, NNlib.output_size(cdims)...,
NNlib.channels_out(cdims), size(x, N)),
_ -> similar(
x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)),
P)
for i in 1:P
$(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...)
$(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...)
NNlib.$(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...)
NNlib.$(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...)
dys₁[i] .+= dys₂[i]
end

# Technically it should `promote_type(Vₓ, Vₚ)` but this causes GPU compilation
# failure. We will assume it matches the type of the input.
return map((yᵢ, dyᵢ...) -> Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y,
dys₁...)
return map(
(yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)),
y, dys₁...)
end
end

function LuxLib._drop_forwarddiff_partials(x::AA{<:Dual})
function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual})
return ForwardDiff.value.(x)
end

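The rules above are all instances of one pattern: for a linear operator, forward-mode AD can run the operator once on the primal values and once per partial, then reassemble the `Dual`s elementwise (this is also what the NOTE about concatenating partials along the batch dimension would optimize). A minimal standalone sketch of that pattern, with a plain matrix product standing in for `conv` and the hypothetical name `dual_linear`:

```julia
using ForwardDiff: ForwardDiff

# Forward-mode rule for the linear map y = W * x with Dual-valued x:
# by linearity, W applies independently to the value and to each of the
# P partials, so one primal pass plus P partial passes suffice.
function dual_linear(W::AbstractMatrix{<:Real},
        x::AbstractVector{<:ForwardDiff.Dual{Tag, V, P}}) where {Tag, V, P}
    x_val = ForwardDiff.value.(x)                           # primal values
    y = W * x_val                                           # one primal pass
    dys = ntuple(i -> W * ForwardDiff.partials.(x, i), P)   # one pass per partial
    # Rebuild Duals elementwise from the primal and the P directional parts.
    return map((yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)),
        y, dys...)
end
```

Linearity is what makes the decomposition exact; a nonlinear kernel would instead need the operator's Jacobian-vector product for the partial passes.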
46 changes: 0 additions & 46 deletions lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl

This file was deleted.

