Fix trivial comparison in model checkpoint test (#3464)
We were comparing keys across the same checkpoint dict instead of comparing ckpt_last vs ckpt_last_epoch.

All other changes here are formatting.
ananthsub authored Sep 11, 2020
1 parent ef20310 commit d1d48e2
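
For context, here is a minimal standalone sketch of the bug being fixed. The values are hypothetical, and a plain string stands in for the type(model_checkpoint) key the real test uses; the point is that the old assertion compared a dict entry against itself, so it could never fail:

    # Two toy checkpoint dicts with a deliberate mismatch (hypothetical values).
    ckpt_last_epoch = {"callbacks": {"ModelCheckpoint": {"best_model_path": "epoch=2.ckpt"}}}
    ckpt_last = {"callbacks": {"ModelCheckpoint": {"best_model_path": "stale.ckpt"}}}
    key = "best_model_path"

    # Old assertion: both sides index ckpt_last_epoch, so it is true by
    # construction and passes despite the mismatch above.
    assert (
        ckpt_last_epoch["callbacks"]["ModelCheckpoint"][key]
        == ckpt_last_epoch["callbacks"]["ModelCheckpoint"][key]
    )

    # Fixed assertion: cross-checks the two checkpoints and catches the mismatch.
    try:
        assert (
            ckpt_last["callbacks"]["ModelCheckpoint"][key]
            == ckpt_last_epoch["callbacks"]["ModelCheckpoint"][key]
        )
    except AssertionError:
        print("mismatch caught -- the old comparison could never reach this branch")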
69 changes: 42 additions & 27 deletions tests/callbacks/test_model_checkpoint.py
@@ -1,14 +1,13 @@
 import os
-import re
 import pickle
 import platform
+import re
 from pathlib import Path
 
 import cloudpickle
 import pytest
+import torch
 
 import tests.base.develop_utils as tutils
-import torch
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -23,21 +22,31 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
 
     checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
 
-    trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        checkpoint_callback=checkpoint,
+        overfit_batches=0.20,
+        max_epochs=2,
+    )
     trainer.fit(model)
-    assert checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints"
+    assert (
+        checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints"
+    )
 
 
 @pytest.mark.parametrize(
-    "logger_version,expected", [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")],
+    "logger_version,expected",
+    [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")],
 )
 def test_model_checkpoint_path(tmpdir, logger_version, expected):
     """Test that "version_" prefix is only added when logger's version is an integer"""
     tutils.reset_seed()
     model = EvalModelTemplate()
     logger = TensorBoardLogger(str(tmpdir), version=logger_version)
 
-    trainer = Trainer(default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=logger)
+    trainer = Trainer(
+        default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=logger
+    )
     trainer.fit(model)
 
     ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
@@ -79,12 +88,17 @@ def on_train_end(self, trainer, pl_module):
         )
 
 
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="Distributed training is not supported on Windows",
+)
 def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
     model = EvalModelTemplate()
     num_epochs = 4
-    model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1)
+    model_checkpoint = ModelCheckpointTestInvocations(
+        expected_count=num_epochs, save_top_k=-1
+    )
     trainer = Trainer(
         distributed_backend="ddp_cpu",
         num_processes=2,
@@ -102,31 +116,32 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     seed_everything(100)
     model = EvalModelTemplate()
     num_epochs = 3
-    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+    model_checkpoint = ModelCheckpoint(
+        filepath=tmpdir, save_top_k=num_epochs, save_last=True
+    )
     trainer = Trainer(
-        default_root_dir=tmpdir, early_stop_callback=False, checkpoint_callback=model_checkpoint, max_epochs=num_epochs,
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=num_epochs,
     )
     trainer.fit(model)
-    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=3.ckpt
+    path_last_epoch = model_checkpoint.format_checkpoint_name(
+        num_epochs - 1, {}
+    )  # epoch=3.ckpt
     path_last = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST)  # last.ckpt
     assert path_last_epoch != path_last
     ckpt_last_epoch = torch.load(path_last_epoch)
     ckpt_last = torch.load(path_last)
 
-    trainer_keys = (
-        "epoch",
-        "global_step",
-    )
+    trainer_keys = ("epoch", "global_step")
     for key in trainer_keys:
         assert ckpt_last_epoch[key] == ckpt_last[key]
 
-    checkpoint_callback_keys = (
-        "best_model_score",
-        "best_model_path",
-    )
+    checkpoint_callback_keys = ("best_model_score", "best_model_path")
     for key in checkpoint_callback_keys:
         assert (
-            ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
+            ckpt_last["callbacks"][type(model_checkpoint)][key]
             == ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
         )
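
The assertions in this hunk walk the saved checkpoint dict. As a rough sketch of the layout being compared (assuming a checkpoint written by the test above; keys the test does not touch, such as state_dict, are omitted):

    import torch

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Hypothetical path; any checkpoint written by the test above would do.
    ckpt = torch.load("last.ckpt")

    # Trainer-level bookkeeping compared by the first loop:
    print(ckpt["epoch"], ckpt["global_step"])

    # Per-callback state is keyed by the callback's class; the second loop
    # compares the ModelCheckpoint entries of last.ckpt and epoch=3.ckpt.
    cb_state = ckpt["callbacks"][ModelCheckpoint]
    print(cb_state["best_model_score"], cb_state["best_model_path"])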

@@ -148,16 +163,16 @@ def test_ckpt_metric_names(tmpdir):
         progress_bar_refresh_rate=0,
         limit_train_batches=0.01,
         limit_val_batches=0.01,
-        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + "/{val_loss:.2f}"),
     )
 
     trainer.fit(model)
 
     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(tmpdir)
-    ckpts = [x for x in ckpts if 'val_loss' in x]
+    ckpts = [x for x in ckpts if "val_loss" in x]
     assert len(ckpts) == 1
-    val = re.sub('[^0-9.]', '', ckpts[0])
+    val = re.sub("[^0-9.]", "", ckpts[0])
     assert len(val) > 3


@@ -179,14 +194,14 @@ def test_ckpt_metric_names_results(tmpdir):
         progress_bar_refresh_rate=0,
         limit_train_batches=0.01,
         limit_val_batches=0.01,
-        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + "/{val_loss:.2f}"),
     )
 
     trainer.fit(model)
 
     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(tmpdir)
-    ckpts = [x for x in ckpts if 'val_loss' in x]
+    ckpts = [x for x in ckpts if "val_loss" in x]
     assert len(ckpts) == 1
-    val = re.sub('[^0-9.]', '', ckpts[0])
+    val = re.sub("[^0-9.]", "", ckpts[0])
     assert len(val) > 3
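
Both metric-name tests rely on the {val_loss:.2f} template embedding the metric value in the checkpoint filename. A quick sketch of what the listdir/regex check accepts, using a hypothetical filename of the shape produced at the time:

    import re

    # Hypothetical filename from filepath=tmpdir + "/{val_loss:.2f}":
    fname = "val_loss=0.15.ckpt"

    assert "val_loss" in fname          # the filter both tests apply to os.listdir(tmpdir)
    val = re.sub("[^0-9.]", "", fname)  # strip everything except digits and dots
    print(val)                          # "0.15." -- trailing dot comes from ".ckpt"
    assert len(val) > 3                 # the length check both tests make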
