From 9cf31ce70877caf312901ff1c260e99785f2d2a8 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 14 Apr 2024 01:20:11 -0400
Subject: [PATCH] Some more cleanup

---
 Project.toml                                            |  7 ++++++-
 README.md                                               |  4 ++--
 ext/LuxLibForwardDiffExt.jl                             |  1 -
 ext/LuxLibReverseDiffExt.jl                             | 10 ++++++----
 ...uxTrackerAMDGPUExt.jl => LuxLibTrackerAMDGPUExt.jl}  |  4 ++--
 ext/LuxLibTrackercuDNNExt.jl                            |  4 ++--
 ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl                    |  4 ++--
 ext/LuxLibcuDNNExt/batchnorm.jl                         |  4 ++--
 src/LuxLib.jl                                           |  2 +-
 src/api/layernorm.jl                                    |  9 ++++-----
 test/qa_tests.jl                                        | 10 ++++++++++
 11 files changed, 37 insertions(+), 22 deletions(-)
 rename ext/{LuxTrackerAMDGPUExt.jl => LuxLibTrackerAMDGPUExt.jl} (94%)

diff --git a/Project.toml b/Project.toml
index 898476a1..1181f429 100644
--- a/Project.toml
+++ b/Project.toml
@@ -61,7 +61,9 @@ cuDNN = "1.3"
 julia = "1.9"
 
 [extras]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -72,10 +74,13 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [targets]
-test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"]
+test = ["AMDGPU", "Aqua", "CUDA", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote", "cuDNN"]

diff --git a/README.md b/README.md
index 0a6e39ce..f2970c30 100644
--- a/README.md
+++ b/README.md
@@ -18,11 +18,11 @@ Backend for [Lux.jl](http://lux.csail.mit.edu/).
 
 This is a developer-facing project and most users **should not** depend on it directly. As
 such, we don't have tutorials for this package. Instead, we recommend you check out the
-[Lux tutorials](http://lux.csail.mit.edu/stable/).
+[Lux tutorials](http://lux.csail.mit.edu/).
 
 ## What's the distinction from [NNlib.jl](https://github.com/FluxML/NNlib.jl)?
 
-This is currently a place to hold more specialized kernels and layer implementation for
+This is currently a place to hold more specialized kernels and layer implementations for
 Lux.jl. Anyone is free to move these to NNlib.jl (this package is MIT licensed), but I
 probably don't have the time to do so myself. But incase you do, open an issue here and
 let me know I will delete the code from this package.
diff --git a/ext/LuxLibForwardDiffExt.jl b/ext/LuxLibForwardDiffExt.jl
index 4c31d830..dd141912 100644
--- a/ext/LuxLibForwardDiffExt.jl
+++ b/ext/LuxLibForwardDiffExt.jl
@@ -1,6 +1,5 @@
 module LuxLibForwardDiffExt
 
-using FastClosures: @closure
 using ForwardDiff: ForwardDiff
 using LuxLib: LuxLib
 using NNlib: NNlib
diff --git a/ext/LuxLibReverseDiffExt.jl b/ext/LuxLibReverseDiffExt.jl
index ac199332..f7017ac0 100644
--- a/ext/LuxLibReverseDiffExt.jl
+++ b/ext/LuxLibReverseDiffExt.jl
@@ -1,17 +1,19 @@
 module LuxLibReverseDiffExt
 
-using ChainRulesCore: NoTangent
+using ChainRulesCore: ChainRulesCore
 using LuxLib: LuxLib
 using NNlib: NNlib
 using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules
 
+const CRC = ChainRulesCore
+
 # Patches: Needs upstreaming
 @inline function ReverseDiff.increment_deriv!(
-        t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
+        t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i)
     return ReverseDiff.increment_deriv!(t, zero(eltype(value(t))), i)
 end
 
 @inline function ReverseDiff.decrement_deriv!(
-        t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
+        t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i)
     return ReverseDiff.decrement_deriv!(t, zero(eltype(value(t))), i)
 end
 
@@ -39,7 +41,7 @@ end
 @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...)
 
 for pool in (:maxpool, :meanpool, :lpnormpool)
-    @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::PoolDims; kwargs...)
+    @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...)
 end
 
 end
diff --git a/ext/LuxTrackerAMDGPUExt.jl b/ext/LuxLibTrackerAMDGPUExt.jl
similarity index 94%
rename from ext/LuxTrackerAMDGPUExt.jl
rename to ext/LuxLibTrackerAMDGPUExt.jl
index 11ed5d5e..eef503f6 100644
--- a/ext/LuxTrackerAMDGPUExt.jl
+++ b/ext/LuxLibTrackerAMDGPUExt.jl
@@ -35,8 +35,8 @@ for poolname in (:maxpool, :meanpool)
             _, workspace = AMDGPU.MIOpen.$(Symbol("$(poolname)!"))(
                 NNlib.insert_singleton_spatial_dimension(y, nd),
                 NNlib.insert_singleton_spatial_dimension(x, nd);
-                dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
-                stride=NNlib.stride(npdims))
+                dims=NNlib.kernel_size(npdims),
+                padding=nnlib_padding(npdims), stride=NNlib.stride(npdims))
 
             function ∇pooling(Δ)
                 dx = similar(x)
diff --git a/ext/LuxLibTrackercuDNNExt.jl b/ext/LuxLibTrackercuDNNExt.jl
index 5c8187be..adddd549 100644
--- a/ext/LuxLibTrackercuDNNExt.jl
+++ b/ext/LuxLibTrackercuDNNExt.jl
@@ -4,7 +4,7 @@ using FastClosures: @closure
 # cuDNN not loaded but it is needed for the batchnorm_cudnn implementation
 using CUDA: CUDA, CuArray, CuVector, CuPtr
 using LuxLib: LuxLib
-using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal
+using Tracker: Tracker, TrackedVector, TrackedArray
 
 # api/batchnorm.jl
 const TR_CUDNN_BN_ARRAY_TYPE = Union{
@@ -42,7 +42,7 @@ end
 @inline __make_nothing(x) = x
 @inline __make_nothing(::CuPtr{Nothing}) = 0
 
-@grad function LuxLib.batchnorm_cudnn(
+Tracker.@grad function LuxLib.batchnorm_cudnn(
         running_mean, running_var, scale, bias, x, momentum, eps, training)
     y, xmean, xivar = LuxLib.batchnorm_cudnn(
         Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale),
diff --git a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
index 644cc90c..7e6c39c3 100644
--- a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
+++ b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -1,7 +1,7 @@
 module LuxLibcuDNNExt
 
 using LuxLib: LuxLib
-using CUDA: CUDA, CuArray, CuVector, CuPtr, CU_NULL, DenseCuArray
+using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray
 using ChainRulesCore: ChainRulesCore
 using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
     cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
@@ -34,7 +34,7 @@ end
         scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps)
 end
 
-function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale,
+function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale,
         bias, x, momentum, epsilon, t::Val{training}) where {training}
     y, xmean, xivar = LuxLib.batchnorm_cudnn(
         running_mean, running_var, scale, bias, x, momentum, epsilon, t)
diff --git a/ext/LuxLibcuDNNExt/batchnorm.jl b/ext/LuxLibcuDNNExt/batchnorm.jl
index a0c16d99..e3787220 100644
--- a/ext/LuxLibcuDNNExt/batchnorm.jl
+++ b/ext/LuxLibcuDNNExt/batchnorm.jl
@@ -80,8 +80,8 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra
 
     xd = cudnnTensorDescriptor(x)
     yd = cudnnTensorDescriptor(y)
-    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T),
-        Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW)))
+    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)),
+        cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW)))
 
     if training
         mean = fill!(similar(x, dims), zero(T))
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index ccf34fea..033f712c 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -11,7 +11,7 @@ using PrecompileTools: @recompile_invalidations
     using NNlib: NNlib
    using Random: Random, AbstractRNG, rand!
    using Reexport: @reexport
-    using Statistics: Statistics, mean, var, varm
+    using Statistics: Statistics, mean, std, var
 end
 
 @reexport using NNlib
diff --git a/src/api/layernorm.jl b/src/api/layernorm.jl
index 72c7b819..aa1fe33d 100644
--- a/src/api/layernorm.jl
+++ b/src/api/layernorm.jl
@@ -31,13 +31,12 @@ Normalized Array of same size as `x`.
 """
 function layernorm(x::AbstractArray{T1, N}, scale::AbstractArray{T2, N},
         bias::AbstractArray{T3, N}; dims, epsilon) where {N, T1, T2, T3}
-    x_norm = layernorm(x, nothing, nothing; dims, epsilon)
-    return scale .* x_norm .+ bias
+    _mean = mean(x; dims)
+    return scale .* (x .- _mean) ./
+           (std(x; dims, mean=_mean, corrected=false) .+ epsilon) .+ bias
 end
 
 function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon)
     _mean = mean(x; dims)
-    _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
-
-    return (x .- _mean) .* _rstd
+    return (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
 end
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index f339224a..e043e388 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -2,3 +2,13 @@
     using Aqua
     Aqua.test_all(LuxLib)
 end
+
+@testitem "Explicit Imports" begin
+    import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib
+
+    using ExplicitImports
+
+    # Skip our own packages
+    @test check_no_implicit_imports(LuxLib; skip=(NNlib, Base, Core)) === nothing
+    @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing
+end
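
Note (not part of the patch): a minimal standalone sketch of what the reworked
layernorm in src/api/layernorm.jl computes, assuming only the Statistics stdlib;
the array sizes and variable names below are illustrative, not taken from LuxLib.

    using Statistics: mean, std

    x     = randn(Float32, 4, 3)   # features x batch
    scale = ones(Float32, 4, 1)
    bias  = zeros(Float32, 4, 1)
    dims, epsilon = 1, 1.0f-5

    # Same computation as the patched scale/bias method: normalize with the
    # uncorrected standard deviation along `dims`, then apply the affine transform.
    mu    = mean(x; dims)
    xnorm = (x .- mu) ./ (std(x; dims, mean=mu, corrected=false) .+ epsilon)
    y     = scale .* xnorm .+ bias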