Skip to content

Commit

Permalink
Merge pull request #2047 from MarcoVela/patch-1
Browse files Browse the repository at this point in the history
Typo in BatchNorm number of channels assertion
  • Loading branch information
ToucheSir authored Aug 24, 2022
2 parents d4f1d81 + 3dbb05e commit f5882c7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/cuda/cudnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},

@assert BN.affine "BatchNorm: only affine=true supported on gpu"
@assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wronng number of channels"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
cache=cache, alpha=1, beta=0, eps=BN.ϵ,
training=Flux._isactive(BN)))
Expand Down

0 comments on commit f5882c7

Please sign in to comment.