Lowercase TrainerFn
carmocca committed May 12, 2021
1 parent 5c52e89 commit 97a6628
Showing 17 changed files with 43 additions and 43 deletions.
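For reference, the commit renames the TrainerFn members from uppercase attribute names (FITTING, VALIDATING, TESTING, PREDICTING, TUNING) to lowercase names that match their string values, and updates every call site accordingly. A minimal, self-contained sketch of the resulting behaviour, using a plain str-backed Enum as a stand-in for LightningEnum (an assumption for illustration only):

from enum import Enum


class TrainerFn(str, Enum):
    # Lowercase members, as introduced by this commit: the attribute name
    # now equals the string value.
    fit = 'fit'
    validate = 'validate'
    test = 'test'
    predict = 'predict'
    tune = 'tune'


# Round-tripping between member and string is now direct:
assert TrainerFn.fit.value == 'fit'
assert TrainerFn('fit') is TrainerFn.fit
assert TrainerFn.fit == 'fit'  # str-backed enum members compare equal to their value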
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -369,7 +369,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
- if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+ if trainer.state.fn not in (TrainerFn.fit, TrainerFn.tune):
return
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -159,7 +159,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:

def _should_skip_check(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerFn
- return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
+ return trainer.state.fn != TrainerFn.fit or trainer.sanity_checking

def on_train_epoch_end(self, trainer, pl_module) -> None:
if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -305,7 +305,7 @@ def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
- or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
+ or trainer.state.fn != TrainerFn.fit # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -280,7 +280,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
- self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
+ self.lightning_module.trainer.state.fn == TrainerFn.fit and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
@@ -298,7 +298,7 @@ def __recover_child_process_weights(self, best_path, last_path):
# todo, pass also best score

# load last weights
- if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+ if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.fit:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
self.lightning_module.load_state_dict(ckpt)

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -532,7 +532,7 @@ def restore_model_state_from_ckpt_path(
if not self.save_full_weights and self.world_size > 1:
# Rely on deepspeed to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerFn
- is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+ is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.fit
save_dir = self._filepath_to_dir(ckpt_path)

if self.zero_stage_3:
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -208,7 +208,7 @@ def _skip_init_connections(self):
Returns: Whether to skip initialization
"""
- return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.FITTING
+ return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.fit

def init_model_parallel_groups(self):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
return self.world_size

def handle_transferred_pipe_module(self) -> None:
- if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+ if self.lightning_module.trainer.state.fn == TrainerFn.fit:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
@@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
- if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+ if self.lightning_module.trainer.state.fn == TrainerFn.fit:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
@@ -66,7 +66,7 @@ def _reinit_optimizers_with_oss(self):
trainer.convert_to_lightning_optimizers()

def _wrap_optimizers(self):
- if self.model.trainer.state.fn != TrainerFn.FITTING:
+ if self.model.trainer.state.fn != TrainerFn.fit:
return
self._reinit_optimizers_with_oss()

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -53,7 +53,7 @@ def _reinit_optimizers_with_oss(self):
trainer.optimizers = optimizers

def _wrap_optimizers(self):
- if self.model.trainer.state.fn != TrainerFn.FITTING:
+ if self.model.trainer.state.fn != TrainerFn.fit:
return
self._reinit_optimizers_with_oss()

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -190,7 +190,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
- self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
+ self.lightning_module.trainer.state.fn == TrainerFn.fit and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -31,14 +31,14 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
model: The model to check the configuration.
"""
- if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+ if self.trainer.state.fn in (TrainerFn.fit, TrainerFn.tune):
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
- elif self.trainer.state.fn == TrainerFn.VALIDATING:
+ elif self.trainer.state.fn == TrainerFn.validate:
self.__verify_eval_loop_configuration(model, 'val')
- elif self.trainer.state.fn == TrainerFn.TESTING:
+ elif self.trainer.state.fn == TrainerFn.test:
self.__verify_eval_loop_configuration(model, 'test')
- elif self.trainer.state.fn == TrainerFn.PREDICTING:
+ elif self.trainer.state.fn == TrainerFn.predict:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)

@@ -355,7 +355,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

# TODO(carmocca): when we implement flushing the logger connector metrics after
# the trainer.state changes, this should check trainer.evaluating instead
- if self.trainer.state.fn in (TrainerFn.TESTING, TrainerFn.VALIDATING):
+ if self.trainer.state.fn in (TrainerFn.test, TrainerFn.validate):
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
@@ -282,7 +282,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:

# log results of evaluation
if (
- self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
+ self.trainer.state.fn != TrainerFn.fit and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -101,7 +101,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook('on_validation_end', *args, **kwargs)

- if self.trainer.state.fn != TrainerFn.FITTING:
+ if self.trainer.state.fn != TrainerFn.fit:
# summarize profile results
self.trainer.profiler.describe()

24 changes: 12 additions & 12 deletions pytorch_lightning/trainer/states.py
@@ -35,20 +35,20 @@ class TrainerFn(LightningEnum):
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
"""
- FITTING = 'fit'
- VALIDATING = 'validate'
- TESTING = 'test'
- PREDICTING = 'predict'
- TUNING = 'tune'
+ fit = 'fit'
+ validate = 'validate'
+ test = 'test'
+ predict = 'predict'
+ tune = 'tune'

@property
def _setup_fn(self) -> 'TrainerFn':
"""
- ``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.
+ ``fit`` is used instead of ``tune`` as there are no "tune" dataloaders.
This is used for the ``setup()`` and ``teardown()`` hooks
"""
- return TrainerFn.FITTING if self == TrainerFn.TUNING else self
+ return TrainerFn.fit if self == TrainerFn.tune else self


class RunningStage(LightningEnum):
@@ -58,11 +58,11 @@ class RunningStage(LightningEnum):
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
More than one running stage value can be set while a :class:`TrainerFn` is running:
- - ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
- - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- - ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- - ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
+ - ``TrainerFn.fit`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
+ - ``TrainerFn.validate`` - ``RunningStage.VALIDATING``
+ - ``TrainerFn.test`` - ``RunningStage.TESTING``
+ - ``TrainerFn.predict`` - ``RunningStage.PREDICTING``
+ - ``TrainerFn.tune`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""
TRAINING = 'train'
SANITY_CHECKING = 'sanity_check'
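The _setup_fn property in the hunk above maps tune onto fit for the setup()/teardown() hooks, since there are no "tune" dataloaders. A short sketch of that mapping, again using a plain str-backed Enum as a stand-in for LightningEnum (an assumption for illustration only):

from enum import Enum


class TrainerFn(str, Enum):
    fit = 'fit'
    validate = 'validate'
    test = 'test'
    predict = 'predict'
    tune = 'tune'

    @property
    def _setup_fn(self) -> 'TrainerFn':
        # ``fit`` is used instead of ``tune`` as there are no "tune" dataloaders.
        return TrainerFn.fit if self == TrainerFn.tune else self


assert TrainerFn.tune._setup_fn is TrainerFn.fit  # tuning reuses the fit setup/teardown
assert TrainerFn.validate._setup_fn is TrainerFn.validate  # every other fn maps to itself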
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -432,7 +432,7 @@ def fit(
"""
Trainer._log_api_event("fit")

- self.state.fn = TrainerFn.FITTING
+ self.state.fn = TrainerFn.fit
self.state.status = TrainerStatus.RUNNING
self.training = True

@@ -492,7 +492,7 @@ def validate(
Trainer._log_api_event("validate")
self.verbose_evaluate = verbose

- self.state.fn = TrainerFn.VALIDATING
+ self.state.fn = TrainerFn.validate
self.state.status = TrainerStatus.RUNNING
self.validating = True

@@ -554,7 +554,7 @@ def test(
Trainer._log_api_event("test")
self.verbose_evaluate = verbose

- self.state.fn = TrainerFn.TESTING
+ self.state.fn = TrainerFn.test
self.state.status = TrainerStatus.RUNNING
self.testing = True

@@ -612,7 +612,7 @@ def predict(

model = model or self.lightning_module

- self.state.fn = TrainerFn.PREDICTING
+ self.state.fn = TrainerFn.predict
self.state.status = TrainerStatus.RUNNING
self.predicting = True

@@ -660,7 +660,7 @@ def tune(
"""
Trainer._log_api_event("tune")

- self.state.fn = TrainerFn.TUNING
+ self.state.fn = TrainerFn.tune
self.state.status = TrainerStatus.RUNNING
self.tuning = True

@@ -742,7 +742,7 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
# TRAIN
# ----------------------------
# hook
- if self.state.fn == TrainerFn.FITTING:
+ if self.state.fn == TrainerFn.fit:
self.call_hook("on_fit_start")

# plugin will setup fitting (e.g. ddp will launch child processes)
@@ -758,7 +758,7 @@
# POST-Training CLEAN UP
# ----------------------------
# hook
- if self.state.fn == TrainerFn.FITTING:
+ if self.state.fn == TrainerFn.fit:
self.call_hook('on_fit_end')

# teardown
8 changes: 4 additions & 4 deletions tests/trainer/test_states.py
@@ -58,19 +58,19 @@ def on_validation_batch_start(self, *_):
def on_test_batch_start(self, *_):
assert self.trainer.testing

- model = TestModel(TrainerFn.TUNING, RunningStage.TRAINING)
+ model = TestModel(TrainerFn.tune, RunningStage.TRAINING)
trainer.tune(model)
assert trainer.state.finished

- model = TestModel(TrainerFn.FITTING, RunningStage.TRAINING)
+ model = TestModel(TrainerFn.fit, RunningStage.TRAINING)
trainer.fit(model)
assert trainer.state.finished

- model = TestModel(TrainerFn.VALIDATING, RunningStage.VALIDATING)
+ model = TestModel(TrainerFn.validate, RunningStage.VALIDATING)
trainer.validate(model)
assert trainer.state.finished

- model = TestModel(TrainerFn.TESTING, RunningStage.TESTING)
+ model = TestModel(TrainerFn.test, RunningStage.TESTING)
trainer.test(model)
assert trainer.state.finished

2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -334,7 +334,7 @@ def mock_save_function(filepath, *args):
verbose=True
)
trainer = Trainer()
- trainer.state.fn = TrainerFn.FITTING
+ trainer.state.fn = TrainerFn.fit
trainer.save_checkpoint = mock_save_function

# emulate callback's calls during the training