diff --git a/ext/LuxLibForwardDiffExt.jl b/ext/LuxLibForwardDiffExt.jl
index fac745ca..e6c52330 100644
--- a/ext/LuxLibForwardDiffExt.jl
+++ b/ext/LuxLibForwardDiffExt.jl
@@ -16,7 +16,7 @@ for op in [:conv, :depthwiseconv]
     op! = Symbol("$(op)!")
 
     @eval function NNlib.$(op)(x::AA{<:Dual{Tag, V, P}, N},
-        w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P}
+            w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P}
         x_ = ForwardDiff.value.(x)
 
         y = $(op)(x_, w, cdims; kwargs...)
@@ -27,7 +27,7 @@ for op in [:conv, :depthwiseconv]
     end
 
     @eval function NNlib.$(op)(x::AA{<:Real, N}, w::AA{<:Dual{Tag, V, P}, N},
-        cdims::ConvDims; kwargs...) where {N, Tag, V, P}
+            cdims::ConvDims; kwargs...) where {N, Tag, V, P}
         w_ = ForwardDiff.value.(w)
 
         y = $(op)(x, w_, cdims; kwargs...)
@@ -38,7 +38,8 @@ for op in [:conv, :depthwiseconv]
     end
 
     @eval function NNlib.$(op)(x::AA{<:Dual{Tag, Vₓ, P}, N},
-        w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P}
+            w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims;
+            kwargs...) where {N, Tag, Vₓ, Vₚ, P}
         x_ = ForwardDiff.value.(x)
         w_ = ForwardDiff.value.(w)
 
diff --git a/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl b/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl
index 80f34b90..78c347d1 100644
--- a/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl
+++ b/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl
@@ -17,8 +17,8 @@ const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}
 const BNParamType = Union{Nothing, CuVector{<:FP_32_64}}
 
 function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType,
-    running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val,
-    epsilon::Real)
+        running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val,
+        epsilon::Real)
     rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)
 
     x_ = first(batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training))
@@ -26,13 +26,13 @@ function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType
 end
 
 function batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps,
-    training)
+        training)
     return batchnorm_cudnn(scale, bias, x, running_mean, running_var, momentum, training;
         ϵ=eps)
 end
 
 function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, bias, x,
-    momentum, epsilon, t::Val{training}) where {training}
+        momentum, epsilon, t::Val{training}) where {training}
     y, xmean, xivar = batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum,
         epsilon, t)
     function ∇batchnorm_cudnn_internal(Δ)
diff --git a/ext/LuxLibLuxCUDAExt/batchnorm.jl b/ext/LuxLibLuxCUDAExt/batchnorm.jl
index 8effb21c..dd4c68c2 100644
--- a/ext/LuxLibLuxCUDAExt/batchnorm.jl
+++ b/ext/LuxLibLuxCUDAExt/batchnorm.jl
@@ -29,15 +29,15 @@ function batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwa
 end
 
 function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2},
-    args...; kwargs...) where {T <: FP_32_64}
+        args...; kwargs...) where {T <: FP_32_64}
     x = reshape(x, 1, 1, size(x, 1), size(x, 2))
     y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...)
     return dropdims(y; dims=(1, 2)), xμ, xσ⁻²
 end
 
 function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
-    x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...;
-    kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64}
+        x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...;
+        kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64}
     @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the
         highest precision type. Avoid this code-path if possible" maxlog=1
     Tₓ = eltype(x)
@@ -57,14 +57,14 @@ function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
 end
 
 function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T},
-    x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...;
-    kwargs...) where {T <: FP_32_64}
+        x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...;
+        kwargs...) where {T <: FP_32_64}
     return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...)
 end
 
 function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T},
-    x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training};
-    α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training}
+        x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training};
+        α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training}
     dims = _wsize(x)
     if ϵ < CUDNN_BN_MIN_EPSILON
         @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
@@ -102,7 +102,7 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra
 end
 
 function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray,
-    running_μ, running_σ², args...; kwargs...)
+        running_μ, running_σ², args...; kwargs...)
     affine_sz = _wsize(x)
     g = fill!(similar(x, affine_sz), 1)
     b = fill!(similar(x, affine_sz), 0)
@@ -118,7 +118,8 @@ function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::Dense
 end
 
 function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2},
-    ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64}
+        ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...;
+        kwargs...) where {T <: FP_32_64}
     ∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)),
         reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), running_μ, running_σ², args...;
         kwargs...)
@@ -126,8 +127,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuAr
 end
 
 function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
-    x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...;
-    kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64}
+        x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...;
+        kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64}
     @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the
         highest precision type. Avoid this code-path if possible" maxlog=1
     Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ)
@@ -148,7 +149,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂},
 end
 
 function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-    ∂y::DenseCuArray{T}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64}
+        ∂y::DenseCuArray{T}, running_μ, running_σ², args...;
+        kwargs...) where {T <: FP_32_64}
     ∂g = similar(g)
     ∂b = similar(b)
     ∂x = similar(x)
