From c2e28a2d65f04ebecccdcba23dae7ba422dee7e1 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 14 Apr 2024 01:20:11 -0400
Subject: [PATCH] Some more cleanup

---
 Project.toml | 7 ++++++-
 README.md | 4 ++--
 ext/LuxLibForwardDiffExt.jl | 1 -
 ext/LuxLibReverseDiffExt.jl | 10 ++++++----
 ...rackerAMDGPUExt.jl => LuxLibTrackerAMDGPUExt.jl} | 4 ++--
 ext/LuxLibTrackerExt.jl | 8 ++++----
 ext/LuxLibTrackercuDNNExt.jl | 13 +++++++------
 ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 6 +++---
 ext/LuxLibcuDNNExt/batchnorm.jl | 4 ++--
 src/LuxLib.jl | 2 +-
 src/api/layernorm.jl | 5 ++---
 test/qa_tests.jl | 10 ++++++++++
 12 files changed, 45 insertions(+), 29 deletions(-)
 rename ext/{LuxTrackerAMDGPUExt.jl => LuxLibTrackerAMDGPUExt.jl} (94%)

diff --git a/Project.toml b/Project.toml
index 898476a1..1181f429 100644
--- a/Project.toml
+++ b/Project.toml
@@ -61,7 +61,9 @@ cuDNN = "1.3"
 julia = "1.9"
 
 [extras]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -72,10 +74,13 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [targets]
-test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"]
+test = ["AMDGPU", "Aqua", "CUDA", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote", "cuDNN"]
diff --git a/README.md b/README.md
index 0a6e39ce..f2970c30 100644
--- a/README.md
+++ b/README.md
@@ -18,11 +18,11 @@ Backend for [Lux.jl](http://lux.csail.mit.edu/).
 
 This is a developer-facing project and most users **should not** depend on it directly. As
 such, we don't have tutorials for this package. Instead, we recommend you check out the
-[Lux tutorials](http://lux.csail.mit.edu/stable/).
+[Lux tutorials](http://lux.csail.mit.edu/).
 
 ## What's the distinction from [NNlib.jl](https://github.com/FluxML/NNlib.jl)?
 
-This is currently a place to hold more specialized kernels and layer implementation for
+This is currently a place to hold more specialized kernels and layer implementations for
 Lux.jl. Anyone is free to move these to NNlib.jl (this package is MIT licensed), but I
 probably don't have the time to do so myself. But incase you do, open an issue here and
 let me know I will delete the code from this package.
diff --git a/ext/LuxLibForwardDiffExt.jl b/ext/LuxLibForwardDiffExt.jl
index 4c31d830..dd141912 100644
--- a/ext/LuxLibForwardDiffExt.jl
+++ b/ext/LuxLibForwardDiffExt.jl
@@ -1,6 +1,5 @@
 module LuxLibForwardDiffExt
 
-using FastClosures: @closure
 using ForwardDiff: ForwardDiff
 using LuxLib: LuxLib
 using NNlib: NNlib
diff --git a/ext/LuxLibReverseDiffExt.jl b/ext/LuxLibReverseDiffExt.jl
index ac199332..f7017ac0 100644
--- a/ext/LuxLibReverseDiffExt.jl
+++ b/ext/LuxLibReverseDiffExt.jl
@@ -1,17 +1,19 @@
 module LuxLibReverseDiffExt
 
-using ChainRulesCore: NoTangent
+using ChainRulesCore: ChainRulesCore
 using LuxLib: LuxLib
 using NNlib: NNlib
 using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules
 
+const CRC = ChainRulesCore
+
 # Patches: Needs upstreaming
 @inline function ReverseDiff.increment_deriv!(
-        t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
+        t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i)
     return ReverseDiff.increment_deriv!(t, zero(eltype(value(t))), i)
 end
 @inline function ReverseDiff.decrement_deriv!(
-        t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
+        t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i)
     return ReverseDiff.decrement_deriv!(t, zero(eltype(value(t))), i)
 end
 
@@ -39,7 +41,7 @@ end
 @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...)
 
 for pool in (:maxpool, :meanpool, :lpnormpool)
-    @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::PoolDims; kwargs...)
+    @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...)
 end
 
 end
diff --git a/ext/LuxTrackerAMDGPUExt.jl b/ext/LuxLibTrackerAMDGPUExt.jl
similarity index 94%
rename from ext/LuxTrackerAMDGPUExt.jl
rename to ext/LuxLibTrackerAMDGPUExt.jl
index 11ed5d5e..eef503f6 100644
--- a/ext/LuxTrackerAMDGPUExt.jl
+++ b/ext/LuxLibTrackerAMDGPUExt.jl
@@ -35,8 +35,8 @@ for poolname in (:maxpool, :meanpool)
             _, workspace = AMDGPU.MIOpen.$(Symbol("$(poolname)!"))(
                 NNlib.insert_singleton_spatial_dimension(y, nd),
                 NNlib.insert_singleton_spatial_dimension(x, nd);
-                dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
-                stride=NNlib.stride(npdims))
+                dims=NNlib.kernel_size(npdims),
+                padding=nnlib_padding(npdims), stride=NNlib.stride(npdims))
 
             function ∇pooling(Δ)
                 dx = similar(x)
