From 9960aa090ff8de0be532b23af281ce83cd6400a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 21:04:14 -0700 Subject: [PATCH] feat: improved fallback BN implementation --- .buildkite/testing.yml | 8 +- src/api/batchnorm.jl | 3 +- src/impl/affine_normalize.jl | 290 ++++++++++++++++++++++++++++++++--- src/impl/normalization.jl | 10 ++ src/utils.jl | 7 + 5 files changed, 287 insertions(+), 31 deletions(-) diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 429b91ac..4f65d90d 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -61,9 +61,7 @@ steps: - src - ext env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 4 BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" @@ -93,9 +91,7 @@ steps: rocm: "*" rocmgpu: "*" env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 4 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 240 matrix: diff --git a/src/api/batchnorm.jl b/src/api/batchnorm.jl index 63d85d6f..7bd80138 100644 --- a/src/api/batchnorm.jl +++ b/src/api/batchnorm.jl @@ -42,7 +42,8 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} - x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias, + x_, xm, xv = _batchnorm_impl( + x, __value(running_mean), __value(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) return (x_, (; running_mean=__value(xm), running_var=__value(xv))) diff --git a/src/impl/affine_normalize.jl b/src/impl/affine_normalize.jl index 11be7a0e..c2fef261 100644 --- a/src/impl/affine_normalize.jl +++ b/src/impl/affine_normalize.jl @@ -18,42 +18,270 @@ end # implementation. We bypass julia's broadcasting mechanism if we can. We still might fall # back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff) -## Group Normalization +for norm_op in (:bn, :gn) + op = Symbol("_affine_normalize_$(norm_op)") + impl_op = Symbol("_affine_normalize_$(norm_op)_impl") + impl_op! = Symbol("__affine_normalize_$(norm_op)_impl!") + @eval begin + function $(op)(act::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F} + return $(op)(internal_operation_mode((x, μ, σ², scale, bias)), + act, x, μ, σ², scale, bias, ϵ) + end -function _affine_normalize_gn( - f::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F} - return _affine_normalize_gn( - internal_operation_mode((x, μ, σ², scale, bias)), f, x, μ, σ², scale, bias, ϵ) -end + function $(op)(::GenericBroadcastOp, act::F, x::AbstractArray{T, N}, + μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + return _affine_normalize( + act, x, μ, σ², _reshape_into_normalization_shape(scale, x), + _reshape_into_normalization_shape(bias, x), ϵ) + end -function _affine_normalize_gn(::GenericBroadcastOp, f::F, x::AbstractArray, - μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F} - return _affine_normalize(f, x, μ, σ², _reshape_into_normalization_shape(scale, x), - _reshape_into_normalization_shape(bias, x), ϵ) + function $(impl_op)(opmode::AbstractInternalArrayOpMode, act::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} + y = similar(x, + promote_type(__eltype(x), __eltype(μ), __eltype(σ²), + __eltype(scale), __eltype(bias))) + $(impl_op!)(opmode, y, act, x, μ, σ², scale, bias, ϵ) + return y + end + end end -function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, +## Batch Normalization + +function _affine_normalize_bn(opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} - x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) - μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) - σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) - scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1) - bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1) + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + μ_ = reshape(μ, 1, size(x, N - 1), 1) + σ²_ = reshape(σ², 1, size(x, N - 1), 1) + scale_ = __reshape(scale, 1, size(x, N - 1), 1) + bias_ = __reshape(bias, 1, size(x, N - 1), 1) + + return reshape( + _affine_normalize_bn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) +end + +function __affine_normalize_bn_impl!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, + μ, σ², scale::Optional{<:AbstractArray{<:Number, 3}}, + bias::Optional{<:AbstractArray{<:Number, 3}}, ϵ::Real, + _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing, + _bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F} + N = size(y, 2) + _scale = _sc === nothing ? + similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, N, 1) : + _sc + _bias = _bc === nothing ? + similar( + x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), 1, N, 1) : _bc + + if scale !== nothing + @simd ivdep for J in axes(y, 2) + @inbounds _scale[1, J, 1] = scale[1, J, 1] / sqrt(σ²[1, J, 1] + ϵ) + @inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + bias[1, J, 1] + end + else + @simd ivdep for J in axes(y, 2) + @inbounds _scale[1, J, 1] = inv(sqrt(σ²[1, J, 1] + ϵ)) + @inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + end + end + + for K in axes(y, 3), J in axes(y, 2) + @simd ivdep for I in axes(y, 1) + @inbounds y[I, J, K] = muladd(x[I, J, K], _scale[1, J, 1], _bias[1, J, 1]) + end + end + _fast_activation!(f, y) # NOTE: don't fuse into the above loop +end + +function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3}, + f::F, x::AbstractArray{<:Number, 3}, μ, σ², + scale::Optional{<:AbstractArray{<:Number, 3}}, + bias::Optional{<:AbstractArray{<:Number, 3}}, + ϵ::Real, _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing, + _bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F} + backend = KA.get_backend(y) + if _sc === nothing + kernel! = __affine_normalize_bn_kernel!(backend) + kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) + else + kernel! = __affine_normalize_bn_kernel_cached!(backend) + kernel!(y, _sc, _bc, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) + end + KA.synchronize(backend) +end - return _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ) +@kernel function __affine_normalize_bn_kernel!( + y::AbstractArray{<:Number, 3}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + (i, j, k) = @index(Global, NTuple) + if scale !== nothing + @inbounds _sc = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) + @inbounds _bc = muladd(-μ[1, j, 1], _sc, bias[1, j, 1]) + else + @inbounds _sc = inv(sqrt(σ²[1, j, 1] + ϵ)) + @inbounds _bc = -μ[1, j, 1] * _sc + end + @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc, _bc)) end -function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F, +@kernel function __affine_normalize_bn_kernel_cached!( + y::AbstractArray{<:Number, 3}, _sc::AbstractArray{<:Number, 3}, + _bc::AbstractArray{<:Number, 3}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + (i, j, k) = @index(Global, NTuple) + if scale !== nothing + @inbounds _sc[1, j, 1] = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) + @inbounds _bc[1, j, 1] = muladd(-μ[1, j, 1], _sc[1, j, 1], bias[1, j, 1]) + else + @inbounds _sc[1, j, 1] = inv(sqrt(σ²[1, j, 1] + ϵ)) + @inbounds _bc[1, j, 1] = -μ[1, j, 1] * _sc[1, j, 1] + end + @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[1, j, 1], _bc[1, j, 1])) +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_bn_impl), + opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} y = similar(x, promote_type( __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) - __affine_normalize_gn_impl!(opmode, y, f, x, μ, σ², scale, bias, ϵ) - return y + _sc = similar( + x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, size(x, N - 1), 1) + _bc = similar( + x, promote_type(__eltype(bias), __eltype(_sc), __eltype(ϵ)), 1, size(x, N - 1), 1) + __affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc, _bc) + z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) + + proj_x = CRC.ProjectTo(x) + proj_μ = CRC.ProjectTo(μ) + proj_σ² = CRC.ProjectTo(σ²) + proj_sc = scale === nothing ? identity : CRC.ProjectTo(scale) + proj_bi = bias === nothing ? identity : CRC.ProjectTo(bias) + + ∇affine_normalize_bn_impl_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_bn_impl( + opmode, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + return ( + ∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅) + end + + return z, ∇affine_normalize_bn_impl_internal +end + +function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + ∂x = similar(x) + ∂μ = similar(μ, size(x)) + ∂σ² = similar(σ², size(x)) + ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) + ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) + + fill!(∂μ, false) + fill!(∂σ², false) + scale === nothing || fill!(∂sc, false) + bias === nothing || fill!(∂b, false) + + backend = KA.get_backend(∂x) + kernel! = ∇affine_normalize_bn_kernel!(backend) + kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc; ndrange=size(∂x)) + KA.synchronize(backend) + + ∂μ_ = __reduce_sum(μ, ∂μ) + ∂σ²_ = __reduce_sum(σ², ∂σ²) + ∂sc_ = __reduce_sum(scale, ∂sc) + ∂b_ = __reduce_sum(bias, ∂b) + + __unsafe_free!(∂μ) + __unsafe_free!(∂σ²) + __unsafe_free!(∂sc) + __unsafe_free!(∂b) + + return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_ +end + +@kernel function ∇affine_normalize_bn_kernel!( + ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), + @Const(scale), @Const(bias), @Const(ϵ), @Const(_sc), @Const(_bc)) + (i, j, k) = @index(Global, NTuple) + if scale !== nothing + @inbounds idenom = inv(sqrt(σ²[1, j, 1] + ϵ)) + else + @inbounds idenom = _sc[1, j, 1] + end + idenom² = idenom^2 + + @inbounds xμ = x[i, j, k] - μ[1, j, 1] + + @inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[1, j, 1] + @inbounds ∂μ[i, j, k] = -∂x[i, j, k] + @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 + + if scale !== nothing + @inbounds ∂sc[i, j, k] = ∂y[i, j, k] * xμ * idenom + @inbounds ∂b[i, j, k] = ∂y[i, j, k] + end +end + +function ∇affine_normalize_bn_impl( + ::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ, _sc, _bc) + ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) + half = eltype(∂σ²)(0.5) + + for K in axes(∂y, 3), J in axes(∂y, 2) + @inbounds idenom = _sc[1, J, 1] + idenom² = idenom^2 + @simd for I in axes(∂y, 1) + @inbounds xμ = x[I, J, K] - μ[1, J, 1] + + @inbounds ∂x[I, J, K] = ∂y[I, J, K] * idenom + @inbounds ∂μ[1, J, 1] -= ∂x[I, J, K] + @inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + end + end + + return ∂x, ∂μ, ∂σ², ∂∅, ∂∅ +end + +function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) + half = eltype(∂σ²)(0.5) + + for K in axes(∂y, 3), J in axes(∂y, 2) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, J, 1] + ϵ)) + idenom² = idenom^2 + @simd for I in axes(∂y, 1) + @inbounds xμ = x[I, J, K] - μ[1, J, 1] + + @inbounds ∂x[I, J, K] = ∂y[I, J, K] * _sc[1, J, 1] + @inbounds ∂μ[1, J, 1] -= ∂x[I, J, K] + @inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + @inbounds ∂sc[1, J, 1] += ∂y[I, J, K] * xμ * idenom + @inbounds ∂b[1, J, 1] += ∂y[I, J, K] + end + end + + return ∂x, ∂μ, ∂σ², ∂sc, ∂b +end + +## Group Normalization + +function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) + μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) + σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) + scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1) + bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1) + + return reshape( + _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, @@ -146,13 +374,27 @@ function ∇affine_normalize_gn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) + fill!(∂μ, false) + fill!(∂σ², false) + scale === nothing || fill!(∂sc, false) + bias === nothing || fill!(∂b, false) + backend = KA.get_backend(∂x) kernel! = ∇affine_normalize_gn_kernel!(backend) kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ; ndrange=size(∂x)) KA.synchronize(backend) - return (∂x, __reduce_sum(μ, ∂μ), __reduce_sum(σ², ∂σ²), - __reduce_sum(scale, ∂sc), __reduce_sum(bias, ∂b)) + ∂μ_ = __reduce_sum(μ, ∂μ) + ∂σ²_ = __reduce_sum(σ², ∂σ²) + ∂sc_ = __reduce_sum(scale, ∂sc) + ∂b_ = __reduce_sum(bias, ∂b) + + __unsafe_free!(∂μ) + __unsafe_free!(∂σ²) + __unsafe_free!(∂sc) + __unsafe_free!(∂b) + + return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_ end @kernel function ∇affine_normalize_gn_kernel!( diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl index a603cbed..3d6301cf 100644 --- a/src/impl/normalization.jl +++ b/src/impl/normalization.jl @@ -113,3 +113,13 @@ function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, x, nothing, nothing, reduce_dims, Val(false), nothing) return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) end + +function _batchnorm_impl(x::AbstractArray, running_mean::Optional{<:AbstractVector}, + running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, reduce_dims::Val, + training::Val, momentum, epsilon, act::F=identity) where {F} + (μ, σ²), (rμ, rσ²) = _get_batch_statistics( + x, _reshape_into_normalization_shape(running_mean, x), + _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) + return _affine_normalize_bn(act, x, μ, σ², scale, bias, epsilon), _vec(rμ), _vec(rσ²) +end diff --git a/src/utils.jl b/src/utils.jl index 8def3aa3..9689c337 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -129,6 +129,7 @@ CRC.@non_differentiable __depwarn(::Any...) EnzymeRules.inactive_noinl(::typeof(__depwarn), ::Any...) = nothing __eltype(::AbstractArray{T}) where {T} = T +__eltype(::T) where {T <: Number} = T __eltype(::Nothing) = Bool CRC.@non_differentiable __eltype(::Any) @@ -148,6 +149,12 @@ __default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) CRC.@non_differentiable __default_epsilon(::Any...) EnzymeRules.inactive_noinl(::typeof(__default_epsilon), ::Any...) = nothing +__unsafe_free!(x) = nothing +__unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) + +CRC.@non_differentiable __unsafe_free!(::Any) +EnzymeRules.inactive_noinl(::typeof(__unsafe_free!), ::Any) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args)