feat: improved fallback BN implementation
avik-pal committed Jul 27, 2024
1 parent baba87c commit 9960aa0
Showing 5 changed files with 287 additions and 31 deletions.
8 changes: 2 additions & 6 deletions .buildkite/testing.yml
@@ -61,9 +61,7 @@ steps:
- src
- ext
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
RETESTITEMS_NWORKERS: 4
BACKEND_GROUP: "AMDGPU"
agents:
queue: "juliagpu"
@@ -93,9 +91,7 @@ steps:
rocm: "*"
rocmgpu: "*"
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
RETESTITEMS_NWORKERS: 4
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/
timeout_in_minutes: 240
matrix:
3 changes: 2 additions & 1 deletion src/api/batchnorm.jl
@@ -42,7 +42,8 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector
bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector},
running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity,
momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N}
x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias,
x_, xm, xv = _batchnorm_impl(
x, __value(running_mean), __value(running_var), scale, bias,
_get_batchnorm_reduce_dims(x), training, momentum, epsilon,
select_fastest_activation(σ, x, scale, bias, running_mean, running_var))
return (x_, (; running_mean=__value(xm), running_var=__value(xv)))
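For reference, the public `batchnorm` entry point above keeps its signature; only the internal call changes from `_normalization` to the new `_batchnorm_impl`. A minimal usage sketch, assuming this repository is LuxLib.jl and using illustrative sizes:

```julia
using LuxLib  # assumption: this commit belongs to LuxLib.jl

x = randn(Float32, 8, 8, 3, 4)                     # W × H × C × N
scale, bias = ones(Float32, 3), zeros(Float32, 3)  # per-channel affine parameters
rμ, rσ² = zeros(Float32, 3), ones(Float32, 3)      # running statistics

# Val(true) selects the training path (batch statistics + running-stat updates);
# the remaining arguments are activation, momentum, and epsilon.
y, stats = batchnorm(x, scale, bias, rμ, rσ², Val(true), identity, 0.1f0, 1.0f-5)
size(y) == size(x)                                 # normalized output
stats.running_mean, stats.running_var              # updated statistics as vectors
```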
290 changes: 266 additions & 24 deletions src/impl/affine_normalize.jl
@@ -18,42 +18,270 @@ end
# implementation. We bypass julia's broadcasting mechanism if we can. We still might fall
# back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff)
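The generic path mentioned in this comment is, in spirit, a plain broadcast of the textbook normalization formula. A rough sketch of such a fallback (not the package's exact `GenericBroadcastOp` code; the function name is made up for illustration):

```julia
# Rough sketch of a generic broadcast fallback: one fused broadcast that dual/tracked
# array types from ForwardDiff/Tracker/ReverseDiff can flow through unchanged.
# γ and β are assumed to be already reshaped to broadcast against x.
affine_normalize_generic(f, x, μ, σ², γ, β, ϵ) = @. f(γ * (x - μ) / sqrt(σ² + ϵ) + β)
```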

## Group Normalization
for norm_op in (:bn, :gn)
op = Symbol("_affine_normalize_$(norm_op)")
impl_op = Symbol("_affine_normalize_$(norm_op)_impl")
impl_op! = Symbol("__affine_normalize_$(norm_op)_impl!")
@eval begin
function $(op)(act::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, ϵ::Real) where {F}
return $(op)(internal_operation_mode((x, μ, σ², scale, bias)),
act, x, μ, σ², scale, bias, ϵ)
end

function _affine_normalize_gn(
f::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, ϵ::Real) where {F}
return _affine_normalize_gn(
internal_operation_mode((x, μ, σ², scale, bias)), f, x, μ, σ², scale, bias, ϵ)
end
function $(op)(::GenericBroadcastOp, act::F, x::AbstractArray{T, N},
μ, σ², scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N}
return _affine_normalize(
act, x, μ, σ², _reshape_into_normalization_shape(scale, x),
_reshape_into_normalization_shape(bias, x), ϵ)
end

function _affine_normalize_gn(::GenericBroadcastOp, f::F, x::AbstractArray,
μ, σ², scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, ϵ::Real) where {F}
return _affine_normalize(f, x, μ, σ², _reshape_into_normalization_shape(scale, x),
_reshape_into_normalization_shape(bias, x), ϵ)
function $(impl_op)(opmode::AbstractInternalArrayOpMode, act::F,
x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray},
bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N}
y = similar(x,
promote_type(__eltype(x), __eltype(μ), __eltype(σ²),
__eltype(scale), __eltype(bias)))
$(impl_op!)(opmode, y, act, x, μ, σ², scale, bias, ϵ)
return y
end
end
end

function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F,
## Batch Normalization

function _affine_normalize_bn(opmode::AbstractInternalArrayOpMode, f::F,
x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N}
x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N))
μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N))
σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N))
scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1)
bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1)
x_ = reshape(x, :, size(x, N - 1), size(x, N))
μ_ = reshape(μ, 1, size(x, N - 1), 1)
σ²_ = reshape(σ², 1, size(x, N - 1), 1)
scale_ = __reshape(scale, 1, size(x, N - 1), 1)
bias_ = __reshape(bias, 1, size(x, N - 1), 1)

return reshape(
_affine_normalize_bn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x))
end
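The wrapper above collapses any N-dimensional input to a 3-D view with the channel dimension isolated, so one 3-D kernel serves every input rank. A small sketch of that collapse with illustrative sizes:

```julia
# Collapse (..., C, N) to (features, C, N), as _affine_normalize_bn does above.
x  = randn(Float32, 5, 5, 3, 2)              # W × H × C × N
x3 = reshape(x, :, size(x, 3), size(x, 4))   # (W*H, C, N)
@assert size(x3) == (25, 3, 2)
```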

function __affine_normalize_bn_impl!(
::LoopedArrayOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3},
μ, σ², scale::Optional{<:AbstractArray{<:Number, 3}},
bias::Optional{<:AbstractArray{<:Number, 3}}, ϵ::Real,
_sc::Optional{<:AbstractArray{<:Number, 3}}=nothing,
_bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F}
N = size(y, 2)
_scale = _sc === nothing ?
similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, N, 1) :
_sc
_bias = _bc === nothing ?
similar(
x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), 1, N, 1) : _bc

if scale !== nothing
@simd ivdep for J in axes(y, 2)
@inbounds _scale[1, J, 1] = scale[1, J, 1] / sqrt(σ²[1, J, 1] + ϵ)
@inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + bias[1, J, 1]
end
else
@simd ivdep for J in axes(y, 2)
@inbounds _scale[1, J, 1] = inv(sqrt(σ²[1, J, 1] + ϵ))
@inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1]
end
end

for K in axes(y, 3), J in axes(y, 2)
@simd ivdep for I in axes(y, 1)
@inbounds y[I, J, K] = muladd(x[I, J, K], _scale[1, J, 1], _bias[1, J, 1])
end
end
_fast_activation!(f, y) # NOTE: don't fuse into the above loop
end
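The loop above works because the normalization folds into a single per-channel affine map that is computed once per channel. A standalone numerical check of that fused form (scalar values are illustrative and not tied to package internals):

```julia
# Verify that precomputing _sc = scale / sqrt(σ² + ϵ) and _bc = -μ * _sc + bias
# reproduces the textbook batch-norm expression with one muladd per element.
x, μ, σ², scale, bias, ϵ = 2.0f0, 0.5f0, 1.5f0, 3.0f0, 0.25f0, 1.0f-5
y_ref   = scale * (x - μ) / sqrt(σ² + ϵ) + bias   # textbook form
_sc     = scale / sqrt(σ² + ϵ)                    # fused scale
_bc     = muladd(-μ, _sc, bias)                   # fused bias
y_fused = muladd(x, _sc, _bc)                     # what the inner loop computes
@assert y_ref ≈ y_fused
```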

function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3},
f::F, x::AbstractArray{<:Number, 3}, μ, σ²,
scale::Optional{<:AbstractArray{<:Number, 3}},
bias::Optional{<:AbstractArray{<:Number, 3}},
ϵ::Real, _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing,
_bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F}
backend = KA.get_backend(y)
if _sc === nothing
kernel! = __affine_normalize_bn_kernel!(backend)
kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y))
else
kernel! = __affine_normalize_bn_kernel_cached!(backend)
kernel!(y, _sc, _bc, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y))
end
KA.synchronize(backend)
end

return _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ)
@kernel function __affine_normalize_bn_kernel!(
y::AbstractArray{<:Number, 3}, @Const(f), @Const(x),
@Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ))
(i, j, k) = @index(Global, NTuple)
if scale !== nothing
@inbounds _sc = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ)
@inbounds _bc = muladd(-μ[1, j, 1], _sc, bias[1, j, 1])
else
@inbounds _sc = inv(sqrt(σ²[1, j, 1] + ϵ))
@inbounds _bc = -μ[1, j, 1] * _sc
end
@inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc, _bc))
end

function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F,
@kernel function __affine_normalize_bn_kernel_cached!(
y::AbstractArray{<:Number, 3}, _sc::AbstractArray{<:Number, 3},
_bc::AbstractArray{<:Number, 3}, @Const(f), @Const(x),
@Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ))
(i, j, k) = @index(Global, NTuple)
if scale !== nothing
@inbounds _sc[1, j, 1] = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ)
@inbounds _bc[1, j, 1] = muladd(-μ[1, j, 1], _sc[1, j, 1], bias[1, j, 1])
else
@inbounds _sc[1, j, 1] = inv(sqrt(σ²[1, j, 1] + ϵ))
@inbounds _bc[1, j, 1] = -μ[1, j, 1] * _sc[1, j, 1]
end
@inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[1, j, 1], _bc[1, j, 1]))
end

function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_bn_impl),
opmode::AbstractInternalArrayOpMode, f::F,
x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray},
bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N}
y = similar(x,
promote_type(
__eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias)))
__affine_normalize_gn_impl!(opmode, y, f, x, μ, σ², scale, bias, ϵ)
return y
_sc = similar(
x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, size(x, N - 1), 1)
_bc = similar(
x, promote_type(__eltype(bias), __eltype(_sc), __eltype(ϵ)), 1, size(x, N - 1), 1)
__affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc, _bc)
z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y)

proj_x = CRC.ProjectTo(x)
proj_μ = CRC.ProjectTo(μ)
proj_σ² = CRC.ProjectTo(σ²)
proj_sc = scale === nothing ? identity : CRC.ProjectTo(scale)
proj_bi = bias === nothing ? identity : CRC.ProjectTo(bias)

∇affine_normalize_bn_impl_internal = @closure Δ -> begin
∂y = last(∇activation(Δ))
∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_bn_impl(
opmode, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc)
return (
∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅)
end

return z, ∇affine_normalize_bn_impl_internal
end

function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc)
∂x = similar(x)
∂μ = similar(μ, size(x))
∂σ² = similar(σ², size(x))
∂sc = scale === nothing ? ∂∅ : similar(scale, size(x))
∂b = bias === nothing ? ∂∅ : similar(bias, size(x))

fill!(∂μ, false)
fill!(∂σ², false)
scale === nothing || fill!(∂sc, false)
bias === nothing || fill!(∂b, false)

backend = KA.get_backend(∂x)
kernel! = ∇affine_normalize_bn_kernel!(backend)
kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc; ndrange=size(∂x))
KA.synchronize(backend)

∂μ_ = __reduce_sum(μ, ∂μ)
∂σ²_ = __reduce_sum(σ², ∂σ²)
∂sc_ = __reduce_sum(scale, ∂sc)
∂b_ = __reduce_sum(bias, ∂b)

__unsafe_free!(∂μ)
__unsafe_free!(∂σ²)
__unsafe_free!(∂sc)
__unsafe_free!(∂b)

return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_
end

@kernel function ∇affine_normalize_bn_kernel!(
∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), @Const(σ²),
@Const(scale), @Const(bias), @Const(ϵ), @Const(_sc), @Const(_bc))
(i, j, k) = @index(Global, NTuple)
if scale !== nothing
@inbounds idenom = inv(sqrt(σ²[1, j, 1] + ϵ))
else
@inbounds idenom = _sc[1, j, 1]
end
idenom² = idenom^2

@inbounds xμ = x[i, j, k] - μ[1, j, 1]

@inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[1, j, 1]
@inbounds ∂μ[i, j, k] = -∂x[i, j, k]
@inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2

if scale !== nothing
@inbounds ∂sc[i, j, k] = ∂y[i, j, k] * xμ * idenom
@inbounds ∂b[i, j, k] = ∂y[i, j, k]
end
end

function ∇affine_normalize_bn_impl(
::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ, _sc, _bc)
∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²)
half = eltype(∂σ²)(0.5)

for K in axes(∂y, 3), J in axes(∂y, 2)
@inbounds idenom = _sc[1, J, 1]
idenom² = idenom^2
@simd for I in axes(∂y, 1)
@inbounds xμ = x[I, J, K] - μ[1, J, 1]

@inbounds ∂x[I, J, K] = ∂y[I, J, K] * idenom
@inbounds ∂μ[1, J, 1] -= ∂x[I, J, K]
@inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom²
end
end

return ∂x, ∂μ, ∂σ², ∂∅, ∂∅
end

function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc)
∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias)
half = eltype(∂σ²)(0.5)

for K in axes(∂y, 3), J in axes(∂y, 2)
@inbounds idenom = @fastmath inv(sqrt(σ²[1, J, 1] + ϵ))
idenom² = idenom^2
@simd for I in axes(∂y, 1)
@inbounds xμ = x[I, J, K] - μ[1, J, 1]

@inbounds ∂x[I, J, K] = ∂y[I, J, K] * _sc[1, J, 1]
@inbounds ∂μ[1, J, 1] -= ∂x[I, J, K]
@inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom²
@inbounds ∂sc[1, J, 1] += ∂y[I, J, K] * xμ * idenom
@inbounds ∂b[1, J, 1] += ∂y[I, J, K]
end
end

return ∂x, ∂μ, ∂σ², ∂sc, ∂b
end
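Spelled out, these backward loops accumulate the partial derivatives of the affine-normalize step with μ and σ² treated as independent inputs; the chain rule through the batch statistics is handled by the surrounding rules. A broadcast-based sketch of the same reductions (not the package's code path; names and shapes are illustrative):

```julia
# x and ∂y are assumed already collapsed to (features, channels, batch).
x  = randn(Float32, 8, 3, 4);  ∂y = randn(Float32, 8, 3, 4)
μ  = vec(sum(x; dims=(1, 3))) ./ (8 * 4)
σ² = vec(sum(abs2, x .- reshape(μ, 1, :, 1); dims=(1, 3))) ./ (8 * 4)
scale, bias, ϵ = rand(Float32, 3), rand(Float32, 3), 1.0f-5

xc     = x .- reshape(μ, 1, :, 1)                                  # x - μ, per channel
s      = scale ./ sqrt.(σ² .+ ϵ)                                   # matches the cached _sc
∂x     = ∂y .* reshape(s, 1, :, 1)
∂μ     = -vec(sum(∂x; dims=(1, 3)))
∂σ²    = -vec(sum(∂x .* xc; dims=(1, 3))) ./ (2 .* (σ² .+ ϵ))
∂scale =  vec(sum(∂y .* xc; dims=(1, 3))) ./ sqrt.(σ² .+ ϵ)
∂bias  =  vec(sum(∂y; dims=(1, 3)))
```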

## Group Normalization

function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F,
x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N}
x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N))
μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N))
σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N))
scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1)
bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1)

return reshape(
_affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x))
end

function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F,
@@ -146,13 +374,27 @@ function ∇affine_normalize_gn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale,
∂sc = scale === nothing ? ∂∅ : similar(scale, size(x))
∂b = bias === nothing ? ∂∅ : similar(bias, size(x))

fill!(∂μ, false)
fill!(∂σ², false)
scale === nothing || fill!(∂sc, false)
bias === nothing || fill!(∂b, false)

backend = KA.get_backend(∂x)
kernel! = ∇affine_normalize_gn_kernel!(backend)
kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ; ndrange=size(∂x))
KA.synchronize(backend)

return (∂x, __reduce_sum(μ, ∂μ), __reduce_sum(σ², ∂σ²),
__reduce_sum(scale, ∂sc), __reduce_sum(bias, ∂b))
∂μ_ = __reduce_sum(μ, ∂μ)
∂σ²_ = __reduce_sum(σ², ∂σ²)
∂sc_ = __reduce_sum(scale, ∂sc)
∂b_ = __reduce_sum(bias, ∂b)

__unsafe_free!(∂μ)
__unsafe_free!(∂σ²)
__unsafe_free!(∂sc)
__unsafe_free!(∂b)

return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_
end

@kernel function ∇affine_normalize_gn_kernel!(
10 changes: 10 additions & 0 deletions src/impl/normalization.jl
@@ -113,3 +113,13 @@ function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector},
x, nothing, nothing, reduce_dims, Val(false), nothing)
return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon)
end

function _batchnorm_impl(x::AbstractArray, running_mean::Optional{<:AbstractVector},
running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, reduce_dims::Val,
training::Val, momentum, epsilon, act::F=identity) where {F}
(μ, σ²), (rμ, rσ²) = _get_batch_statistics(
x, _reshape_into_normalization_shape(running_mean, x),
_reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum)
return _affine_normalize_bn(act, x, μ, σ², scale, bias, epsilon), _vec(rμ), _vec(rσ²)
end
7 changes: 7 additions & 0 deletions src/utils.jl
@@ -129,6 +129,7 @@ CRC.@non_differentiable __depwarn(::Any...)
EnzymeRules.inactive_noinl(::typeof(__depwarn), ::Any...) = nothing

__eltype(::AbstractArray{T}) where {T} = T
__eltype(::T) where {T <: Number} = T
__eltype(::Nothing) = Bool

CRC.@non_differentiable __eltype(::Any)
@@ -148,6 +149,12 @@ __default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T)
CRC.@non_differentiable __default_epsilon(::Any...)
EnzymeRules.inactive_noinl(::typeof(__default_epsilon), ::Any...) = nothing

__unsafe_free!(x) = nothing
__unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x)

CRC.@non_differentiable __unsafe_free!(::Any)
EnzymeRules.inactive_noinl(::typeof(__unsafe_free!), ::Any) = nothing

# Meta Programming Utilities
__is_tracked(x) = x == :TrackedArray || x == :TrackedVector
__is_tracked(args...) = any(__is_tracked, args)
