Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add typing to accelerators/gpu.py #11333

Merged
merged 6 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ warn_no_return = "False"
# the list can be generated with:
# mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.accelerators.gpu",
"pytorch_lightning.callbacks.finetuning",
"pytorch_lightning.callbacks.model_checkpoint",
"pytorch_lightning.callbacks.progress.base",
Expand Down
9 changes: 6 additions & 3 deletions pytorch_lightning/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
from __future__ import annotations

from typing import Any

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE


class CPUAccelerator(Accelerator):
Expand All @@ -28,10 +31,10 @@ def setup_environment(self, root_device: torch.device) -> None:
MisconfigurationException:
If the selected device is not CPU.
"""
if "cpu" not in str(root_device):
if root_device.type != "cpu":
raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""CPU device stats aren't supported yet."""
return {}

Expand Down
20 changes: 12 additions & 8 deletions pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Union
from typing import Any

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.types import _DEVICE

_log = logging.getLogger(__name__)

Expand All @@ -36,11 +39,11 @@ def setup_environment(self, root_device: torch.device) -> None:
MisconfigurationException:
If the selected device is not GPU.
"""
if "cuda" not in str(root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
if root_device.type != "cuda":
raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
torch.cuda.set_device(root_device)

def setup(self, trainer: "pl.Trainer") -> None:
def setup(self, trainer: pl.Trainer) -> None:
# TODO refactor input from trainer to local_rank @four4fish
self.set_nvidia_flags(trainer.local_rank)
# clear cache before training
Expand All @@ -54,7 +57,7 @@ def set_nvidia_flags(local_rank: int) -> None:
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""Gets stats for the given GPU device.

Args:
Expand All @@ -77,7 +80,7 @@ def auto_device_count() -> int:
return torch.cuda.device_count()


def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

Args:
Expand Down Expand Up @@ -106,7 +109,8 @@ def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
gpu_query = ",".join(gpu_stat_keys)

gpu_id = _get_gpu_id(device.index)
index = torch._utils._get_device_index(device)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
gpu_id = _get_gpu_id(index)
result = subprocess.run(
[nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
encoding="utf-8",
Expand All @@ -130,5 +134,5 @@ def _get_gpu_id(device_id: int) -> str:
"""Get the unmasked real GPU IDs."""
# All devices if `CUDA_VISIBLE_DEVICES` unset
default = ",".join(str(i) for i in range(torch.cuda.device_count()))
cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return cuda_visible_devices[device_id].strip()
23 changes: 13 additions & 10 deletions pytorch_lightning/strategies/single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from __future__ import annotations

from typing import Any

import torch

Expand All @@ -20,20 +22,21 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.types import _DEVICE


class SingleDeviceStrategy(Strategy):
"""Strategy that handles communication on a single device."""

def __init__(
self,
device: torch.device,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
device: _DEVICE,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
accelerator: pl.accelerators.accelerator.Accelerator | None = None,
checkpoint_io: CheckpointIO | None = None,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
precision_plugin: PrecisionPlugin | None = None,
):
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
self.device: torch.device = device
self._root_device = torch.device(device)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.global_rank = 0
self.local_rank = 0
self.world_size = 1
Expand All @@ -46,7 +49,7 @@ def on_tpu(self) -> bool:
def on_gpu(self) -> bool:
return self.root_device.type == "cuda" and torch.cuda.is_available()

def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
operates with a single device, the reduction is simply the identity.

Expand All @@ -60,18 +63,18 @@ def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) ->
"""
return tensor

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform a all_gather on all processes."""
return tensor

@property
def root_device(self) -> torch.device:
return self.device
return self._root_device

def model_to_device(self) -> None:
self.model.to(self.root_device)

def setup(self, trainer: "pl.Trainer") -> None:
def setup(self, trainer: pl.Trainer) -> None:
self.model_to_device()
super().setup(trainer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def select_strategy(self) -> Strategy:
plugin = IPUStrategy(parallel_devices=self.parallel_devices)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDeviceStrategy(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu"))
plugin = SingleDeviceStrategy(device=single_gpu_ordinal if self.use_gpu else "cpu")
return plugin

def resolve_strategy(self, training_type: Strategy) -> Strategy:
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
Dict[str, Sequence[DataLoader]],
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
_DEVICE = Union[torch.device, str, int]
carmocca marked this conversation as resolved.
Show resolved Hide resolved


# Copied from `torch.optim.lr_scheduler.pyi`
Expand Down
2 changes: 1 addition & 1 deletion tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_v1_8_0_deprecated_single_device_plugin_class():
" Use `.*SingleDeviceStrategy` instead."
)
):
SingleDevicePlugin(Mock())
SingleDevicePlugin("cpu")


@RunIf(tpu=True)
Expand Down
6 changes: 2 additions & 4 deletions tests/plugins/test_checkpoint_io_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ def test_checkpoint_plugin_called(tmpdir):
ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)

model = BoringModel()
device = torch.device("cpu")
trainer = Trainer(
default_root_dir=tmpdir,
strategy=SingleDeviceStrategy(device, checkpoint_io=checkpoint_plugin),
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
callbacks=ck,
max_epochs=2,
)
Expand All @@ -63,10 +62,9 @@ def test_checkpoint_plugin_called(tmpdir):
ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)

model = BoringModel()
device = torch.device("cpu")
trainer = Trainer(
default_root_dir=tmpdir,
strategy=SingleDeviceStrategy(device),
strategy=SingleDeviceStrategy("cpu"),
plugins=[checkpoint_plugin],
callbacks=ck,
max_epochs=2,
Expand Down