train.jl
using Juno
import Zygote: Params, gradient
"""
update!(x, x̄)
Update the array `x` according to `x .-= x̄`.
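
# Examples
A minimal sketch of the in-place update; the values are only illustrative:
```julia
x = [1.0, 2.0]
x̄ = [0.1, 0.2]
update!(x, x̄)   # x is now [0.9, 1.8]
```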
"""
function update!(x::AbstractArray, x̄)
  x .-= x̄
end
"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
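
# Examples
A short sketch assuming Flux's `Descent` optimiser and the `Params`/`gradient`
names imported in this file are in scope; the numbers are only illustrative:
```julia
opt = Descent(0.1)

x = [1.0, 2.0]
g = [0.5, 0.5]                    # a gradient for `x`
update!(opt, x, g)                # `x` is mutated in place

ps = Params([x])
gs = gradient(() -> sum(x), ps)
update!(opt, ps, gs)              # every parameter in `ps` with a gradient is updated
```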
"""
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄)
end

function update!(opt, xs::Params, gs)
  for x in xs
    gs[x] === nothing && continue   # skip parameters that received no gradient
    update!(opt, x, gs[x])
  end
end

# Callback niceties: `runall` normalises the callback argument so that either a
# single function or a vector of functions can be passed as `cb`.
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

struct SkipException <: Exception end

"""
    skip()

Call `Flux.skip()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to skip the current data point and not update with the calculated gradient.

# Examples
```julia
cb = function ()
  loss() > 1e7 && Flux.skip()
end
```
"""
function skip()
  throw(SkipException())
end

struct StopException <: Exception end

"""
    stop()

Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to stop and exit.

# Examples
```julia
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
  throw(StopException())
end

# `batchmemaybe` wraps a single datapoint in a tuple, so that `loss(batchmemaybe(d)...)`
# works whether `d` is already a tuple of arguments or a single argument.
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x
"""
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data`, compute the gradient of `loss` with
respect to `params` through backpropagation and call the optimizer `opt`.
If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.
A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
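
# Examples
A self-contained sketch of a single call; the model, data shapes and
hyper-parameters below are arbitrary placeholders, not a recommended setup:
```julia
using Flux

X = rand(2, 10)                       # hypothetical inputs
Y = rand(1, 10)                       # hypothetical targets
m = Dense(2, 1)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)
opt = Descent(0.1)

Flux.train!(loss, ps, [(X, Y)], opt)  # `data` here is a one-batch iterable of (input, target) tuples
```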
"""
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      gs = gradient(ps) do
        loss(batchmemaybe(d)...)
      end
      update!(opt, ps, gs)
      cb()
    catch ex
      # `Flux.stop()` and `Flux.skip()` signal the loop through these exceptions.
      if ex isa StopException
        break
      elseif ex isa SkipException
        continue
      else
        rethrow(ex)
      end
    end
  end
end
"""
@epochs N body
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
# Examples
```jldoctest
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end