diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py index 551a875..23bce8d 100644 --- a/dsm/dsm_torch.py +++ b/dsm/dsm_torch.py @@ -152,7 +152,7 @@ def __init__(self, inputdim, k, layers=None, dist='Weibull', self.act = nn.SELU() self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) for r in range(self.risks)}) - self.scale = nn.ParameterDict({str(r+1):nn.Parameter(-torch.ones(k)) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) for r in range(self.risks)}) elif self.dist in ['Normal']: self.act = nn.Identity()