Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/4] Add get_device_stats to accelerator interface #9586

Merged
merged 31 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c11eb87
Add interface to accelerator to get_device_stats
daniellepintz Sep 17, 2021
cba4916
Update changelog
daniellepintz Sep 17, 2021
d4252c5
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Sep 17, 2021
d0e1233
address comments
daniellepintz Sep 17, 2021
269f3ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
4d8cc75
comments
daniellepintz Sep 18, 2021
8e37419
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
6d9cc2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2021
018e5cd
fix gpu
daniellepintz Sep 18, 2021
310f254
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
ec8084d
fix
daniellepintz Sep 18, 2021
5abce11
update docstring
daniellepintz Sep 18, 2021
3936242
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2021
0fdd368
fix tests
daniellepintz Sep 18, 2021
32f1047
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
d8314cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2021
5699e85
type fix
daniellepintz Sep 18, 2021
8d66aba
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
3ac0821
fix test
daniellepintz Sep 18, 2021
1160cd0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2021
ef5bc17
Update pytorch_lightning/accelerators/gpu.py
daniellepintz Sep 19, 2021
497680c
address comments
daniellepintz Sep 21, 2021
d3d13ec
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 21, 2021
ae7e912
Add unit tests
daniellepintz Sep 23, 2021
418e4a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
46b9f36
comments
daniellepintz Sep 23, 2021
ccadca5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
2658b4a
lint
daniellepintz Sep 23, 2021
07bc597
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 23, 2021
e19239e
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Sep 23, 2021
c4f0d02
comments
daniellepintz Sep 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))


- Added `get_device_stats` to Accelerator interface and implement it for GPU and TPU ([#9586](https://github.com/PyTorchLightning/pytorch-lightning/pull/9586))
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved


- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))


Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Accelerator:
- CPU
- GPU
- TPU
- IPU

Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.
Expand Down Expand Up @@ -436,6 +437,10 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool:
"""
return self.training_type_plugin.restore_checkpoint_after_pre_dispatch

def get_device_stats(self, device: Optional[torch.device] = None) -> Dict[str, Any]:
    """Get stats for a given device.

    Args:
        device: the device whose stats are requested. Defaults to ``None``.

    Returns:
        A dictionary of device stats. This base implementation provides no
        stats and implicitly returns ``None``; concrete accelerators
        (GPU/TPU) override it.
    """
    # No `pass` needed: the docstring alone is a valid (empty) body.
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved

def on_train_start(self) -> None:
"""Called when train begins."""
return self.training_type_plugin.on_train_start()
Expand Down
63 changes: 63 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# limitations under the License.
import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Optional

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8

_log = logging.getLogger(__name__)

Expand All @@ -39,6 +43,13 @@ def setup(self, trainer: "pl.Trainer") -> None:
If the selected device is not GPU.
"""
self.set_nvidia_flags(trainer.local_rank)

# The logical device IDs for selected devices
self._device_ids: List[int] = sorted(set(trainer.data_parallel_device_ids))

# The unmasked real GPU IDs
self._gpu_ids: List[int] = self._get_gpu_ids(self._device_ids)

return super().setup(trainer)

def on_train_start(self) -> None:
Expand All @@ -53,6 +64,58 @@ def set_nvidia_flags(local_rank: int) -> None:
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def get_device_stats(self, device: Optional[torch.device] = None) -> Dict[str, Any]:
    """Gets stats for the given GPU device.

    On torch >= 1.8 this delegates to ``torch.cuda.memory_stats``; older
    versions fall back to querying ``nvidia-smi``.
    """
    if _TORCH_GREATER_EQUAL_1_8:
        return torch.cuda.memory_stats(device=device)

    # (query, unit) pairs requested from nvidia-smi on the fallback path.
    stat_keys = [
        ("utilization.gpu", "%"),
        ("memory.used", "MB"),
        ("memory.free", "MB"),
        ("utilization.memory", "%"),
        ("fan.speed", "%"),
        ("temperature.gpu", "°C"),
        ("temperature.memory", "°C"),
    ]
    queries = [query for query, _ in stat_keys]
    raw_stats = self._get_gpu_stats(queries)
    return self._parse_gpu_stats(self._device_ids, raw_stats, stat_keys)
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved

def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]:
if not queries:
return []

"""Run nvidia-smi to get the gpu stats"""
gpu_query = ",".join(queries)
format = "csv,nounits,noheader"
gpu_ids = ",".join(self._gpu_ids)
result = subprocess.run(
[shutil.which("nvidia-smi"), f"--query-gpu={gpu_query}", f"--format={format}", f"--id={gpu_ids}"],
encoding="utf-8",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE, # for backward compatibility with python version 3.6
check=True,
)

def _to_float(x: str) -> float:
try:
return float(x)
except ValueError:
return 0.0

stats = result.stdout.strip().split(os.linesep)
stats = [[_to_float(x) for x in s.split(", ")] for s in stats]
return stats

@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:
"""Get the unmasked real GPU IDs."""
# All devices if `CUDA_VISIBLE_DEVICES` unset
default = ",".join(str(i) for i in range(torch.cuda.device_count()))
cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
return [cuda_visible_devices[device_id].strip() for device_id in device_ids]

daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
def teardown(self) -> None:
    """Tear down the GPU accelerator.

    Runs the base-class teardown, then moves optimizer state back to CPU —
    presumably so GPU memory is released once the run finishes (TODO confirm).
    """
    super().teardown()
    self._move_optimizer_state(torch.device("cpu"))
17 changes: 16 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional

import torch
from torch.optim import Optimizer
Expand Down Expand Up @@ -59,3 +59,18 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
for opt in self.optimizers:
for p, v in opt.state.items():
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)

def get_device_stats(self, device: Optional[torch.device] = None) -> Dict[str, Any]:
    """Gets stats for the given TPU device."""
    memory_info = xm.get_memory_info(device)

    kb_free = memory_info["kb_free"]
    kb_peak = memory_info["kb_total"] - kb_free

    # Average across processes via the plugin's reduce, then scale the
    # KB figures by 0.001 to report MB.
    avg_free = self.training_type_plugin.reduce(kb_free) * 0.001
    avg_peak = self.training_type_plugin.reduce(kb_peak) * 0.001

    return {
        "avg. free memory (MB)": avg_free,
        "avg. peak memory (MB)": avg_peak,
    }