Skip to content

Commit

Permalink
Don't register signal in thread (#10610)
Browse files Browse the repository at this point in the history
Co-authored-by: tchaton <[email protected]>
  • Loading branch information
awaelchli and tchaton committed Nov 24, 2021
1 parent c179a7d commit 7d35da1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631))


- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))



## [1.5.2] - 2021-11-16

Expand Down
14 changes: 10 additions & 4 deletions pytorch_lightning/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import signal
import sys
import threading
from signal import Signals
from subprocess import call
from types import FrameType, FunctionType
Expand Down Expand Up @@ -43,11 +44,11 @@ def register_signal_handlers(self) -> None:

# signal.SIGUSR1 is not available on Windows
if not self._is_on_windows():
if not self._has_already_handler(signal.SIGUSR1):
signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1):
self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))

if not self._has_already_handler(signal.SIGTERM):
signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
if self.trainer.is_global_zero:
Expand Down Expand Up @@ -107,3 +108,8 @@ def _has_already_handler(self, signum: Signals) -> bool:
return isinstance(signal.getsignal(signum), FunctionType)
except AttributeError:
return False

@staticmethod
def _register_signal(signum: Signals, handlers: HandlersCompose) -> None:
    """Install ``handlers`` for ``signum`` when running on the main thread.

    Python only allows ``signal.signal`` to be called from the main thread
    of the main interpreter, so the registration is silently skipped when
    this is invoked from any other thread.

    Args:
        signum: The signal number to attach the handler to.
        handlers: The composed handler to install for ``signum``.
    """
    if threading.current_thread() is not threading.main_thread():
        return
    signal.signal(signum, handlers)
17 changes: 17 additions & 0 deletions tests/trainer/connectors/test_signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import os
import signal
from time import sleep
Expand Down Expand Up @@ -57,3 +58,19 @@ def training_step(self, batch, batch_idx):
else:
trainer.fit(model)
assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully)

# reset the signal to system defaults
signal.signal(signal.SIGUSR1, signal.SIG_DFL)


def _registering_signals():
    """Build a default ``Trainer`` and register its signal handlers."""
    Trainer().signal_connector.register_signal_handlers()


@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_signal_connector_in_thread():
    """Check that registering signal handlers from a worker thread does not raise."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        # Run the registration off the main thread.
        submitted = [pool.submit(_registering_signals)]
        for finished in concurrent.futures.as_completed(submitted):
            assert finished.exception() is None

0 comments on commit 7d35da1

Please sign in to comment.