Skip to content

Commit

Permalink
Update Activation class for easy select act type.
Browse files Browse the repository at this point in the history
  • Loading branch information
PistonY committed Apr 8, 2020
1 parent e2f465e commit 7f1549f
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions torchtoolbox/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .functional import swish
from torch import nn
from torch.nn import functional as F


class Swish(nn.Module):
Expand All @@ -21,3 +22,47 @@ def __init__(self, beta=1.0):

def forward(self, x):
return swish(x, self.beta)


class HardSwish(nn.Module):
def __init__(self, inplace=False):
super(HardSwish, self).__init__()
self.inplace = inplace

def forward(self, x):
return x * F.relu6(x + 3., inplace=self.inplace) / 6.


class HardSigmoid(nn.Module):
def __init__(self, inplace=False):
super(HardSigmoid, self).__init__()
self.inplace = inplace

def forward(self, x):
return F.relu6(x + 3., inplace=self.inplace) / 6.


class Activation(nn.Module):
def __init__(self, act_type, **kwargs):
super(Activation, self).__init__()
if act_type == 'relu':
self.act = nn.ReLU(**kwargs)
elif act_type == 'relu6':
self.act = nn.ReLU6(**kwargs)
elif act_type == 'h_swish':
self.act = HardSwish(**kwargs)
elif act_type == 'h_sigmoid':
self.act = HardSigmoid(**kwargs)
elif act_type == 'swish':
self.act = Swish(**kwargs)
elif act_type == 'sigmoid':
self.act = nn.Sigmoid()
elif act_type == 'lrelu':
self.act = nn.LeakyReLU(**kwargs)
elif act_type == 'prelu':
self.act = nn.PReLU(**kwargs)
else:
raise NotImplementedError('{} activation is not implemented.'.format(act_type))

def forward(self, x):
return self.act(x)

0 comments on commit 7f1549f

Please sign in to comment.