Skip to content

Commit

Permalink
Add HealthChecker Callback (#2002)
Browse files Browse the repository at this point in the history
Adds a HealthChecker callback that alerts on anomalous GPU metrics
  • Loading branch information
hanlint authored Feb 28, 2023
1 parent e07de7e commit 6f68c16
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 0 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.health_checker import HealthChecker
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
Expand All @@ -29,5 +30,6 @@
'ExportForInferenceCallback',
'ThresholdStopper',
'ImageVisualizer',
'HealthChecker',
'RuntimeEstimator',
]
193 changes: 193 additions & 0 deletions composer/callbacks/health_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Check GPU Health during training."""
import logging
from collections import deque
from datetime import datetime
from typing import List, Optional, Tuple

import torch

try:
import pynvml
except ImportError:
pynvml = None

import os

import numpy as np
from slack_sdk.webhook import WebhookClient

from composer.core import Callback, State
from composer.core.time import Timestamp
from composer.loggers import Logger
from composer.utils import dist

log = logging.getLogger(__name__)

__all__ = ['HealthChecker']


class HealthChecker(Callback):
    """Checks for GPU health.

    This callback checks for GPU health by tracking and alerting for abnormal
    GPU utilizations.

    For example, if the average utilization during the observation window is,
    [30, 30, 45], then the range (45-30=15) would exceed a threshold of 10%.

    Args:
        threshold (float, optional): Threshold of GPU utilization range to
            trigger an alert. Defaults to 10.
        sample_freq (int, optional): Sample frequency in seconds. Default: 5.
        window_size (int, optional): Window size in seconds. HealthChecker will
            check for abnormalities at this frequency. Default: 120.
        wait (int, optional): Seconds to wait for starting to sample. Default: 120.
        slack_webhook_url (str, optional): Slack URL to send alerts. Can also
            be set with the SLACK_WEBHOOK_URL environment variable. Default: None
        test_mode (bool, optional): If True, will send a test alert at the first check.
            Default: False
    """

    def __init__(
        self,
        threshold: float = 10,
        sample_freq: int = 5,
        window_size: int = 120,
        wait: int = 120,
        slack_webhook_url: Optional[str] = None,
        test_mode: bool = False,
    ) -> None:
        self.sample_freq = sample_freq
        self.window_size = window_size
        self.wait = wait
        self.slack_webhook_url = slack_webhook_url
        self.test_mode = test_mode

        # An explicit argument wins; otherwise fall back to the environment.
        if not self.slack_webhook_url:
            self.slack_webhook_url = os.environ.get('SLACK_WEBHOOK_URL', None)

        # Wall-clock seconds (since training start) of the last sample / check.
        self.last_sample = 0
        self.last_check = 0

        # Metrics stay empty when CUDA or NVML is unavailable, which turns
        # ``after_train_batch`` into a no-op.
        self.metrics = []
        if self._is_available():
            self.metrics.append(GPUUtilization(threshold))

    def init(self, state: State, logger: Logger) -> None:
        pass

    def after_train_batch(self, state: State, logger: Logger):
        """Sample the metrics and, once per window, check them and alert."""
        if not self.metrics:
            return

        if self._sample(state.timestamp):
            for metric in self.metrics:
                metric.sample()

        if self._check(state.timestamp):
            for metric in self.metrics:
                message, alert = metric.check()
                if self.test_mode and message:
                    alert = True
                    message = '[**THIS IS A TEST**]' + message
                # Each metric alerts at most once per run.
                if alert and not metric.alerted:
                    self._alert(message, state)
                    metric.alerted = True
                # Start a fresh observation window.
                metric.clear()

    def _sample(self, timestamp: Timestamp) -> bool:
        """Return True when a new sample is due at ``timestamp``."""
        # ``timedelta.seconds`` is only the 0..86399 component and wraps after
        # a day; ``total_seconds()`` is the full elapsed wall-clock time.
        now = timestamp.total_wct.total_seconds()

        if now < self.wait:
            return False

        if now - self.last_sample >= self.sample_freq:
            self.last_sample = now
            return True

        return False

    def _check(self, timestamp: Timestamp) -> bool:
        """Return True when a full window has elapsed since the last check."""
        # See ``_sample`` for why ``total_seconds()`` is used.
        now = timestamp.total_wct.total_seconds()

        if now - self.last_check >= self.window_size:
            self.last_check = now
            return True
        return False

    def _alert(self, message: str, state: State) -> None:
        """Log ``message`` as a warning and post it to Slack when configured."""
        prefix = '[{now}][{run_name}][node_rank={node_rank}]'.format(
            now=datetime.now(),
            run_name=state.run_name,
            node_rank=dist.get_node_rank(),
        )

        node_name = os.environ.get('NODENAME', None)
        if node_name is not None:
            prefix += f'[node={node_name}]'

        message = prefix + ' : ' + message

        log.warning(message)
        if self.slack_webhook_url:
            # Alerting is best-effort: a Slack outage or a bad webhook must
            # not take down the training loop.
            try:
                client = WebhookClient(url=self.slack_webhook_url)
                response = client.send(text=message)
                if response.status_code != 200:
                    log.warning(f'Slack webhook returned status {response.status_code}: {response.body}')
            except Exception as e:
                log.warning(f'Failed to send Slack alert: {e}')

    @staticmethod
    def _is_available() -> bool:
        """Return True when CUDA is available and NVML initializes successfully."""
        if not torch.cuda.is_available():
            return False
        # The optional import at module top leaves ``pynvml`` as None when the
        # package is missing; guard before touching any of its attributes.
        if pynvml is None:
            log.warning('pynvml library not found, disabling GPU health checking.')
            return False
        try:
            pynvml.nvmlInit()  # type: ignore
            return True
        except pynvml.NVMLError_LibraryNotFound:  # type: ignore
            log.warning('NVML not found, disabling GPU health checking')
        except Exception as e:
            log.warning(f'Error initializing NVML: {e}')

        return False


class GPUUtilization:
    """GPU Utilization Metric.

    Collects per-device GPU utilization readings through NVML and flags the
    window as abnormal when the spread between the most and least utilized
    device (averaged over the window) exceeds ``threshold``.

    Args:
        threshold: Maximum allowed difference between the highest and lowest
            window-averaged device utilization. Defaults to 10.
    """

    def __init__(self, threshold=10) -> None:
        # One entry per sample; each entry is a list with one reading per device.
        self.samples = deque()
        self.threshold = threshold
        # Set by the owner once an alert has been sent for this metric.
        self.alerted = False

    def sample(self) -> None:
        """Record one utilization reading; only local rank zero samples."""
        if dist.get_local_rank() == 0:
            sample = self._sample()
            if sample is not None:
                self.samples.append(sample)

    def _sample(self) -> Optional[List]:
        """Return the utilization of every visible device, or None on failure."""
        try:
            samples = []
            device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
            for i in range(device_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
                samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)  # type: ignore
        except pynvml.NVMLError:  # type: ignore
            return None
        # No devices means nothing to compare; skip the sample so that
        # ``check`` never reduces over a zero-size array.
        if not samples:
            return None
        return samples

    def check(self) -> Tuple[Optional[str], bool]:
        """Evaluate the current window.

        Returns:
            A ``(message, alert)`` tuple. ``message`` is None when there is
            nothing to report (no samples collected, or not local rank zero).
        """
        # Without samples ``np.nanmean`` would yield NaN (plus warnings) and
        # the window would be reported as "Normal". Ranks other than local
        # rank zero never collect samples, so they also return here.
        if not self.samples:
            return None, False
        if dist.get_local_rank() != 0:
            return None, False

        # Average over the window, per device (axis 0 is the sample axis).
        average_sample = np.nanmean(list(self.samples), axis=0)
        if np.nanmax(average_sample) - np.nanmin(average_sample) > self.threshold:
            message = f'Abnormal GPU utilizations: {average_sample}'
            return message, True

        message = f':+1: Normal GPU utilizations: {average_sample}'
        return message, False

    def clear(self) -> None:
        """Drop all collected samples to start a new window."""
        self.samples.clear()
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def package_files(prefix: str, directory: str, extension: str):
'setuptools<=59.5.0',
]

# Optional packages imported by composer/callbacks/health_checker.py.
# Kept as a list, like the neighbouring extras, so the order is deterministic.
extra_deps['health_checker'] = [
    'pynvml>=11.5.0,<12',
    'slack_sdk>=3.19.5,<4',
]

extra_deps['deepspeed'] = [
'deepspeed==0.7.7',
]
Expand Down
109 changes: 109 additions & 0 deletions tests/callbacks/test_health_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import datetime
from unittest.mock import MagicMock, patch

import pytest

from composer import Timestamp
from composer.callbacks import HealthChecker
from composer.callbacks.health_checker import GPUUtilization
from composer.utils import dist
from tests.common import world_size

pynvml = pytest.importorskip('pynvml')
pytest.importorskip('slack_sdk')


class MockUtil:
    """Minimal stand-in for an NVML utilization result exposing only ``gpu``."""

    def __init__(self, util) -> None:
        # The tests feed these objects to a patched
        # ``nvmlDeviceGetUtilizationRates``; only ``.gpu`` is set here.
        self.gpu = util


@pytest.mark.gpu
@world_size(1, 2)
def test_gpu_utilization(world_size):
    """GPUUtilization alerts only on local rank zero when devices diverge."""
    assert HealthChecker._is_available()

    # Readings are consumed in call order: one per device per sample.
    readings = [MockUtil(pct) for pct in (100, 10, 100, 100, 100, 100)]
    fake_rates = MagicMock(side_effect=readings)
    fake_count = MagicMock(return_value=world_size)

    with patch.multiple(pynvml, nvmlDeviceGetUtilizationRates=fake_rates, nvmlDeviceGetCount=fake_count):
        metric = GPUUtilization()
        for _ in range(3):
            metric.sample()
        _, alert = metric.check()

        # A single device has no spread, so it can never exceed the threshold.
        should_alert = dist.get_local_rank() == 0 and world_size > 1
        assert alert == should_alert


@pytest.mark.gpu
@world_size(1, 2)
def test_health_checker(world_size):
    """HealthChecker marks its metric as alerted after one abnormal window."""
    mock_state = MagicMock()
    mock_state.run_name = 'pytest-mock-run-kwei73'
    mock_logger = MagicMock()

    checker = HealthChecker(sample_freq=1, window_size=3, wait=0)

    # Readings are consumed in call order: one per device per sample.
    readings = [MockUtil(pct) for pct in (100, 10, 100, 100, 100, 100)]
    fake_rates = MagicMock(side_effect=readings)
    fake_count = MagicMock(return_value=world_size)

    with patch.multiple(pynvml, nvmlDeviceGetUtilizationRates=fake_rates, nvmlDeviceGetCount=fake_count):
        # Three one-second steps: a sample on each, and a check on the third.
        for elapsed in (1, 2, 3):
            mock_state.timestamp = Timestamp(total_wct=datetime.timedelta(seconds=elapsed))
            checker.after_train_batch(mock_state, mock_logger)

        should_alert = dist.get_local_rank() == 0 and world_size > 1
        assert checker.metrics[0].alerted == should_alert


def test_health_checker_sampling():
    """``_sample`` honours the initial wait and the sample frequency."""
    health_checker = HealthChecker(
        sample_freq=1,
        window_size=5,
        wait=10,
    )

    # (elapsed wall-clock seconds, whether a sample is expected). The cases
    # run in order because ``_sample`` remembers the last sampled time.
    config = [
        (5, False),  # before wait
        (11, True),
        (11.5, False),  # below sample frequency
        (12, True),
        (20, True),
        (11, False),  # no time travel
    ]

    for seconds, is_sample in config:
        timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
        assert health_checker._sample(timestamp) == is_sample, f'unexpected sampling decision at {seconds}s'

0 comments on commit 6f68c16

Please sign in to comment.