[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 10, 2024
1 parent 32db7d4 commit 0b3322d
Showing 2 changed files with 11 additions and 16 deletions.
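
pre-commit.ci runs the repository's configured pre-commit hooks on every pull request and, when a hook rewrites files, pushes the result back as an automatic commit like this one. The hooks come from the repository's .pre-commit-config.yaml; as a rough sketch only (hook choice and pinned revision are assumptions, not read from this repository), a config that produces import-sorting and formatting fixes of this kind could look like:

# Hypothetical minimal .pre-commit-config.yaml -- the hook ids exist in
# astral-sh/ruff-pre-commit, but the rev pin here is illustrative.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.8.2
    hooks:
      - id: ruff          # linter; its isort-style "I" rules reorder imports
        args: [--fix]     # let the hook rewrite files for pre-commit.ci to commit
      - id: ruff-format   # Black-compatible code formatter

With autofix enabled, pre-commit.ci amends the pull request branch directly, which is why this commit is authored by pre-commit-ci[bot].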
7 changes: 3 additions & 4 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -30,12 +30,9 @@
 from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy
 
+import pytorch_lightning as pl
 import torch
 import yaml
-from torch import Tensor
-from typing_extensions import override
-
-import pytorch_lightning as pl
 from lightning_fabric.utilities.cloud_io import (
     _is_dir,
     _is_local_file_protocol,
@@ -50,6 +47,8 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import Tensor
+from typing_extensions import override
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
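
The net effect of the two hunks above is one merged, alphabetized import section in which plain imports and from-imports sort together (isort-style ordering assumed; this reconstruction is for illustration only, and lines between the two hunks are elided):

import pytorch_lightning as pl
import torch
import yaml
from lightning_fabric.utilities.cloud_io import (
    _is_dir,
    _is_local_file_protocol,
    # ...further names cut off at the hunk boundary...
)
# ...imports between the two hunks not shown in this diff...
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from typing_extensions import override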
20 changes: 8 additions & 12 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -28,9 +28,6 @@
 import pytest
 import torch
 import yaml
-from tests_pytorch.helpers.runif import RunIf
-from torch import optim
-
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -39,6 +36,9 @@
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
+from torch import optim
+
+from tests_pytorch.helpers.runif import RunIf
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf
@@ -888,13 +888,11 @@ def test_default_checkpoint_behavior(tmp_path):
     assert len(results) == 1
     save_dir = tmp_path / "checkpoints"
     save_weights_only = trainer.checkpoint_callback.save_weights_only
-    save_mock.assert_has_calls(
-        [
-            call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
-            call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
-            call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
-        ]
-    )
+    save_mock.assert_has_calls([
+        call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
+        call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
+        call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
+    ])
     ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=15.ckpt"
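
This hunk changes formatting only: when a call's sole argument is a list, some formatters (for example Black's preview "hug" style) attach the opening bracket to the call's parenthesis instead of indenting it on its own line; the assertion's behavior is unchanged. A minimal self-contained sketch of the mock idiom being reformatted, with illustrative file names not taken from the test suite:

from unittest.mock import MagicMock, call

save_mock = MagicMock()
save_mock("epoch=0-step=5.ckpt", False)
save_mock("epoch=1-step=10.ckpt", False)

# Passes when the listed calls appear consecutively, in order,
# anywhere in the mock's recorded call history.
save_mock.assert_has_calls([
    call("epoch=0-step=5.ckpt", False),
    call("epoch=1-step=10.ckpt", False),
])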
@@ -1478,8 +1476,6 @@ def test_save_last_versioning(tmp_path):
     assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))
 
 
-
-
 def test_none_monitor_saves_correct_best_model_path(tmp_path):
     mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
     trainer = Trainer(callbacks=mc)
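
The last hunk only drops surplus blank lines: PEP 8 expects exactly two blank lines between top-level definitions, and pycodestyle reports more than two as E303. A trivial sketch of the enforced spacing:

def test_first():
    pass


def test_second():  # exactly two blank lines above; extra ones get removed
    pass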
