diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 8bf34229a9742..4da535c207494 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -369,7 +369,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         Args:
             trainer: the Trainer, these optimizers should be connected to
         """
-        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if trainer.state.fn not in (TrainerFn.fit, TrainerFn.tune):
             return
         optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
             trainer=trainer, model=self.lightning_module
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 242eeed808f34..b37ff8d77b1ce 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -159,7 +159,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:

     def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
-        return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
+        return trainer.state.fn != TrainerFn.fit or trainer.sanity_checking

     def on_train_epoch_end(self, trainer, pl_module) -> None:
         if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index fd749aef8977b..5d29ee07aa910 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -305,7 +305,7 @@ def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
         return (
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
-            or trainer.state.fn != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.state.fn != TrainerFn.fit  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
         )
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index df9f0ee158ba3..4ae13c4cb10a1 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -280,7 +280,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         # save the last weights
         last_path = None
         if (
-            self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
+            self.lightning_module.trainer.state.fn == TrainerFn.fit and best_model_path is not None
             and len(best_model_path) > 0
         ):
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
@@ -298,7 +298,7 @@ def __recover_child_process_weights(self, best_path, last_path):
         # todo, pass also best score

         # load last weights
-        if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+        if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.fit:
             ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
             self.lightning_module.load_state_dict(ckpt)
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 8dd04aafa6b86..ec9b710b2076f 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -532,7 +532,7 @@ def restore_model_state_from_ckpt_path(
         if not self.save_full_weights and self.world_size > 1:
             # Rely on deepspeed to load the checkpoint and necessary information
             from pytorch_lightning.trainer.states import TrainerFn
-            is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+            is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.fit
             save_dir = self._filepath_to_dir(ckpt_path)

             if self.zero_stage_3:
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index a75839cbdb714..61a831ccde997 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -208,7 +208,7 @@ def _skip_init_connections(self):
         Returns:
             Whether to skip initialization
         """
-        return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.FITTING
+        return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.fit

     def init_model_parallel_groups(self):
         num_model_parallel = 1  # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
         return self.world_size

     def handle_transferred_pipe_module(self) -> None:
-        if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+        if self.lightning_module.trainer.state.fn == TrainerFn.fit:
             torch_distrib.barrier()  # Ensure we await main process initialization
             # Add trainer/configure_optimizers to the pipe model for access in all worker processes
             rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
@@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
         # Create pipe_module
         model = self.lightning_module
         self._find_and_init_pipe_module(model)
-        if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+        if self.lightning_module.trainer.state.fn == TrainerFn.fit:
             torch_distrib.barrier()  # Ensure we join main process initialization
             model.sequential_module.foreach_worker(register_optimizers, include_self=True)
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 02da937286dcc..fbcf802736ed8 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -66,7 +66,7 @@ def _reinit_optimizers_with_oss(self):
         trainer.convert_to_lightning_optimizers()

     def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
+        if self.model.trainer.state.fn != TrainerFn.fit:
             return
         self._reinit_optimizers_with_oss()
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 5daf4e5be3735..d9d9efa6a751b 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -53,7 +53,7 @@ def _reinit_optimizers_with_oss(self):
         trainer.optimizers = optimizers

     def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
+        if self.model.trainer.state.fn != TrainerFn.fit:
             return
         self._reinit_optimizers_with_oss()
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 1d4a38498b20d..09b9ce113c569 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -190,7 +190,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         # save the last weights
         last_path = None
         if (
-            self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
+            self.lightning_module.trainer.state.fn == TrainerFn.fit and best_model_path is not None
             and len(best_model_path) > 0
         ):
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index e73bee761a241..e2f9d4f36d1fb 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -31,14 +31,14 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
             model: The model to check the configuration.

        """
-        if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if self.trainer.state.fn in (TrainerFn.fit, TrainerFn.tune):
             self.__verify_train_loop_configuration(model)
             self.__verify_eval_loop_configuration(model, 'val')
-        elif self.trainer.state.fn == TrainerFn.VALIDATING:
+        elif self.trainer.state.fn == TrainerFn.validate:
             self.__verify_eval_loop_configuration(model, 'val')
-        elif self.trainer.state.fn == TrainerFn.TESTING:
+        elif self.trainer.state.fn == TrainerFn.test:
             self.__verify_eval_loop_configuration(model, 'test')
-        elif self.trainer.state.fn == TrainerFn.PREDICTING:
+        elif self.trainer.state.fn == TrainerFn.predict:
             self.__verify_predict_loop_configuration(model)

         self.__verify_dp_batch_transfer_support(model)
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 09d66c13502a2..e84997a8717c3 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -355,7 +355,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

         # TODO(carmocca): when we implement flushing the logger connector metrics after
         # the trainer.state changes, this should check trainer.evaluating instead
-        if self.trainer.state.fn in (TrainerFn.TESTING, TrainerFn.VALIDATING):
+        if self.trainer.state.fn in (TrainerFn.test, TrainerFn.validate):
             logger_connector.evaluation_callback_metrics.update(callback_metrics)

         # update callback_metrics
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 8c09de075147a..b38daa00c5aeb 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -282,7 +282,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:

         # log results of evaluation
         if (
-            self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
+            self.trainer.state.fn != TrainerFn.fit and self.trainer.evaluating and self.trainer.is_global_zero
             and self.trainer.verbose_evaluate
         ):
             print('-' * 80)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 0d551446b5cf1..620f99b165341 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -101,7 +101,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook('on_validation_end', *args, **kwargs)

-        if self.trainer.state.fn != TrainerFn.FITTING:
+        if self.trainer.state.fn != TrainerFn.fit:
             # summarize profile results
             self.trainer.profiler.describe()
diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py
index b1dfbc83ac4f6..ce72a3a527d9a 100644
--- a/pytorch_lightning/trainer/states.py
+++ b/pytorch_lightning/trainer/states.py
@@ -35,20 +35,20 @@ class TrainerFn(LightningEnum):
     such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
     :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
     """
-    FITTING = 'fit'
-    VALIDATING = 'validate'
-    TESTING = 'test'
-    PREDICTING = 'predict'
-    TUNING = 'tune'
+    fit = 'fit'
+    validate = 'validate'
+    test = 'test'
+    predict = 'predict'
+    tune = 'tune'

     @property
     def _setup_fn(self) -> 'TrainerFn':
         """
-        ``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.
+        ``fit`` is used instead of ``tune`` as there are no "tune" dataloaders.

         This is used for the ``setup()`` and ``teardown()`` hooks
         """
-        return TrainerFn.FITTING if self == TrainerFn.TUNING else self
+        return TrainerFn.fit if self == TrainerFn.tune else self


 class RunningStage(LightningEnum):
@@ -58,11 +58,11 @@ class RunningStage(LightningEnum):
     This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
     More than one running stage value can be set while a :class:`TrainerFn` is running:

-        - ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
-        - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
-        - ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
-        - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
-        - ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
+        - ``TrainerFn.fit`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
+        - ``TrainerFn.validate`` - ``RunningStage.VALIDATING``
+        - ``TrainerFn.test`` - ``RunningStage.TESTING``
+        - ``TrainerFn.predict`` - ``RunningStage.PREDICTING``
+        - ``TrainerFn.tune`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
     """
     TRAINING = 'train'
     SANITY_CHECKING = 'sanity_check'
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8732d8c33dce7..290274b461ca0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -432,7 +432,7 @@ def fit(
         """
         Trainer._log_api_event("fit")

-        self.state.fn = TrainerFn.FITTING
+        self.state.fn = TrainerFn.fit
         self.state.status = TrainerStatus.RUNNING
         self.training = True

@@ -492,7 +492,7 @@ def validate(
         Trainer._log_api_event("validate")
         self.verbose_evaluate = verbose

-        self.state.fn = TrainerFn.VALIDATING
+        self.state.fn = TrainerFn.validate
         self.state.status = TrainerStatus.RUNNING
         self.validating = True

@@ -554,7 +554,7 @@ def test(
         Trainer._log_api_event("test")
         self.verbose_evaluate = verbose

-        self.state.fn = TrainerFn.TESTING
+        self.state.fn = TrainerFn.test
         self.state.status = TrainerStatus.RUNNING
         self.testing = True

@@ -612,7 +612,7 @@ def predict(

         model = model or self.lightning_module

-        self.state.fn = TrainerFn.PREDICTING
+        self.state.fn = TrainerFn.predict
         self.state.status = TrainerStatus.RUNNING
         self.predicting = True

@@ -660,7 +660,7 @@ def tune(
         """
         Trainer._log_api_event("tune")

-        self.state.fn = TrainerFn.TUNING
+        self.state.fn = TrainerFn.tune
         self.state.status = TrainerStatus.RUNNING
         self.tuning = True

@@ -742,7 +742,7 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
         # ----------------------------
         # TRAIN
         # ----------------------------
         # hook
-        if self.state.fn == TrainerFn.FITTING:
+        if self.state.fn == TrainerFn.fit:
             self.call_hook("on_fit_start")

         # plugin will setup fitting (e.g. ddp will launch child processes)
@@ -758,7 +758,7 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
         # ----------------------------
         # POST-Training CLEAN UP
         # ----------------------------
         # hook
-        if self.state.fn == TrainerFn.FITTING:
+        if self.state.fn == TrainerFn.fit:
             self.call_hook('on_fit_end')

         # teardown
diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py
index c9fb50e8501dd..a43f864575a74 100644
--- a/tests/trainer/test_states.py
+++ b/tests/trainer/test_states.py
@@ -58,19 +58,19 @@ def on_validation_batch_start(self, *_):
         def on_test_batch_start(self, *_):
             assert self.trainer.testing

-    model = TestModel(TrainerFn.TUNING, RunningStage.TRAINING)
+    model = TestModel(TrainerFn.tune, RunningStage.TRAINING)
     trainer.tune(model)
     assert trainer.state.finished

-    model = TestModel(TrainerFn.FITTING, RunningStage.TRAINING)
+    model = TestModel(TrainerFn.fit, RunningStage.TRAINING)
     trainer.fit(model)
     assert trainer.state.finished

-    model = TestModel(TrainerFn.VALIDATING, RunningStage.VALIDATING)
+    model = TestModel(TrainerFn.validate, RunningStage.VALIDATING)
     trainer.validate(model)
     assert trainer.state.finished

-    model = TestModel(TrainerFn.TESTING, RunningStage.TESTING)
+    model = TestModel(TrainerFn.test, RunningStage.TESTING)
     trainer.test(model)
     assert trainer.state.finished
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f04061a23e096..3921b0c6a9acc 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -334,7 +334,7 @@ def mock_save_function(filepath, *args):
         verbose=True
     )
     trainer = Trainer()
-    trainer.state.fn = TrainerFn.FITTING
+    trainer.state.fn = TrainerFn.fit
    trainer.save_checkpoint = mock_save_function

     # emulate callback's calls during the training
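For reference, a minimal sketch (not part of the diff) of how user code would check the renamed enum members after this change. The callback name `FitOnlyCallback` is hypothetical; the guard mirrors the `_should_skip_check` pattern in `early_stopping.py` above, assuming the lowercase `TrainerFn` members introduced by this diff.

```python
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn


class FitOnlyCallback(pl.Callback):
    """Illustrative callback that only acts while ``Trainer.fit`` is running."""

    def on_validation_end(self, trainer, pl_module):
        # Skip anything outside of fitting, including the sanity check,
        # using the renamed lowercase members (TrainerFn.fit instead of TrainerFn.FITTING).
        if trainer.state.fn != TrainerFn.fit or trainer.sanity_checking:
            return
        print(f"validation during fit finished at global step {trainer.global_step}")
```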