diff --git a/ext/LuxLibTrackerExt.jl b/ext/LuxLibTrackerExt.jl
index bdf98df6..57354cb1 100644
--- a/ext/LuxLibTrackerExt.jl
+++ b/ext/LuxLibTrackerExt.jl
@@ -46,9 +46,9 @@ Base.repeat(x::TrackedArray, counts...) = Tracker.track(Base.repeat, x, counts..
 @grad function Base.repeat(x, counts...)
     y, ∇repeat_cr = CRC.rrule(Base.repeat, Tracker.data(x), counts...)
     ∇repeat = @closure Δ -> begin
-        _, res... = ∇repeat_cr(Δ)
-        return nobacksies(
-            :repeat, map(x -> x == CRC.NoTangent() ? nothing : CRC.unthunk(x), res))
+        res = ∇repeat_cr(Δ)[2:(2 + length(counts))]
+        return Tracker.nobacksies(
+            :repeat, map(x -> x isa CRC.NoTangent ? nothing : CRC.unthunk(x), res))
     end
     return y, ∇repeat
 end
@@ -109,7 +109,7 @@ end
     ∇groupnorm = @closure Δ -> begin
         dx, dscale, dbias = LuxLib._∇groupnorm(
             Δ, y, Tracker.data(x), groups, Tracker.data(scale), Tracker.data(bias), μ, σ⁻¹)
-        return nobacksies(:groupnorm, (dx, dscale, dbias))
+        return Tracker.nobacksies(:groupnorm, (dx, dscale, dbias))
     end
     return y, ∇groupnorm
 end
diff --git a/ext/LuxLibTrackercuDNNExt.jl b/ext/LuxLibTrackercuDNNExt.jl
index 5c8187be..1694ef8e 100644
--- a/ext/LuxLibTrackercuDNNExt.jl
+++ b/ext/LuxLibTrackercuDNNExt.jl
@@ -2,9 +2,9 @@ module LuxLibTrackercuDNNExt
 
 using FastClosures: @closure
 # cuDNN not loaded but it is needed for the batchnorm_cudnn implementation
-using CUDA: CUDA, CuArray, CuVector, CuPtr
+using CUDA: CUDA, CuArray, CuVector, CU_NULL
 using LuxLib: LuxLib
-using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal
+using Tracker: Tracker, TrackedVector, TrackedArray
 
 # api/batchnorm.jl
 const TR_CUDNN_BN_ARRAY_TYPE = Union{
@@ -20,7 +20,8 @@ function LuxLib.batchnorm(
         running_mean::TR_BNParamType, running_var::TR_BNParamType;
         momentum::Real, training::Val, epsilon::Real)
     rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
-    x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training))
+    # NOTE: The following returns a tracked tuple so we can't do `first` on it
+    x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1]
     return x_, (; running_mean=rm, running_var=rv)
 end
 
@@ -40,16 +41,16 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
 end
 
 @inline __make_nothing(x) = x
-@inline __make_nothing(::CuPtr{Nothing}) = 0
+@inline __make_nothing(::typeof(CU_NULL)) = 0
 
-@grad function LuxLib.batchnorm_cudnn(
+Tracker.@grad function LuxLib.batchnorm_cudnn(
         running_mean, running_var, scale, bias, x, momentum, eps, training)
     y, xmean, xivar = LuxLib.batchnorm_cudnn(
         Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale),
         Tracker.data(bias), Tracker.data(x), momentum, eps, training)
     ∇batchnorm_cudnn_internal = @closure Δ -> begin
         ∂y = first(Δ)
-        ∂g, ∂b, ∂x = ∇batchnorm_cudnn(
+        ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(
             Tracker.data(scale), Tracker.data(bias), Tracker.data(x), ∂y,
             Tracker.data(running_mean), Tracker.data(running_var), xmean, xivar; ϵ=eps)
         return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing)
diff --git a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
index 644cc90c..3727b3b5 100644
--- a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
+++ b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -1,9 +1,9 @@
 module LuxLibcuDNNExt
 
 using LuxLib: LuxLib
-using CUDA: CUDA, CuArray, CuVector, CuPtr, CU_NULL, DenseCuArray
+using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray
 using ChainRulesCore: ChainRulesCore
-using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
+using cuDNN: cuDNN, CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
     cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
     cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor,
     CUDNN_TENSOR_NCHW, cudnnDataType
@@ -34,7 +34,7 @@ end
         scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps)
 end
 
-function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale,
+function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale,
         bias, x, momentum, epsilon, t::Val{training}) where {training}
     y, xmean, xivar = LuxLib.batchnorm_cudnn(
         running_mean, running_var, scale, bias, x, momentum, epsilon, t)
diff --git a/ext/LuxLibcuDNNExt/batchnorm.jl b/ext/LuxLibcuDNNExt/batchnorm.jl
index a0c16d99..e3787220 100644
--- a/ext/LuxLibcuDNNExt/batchnorm.jl
+++ b/ext/LuxLibcuDNNExt/batchnorm.jl
@@ -80,8 +80,8 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra
 
     xd = cudnnTensorDescriptor(x)
     yd = cudnnTensorDescriptor(y)
-    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T),
-        Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW)))
+    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)),
+        cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW)))
 
     if training
         mean = fill!(similar(x, dims), zero(T))
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index ccf34fea..033f712c 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -11,7 +11,7 @@ using PrecompileTools: @recompile_invalidations
     using NNlib: NNlib
     using Random: Random, AbstractRNG, rand!
     using Reexport: @reexport
-    using Statistics: Statistics, mean, var, varm
+    using Statistics: Statistics, mean, std, var
 end
 
 @reexport using NNlib
diff --git a/src/api/layernorm.jl b/src/api/layernorm.jl
index 72c7b819..3cc25e93 100644
--- a/src/api/layernorm.jl
+++ b/src/api/layernorm.jl
@@ -37,7 +37,6 @@ end
 
 function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon)
     _mean = mean(x; dims)
-    _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
-
-    return (x .- _mean) .* _rstd
+    rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
+    return (x .- _mean) .* rstd
 end
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index f339224a..e043e388 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -2,3 +2,13 @@
     using Aqua
     Aqua.test_all(LuxLib)
 end
+
+@testitem "Explicit Imports" begin
+    import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib
+
+    using ExplicitImports
+
+    # Skip our own packages
+    @test check_no_implicit_imports(LuxLib; skip=(NNlib, Base, Core)) === nothing
+    @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing
+end