Bug fix / Enhancement: array initialization on different devices #678

Merged: 27 commits, Oct 26, 2020. The diff below shows changes from 22 commits.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
# v0.5.1

- [#678](https://github.com/helmholtz-analytics/heat/pull/678) Bugfix: High number of implicit device transfers in `DNDarray.__init__()`

# v0.5.0

- [#488](https://github.com/helmholtz-analytics/heat/pull/488) Enhancement: Rework of the test device selection.
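
For context, a minimal torch-level sketch (not heat code) of the allocation pattern this bugfix targets: building a tensor on the CPU and moving it afterwards causes an extra implicit transfer, while passing `device=` allocates it on the target device directly.

```python
import torch

dev = "cuda:0" if torch.cuda.is_available() else "cpu"

# implicit transfer: the data is first materialized on the CPU, then copied over
a = torch.tensor([1.0, 2.0, 3.0]).to(dev)

# direct allocation: the tensor is created on the target device right away
b = torch.tensor([1.0, 2.0, 3.0], device=dev)

assert a.device == b.device
```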
11 changes: 6 additions & 5 deletions heat/core/_operations.py
@@ -5,6 +5,7 @@

from .communication import MPI, MPI_WORLD
from . import factories
from . import devices
from . import stride_tricks
from . import sanitation
from . import dndarray
@@ -38,7 +39,9 @@ def __binary_op(operation, t1, t2, out=None):
"""
if np.isscalar(t1):
try:
t1 = factories.array([t1])
t1 = factories.array(
[t1], device=t2.device if isinstance(t2, dndarray.DNDarray) else None
)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))

@@ -51,11 +54,9 @@
)
output_shape = (1,)
output_split = None
output_device = None
output_device = t2.device
output_comm = MPI_WORLD
elif isinstance(t2, dndarray.DNDarray):
t1 = t1.gpu() if t2.device.device_type == "gpu" else t1.cpu()

output_shape = t2.shape
output_split = t2.split
output_device = t2.device
@@ -154,7 +155,7 @@ def __binary_op(operation, t1, t2, out=None):
)

if not isinstance(result, torch.Tensor):
result = torch.tensor(result)
result = torch.tensor(result, device=output_device.torch_device)

if out is not None:
out_dtype = out.dtype
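
As a side note, here is a hedged sketch (plain torch, the helper name `wrap_scalar_like` is illustrative and not part of heat) of the scalar handling introduced in this hunk: the scalar operand is wrapped directly on the device of the `DNDarray` operand instead of defaulting to the CPU and being transferred later.

```python
import torch

def wrap_scalar_like(scalar, other):
    # place the scalar on the same device as the tensor operand,
    # mirroring the new `device=t2.device ...` argument above
    device = other.device if isinstance(other, torch.Tensor) else None
    return torch.tensor([scalar], device=device)

t2 = torch.arange(4, device="cuda:0" if torch.cuda.is_available() else "cpu")
t1 = wrap_scalar_like(2, t2)
assert t1.device == t2.device  # both operands now live on the same device
```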
3 changes: 2 additions & 1 deletion heat/core/devices.py
@@ -56,7 +56,7 @@ def __str__(self):


# create a CPU device singleton
cpu = Device("cpu", 0, "cpu:0")
cpu = Device("cpu", 0, "cpu")

# define the default device to be the CPU
__default_device = cpu
@@ -71,6 +71,7 @@ def __str__(self):
gpu = Device("gpu", gpu_id, "cuda:{}".format(gpu_id))
# add a GPU device string
__device_mapping[gpu.device_type] = gpu
__device_mapping["cuda"] = gpu
# the GPU device should be exported as global symbol
__all__.append("gpu")
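
The `"cpu:0"` to `"cpu"` change matters because torch never reports an index for CPU devices, so string comparisons against a heat `Device.torch_device` (such as the new assertion in `DNDarray.__init__`) would otherwise always fail; registering `"cuda"` in the mapping additionally lets torch-style device strings resolve to heat's GPU device. A quick check of the torch behaviour:

```python
import torch

t = torch.zeros(3)
print(str(t.device))             # "cpu" - torch omits the index for CPU tensors
print(str(t.device) == "cpu:0")  # False - the old torch_device string never matches
print(str(t.device) == "cpu")    # True  - the new torch_device string does
```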

11 changes: 3 additions & 8 deletions heat/core/dndarray.py
@@ -80,13 +80,8 @@ def __init__(self, array, gshape, dtype, split, device, comm):
self.__halo_next = None
self.__halo_prev = None

# handle inconsistencies between torch and heat devices
if (
isinstance(array, torch.Tensor)
and isinstance(device, devices.Device)
and array.device.type != device.device_type
):
self.__array = self.__array.to(devices.sanitize_device(self.__device).torch_device)
# check for inconsistencies between torch and heat devices
assert str(array.device) == device.torch_device

@property
def halo_next(self):
@@ -1509,7 +1504,7 @@ def __getitem__(self, key):
chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
chunk_start = chunk_starts[rank]
chunk_end = chunk_ends[rank]
arr = torch.Tensor()
arr = torch.tensor([], device=self.device.torch_device)
# all keys should be tuples here
gout = [0] * len(self.gshape)
# handle the dimensional reduction for integers
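
`torch.Tensor()` always allocates on the CPU, so the empty placeholder in `__getitem__` has to be created with an explicit `device=` to stay on the array's device. A small illustration of the difference:

```python
import torch

dev = "cuda:0" if torch.cuda.is_available() else "cpu"

empty_default = torch.Tensor()               # always lands on the CPU
empty_on_dev = torch.tensor([], device=dev)  # created directly on the target device

print(empty_default.device)  # cpu
print(empty_on_dev.device)   # matches dev
```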
30 changes: 27 additions & 3 deletions heat/core/factories.py
@@ -1,5 +1,6 @@
import numpy as np
import torch
import warnings

from .communication import MPI, sanitize_comm
from .stride_tricks import sanitize_axis, sanitize_shape
@@ -289,6 +290,10 @@ def array(
if dtype is not None:
dtype = types.canonical_heat_type(dtype)

# sanitize device
if device is not None:
device = devices.sanitize_device(device)

# initialize the array
if bool(copy):
if isinstance(obj, torch.Tensor):
@@ -297,7 +302,13 @@
obj = obj.clone().detach()
else:
try:
obj = torch.tensor(obj, dtype=dtype.torch_type() if dtype is not None else None)
obj = torch.tensor(
obj,
dtype=dtype.torch_type() if dtype is not None else None,
device=device.torch_device
if device is not None
else devices.get_device().torch_device,
)
except RuntimeError:
raise TypeError("invalid data of type {}".format(type(obj)))

@@ -309,6 +320,20 @@
if obj.dtype != torch_dtype:
obj = obj.type(torch_dtype)

# infer device from obj if not explicitly given
if device is None:
device = devices.sanitize_device(obj.device.type)

    # change device if it does not match
assert str(obj.device) == device.torch_device

if str(obj.device) != device.torch_device:
warnings.warn(
"Array 'obj' is not on device '{}'. It will be copied to it.".format(device),
UserWarning,
)
Review thread on the assertion and warning above:

Collaborator (PR author): Should I keep the assertion or the warning here?

mtar (Oct 14, 2020): I deleted the assertion, as there are some rare cases where it is allowed. However, the unit test results must be reviewed for possible wrong devices.

obj = obj.to(device.torch_device)
Review thread on the implicit transfer:

Contributor: Why is the implicit data transfer OK here? 🤔

Collaborator (PR author): It's analogous to dtype in factories.array(). The user explicitly wants to use a specific device here.

mtar (Oct 5, 2020): On internal functions, however, it can still obfuscate some wrong devices if the device parameter is overused. I'll add a warning.

Contributor: OK, yes, I can see why we want to transfer data here. On internal functions, we will have to keep an eye on it during review and testing, I guess.


# sanitize minimum number of dimensions
if not isinstance(ndmin, int):
raise TypeError("expected ndmin to be int, but was {}".format(type(ndmin)))
@@ -326,8 +351,7 @@
if split is not None and is_split is not None:
raise ValueError("split and is_split are mutually exclusive parameters")

# sanitize device and object
device = devices.sanitize_device(device)
# sanitize comm object
comm = sanitize_comm(comm)

# determine the local and the global shape, if not split is given, they are identical
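
Summarizing the behaviour discussed in the thread above, here is a hedged sketch (plain Python/torch, the helper name `_match_device` is illustrative and not part of heat) of what `factories.array()` now does with the `device` argument: keep the input's device when none is requested, otherwise warn and copy on a mismatch.

```python
import warnings
import torch

def _match_device(obj, torch_device=None):
    # torch_device plays the role of device.torch_device ("cpu", "cuda:0", ...)
    if torch_device is None:
        return obj  # keep the device the input already lives on
    if str(obj.device) != torch_device:
        warnings.warn(
            "Array 'obj' is not on device '{}'. It will be copied to it.".format(torch_device),
            UserWarning,
        )
        obj = obj.to(torch_device)
    return obj

x = torch.arange(5)           # CPU tensor
x = _match_device(x, "cpu")   # already there, no warning
# _match_device(x, "cuda:0")  # would warn and copy if a GPU were available
```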
4 changes: 3 additions & 1 deletion heat/core/io.py
@@ -619,7 +619,9 @@ def load_csv(
# Create empty tensor and iteratively fill it with the values
local_shape = (len(line_starts), columns)
actual_length = 0
local_tensor = torch.empty(local_shape, dtype=dtype.torch_type())
local_tensor = torch.empty(
local_shape, dtype=dtype.torch_type(), device=device.torch_device
)
for ind, start in enumerate(line_starts):
if ind == len(line_starts) - 1:
f.seek(displs[rank] + start, 0)
2 changes: 1 addition & 1 deletion heat/core/logical.py
@@ -206,7 +206,7 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False):
# If x is distributed, then y is also distributed along the same axis
if t1.comm.is_distributed() and t1.split is not None:
output_gshape = stride_tricks.broadcast_shape(t1.gshape, t2.gshape)
res = torch.empty(output_gshape).bool()
res = torch.empty(output_gshape, device=t1.device.torch_device).bool()
t1.comm.Allgather(_local_isclose, res)
result = factories.array(res, dtype=types.bool, device=t1.device, split=t1.split)
else:
10 changes: 6 additions & 4 deletions heat/core/manipulations.py
@@ -1154,7 +1154,9 @@ def pad(array, pad_width, mode="constant", constant_values=0):
0 if i == array.split else output_shape[i] for i in range(len(output_shape))
]
adapted_lshape = tuple(adapted_lshape_list)
padded_torch_tensor = torch.empty(adapted_lshape, dtype=array._DNDarray__array.dtype)
padded_torch_tensor = torch.empty(
adapted_lshape, dtype=array._DNDarray__array.dtype, device=array.device.torch_device
)
else:
if array.split is None or array.split not in pad_dim or amount_of_processes == 1:
# values = scalar
@@ -1987,8 +1989,8 @@ def stack(arrays, axis=0, out=None):
devices = list(array.device for array in arrays)
if devices.count(devices[0]) != num_arrays:
raise RuntimeError(
"DNDarrays in sequence must reside on the same device, got devices {}".format(
devices
"DNDarrays in sequence must reside on the same device, got devices {} {} {}".format(
devices, devices[0].device_id, devices[1].device_id
)
)
balance = list(array.is_balanced() for array in arrays)
@@ -2155,7 +2157,7 @@ def unique(a, sorted=False, return_inverse=False, axis=None):
# Gather all unique vectors
counts = list(uniques_buf.tolist())
displs = list([0] + uniques_buf.cumsum(0).tolist()[:-1])
gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type())
gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device)
a.comm.Allgatherv(lres, (gres_buf, counts, displs), recv_axis=0)

if return_inverse:
14 changes: 7 additions & 7 deletions heat/core/random.py
@@ -122,7 +122,7 @@ def __counter_sequence(shape, dtype, split, device, comm):
lrange[0], lrange[1] = lrange[0] - diff, lrange[1] - diff

# create x_1 counter sequence
x_1 = torch.arange(*lrange, dtype=dtype)
x_1 = torch.arange(*lrange, dtype=dtype, device=device.torch_device)
while diff > signed_mask:
# signed_mask is maximum that can be added at a time because torch does not support unit64 or unit32
x_1 += signed_mask
@@ -661,9 +661,9 @@ def __threefry32(X_0, X_1):
seed_32 = __seed & 0x7FFFFFFF

# set up key buffer
ks_0 = torch.full((samples,), seed_32, dtype=torch.int32)
ks_1 = torch.full((samples,), seed_32, dtype=torch.int32)
ks_2 = torch.full((samples,), 466688986, dtype=torch.int32)
ks_0 = torch.full((samples,), seed_32, dtype=torch.int32, device=X_0.device)
ks_1 = torch.full((samples,), seed_32, dtype=torch.int32, device=X_1.device)
ks_2 = torch.full((samples,), 466688986, dtype=torch.int32, device=X_0.device)
ks_2 ^= ks_0
ks_2 ^= ks_0

@@ -751,9 +751,9 @@ def __threefry64(X_0, X_1):
samples = len(X_0)

# set up key buffer
ks_0 = torch.full((samples,), __seed, dtype=torch.int64)
ks_1 = torch.full((samples,), __seed, dtype=torch.int64)
ks_2 = torch.full((samples,), 2004413935125273122, dtype=torch.int64)
ks_0 = torch.full((samples,), __seed, dtype=torch.int64, device=X_0.device)
ks_1 = torch.full((samples,), __seed, dtype=torch.int64, device=X_1.device)
ks_2 = torch.full((samples,), 2004413935125273122, dtype=torch.int64, device=X_0.device)
ks_2 ^= ks_0
ks_2 ^= ks_0
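
The key buffers have to live on the same device as the counter sequences, because torch raises a RuntimeError when an elementwise operation mixes CPU and CUDA tensors. A minimal illustration of the pattern used in these hunks:

```python
import torch

x_0 = torch.arange(4, dtype=torch.int64)  # stands in for a counter sequence
ks_0 = torch.full((len(x_0),), 7, dtype=torch.int64, device=x_0.device)
ks_0 ^= x_0  # works because both operands share a device
```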

3 changes: 2 additions & 1 deletion heat/core/tests/test_exponential.py
@@ -97,7 +97,8 @@ def test_expm1(self):
def test_exp2(self):
elements = 10
tmp = np.exp2(torch.arange(elements, dtype=torch.float64))
comparison = ht.array(tmp)
tmp = tmp.to(self.device.torch_device)
comparison = ht.array(tmp, device=self.device)

# exponential of float32
float32_tensor = ht.arange(elements, dtype=ht.float32)
12 changes: 8 additions & 4 deletions heat/core/tests/test_manipulations.py
@@ -1019,7 +1019,7 @@ def test_pad(self):
# test padding of non-distributed tensor
# ======================================

data = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
data = torch.arange(2 * 3 * 4, device=self.device.torch_device).reshape(2, 3, 4)
data_ht = ht.array(data, device=self.device)
data_np = data_ht.numpy()

@@ -1307,7 +1307,7 @@ def test_pad(self):
# test padding of large distributed tensor
# =========================================

data = torch.arange(8 * 3 * 4).reshape(8, 3, 4)
data = torch.arange(8 * 3 * 4, device=self.device.torch_device).reshape(8, 3, 4)
data_ht_split = ht.array(data, split=0)
data_np = data_ht_split.numpy()

@@ -1996,7 +1996,9 @@ def test_topk(self):
if size == 1:
size = 4

torch_array = torch.arange(size, dtype=torch.int32).expand(size, size)
torch_array = torch.arange(size, dtype=torch.int32, device=self.device.torch_device).expand(
size, size
)
split_zero = ht.array(torch_array, split=0)
split_one = ht.array(torch_array, split=1)

@@ -2018,7 +2020,9 @@
self.assertTrue((indcs._DNDarray__array == exp_one_indcs._DNDarray__array).all())
self.assertTrue(indcs._DNDarray__array.dtype == exp_one_indcs._DNDarray__array.dtype)

torch_array = torch.arange(size, dtype=torch.float64).expand(size, size)
torch_array = torch.arange(
size, dtype=torch.float64, device=self.device.torch_device
).expand(size, size)
split_zero = ht.array(torch_array, split=0)
split_one = ht.array(torch_array, split=1)

3 changes: 2 additions & 1 deletion heat/core/tests/test_suites/basic_test.py
@@ -276,10 +276,11 @@ def assert_func_equal_for_tensor(

if isinstance(tensor, np.ndarray):
torch_tensor = torch.from_numpy(tensor.copy())
torch_tensor = torch_tensor.to(self.device.torch_device)
np_array = tensor
elif isinstance(tensor, torch.Tensor):
torch_tensor = tensor
np_array = tensor.numpy().copy()
np_array = tensor.cpu().numpy().copy()
else:
raise TypeError(
"The input tensors type must be one of [tuple, list, "
6 changes: 3 additions & 3 deletions heat/core/tests/test_suites/test_basic_test.py
@@ -24,9 +24,9 @@ def test_assert_array_equal(self):
self.assert_array_equal(heat_array, expected_array)

if self.get_rank() == 0:
data = torch.arange(self.get_size(), dtype=torch.int32)
data = torch.arange(self.get_size(), dtype=torch.int32, device=self.device.torch_device)
else:
data = torch.empty((0,), dtype=torch.int32)
data = torch.empty((0,), dtype=torch.int32, device=self.device.torch_device)

ht_array = ht.array(data, is_split=0)
np_array = np.arange(self.get_size(), dtype=np.int32)
@@ -77,7 +77,7 @@ def test_assert_func_equal_for_tensor(self):
array, ht_func, np_func, heat_args=ht_args, numpy_args=np_args
)

array = torch.randn(15, 15)
array = torch.randn(15, 15, device=self.device.torch_device)
ht_func = ht.exp
np_func = np.exp
self.assert_func_equal_for_tensor(array, heat_func=ht_func, numpy_func=np_func)