Skip to content

Commit

Permalink
enhance overfit_batches logic + revert logging changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximilienLC committed Jan 15, 2025
1 parent f1f3be0 commit e1c972b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
3 changes: 1 addition & 2 deletions cneuromax/fitting/deeplearning/litmodule/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ def stage_step(
if isinstance(data, list):
data = tuple(data)
loss: Num[Tensor, " *_"] = self.step(data, stage)
- self.log(name=f"{stage}_step/loss", value=loss, on_step=True)
- self.log(name=f"{stage}_epoch/loss", value=loss, on_epoch=True)
+ self.log(name=f"{stage}/loss", value=loss)
return loss

@final
Expand Down
5 changes: 3 additions & 2 deletions cneuromax/fitting/deeplearning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ def train(
>= TORCH_COMPILE_MINIMUM_CUDA_VERSION
):
litmodule.nnmodule = torch.compile(litmodule.nnmodule) # type: ignore [assignment]
litmodule.trainer = trainer
if trainer.overfit_batches > 0:
datamodule.val_dataloader = datamodule.train_dataloader
trainer.fit(
model=litmodule,
datamodule=datamodule,
ckpt_path=config.ckpt_path,
)
"""TODO: Add logic for HPO"""
return trainer.validate(model=litmodule, datamodule=datamodule)[0][
- "val_epoch/loss"
+ "val/loss"
]
2 changes: 1 addition & 1 deletion cneuromax/fitting/deeplearning/utils/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def instantiate_trainer(
callbacks.append(
ModelCheckpoint(
dirpath=trainer_partial.keywords["default_root_dir"],
- monitor="val_epoch/loss",
+ monitor="val/loss",
save_last=True,
save_top_k=1,
every_n_train_steps=save_every_n_train_steps,
Expand Down

0 comments on commit e1c972b

Please sign in to comment.