This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Format .jl files
avik-pal authored Oct 30, 2023
1 parent 63e17d9 commit 30975b8
Showing 12 changed files with 55 additions and 52 deletions.
7 changes: 4 additions & 3 deletions ext/LuxLibForwardDiffExt.jl
@@ -16,7 +16,7 @@ for op in [:conv, :depthwiseconv]
op! = Symbol("$(op)!")

@eval function NNlib.$(op)(x::AA{<:Dual{Tag, V, P}, N},
w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P}
w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P}
x_ = ForwardDiff.value.(x)

y = $(op)(x_, w, cdims; kwargs...)
@@ -27,7 +27,7 @@ for op in [:conv, :depthwiseconv]
end

@eval function NNlib.$(op)(x::AA{<:Real, N}, w::AA{<:Dual{Tag, V, P}, N},
cdims::ConvDims; kwargs...) where {N, Tag, V, P}
cdims::ConvDims; kwargs...) where {N, Tag, V, P}
w_ = ForwardDiff.value.(w)

y = $(op)(x, w_, cdims; kwargs...)
@@ -38,7 +38,8 @@ for op in [:conv, :depthwiseconv]
end

@eval function NNlib.$(op)(x::AA{<:Dual{Tag, Vₓ, P}, N},
w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P}
w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims;
kwargs...) where {N, Tag, Vₓ, Vₚ, P}
x_ = ForwardDiff.value.(x)
w_ = ForwardDiff.value.(w)

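
A quick illustration of the Dual-valued conv path reformatted above: the sketch below differentiates through NNlib.conv with ForwardDiff. The array sizes, the scalar parameterization, and loading NNlib directly are illustrative assumptions, not part of this commit.

using LuxLib, ForwardDiff, NNlib   # loading LuxLib together with ForwardDiff activates this extension

x = randn(Float32, 5, 5, 1, 1)     # W × H × C × N input
w = randn(Float32, 3, 3, 1, 1)     # convolution kernel

# Scaling x by a Dual-valued θ makes the first argument Dual-valued, which
# dispatches to the NNlib.conv overload for x::AA{<:Dual} defined in this file.
f(θ) = sum(NNlib.conv(θ .* x, w))
ForwardDiff.derivative(f, 1.0f0)
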
8 changes: 4 additions & 4 deletions ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl
@@ -17,22 +17,22 @@ const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}
const BNParamType = Union{Nothing, CuVector{<:FP_32_64}}

function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType,
running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val,
epsilon::Real)
running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val,
epsilon::Real)
rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)

x_ = first(batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training))
return x_, (; running_mean=rm, running_var=rv)
end

function batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps,
training)
training)
return batchnorm_cudnn(scale, bias, x, running_mean, running_var, momentum,
training; ϵ=eps)
end

function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, bias, x,
momentum, epsilon, t::Val{training}) where {training}
momentum, epsilon, t::Val{training}) where {training}
y, xmean, xivar = batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum,
epsilon, t)
function ∇batchnorm_cudnn_internal(Δ)
30 changes: 16 additions & 14 deletions ext/LuxLibLuxCUDAExt/batchnorm.jl
@@ -29,15 +29,15 @@ function batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwa
end

function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2},
args...; kwargs...) where {T <: FP_32_64}
args...; kwargs...) where {T <: FP_32_64}
x = reshape(x, 1, 1, size(x, 1), size(x, 2))
y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...)
return dropdims(y; dims=(1, 2)), xμ, xσ⁻²
end

function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...;
kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64}
x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...;
kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64}
@warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the
highest precision type. Avoid this code-path if possible" maxlog=1
Tₓ = eltype(x)
@@ -57,14 +57,14 @@ function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
end

function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T},
x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...;
kwargs...) where {T <: FP_32_64}
x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...;
kwargs...) where {T <: FP_32_64}
return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...)
end

function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T},
x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training};
α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training}
x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training};
α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training}
dims = _wsize(x)
if ϵ < CUDNN_BN_MIN_EPSILON
@warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
@@ -102,7 +102,7 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra
end

function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray,
running_μ, running_σ², args...; kwargs...)
running_μ, running_σ², args...; kwargs...)
affine_sz = _wsize(x)
g = fill!(similar(x, affine_sz), 1)
b = fill!(similar(x, affine_sz), 0)
@@ -118,16 +118,17 @@ function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::Dense
end

function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2},
∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64}
∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...;
kwargs...) where {T <: FP_32_64}
∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)),
reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), running_μ, running_σ², args...;
kwargs...)
return (∂g, ∂b, dropdims(∂x; dims=(1, 2)))
end

function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...;
kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64}
x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...;
kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64}
@warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the
highest precision type. Avoid this code-path if possible" maxlog=1
Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ)
@@ -148,7 +149,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
end

function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
∂y::DenseCuArray{T}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64}
∂y::DenseCuArray{T}, running_μ, running_σ², args...;
kwargs...) where {T <: FP_32_64}
∂g = similar(g)
∂b = similar(b)
∂x = similar(x)
@@ -157,8 +159,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuAr
end

function cudnnBNBackward!(∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T},
∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ²,
xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64}
∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ²,
xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64}
if running_μ === nothing && running_σ² === nothing
running_μ = CU_NULL
running_σ² = CU_NULL
8 changes: 4 additions & 4 deletions ext/LuxLibLuxCUDATrackerExt.jl
@@ -14,8 +14,8 @@ const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP
CuVector{<:FP_32_64}}

function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType,
bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType;
momentum::Real, training::Val, epsilon::Real)
bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType;
momentum::Real, training::Val, epsilon::Real)
rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)

x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1]
@@ -31,7 +31,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
__is_tracked(RM, RV, S, B, XT) || continue

@eval function batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S,
bias::$B, x::$XT, momentum, eps, training::Val)
bias::$B, x::$XT, momentum, eps, training::Val)
return track(batchnorm_cudnn, running_mean, running_var, scale, bias, x, momentum,
eps, training)
end
@@ -41,7 +41,7 @@ __make_nothing(x) = x
__make_nothing(::CuPtr{Nothing}) = 0

@grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum,
eps, training)
eps, training)
y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale),
data(bias), data(x), momentum, eps, training)
function ∇batchnorm_cudnn_internal(Δ)
4 changes: 2 additions & 2 deletions ext/LuxLibTrackerExt.jl
@@ -78,13 +78,13 @@ for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedVector, :AbstractVecto
__is_tracked(T1, T2, T3) || continue

@eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, scale::$T2{<:FP_32_64},
bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real)
bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real)
return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon)
end
end

@grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64},
bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
LuxLib._assert_same_backend(data(x), data(scale), data(bias))
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
2 changes: 1 addition & 1 deletion src/api/batchnorm.jl
@@ -39,7 +39,7 @@ fallback is used which is not highly optimized.
learning. PMLR, 2015.
"""
function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR,
running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N}
running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N}
x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean),
_drop_forwarddiff_partials(running_var), scale, bias,
_get_batchnorm_reduce_dims(x), training, momentum, epsilon)
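
A minimal usage sketch of the batchnorm signature above; the array sizes and hyperparameter values are illustrative assumptions, not part of this commit.

using LuxLib

x  = randn(Float32, 4, 4, 3, 2)                  # W × H × C × N
γ  = ones(Float32, 3);  β  = zeros(Float32, 3)   # per-channel scale / bias
rμ = zeros(Float32, 3); rσ² = ones(Float32, 3)   # running statistics

# Returns the normalized array plus a named tuple of updated running statistics,
# mirroring the return value of the CUDA method shown earlier in this commit.
y, stats = batchnorm(x, γ, β, rμ, rσ²;
                     momentum=0.1f0, training=Val(true), epsilon=1f-5)
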
8 changes: 4 additions & 4 deletions src/api/dropout.jl
@@ -46,23 +46,23 @@ function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) wh
end

function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}, invp::T;
dims) where {T}
dims) where {T}
return dropout(rng, x, p, t; dims, invp)
end

function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true},
::Val{false}, invp::T; dims) where {T, T1, T2, N}
::Val{false}, invp::T; dims) where {T, T1, T2, N}
size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp)
return x .* ignore_derivatives(mask), mask, rng
end

function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false},
::Val{false}, invp::T; dims) where {T, T1, T2, N}
::Val{false}, invp::T; dims) where {T, T1, T2, N}
return (x, mask, rng)
end

function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, t::Val, um::Val;
dims, invp::T=inv(p)) where {T, T1, T2, N}
dims, invp::T=inv(p)) where {T, T1, T2, N}
return dropout(rng, x, mask, p, t, um, invp; dims)
end

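
A minimal usage sketch of the dropout API above; the RNG choice, array size, and dims=: are illustrative assumptions, not part of this commit.

using LuxLib, Random

rng = Xoshiro(0)
x = randn(Float32, 4, 3)

# Val(true) selects the training path; per the mask-passing methods above, the
# result is a (dropped-out array, mask, rng) tuple.
y, mask, rng_new = dropout(rng, x, 0.5f0, Val(true); dims=:)
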
6 changes: 3 additions & 3 deletions src/api/groupnorm.jl
@@ -42,7 +42,7 @@ interface.
on computer vision (ECCV). 2018.
"""
function groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, bias::AV{<:FP_32_64};
groups::Int, epsilon::Real)
groups::Int, epsilon::Real)
_assert_same_backend(x, scale, bias)
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
@@ -56,7 +56,7 @@ end

# Slow Fallback (without custom Pullback Implementation)
function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int,
epsilon::Real) where {N}
epsilon::Real) where {N}
_assert_same_backend(x, scale, bias)
if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
@@ -79,7 +79,7 @@ end

# Custom Pullbacks
function CRC.rrule(::typeof(groupnorm), x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64},
bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
_assert_same_backend(x, scale, bias)
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
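
A minimal usage sketch of the groupnorm signature above; sizes and values are illustrative assumptions (the channel count must equal the length of scale and bias and be divisible by groups).

using LuxLib

x = randn(Float32, 8, 8, 6, 2)                   # 6 channels (N - 1 dim of x)
γ = ones(Float32, 6); β = zeros(Float32, 6)

y = groupnorm(x, γ, β; groups=3, epsilon=1f-5)
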
2 changes: 1 addition & 1 deletion src/api/instancenorm.jl
@@ -29,7 +29,7 @@ mean and variance.
missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
"""
function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val,
epsilon::Real) where {N}
epsilon::Real) where {N}
_test_valid_instancenorm_arguments(x)

x_, xm, xv = _normalization(x, nothing, nothing, scale, bias,
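
A minimal usage sketch of the instancenorm signature above; the sizes are illustrative, and the two-element return (output plus statistics) is an assumption based on the batchnorm pattern shown earlier.

using LuxLib

x = randn(Float32, 4, 4, 3, 2)
γ = ones(Float32, 3); β = zeros(Float32, 3)

y, _ = instancenorm(x, γ, β; training=Val(true), epsilon=1f-5)
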
2 changes: 1 addition & 1 deletion src/api/layernorm.jl
@@ -30,7 +30,7 @@ Normalized Array of same size as `x`.
preprint arXiv:1607.06450 (2016).
"""
function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims,
epsilon) where {N}
epsilon) where {N}
x_norm = layernorm(x, nothing, nothing; dims, epsilon)
return scale .* x_norm .+ bias
end
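
A minimal usage sketch of the layernorm signature above; the features × batch layout and the values are illustrative assumptions (scale and bias must have the same number of dimensions as x).

using LuxLib

x = randn(Float32, 6, 4)                          # features × batch
γ = ones(Float32, 6, 1); β = zeros(Float32, 6, 1)

# Per the method body above, the result is scale .* normalized(x) .+ bias.
y = layernorm(x, γ, β; dims=1, epsilon=1f-5)
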
12 changes: 6 additions & 6 deletions src/impl/groupnorm.jl
@@ -5,7 +5,7 @@ _linear_threads_groupnorm(::GPU) = 256
# Low-Level Kernels
## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu
@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @Const(μ),
@Const(σ⁻¹), @Const(γ), @Const(β))
@Const(σ⁻¹), @Const(γ), @Const(β))
idx = @index(Global)
ng = _div_idx(idx, K)
c = _mod_idx(idx, C)
@@ -16,14 +16,14 @@ _linear_threads_groupnorm(::GPU) = 256
end

@kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale),
@Const(bias))
@Const(bias))
idx = @index(Global)
nc = _div_idx(idx, WxH)
@inbounds Y[idx] = X[idx] * scale[nc] + bias[nc]
end

@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹),
@Const(γ))
@Const(γ))
idx = @index(Global)
ng = _div_idx(idx, K)
c = _mod_idx(idx, C)
@@ -32,15 +32,15 @@ end
end

@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ),
@Const(σ⁻¹), @Const(ds_sum), @Const(db_sum))
@Const(σ⁻¹), @Const(ds_sum), @Const(db_sum))
idx = @index(Global)
@inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha
@inbounds X_scale[idx] = x
@inbounds bias[idx] = -(x * μ[idx] + db_sum[idx] * σ⁻¹[idx] * alpha)
end

@kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale),
@Const(dY), @Const(X_scale), @Const(X), @Const(bias))
@Const(dY), @Const(X_scale), @Const(X), @Const(bias))
idx = @index(Global)
nc = _div_idx(idx, WxH)
ng = _div_idx(nc, K)
@@ -77,7 +77,7 @@ end
end

@inbounds function _∇groupnorm(dY::AA4D, Y::AA4D, X::AA4D, G::Int, γ::AV, β::AV, μ::AA5D,
σ⁻¹::AA5D)
σ⁻¹::AA5D)
W, H, C, N = size(X)
K = div(C, G)
WxH = W * H
18 changes: 9 additions & 9 deletions src/impl/normalization.jl
@@ -1,7 +1,7 @@
# Generic Normalization Implementation
function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:Real, N},
running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N},
momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims}
running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N},
momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims}
m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims))
m_ = m / (m - one(m))
if last(reduce_dims) != N
@@ -14,8 +14,8 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R
end

@generated function _get_batch_statistics(x::AA, running_mean::R, running_var::R,
r::Val{rdims}, ::Val{training},
momentum::Union{Real, Nothing}) where {R, rdims, training}
r::Val{rdims}, ::Val{training},
momentum::Union{Real, Nothing}) where {R, rdims, training}
calls = []
if !training
if R == Nothing
@@ -40,7 +40,7 @@ end
end

@generated function _affine_normalize(x::AA, xmean::ST, xvar::ST, scale::A,
bias::A, epsilon::Real) where {ST, A}
bias::A, epsilon::Real) where {ST, A}
if A != Nothing
return quote
x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon)
@@ -52,17 +52,17 @@ end
end

function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A,
bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing},
epsilon::Real) where {R, A, reduce_dims}
bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing},
epsilon::Real) where {R, A, reduce_dims}
_stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum)
(batchmean, batchvar), (running_mean, running_var) = _stats
x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon)
return (x_norm, running_mean, running_var)
end

function _normalization(x::AA, running_mean::NOrAVR, running_var::NOrAVR, scale::NOrAVR,
bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing},
epsilon::Real)
bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing},
epsilon::Real)
rm_ = _reshape_into_proper_shape(running_mean, x)
rv_ = _reshape_into_proper_shape(running_var, x)
s_ = _reshape_into_proper_shape(scale, x)
