From d41de6c0c2746430fd4ee33961ced932657a0b80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 9 Aug 2021 16:31:53 +0200 Subject: [PATCH] is-instance check to determine the type of a plugin for teardown decision (#8741) --- pytorch_lightning/trainer/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2514d064fb24a..07c8d45601b31 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -32,7 +32,7 @@ from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.plugins import Plugin +from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -76,7 +76,6 @@ ) from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.enums import DistributedType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.model_helpers import is_overridden @@ -947,7 +946,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # teardown if necessary (similar calls for spawn plugins are excluded as they have # been included at the end of `new_process` functions) - if self._distrib_type not in DistributedType.interactive_compatible_types(): + if not isinstance(self.training_type_plugin, DDPSpawnPlugin): self._call_teardown_hook() if self.state.status != TrainerStatus.INTERRUPTED: