Skip to content

Commit

Permalink
reinitialize trainer after fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Dec 4, 2024
1 parent 2f9baee commit 86d946f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions molpipeline/estimators/chemprop/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def __init__(
self.trainer_params = get_params_trainer(self.lightning_trainer)
self.set_params(**kwargs)

def _update_trainer(
self,
) -> None:
def _update_trainer(self) -> None:
"""Update the trainer for the model."""
trainer_params = dict(self.trainer_params)
if self.model_ckpoint_params:
Expand Down Expand Up @@ -139,6 +137,8 @@ def fit(
X, batch_size=self.batch_size, num_workers=self.n_jobs
)
self.lightning_trainer.fit(self.model, training_data)
        # The trainer is reinitialized to avoid storing the training data
self._update_trainer()
return self

def set_params(self, **params: Any) -> Self:
Expand Down

0 comments on commit 86d946f

Please sign in to comment.