Skip to content

Commit

Permalink
Make device and dtype required (#9168)
Browse files Browse the repository at this point in the history
  • Loading branch information
stancld authored Aug 28, 2021
1 parent 39dd3a6 commit 46b00a7
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy as np
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE

if _TORCHTEXT_AVAILABLE:
Expand All @@ -36,22 +35,16 @@


def to_dtype_tensor(
value: Union[int, float, List[Union[int, float]]],
dtype: Optional[torch.dtype] = None,
device: Union[str, torch.device] = None,
value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.tensor(value, dtype=dtype, device=device)


def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> torch.Tensor:
return torch.from_numpy(value).to(device)


CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], torch.Tensor]]] = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
Expand Down Expand Up @@ -276,9 +269,6 @@ def batch_to(data: Any) -> Any:


def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
if device is None:
raise MisconfigurationException("`torch.device` should be provided.")

for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, conversion_func, device=device)

Expand Down

0 comments on commit 46b00a7

Please sign in to comment.