-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathschedulers.py
58 lines (49 loc) · 1.3 KB
/
schedulers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Schedulers for our model
import optax
def get_adam(params, learning_rate=1e-5):
    """Build an Adam optimizer and its initialized state.

    Args:
        params: PyTree of model parameters used to initialize optimizer state.
        learning_rate: Fixed step size. Default 1e-5 preserves the original
            hard-coded value, so existing callers are unaffected.

    Returns:
        Tuple of (optimizer, opt_state).
    """
    optimizer = optax.adam(learning_rate=learning_rate)
    opt_state = optimizer.init(params)
    return optimizer, opt_state
def get_adan(params, learning_rate=5e-5, weight_decay=1e-4,
             beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Build an Adan optimizer and its initialized state.

    Args:
        params: PyTree of model parameters used to initialize optimizer state.
        learning_rate: Fixed step size (default 5e-5, the original value).
        weight_decay: Decoupled weight-decay coefficient (default 1e-4).
        beta1: First-moment decay rate passed as ``b1`` (default 0.9).
        beta2: Second-moment decay rate passed as ``b2`` (default 0.999).
        epsilon: Numerical-stability constant passed as ``eps`` (default 1e-8).

    Returns:
        Tuple of (optimizer, opt_state).
    """
    optimizer = optax.adan(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        b1=beta1,
        b2=beta2,
        eps=epsilon
    )
    opt_state = optimizer.init(params)
    return optimizer, opt_state
def get_onecycle(params, epoch):
    """Build an AdamW optimizer driven by a cosine one-cycle LR schedule.

    Each supported epoch gets a progressively shorter cycle with a smaller
    peak learning rate.

    Args:
        params: PyTree of model parameters used to initialize optimizer state.
        epoch: Which training epoch the schedule is for; must be 0, 1, or 2.

    Returns:
        Tuple of (optimizer, opt_state).

    Raises:
        ValueError: If ``epoch`` has no configured schedule. (Previously an
            unsupported epoch crashed with ``UnboundLocalError`` because the
            schedule variables were never assigned.)
    """
    # Per-epoch (transition_steps, peak_value) configuration.
    schedule_config = {
        0: (2e6, 1e-3),
        1: (1e6, 1e-4),
        2: (500e3, 1e-5),
    }
    try:
        transition_steps, peak_value = schedule_config[epoch]
    except KeyError:
        raise ValueError(
            f'No one-cycle schedule configured for epoch {epoch}; '
            f'expected one of {sorted(schedule_config)}'
        ) from None
    print(f'{transition_steps=}')
    print(f'{peak_value=}')
    one_cycle_schedule = optax.schedules.cosine_onecycle_schedule(
        transition_steps = transition_steps,
        peak_value = peak_value,
        pct_start = 0.1,    # warm up for the first 10% of steps
        div_factor = 10,    # initial LR = peak_value / 10
        final_div_factor = 100  # final LR = initial LR / 100
    )
    optimizer = optax.adamw(
        learning_rate=one_cycle_schedule,
        weight_decay=1e-4,
        b1=0.9,
        b2=0.999,
        eps=1e-8
    )
    opt_state = optimizer.init(params)
    return optimizer, opt_state