
Commit

Merge branch 'FluxML:master' into 0.13.5_regression_fix
christiangnrd authored Aug 28, 2022
2 parents 0273c50 + 6c747f3 commit 7064b94
Showing 4 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/cuda/cudnn.jl
@@ -5,7 +5,7 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},

@assert BN.affine "BatchNorm: only affine=true supported on gpu"
@assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wronng number of channels"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
cache=cache, alpha=1, beta=0, eps=BN.ϵ,
training=Flux._isactive(BN)))
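For context, the assertions above spell out the requirements of the GPU fast path: the `BatchNorm` layer must be built with `affine=true` and `track_stats=true` (its defaults), and the channel dimension of the input must match `length(BN.β)`. A minimal sketch of a call that satisfies them, assuming CUDA.jl is installed and a CUDA-capable device is available:

```julia
using Flux, CUDA

bn = BatchNorm(16) |> gpu                # defaults affine=true, track_stats=true pass the asserts

x = CUDA.rand(Float32, 8, 8, 16, 4)      # WHCN layout; 16 channels matches length(bn.β)
y = bn(x)                                # dispatches to the CuArray method shown above
```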
8 changes: 8 additions & 0 deletions src/layers/normalise.jl
@@ -431,6 +431,10 @@ function InstanceNorm(chs::Int, λ=identity;
affine=false, track_stats=false,
ϵ=1f-5, momentum=0.1f0)

if track_stats
Base.depwarn("`track_stats=true` will be removed from InstanceNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :InstanceNorm)
end

β = affine ? initβ(chs) : nothing
γ = affine ? initγ(chs) : nothing
μ = track_stats ? zeros32(chs) : nothing
@@ -529,6 +533,10 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
affine=true, track_stats=false,
ϵ=1f-5, momentum=0.1f0)

if track_stats
Base.depwarn("`track_stats=true` will be removed from GroupNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :GroupNorm)
end

chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")

β = affine ? initβ(chs) : nothing
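A hedged sketch of what these additions mean at the call site: constructing either layer with the soon-to-be-deprecated `track_stats=true` now goes through `Base.depwarn`, while the defaults are untouched. (Whether the warning is actually printed depends on Julia's `--depwarn` setting.)

```julia
using Flux

InstanceNorm(8)                     # default track_stats=false — no warning, behaviour unchanged
InstanceNorm(8; track_stats=true)   # triggers the Base.depwarn added above

GroupNorm(8, 4)                     # default track_stats=false — no warning
GroupNorm(8, 4; track_stats=true)   # triggers the Base.depwarn added above
```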
17 changes: 16 additions & 1 deletion src/optimise/train.jl
@@ -38,6 +38,9 @@ struct SkipException <: Exception end
Call `Flux.skip()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to skip the current data point and not update with the calculated gradient.
!!! note
`Flux.skip()` will be removed from Flux 0.14
# Examples
```julia
cb = function ()
@@ -46,6 +49,8 @@ end
```
"""
function skip()
Base.depwarn("""Flux.skip() will be removed from Flux 0.14.
and should be replaced with `continue` in an ordinary `for` loop.""", :skip)
throw(SkipException())
end

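As the new warning says, the callback-driven `Flux.skip()` corresponds to a plain `continue` once the training loop is written by hand. A hedged sketch of that replacement; the model, data, and skip condition below are illustrative assumptions, not part of this diff:

```julia
using Flux

model = Dense(2 => 1)                                  # placeholder model
loss(x, y) = Flux.mse(model(x), y)
data = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:10]
opt  = Descent(0.1)
ps   = Flux.params(model)
should_skip(x, y) = any(isnan, y)                      # hypothetical skip condition

for (x, y) in data
    # Where a callback would call Flux.skip(), drop the batch directly:
    should_skip(x, y) && continue
    gs = gradient(() -> loss(x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```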
@@ -58,6 +63,9 @@ struct StopException <: Exception end
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to stop and exit.
!!! note
`Flux.stop()` will be removed from Flux 0.14. It should be replaced with `break` in an ordinary `for` loop.
# Examples
```julia
cb = function ()
@@ -66,6 +74,8 @@ end
```
"""
function stop()
Base.depwarn("""Flux.stop() will be removed from Flux 0.14.
It should be replaced with `break` in an ordinary `for` loop.""", :stop)
throw(StopException())
end

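Likewise, `Flux.stop()` maps onto `break` in a hand-written loop. A hedged sketch with the same kind of placeholder setup; the stopping threshold is an illustrative assumption:

```julia
using Flux

model = Dense(2 => 1)                                  # placeholder model
loss(x, y) = Flux.mse(model(x), y)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
opt  = Descent(0.1)
ps   = Flux.params(model)

for epoch in 1:100
    gs = gradient(() -> loss(x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
    # Where a callback would call Flux.stop(), exit the loop directly:
    loss(x, y) < 1f-3 && break
end
```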
@@ -140,8 +150,11 @@ end
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
!!! note
The macro `@epochs` will be removed from Flux 0.14. Please just write an ordinary `for` loop.
# Examples
```jldoctest
```julia
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
@@ -150,6 +163,8 @@ hello
```
"""
macro epochs(n, ex)
Base.depwarn("""The macro `@epochs` will be removed from Flux 0.14.
As an alternative, you can write a simple `for i in 1:epochs` loop.""", Symbol("@epochs"), force=true)
:(@progress for i = 1:$(esc(n))
@info "Epoch $i"
$(esc(ex))
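The `@epochs` macro expands into exactly the loop that its deprecation message recommends writing by hand. A minimal sketch of the replacement, mirroring the `println("hello")` example from the docstring above:

```julia
epochs = 2
for epoch in 1:epochs
    @info "Epoch $epoch"
    println("hello")     # one epoch of training would go here instead
end
```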
9 changes: 6 additions & 3 deletions src/utils.jl
@@ -648,7 +648,8 @@ julia> loss() = rand();
julia> trigger = Flux.patience(() -> loss() < 1, 3);
julia> Flux.@epochs 10 begin
julia> for i in 1:10
@info "Epoch \$i"
trigger() && break
end
[ Info: Epoch 1
@@ -685,7 +686,8 @@ julia> loss = let l = 0
julia> es = Flux.early_stopping(loss, 3);
julia> Flux.@epochs 10 begin
julia> for i in 1:10
@info "Epoch \$i"
es() && break
end
[ Info: Epoch 1
@@ -726,7 +728,8 @@ julia> f = let v = 10
julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);
julia> Flux.@epochs 10 begin
julia> for i in 1:10
@info "Epoch \$i"
trigger() && break
end
[ Info: Epoch 1
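All three docstring edits above make the same substitution: the deprecated `Flux.@epochs 10 begin ... end` wrapper becomes a plain `for i in 1:10` loop around the trigger. A hedged, self-contained sketch of the early-stopping variant, following the same shape as the docstring examples:

```julia
using Flux

# Stand-in loss that never improves, so the trigger fires quickly (illustrative assumption).
loss = let l = 0
    () -> l += 1
end

es = Flux.early_stopping(loss, 3)   # fires after 3 calls with no improvement

for i in 1:10
    @info "Epoch $i"
    es() && break
end
```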
