From 1158df85fa3b7ea0977be5c7e8eefb5aaf667965 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 2 Sep 2023 16:48:21 -0400 Subject: [PATCH] Re-fix type stability --- Project.toml | 2 +- src/impl/normalization.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 8b6329ac..44514925 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.2" +version = "0.3.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl index a4e6701a..20337774 100644 --- a/src/impl/normalization.jl +++ b/src/impl/normalization.jl @@ -3,12 +3,13 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N}, momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) + m_ = m / (m - one(m)) if last(reduce_dims) != N batchmean = mean(batchmean; dims=N) batchvar = mean(batchvar; dims=N) end running_mean = @. (1 - momentum) * running_mean + momentum * batchmean - running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m))) + running_var = @. (1 - momentum) * running_var + momentum * batchvar * m_ return (running_mean, running_var) end