Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Some more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 14, 2024
1 parent 58cdcfb commit 1dbcf1e
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 21 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ cuDNN = "1.3"
julia = "1.9"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand All @@ -72,10 +73,13 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["AMDGPU", "Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote", "cuDNN"]
1 change: 0 additions & 1 deletion ext/LuxLibForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module LuxLibForwardDiffExt

using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LuxLib: LuxLib
using NNlib: NNlib
Expand Down
10 changes: 6 additions & 4 deletions ext/LuxLibReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
module LuxLibReverseDiffExt

using ChainRulesCore: NoTangent
using ChainRulesCore: ChainRulesCore
using LuxLib: LuxLib
using NNlib: NNlib
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules

const CRC = ChainRulesCore

# Patches: Needs upstreaming
@inline function ReverseDiff.increment_deriv!(
t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i)
return ReverseDiff.increment_deriv!(t, zero(eltype(value(t))), i)
end
@inline function ReverseDiff.decrement_deriv!(
t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i)
return ReverseDiff.decrement_deriv!(t, zero(eltype(value(t))), i)
end

Expand Down Expand Up @@ -39,7 +41,7 @@ end
@grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...)

for pool in (:maxpool, :meanpool, :lpnormpool)
@eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::PoolDims; kwargs...)
@eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...)
end

end
4 changes: 2 additions & 2 deletions ext/LuxTrackerAMDGPUExt.jl → ext/LuxLibTrackerAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ for poolname in (:maxpool, :meanpool)
_, workspace = AMDGPU.MIOpen.$(Symbol("$(poolname)!"))(
NNlib.insert_singleton_spatial_dimension(y, nd),
NNlib.insert_singleton_spatial_dimension(x, nd);
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
stride=NNlib.stride(npdims))
dims=NNlib.kernel_size(npdims),
padding=nnlib_padding(npdims), stride=NNlib.stride(npdims))

function ∇pooling(Δ)
dx = similar(x)
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxLibTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using LuxLib: LuxLib
using NNlib: NNlib, batched_mul, batched_adjoint
using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal
using Tracker: Tracker, @grad, TrackedArray, TrackedReal

const CRC = ChainRulesCore

Expand Down
4 changes: 2 additions & 2 deletions ext/LuxLibTrackercuDNNExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using FastClosures: @closure
# cuDNN not loaded but it is needed for the batchnorm_cudnn implementation
using CUDA: CUDA, CuArray, CuVector, CuPtr
using LuxLib: LuxLib
using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal
using Tracker: Tracker, TrackedArray

# api/batchnorm.jl
const TR_CUDNN_BN_ARRAY_TYPE = Union{
Expand Down Expand Up @@ -42,7 +42,7 @@ end
# `__make_nothing` passes any value through unchanged, except a `CuPtr{Nothing}`,
# which is replaced by the integer `0`.
# NOTE(review): presumably used to turn null device pointers into a plain value before
# they reach Tracker — the call site is outside this hunk, confirm there.
@inline __make_nothing(x) = x
@inline __make_nothing(::CuPtr{Nothing}) = 0

@grad function LuxLib.batchnorm_cudnn(
Tracker.@grad function LuxLib.batchnorm_cudnn(
running_mean, running_var, scale, bias, x, momentum, eps, training)
y, xmean, xivar = LuxLib.batchnorm_cudnn(
Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale),
Expand Down
4 changes: 2 additions & 2 deletions ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LuxLibcuDNNExt

using LuxLib: LuxLib
using CUDA: CUDA, CuArray, CuVector, CuPtr, CU_NULL, DenseCuArray
using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray
using ChainRulesCore: ChainRulesCore
using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
Expand Down Expand Up @@ -34,7 +34,7 @@ end
scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps)
end

function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale,
function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale,
bias, x, momentum, epsilon, t::Val{training}) where {training}
y, xmean, xivar = LuxLib.batchnorm_cudnn(
running_mean, running_var, scale, bias, x, momentum, epsilon, t)
Expand Down
4 changes: 2 additions & 2 deletions ext/LuxLibcuDNNExt/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra

xd = cudnnTensorDescriptor(x)
yd = cudnnTensorDescriptor(y)
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T),
Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW)))
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)),
cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW)))

if training
mean = fill!(similar(x, dims), zero(T))
Expand Down
2 changes: 1 addition & 1 deletion src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using PrecompileTools: @recompile_invalidations
using NNlib: NNlib
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using Statistics: Statistics, mean, var, varm
using Statistics: Statistics, mean, std, var
end

@reexport using NNlib
Expand Down
9 changes: 4 additions & 5 deletions src/api/layernorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ Normalized Array of same size as `x`.
"""
function layernorm(x::AbstractArray{T1, N}, scale::AbstractArray{T2, N},
bias::AbstractArray{T3, N}; dims, epsilon) where {N, T1, T2, T3}
x_norm = layernorm(x, nothing, nothing; dims, epsilon)
return scale .* x_norm .+ bias
_mean = mean(x; dims)
return scale .* (x .- _mean) ./
(std(x; dims, mean=_mean, corrected=false) .+ epsilon) .+ bias
end

function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon)
_mean = mean(x; dims)
_rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)

return (x .- _mean) .* _rstd
return (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
end
10 changes: 10 additions & 0 deletions test/qa_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,13 @@
using Aqua
Aqua.test_all(LuxLib)
end

# QA test item: uses ExplicitImports.jl to check that LuxLib has no implicit imports
# and no stale explicit imports.
@testitem "Explicit Imports" begin
# NOTE(review): these packages are presumably imported so that LuxLib's package
# extensions are loaded before the checks run — confirm.
import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib

using ExplicitImports

# The skip list is NNlib, Base and Core (not only "our own" packages).
# NOTE(review): NNlib is presumably skipped because LuxLib re-exports it — confirm.
@test check_no_implicit_imports(LuxLib; skip=(NNlib, Base, Core)) === nothing
@test check_no_stale_explicit_imports(LuxLib) === nothing
end

0 comments on commit 1dbcf1e

Please sign in to comment.