[2/4] Add DeviceStatsMonitor callback (#9712)
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
7 people authored Oct 13, 2021
1 parent 23e8b59 commit 940b910
Showing 11 changed files with 228 additions and 18 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when an unknown key is encountered in optimizer configuration, and when `OneCycleLR` is used with `"interval": "epoch"` ([#9666](https://github.com/PyTorchLightning/pytorch-lightning/pull/9666))


- Added `DeviceStatsMonitor` callback ([#9712](https://github.com/PyTorchLightning/pytorch-lightning/pull/9712))


- Added `enable_progress_bar` to Trainer constructor ([#9664](https://github.com/PyTorchLightning/pytorch-lightning/pull/9664))


1 change: 1 addition & 0 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -36,6 +36,7 @@ local tputests = base.BaseTest {
tests/profiler/test_xla_profiler.py \
pytorch_lightning/utilities/xla_device.py \
tests/accelerators/test_tpu_backend.py \
tests/callbacks/test_device_stats_monitor.py \
tests/models/test_tpu.py
test_exit_code=$?
echo "\n||| END PYTEST LOGS |||\n"
1 change: 1 addition & 0 deletions docs/source/extensions/accelerators.rst
@@ -14,6 +14,7 @@ Currently there are accelerators for:
- CPU
- GPU
- TPU
- IPU

Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.
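
For context, a rough sketch of how those two plugins are wired into an accelerator around the time of this PR, assuming the 1.4/1.5-era API (``CPUAccelerator``, ``PrecisionPlugin``, ``SingleDevicePlugin``); exact class names and constructor signatures may differ in other releases:

import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

# One plugin decides where/how the training routine runs, the other handles precision.
accelerator = CPUAccelerator(
    precision_plugin=PrecisionPlugin(),
    training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
)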
1 change: 1 addition & 0 deletions docs/source/extensions/callbacks.rst
@@ -99,6 +99,7 @@ Lightning has a few built-in callbacks.
BaseFinetuning
BasePredictionWriter
Callback
DeviceStatsMonitor
EarlyStopping
GPUStatsMonitor
GradientAccumulationScheduler
1 change: 1 addition & 0 deletions pyproject.toml
@@ -61,6 +61,7 @@ ignore_errors = "True"

[[tool.mypy.overrides]]
module = [
"pytorch_lightning.callbacks.device_stats_monitor",
"pytorch_lightning.callbacks.model_summary",
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.callbacks.rich_model_summary",
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/cpu.py
@@ -35,5 +35,5 @@ def setup(self, trainer: "pl.Trainer") -> None:
return super().setup(trainer)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Returns dummy implementation for now."""
"""CPU device stats aren't supported yet."""
return {}
7 changes: 6 additions & 1 deletion pytorch_lightning/accelerators/ipu.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable
from typing import Any, Callable, Dict, Union

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -37,3 +38,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
# Optimizer step is handled by the IPU accelerator.
lambda_closure()

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""IPU device stats aren't supported yet."""
return {}
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
@@ -33,6 +34,7 @@
"BackboneFinetuning",
"BaseFinetuning",
"Callback",
"DeviceStatsMonitor",
"EarlyStopping",
"GPUStatsMonitor",
"XLAStatsMonitor",
82 changes: 82 additions & 0 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Device Stats Monitor
====================
Monitors and logs device stats during training.
"""
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT


class DeviceStatsMonitor(Callback):
r"""
Automatically monitors and logs device stats during the training stage. ``DeviceStatsMonitor``
is a special callback as it requires a ``logger`` to be passed as an argument to the ``Trainer``.
Raises:
MisconfigurationException:
If ``Trainer`` has no logger.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import DeviceStatsMonitor
>>> device_stats = DeviceStatsMonitor() # doctest: +SKIP
>>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP
"""

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger_connector.should_update_logs:
return

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger_connector.should_update_logs:
return

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
return {prefix + "." + k: v for k, v in metrics_dict.items()}
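
As a quick illustration of the key format produced by ``prefix_metrics_keys`` above (the stats dict here is made up):

stats = {"active.all.peak": 1024, "allocated_bytes.all.current": 2048}
print(prefix_metrics_keys(stats, "on_train_batch_start"))
# {'on_train_batch_start.active.all.peak': 1024, 'on_train_batch_start.allocated_bytes.all.current': 2048}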
16 changes: 0 additions & 16 deletions tests/accelerators/test_tpu.py

This file was deleted.

130 changes: 130 additions & 0 deletions tests/callbacks/test_device_stats_monitor.py
@@ -0,0 +1,130 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_torch(tmpdir):
"""Test GPU stats are logged using a logger with Pytorch >= 1.8.0."""
model = BoringModel()
device_stats = DeviceStatsMonitor()

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
for f in fields:
assert any(f in h for h in metrics.keys())

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=7,
log_every_n_steps=1,
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_progress_bar=False,
)

trainer.fit(model)


@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_nvidia(tmpdir):
"""Test GPU stats are logged using a logger with Pytorch < 1.8.0."""
model = BoringModel()
device_stats = DeviceStatsMonitor()

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]
for f in fields:
assert any(f in h for h in metrics.keys())

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=7,
log_every_n_steps=1,
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_progress_bar=False,
)

trainer.fit(model)


@RunIf(tpu=True)
def test_device_stats_monitor_tpu(tmpdir):
"""Test TPU stats are logged using a logger."""

model = BoringModel()
device_stats = DeviceStatsMonitor()

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
for f in fields:
assert any(f in h for h in metrics.keys())

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=1,
tpu_cores=8,
log_every_n_steps=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_progress_bar=False,
)

trainer.fit(model)


def test_device_stats_monitor_no_logger(tmpdir):
"""Test DeviceStatsMonitor with no logger in Trainer."""

model = BoringModel()
device_stats = DeviceStatsMonitor()

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[device_stats],
max_epochs=1,
logger=False,
checkpoint_callback=False,
enable_progress_bar=False,
)

with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
trainer.fit(model)
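
For reference, the field names asserted in ``test_device_stats_gpu_from_torch`` above appear to mirror the keys returned by ``torch.cuda.memory_stats``; a small sanity-check sketch (assumes a CUDA-capable machine):

import torch

if torch.cuda.is_available():
    stats = torch.cuda.memory_stats(torch.device("cuda:0"))
    # Expect keys such as "allocated_bytes.all.freed" and "reserved_bytes.large_pool.peak",
    # which are the fields the DebugLogger in that test looks for.
    print(sorted(stats)[:5])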