@@ -157,8 +159,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuAr
 end
 
 function cudnnBNBackward!(∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T},
-    ∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ²,
-    xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64}
+        ∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ²,
+        xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64}
     if running_μ === nothing && running_σ² === nothing
         running_μ = CU_NULL
         running_σ² = CU_NULL
diff --git a/ext/LuxLibLuxCUDATrackerExt.jl b/ext/LuxLibLuxCUDATrackerExt.jl
index 4726610b..06f45a8a 100644
--- a/ext/LuxLibLuxCUDATrackerExt.jl
+++ b/ext/LuxLibLuxCUDATrackerExt.jl
@@ -14,8 +14,8 @@ const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP
     CuVector{<:FP_32_64}}
 
 function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType,
-    bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType;
-    momentum::Real, training::Val, epsilon::Real)
+        bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType;
+        momentum::Real, training::Val, epsilon::Real)
     rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)
 
     x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1]
@@ -31,7 +31,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
     __is_tracked(RM, RV, S, B, XT) || continue
 
     @eval function batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S,
-        bias::$B, x::$XT, momentum, eps, training::Val)
+            bias::$B, x::$XT, momentum, eps, training::Val)
         return track(batchnorm_cudnn, running_mean, running_var, scale, bias, x, momentum,
             eps, training)
     end
@@ -41,7 +41,7 @@ __make_nothing(x) = x
 __make_nothing(::CuPtr{Nothing}) = 0
 
 @grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum,
-    eps, training)
+        eps, training)
     y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale),
         data(bias), data(x), momentum, eps, training)
     function ∇batchnorm_cudnn_internal(Δ)
diff --git a/ext/LuxLibTrackerExt.jl b/ext/LuxLibTrackerExt.jl
index 35a41697..26fa3bb3 100644
--- a/ext/LuxLibTrackerExt.jl
+++ b/ext/LuxLibTrackerExt.jl
@@ -78,13 +78,13 @@ for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedVector, :AbstractVecto
     __is_tracked(T1, T2, T3) || continue
 
     @eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, scale::$T2{<:FP_32_64},
-        bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real)
+            bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real)
         return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon)
     end
 end
 
 @grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64},
-    bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
+        bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
     LuxLib._assert_same_backend(data(x), data(scale), data(bias))
     if length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
diff --git a/src/api/batchnorm.jl b/src/api/batchnorm.jl
index c2a2e120..134e394c 100644
--- a/src/api/batchnorm.jl
+++ b/src/api/batchnorm.jl
@@ -39,7 +39,7 @@ fallback is used which is not highly optimized.
     learning. PMLR, 2015.
 """
 function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR,
-    running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N}
+        running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N}
     x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean),
         _drop_forwarddiff_partials(running_var), scale, bias,
         _get_batchnorm_reduce_dims(x), training, momentum, epsilon)
diff --git a/src/api/dropout.jl b/src/api/dropout.jl
index 6fd9f409..0612ef76 100644
--- a/src/api/dropout.jl
+++ b/src/api/dropout.jl
@@ -46,23 +46,23 @@ function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) wh
 end
 
 function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}, invp::T;
-    dims) where {T}
+        dims) where {T}
     return dropout(rng, x, p, t; dims, invp)
 end
 
 function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true},
-    ::Val{false}, invp::T; dims) where {T, T1, T2, N}
+        ::Val{false}, invp::T; dims) where {T, T1, T2, N}
     size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp)
     return x .* ignore_derivatives(mask), mask, rng
 end
 
 function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false},
-    ::Val{false}, invp::T; dims) where {T, T1, T2, N}
+        ::Val{false}, invp::T; dims) where {T, T1, T2, N}
     return (x, mask, rng)
 end
 
 function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, t::Val, um::Val;
-    dims, invp::T=inv(p)) where {T, T1, T2, N}
+        dims, invp::T=inv(p)) where {T, T1, T2, N}
     return dropout(rng, x, mask, p, t, um, invp; dims)
 end
 
diff --git a/src/api/groupnorm.jl b/src/api/groupnorm.jl
index 296d381a..f8b4d4a5 100644
--- a/src/api/groupnorm.jl
+++ b/src/api/groupnorm.jl
@@ -42,7 +42,7 @@ interface.
     on computer vision (ECCV). 2018.
 """
 function groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, bias::AV{<:FP_32_64};
-    groups::Int, epsilon::Real)
+        groups::Int, epsilon::Real)
     _assert_same_backend(x, scale, bias)
     if length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
@@ -56,7 +56,7 @@ end
 
 # Slow Fallback (without custom Pullback Implementation)
 function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int,
-    epsilon::Real) where {N}
+        epsilon::Real) where {N}
     _assert_same_backend(x, scale, bias)
     if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
@@ -79,7 +79,7 @@ end
 
 # Custom Pullbacks
 function CRC.rrule(::typeof(groupnorm), x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64},
-    bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
+        bias::AV{<:FP_32_64}; groups::Int, epsilon::Real)
     _assert_same_backend(x, scale, bias)
     if length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))
