-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathdecayer.py
30 lines (24 loc) · 894 Bytes
/
decayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import math

import torch
import torch.nn as nn
from torch.nn import init
class Decayer(nn.Module):
    """Map a time difference ``delta_t`` to a decay weight.

    The curve is selected by ``decay_method``:

    * ``'exp'`` -- ``exp(-w * delta_t)``
    * ``'log'`` -- ``1 / ln(e + w * delta_t)``
    * ``'rev'`` -- ``1 / (1 + w * delta_t)``

    Every curve evaluates to 1 at ``delta_t == 0`` (up to floating-point
    rounding). Any other ``decay_method`` value silently falls back to
    ``'exp'``.

    Args:
        device: Device the ``linear`` sub-module is moved to.
        w: Coefficient applied as ``w * delta_t``; a number, or a tensor
            that broadcasts against ``delta_t``.
        decay_method: ``'exp'`` (default), ``'log'`` or ``'rev'``.
    """

    def __init__(self, device, w, decay_method='exp'):
        super().__init__()
        self.decay_method = decay_method
        # Not used by any decay curve below. Kept so that existing checkpoints
        # whose state_dict contains ``linear.weight`` still load with strict=True.
        self.linear = nn.Linear(1, 1, bias=False).to(device)
        self.w = w

    def exponential_decay(self, delta_t):
        """Return ``exp(-w * delta_t)``; accepts tensors or plain numbers."""
        # torch.exp rejects Python scalars; as_tensor is a no-op for tensors.
        delta_t = torch.as_tensor(delta_t)
        return torch.exp(-self.w * delta_t)

    # Backward-compatible alias for the original (misspelled) method name.
    exponetial_decay = exponential_decay

    def log_decay(self, delta_t):
        """Return ``1 / ln(e + w * delta_t)``; accepts tensors or plain numbers.

        NOTE(review): presumably ``w * delta_t >= 0``; values at or below
        ``1 - e`` give a non-positive log argument (inf / nan) -- confirm.
        """
        delta_t = torch.as_tensor(delta_t)
        # Exact e (not a rounded literal) so the decay is exactly 1 at delta_t == 0,
        # matching the 'exp' and 'rev' curves.
        return 1 / torch.log(math.e + self.w * delta_t)

    def rev_decay(self, delta_t):
        """Return ``1 / (1 + w * delta_t)`` (reciprocal decay).

        Works on plain numbers as well as tensors, returning the same kind
        of value the arithmetic produces.
        """
        return 1 / (1 + self.w * delta_t)

    def forward(self, delta_t):
        """Apply the decay curve selected by ``self.decay_method`` to ``delta_t``."""
        if self.decay_method == 'log':
            return self.log_decay(delta_t)
        if self.decay_method == 'rev':
            return self.rev_decay(delta_t)
        # 'exp' and any unrecognised method both use exponential decay.
        return self.exponential_decay(delta_t)