diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py
index b936fde6db..660c2272ad 100644
--- a/composer/callbacks/__init__.py
+++ b/composer/callbacks/__init__.py
@@ -9,6 +9,7 @@
 from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.callbacks.early_stopper import EarlyStopper
 from composer.callbacks.export_for_inference import ExportForInferenceCallback
+from composer.callbacks.health_checker import HealthChecker
 from composer.callbacks.image_visualizer import ImageVisualizer
 from composer.callbacks.lr_monitor import LRMonitor
 from composer.callbacks.memory_monitor import MemoryMonitor
@@ -29,5 +30,6 @@
     'ExportForInferenceCallback',
     'ThresholdStopper',
     'ImageVisualizer',
+    'HealthChecker',
     'RuntimeEstimator',
 ]
diff --git a/composer/callbacks/health_checker.py b/composer/callbacks/health_checker.py
new file mode 100644
--- /dev/null
+++ b/composer/callbacks/health_checker.py
@@ -0,0 +1,230 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Check GPU Health during training."""
+import logging
+import os
+from collections import deque
+from datetime import datetime
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
+
+try:
+    from slack_sdk.webhook import WebhookClient
+except ImportError:
+    # slack_sdk is only installed with the optional ``health_checker`` extra.
+    WebhookClient = None
+
+from composer.core import Callback, State
+from composer.core.time import Timestamp
+from composer.loggers import Logger
+from composer.utils import dist
+
+log = logging.getLogger(__name__)
+
+__all__ = ['HealthChecker']
+
+
+class HealthChecker(Callback):
+    """Checks for GPU health.
+
+    This callback checks for GPU health by tracking and alerting for abnormal
+    GPU utilizations.
+
+    For example, if the average utilization during the observation window is
+    [30, 30, 45], then the range (45-30=15) would exceed a threshold of 10%.
+
+    Args:
+        threshold (float, optional): Threshold of GPU utilization range to
+            trigger an alert. Default: 10.
+        sample_freq (int, optional): Sample frequency in seconds. Default: 5.
+        window_size (int, optional): Window size in seconds. HealthChecker will
+            check for abnormalities at this frequency. Default: 120.
+        wait (int, optional): Seconds to wait for starting to sample. Default: 120.
+        slack_webhook_url (str, optional): Slack URL to send alerts. Can also
+            be set with the SLACK_WEBHOOK_URL environment variable. Default: None
+        test_mode (bool, optional): If True, will send a test alert at the first check.
+            Default: False
+    """
+
+    def __init__(
+        self,
+        threshold: float = 10,
+        sample_freq: int = 5,
+        window_size: int = 120,
+        wait: int = 120,
+        slack_webhook_url: Optional[str] = None,
+        test_mode: bool = False,
+    ) -> None:
+        self.sample_freq = sample_freq
+        self.window_size = window_size
+        self.wait = wait
+        self.slack_webhook_url = slack_webhook_url
+        self.test_mode = test_mode
+
+        if not self.slack_webhook_url:
+            self.slack_webhook_url = os.environ.get('SLACK_WEBHOOK_URL', None)
+
+        # Elapsed training wall-clock seconds at the most recent sample / check.
+        self.last_sample = 0
+        self.last_check = 0
+
+        # Stays empty when CUDA or NVML is unavailable, which turns the callback into a no-op.
+        self.metrics = []
+        if self._is_available():
+            self.metrics.append(GPUUtilization(threshold))
+
+    def init(self, state: State, logger: Logger) -> None:
+        pass
+
+    def after_train_batch(self, state: State, logger: Logger):
+        if not self.metrics:
+            return
+
+        if self._sample(state.timestamp):
+            for metric in self.metrics:
+                metric.sample()
+
+        if self._check(state.timestamp):
+            for metric in self.metrics:
+                message, alert = metric.check()
+                if self.test_mode and message:
+                    alert = True
+                    message = '[**THIS IS A TEST**]' + message
+                # ``alerted`` is never reset, so each metric alerts at most once.
+                if alert and not metric.alerted:
+                    self._alert(message, state)
+                    metric.alerted = True
+                metric.clear()
+
+    @staticmethod
+    def _elapsed_seconds(timestamp: Timestamp) -> int:
+        """Whole seconds of training wall-clock time elapsed at ``timestamp``."""
+        # ``timedelta.seconds`` is only the seconds component and wraps to 0 every
+        # 24 hours; ``total_seconds()`` keeps growing on runs longer than a day.
+        return int(timestamp.total_wct.total_seconds())
+
+    def _sample(self, timestamp: Timestamp) -> bool:
+        """Whether a new sample is due: past ``wait`` and ``sample_freq`` seconds since the last one."""
+        now = self._elapsed_seconds(timestamp)
+
+        if now < self.wait:
+            return False
+
+        if now - self.last_sample >= self.sample_freq:
+            self.last_sample = now
+            return True
+
+        return False
+
+    def _check(self, timestamp: Timestamp) -> bool:
+        """Whether a check is due: ``window_size`` seconds have passed since the last one."""
+        now = self._elapsed_seconds(timestamp)
+
+        if now - self.last_check >= self.window_size:
+            self.last_check = now
+            return True
+        return False
+
+    def _alert(self, message: str, state: State) -> None:
+        """Log ``message`` as a warning and post it to Slack when a webhook URL is configured."""
+        prefix = '[{now}][{run_name}][node_rank={node_rank}]'.format(
+            now=datetime.now(),
+            run_name=state.run_name,
+            node_rank=dist.get_node_rank(),
+        )
+
+        node_name = os.environ.get('NODENAME', None)
+        if node_name is not None:
+            prefix += f'[node={node_name}]'
+
+        message = prefix + ' : ' + message
+
+        log.warning(message)
+        if not self.slack_webhook_url:
+            return
+        if WebhookClient is None:
+            log.warning('slack_sdk library not found, unable to send the alert to Slack.')
+            return
+        try:
+            WebhookClient(url=self.slack_webhook_url).send(text=message)
+        except Exception:
+            # Alerting is best-effort: a failed webhook call must not interrupt training.
+            log.exception('Failed to send the alert to Slack.')
+
+    @staticmethod
+    def _is_available() -> bool:
+        """Whether CUDA is available and NVML could be initialized."""
+        if not torch.cuda.is_available():
+            return False
+        if pynvml is None:
+            # Without this guard ``pynvml.nvmlInit`` and the ``except pynvml...`` clause
+            # below would both raise AttributeError on ``None``.
+            log.warning('pynvml library not found, disabling GPU health checking.')
+            return False
+        try:
+            pynvml.nvmlInit()  # type: ignore
+            return True
+        except pynvml.NVMLError_LibraryNotFound:  # type: ignore
+            log.warning('NVML not found, disabling GPU health checking')
+        except Exception as e:
+            log.warning(f'Error initializing NVML: {e}')
+
+        return False
+
+
+class GPUUtilization:
+    """GPU Utilization Metric."""
+
+    def __init__(self, threshold=10) -> None:
+        self.samples = deque()
+        self.threshold = threshold
+        self.alerted = False
+
+    def sample(self) -> None:
+        """Record the current per-GPU utilization; only local rank 0 collects samples."""
+        if dist.get_local_rank() == 0:
+            sample = self._sample()
+            if sample is not None:
+                self.samples.append(sample)
+
+    def _sample(self) -> Optional[List]:
+        """Utilization of every GPU NVML reports, or None if an NVML call fails."""
+        try:
+            samples = []
+            device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
+            for i in range(device_count):
+                handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
+                samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)  # type: ignore
+        except pynvml.NVMLError:  # type: ignore
+            return None
+        return samples
+
+    def check(self) -> Tuple[Optional[str], bool]:
+        """Compare the per-GPU averages of the collected samples.
+
+        Returns:
+            Tuple[Optional[str], bool]: A status message and whether the spread between
+            the most and least utilized GPU exceeds ``threshold``. ``(None, False)`` on
+            ranks other than local rank 0 and when no samples were collected.
+        """
+        # Averaging an empty window would yield NaN and report it as "normal".
+        if dist.get_local_rank() == 0 and self.samples:
+            average_sample = np.nanmean(list(self.samples), axis=0)
+            if np.nanmax(average_sample) - np.nanmin(average_sample) > self.threshold:
+                message = f'Abnormal GPU utilizations: {average_sample}'
+                return message, True
+            else:
+                message = f':+1: Normal GPU utilizations: {average_sample}'
+                return message, False
+        return None, False
+
+    def clear(self) -> None:
+        self.samples.clear()
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -136,6 +136,11 @@ def package_files(prefix: str, directory: str, extension: str):
     'setuptools<=59.5.0',
 ]
 
+extra_deps['health_checker'] = [
+    'pynvml>=11.5.0,<12',
+    'slack_sdk>=3.19.5,<4',
+]
+
 extra_deps['deepspeed'] = [
     'deepspeed==0.7.7',
 ]
diff --git a/tests/callbacks/test_health_checker.py b/tests/callbacks/test_health_checker.py
new file mode 100644
--- /dev/null
+++ b/tests/callbacks/test_health_checker.py
@@ -0,0 +1,109 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import datetime
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from composer import Timestamp
+from composer.callbacks import HealthChecker
+from composer.callbacks.health_checker import GPUUtilization
+from composer.utils import dist
+from tests.common import world_size
+
+pynvml = pytest.importorskip('pynvml')
+pytest.importorskip('slack_sdk')
+
+
+class MockUtil:
+
+    def __init__(self, util):
+        self.gpu = util
+
+
+@pytest.mark.gpu
+@world_size(1, 2)
+def test_gpu_utilization(world_size):
+    assert HealthChecker._is_available()
+
+    gpu_utilization_values = [
+        MockUtil(100),
+        MockUtil(10),
+        MockUtil(100),
+        MockUtil(100),
+        MockUtil(100),
+        MockUtil(100),
+    ]
+
+    with patch.multiple(pynvml,
+                        nvmlDeviceGetUtilizationRates=MagicMock(side_effect=gpu_utilization_values),
+                        nvmlDeviceGetCount=MagicMock(return_value=world_size)):
+
+        gpu_utilization = GPUUtilization()
+        gpu_utilization.sample()
+        gpu_utilization.sample()
+        gpu_utilization.sample()
+        _, alert = gpu_utilization.check()
+
+        should_alert = dist.get_local_rank() == 0 and world_size > 1
+        assert alert == should_alert
+
+
+@pytest.mark.gpu
+@world_size(1, 2)
+def test_health_checker(world_size):
+
+    state = MagicMock()
+    state.run_name = 'pytest-mock-run-kwei73'
+    logger = MagicMock()
+
+    health_checker = HealthChecker(
+        sample_freq=1,
+        window_size=3,
+        wait=0,
+    )
+
+    gpu_utilization_values = [
+        MockUtil(100),
+        MockUtil(10),
+        MockUtil(100),
+        MockUtil(100),
+        MockUtil(100),
+        MockUtil(100),
+    ]
+
+    with patch.multiple(pynvml,
+                        nvmlDeviceGetUtilizationRates=MagicMock(side_effect=gpu_utilization_values),
+                        nvmlDeviceGetCount=MagicMock(return_value=world_size)):
+
+        # collect data and checker
+        for seconds in [1, 2, 3]:
+            state.timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
+            health_checker.after_train_batch(state, logger)
+
+        should_alert = dist.get_local_rank() == 0 and world_size > 1
+        assert health_checker.metrics[0].alerted == should_alert
+
+
+def test_health_checker_sampling():
+    health_checker = HealthChecker(
+        sample_freq=1,
+        window_size=5,
+        wait=10,
+    )
+
+    config = [
+        (5, False),  # before wait
+        (11, True),
+        (11.5, False),  # below sample frequency
+        (12, True),
+        (20, True),
+        (11, False),  # no time travel
+        (86390, True),
+        (86405, True),  # elapsed time must not wrap around after 24 hours
+    ]
+
+    for seconds, is_sample in config:
+        timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
+        assert health_checker._sample(timestamp) == is_sample