diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index a205029..6988728 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -116,6 +116,8 @@ def fit(self, x, t, e, vsize=0.15, self.torch_model = model.eval() self.fitted = True + + return self def _prepocess_test_data(self, x): return torch.from_numpy(x)