Skip to content

Commit

Permalink
Fix saving relative symlink for ModelCheckpoint callback (#19303)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <[email protected]>
  • Loading branch information
shenmishajing and awaelchli authored Jan 20, 2024
1 parent e89f46a commit d02009a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))


- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))


## [2.1.3] - 2023-12-21

### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> Non
elif os.path.isdir(linkpath):
shutil.rmtree(linkpath)
try:
os.symlink(filepath, linkpath)
os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath)
except OSError:
# on Windows, special permissions are required to create symbolic links as a regular user
# fall back to copying the file
Expand Down
21 changes: 21 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,20 +534,23 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(file)
assert not os.path.isabs(os.readlink(link))

# link exists (is a file)
new_file1 = tmp_path / "new_file1"
new_file1.touch()
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(new_file1)
assert not os.path.isabs(os.readlink(link))

# link exists (is a link)
new_file2 = tmp_path / "new_file2"
new_file2.touch()
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(new_file2)
assert not os.path.isabs(os.readlink(link))

# link exists (is a folder)
folder = tmp_path / "folder"
Expand All @@ -557,13 +560,15 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link))
assert os.path.islink(folder_link)
assert os.path.realpath(folder_link) == str(folder)
assert not os.path.isabs(os.readlink(folder_link))

# link exists (is a link to a folder)
new_folder = tmp_path / "new_folder"
new_folder.mkdir()
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link))
assert os.path.islink(folder_link)
assert os.path.realpath(folder_link) == str(new_folder)
assert not os.path.isabs(os.readlink(folder_link))

# simulate permission error on Windows (creation of symbolic links requires privileges)
file = tmp_path / "win_file"
Expand All @@ -575,6 +580,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
assert os.path.isfile(link) # fall back to copying instead of linking


def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch):
"""Test that linking a checkpoint works with relative paths."""
trainer = Mock()
monkeypatch.chdir(tmp_path)

folder = Path("x/z/z")
folder.mkdir(parents=True)
file = folder / "file"
file.touch()
link = folder / "link"
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file.absolute()), linkpath=str(link.absolute()))
assert os.path.islink(link)
assert Path(os.readlink(link)) == file.relative_to(folder)
assert not os.path.isabs(os.readlink(link))


def test_invalid_top_k(tmpdir):
"""Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):
Expand Down

0 comments on commit d02009a

Please sign in to comment.