Remove *_step_end from LitModular (#170)
Merge train/validation/test_step_end into train/validation/test_step.
dxoigmn authored Jun 22, 2023
1 parent 4185b3d commit 1b55378
Showing 1 changed file with 4 additions and 16 deletions.
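
For context, the diff below inlines what each *_step_end hook used to do into the corresponding *_step method. The following is a minimal, hypothetical sketch of that pattern, not MART's actual code: SketchModule, the hard-coded "loss"/"preds"/"target" keys, and the torchmetrics.Accuracy metric are stand-ins for LitModular's configurable output keys and metric collections.

# Minimal hypothetical sketch: the metric update and loss extraction that
# previously ran in training_step_end now happen at the end of training_step.
import torch
import torchmetrics
from pytorch_lightning import LightningModule


class SketchModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 2)
        self.training_metrics = torchmetrics.Accuracy(task="multiclass", num_classes=2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        output = {"loss": loss, "preds": logits, "target": y}

        # Formerly done in training_step_end(output); now inlined here.
        if self.training_metrics is not None:
            self.training_metrics(output["preds"], output["target"])

        # Return only the loss tensor instead of the whole output dict.
        return output["loss"]

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)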
mart/models/modular.py: 20 changes, 4 additions & 16 deletions
@@ -134,19 +134,15 @@ def training_step(self, batch, batch_idx):
for log_name, output_key in self.training_step_log.items():
self.log(f"training/{log_name}", output[output_key])

assert "loss" in output
return output

def training_step_end(self, output):
if self.training_metrics is not None:
# Some models only return loss in the training mode.
if self.output_preds_key not in output or self.output_target_key not in output:
raise ValueError(
f"You have specified training_metrics, but the model does not return {self.output_preds_key} or {self.output_target_key} during training. You can either nullify training_metrics or configure the model to return {self.output_preds_key} and {self.output_target_key} in the training output."
)
self.training_metrics(output[self.output_preds_key], output[self.output_target_key])
loss = output.pop(self.output_loss_key)
return loss

return output[self.output_loss_key]

def training_epoch_end(self, outputs):
if self.training_metrics is not None:
@@ -168,13 +164,9 @@ def validation_step(self, batch, batch_idx):
for log_name, output_key in self.validation_step_log.items():
self.log(f"validation/{log_name}", output[output_key])

return output

def validation_step_end(self, output):
self.validation_metrics(output[self.output_preds_key], output[self.output_target_key])

# I don't know why this is required to prevent CUDA memory leak in validation and test. (Not required in training.)
output.clear()
return None

def validation_epoch_end(self, outputs):
metrics = self.validation_metrics.compute()
@@ -194,13 +186,9 @@ def test_step(self, batch, batch_idx):
for log_name, output_key in self.test_step_log.items():
self.log(f"test/{log_name}", output[output_key])

return output

def test_step_end(self, output):
self.test_metrics(output[self.output_preds_key], output[self.output_target_key])

# I don't know why this is required to prevent CUDA memory leak in validation and test. (Not required in training.)
output.clear()
return None

def test_epoch_end(self, outputs):
metrics = self.test_metrics.compute()
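
The validation and test hunks follow the same pattern: the metric update that previously lived in validation_step_end / test_step_end now runs at the end of validation_step / test_step. A comparable hypothetical sketch for the evaluation side (again a stand-in module with hard-coded keys and metrics, not MART's code):

# Minimal hypothetical sketch: metric updates formerly done in
# validation_step_end / test_step_end now run inside the step methods.
import torch
import torchmetrics
from pytorch_lightning import LightningModule


class SketchEvalModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 2)
        self.validation_metrics = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.test_metrics = torchmetrics.Accuracy(task="multiclass", num_classes=2)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = {"preds": self.model(x), "target": y}
        # Formerly done in validation_step_end(output); now inlined here.
        self.validation_metrics(output["preds"], output["target"])

    def test_step(self, batch, batch_idx):
        x, y = batch
        output = {"preds": self.model(x), "target": y}
        # Formerly done in test_step_end(output); now inlined here.
        self.test_metrics(output["preds"], output["target"])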