diff --git a/src/api/instancenorm.jl b/src/api/instancenorm.jl
index 56e77dd7..8222e45a 100644
--- a/src/api/instancenorm.jl
+++ b/src/api/instancenorm.jl
@@ -29,7 +29,7 @@ mean and variance.
     missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
 """
 function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val,
-    epsilon::Real) where {N}
+        epsilon::Real) where {N}
     _test_valid_instancenorm_arguments(x)
 
     x_, xm, xv = _normalization(x, nothing, nothing, scale, bias,
diff --git a/src/api/layernorm.jl b/src/api/layernorm.jl
index f33ddcbc..39ad6cbf 100644
--- a/src/api/layernorm.jl
+++ b/src/api/layernorm.jl
@@ -30,7 +30,7 @@ Normalized Array of same size as `x`.
     preprint arXiv:1607.06450 (2016).
 """
 function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims,
-    epsilon) where {N}
+        epsilon) where {N}
     x_norm = layernorm(x, nothing, nothing; dims, epsilon)
     return scale .* x_norm .+ bias
 end
diff --git a/src/impl/groupnorm.jl b/src/impl/groupnorm.jl
index 89e40322..e9c0e769 100644
--- a/src/impl/groupnorm.jl
+++ b/src/impl/groupnorm.jl
@@ -5,7 +5,7 @@ _linear_threads_groupnorm(::GPU) = 256
 # Low-Level Kernels
 ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu
 @kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @Const(μ),
-    @Const(σ⁻¹), @Const(γ), @Const(β))
+        @Const(σ⁻¹), @Const(γ), @Const(β))
     idx = @index(Global)
     ng = _div_idx(idx, K)
     c = _mod_idx(idx, C)
@@ -16,14 +16,14 @@ _linear_threads_groupnorm(::GPU) = 256
 end
 
 @kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale),
-    @Const(bias))
+        @Const(bias))
     idx = @index(Global)
     nc = _div_idx(idx, WxH)
     @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc]
 end
 
 @kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹),
-    @Const(γ))
+        @Const(γ))
     idx = @index(Global)
     ng = _div_idx(idx, K)
     c = _mod_idx(idx, C)
@@ -32,7 +32,7 @@ end
 end
 
 @kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ),
-    @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum))
+        @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum))
     idx = @index(Global)
     @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha
     @inbounds X_scale[idx] = x
@@ -40,7 +40,7 @@ end
 end
 
 @kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale),
-    @Const(dY), @Const(X_scale), @Const(X), @Const(bias))
+        @Const(dY), @Const(X_scale), @Const(X), @Const(bias))
     idx = @index(Global)
     nc = _div_idx(idx, WxH)
     ng = _div_idx(nc, K)
@@ -77,7 +77,7 @@ end
 end
 
 @inbounds function _∇groupnorm(dY::AA4D, Y::AA4D, X::AA4D, G::Int, γ::AV, β::AV, μ::AA5D,
-    σ⁻¹::AA5D)
+        σ⁻¹::AA5D)
     W, H, C, N = size(X)
     K = div(C, G)
     WxH = W * H
diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl
index a1d6f7cc..b36a8169 100644
--- a/src/impl/normalization.jl
+++ b/src/impl/normalization.jl
@@ -1,7 +1,7 @@
 # Generic Normalization Implementation
 function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:Real, N},
-    running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N},
-    momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims}
+        running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N},
+        momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims}
     m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims))
     m_ = m / (m - one(m))
     if last(reduce_dims) != N
@@ -14,8 +14,8 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R
 end
 
 @generated function _get_batch_statistics(x::AA, running_mean::R, running_var::R,
-    r::Val{rdims}, ::Val{training},
-    momentum::Union{Real, Nothing}) where {R, rdims, training}
+        r::Val{rdims}, ::Val{training},
+        momentum::Union{Real, Nothing}) where {R, rdims, training}
     calls = []
     if !training
         if R == Nothing
@@ -40,7 +40,7 @@ end
 end
 
 @generated function _affine_normalize(x::AA, xmean::ST, xvar::ST, scale::A,
-    bias::A, epsilon::Real) where {ST, A}
+        bias::A, epsilon::Real) where {ST, A}
     if A != Nothing
         return quote
             x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon)
@@ -52,8 +52,8 @@ end
 end
 
 function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A,
-    bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing},
-    epsilon::Real) where {R, A, reduce_dims}
+        bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing},
+        epsilon::Real) where {R, A, reduce_dims}
     _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum)
     (batchmean, batchvar), (running_mean, running_var) = _stats
     x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon)
@@ -61,8 +61,8 @@ function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A,
 end
 
 function _normalization(x::AA, running_mean::NOrAVR, running_var::NOrAVR, scale::NOrAVR,
-    bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing},
-    epsilon::Real)
+        bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing},
+        epsilon::Real)
     rm_ = _reshape_into_proper_shape(running_mean, x)
     rv_ = _reshape_into_proper_shape(running_var, x)
     s_ = _reshape_into_proper_shape(scale, x)