Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Re-fix type stability
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 2, 2023
1 parent b424886 commit 1158df8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.2"
version = "0.3.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 2 additions & 1 deletion src/impl/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R
running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N},
momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims}
m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims))
m_ = m / (m - one(m))
if last(reduce_dims) != N
batchmean = mean(batchmean; dims=N)
batchvar = mean(batchvar; dims=N)
end
running_mean = @. (1 - momentum) * running_mean + momentum * batchmean
running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m)))
running_var = @. (1 - momentum) * running_var + momentum * batchvar * m_
return (running_mean, running_var)
end

Expand Down

0 comments on commit 1158df8

Please sign in to comment.