Skip to content

Commit

Permalink
Merge pull request #678 from helmholtz-analytics/bug/675-initdevice
Browse files Browse the repository at this point in the history
Bugfix: Internal functions now use explicit device parameters for DNDarray and torch.Tensor initializations.
  • Loading branch information
mtar authored Oct 26, 2020
2 parents 88f9157 + 27d7134 commit 9c90cb4
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 39 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# v0.5.1

- [#678](https://github.com/helmholtz-analytics/heat/pull/678) Bugfix: Internal functions now use explicit device parameters for DNDarray and torch.Tensor initializations.
- [#684](https://github.com/helmholtz-analytics/heat/pull/684) Bug fix: distributed `reshape` does not work on booleans.

# v0.5.0
Expand Down
11 changes: 6 additions & 5 deletions heat/core/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .communication import MPI, MPI_WORLD
from . import factories
from . import devices
from . import stride_tricks
from . import sanitation
from . import dndarray
Expand Down Expand Up @@ -38,7 +39,9 @@ def __binary_op(operation, t1, t2, out=None):
"""
if np.isscalar(t1):
try:
t1 = factories.array([t1])
t1 = factories.array(
[t1], device=t2.device if isinstance(t2, dndarray.DNDarray) else None
)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))

Expand All @@ -51,11 +54,9 @@ def __binary_op(operation, t1, t2, out=None):
)
output_shape = (1,)
output_split = None
output_device = None
output_device = t2.device
output_comm = MPI_WORLD
elif isinstance(t2, dndarray.DNDarray):
t1 = t1.gpu() if t2.device.device_type == "gpu" else t1.cpu()

output_shape = t2.shape
output_split = t2.split
output_device = t2.device
Expand Down Expand Up @@ -154,7 +155,7 @@ def __binary_op(operation, t1, t2, out=None):
)

if not isinstance(result, torch.Tensor):
result = torch.tensor(result)
result = torch.tensor(result, device=output_device.torch_device)

if out is not None:
out_dtype = out.dtype
Expand Down
3 changes: 2 additions & 1 deletion heat/core/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __str__(self):


# create a CPU device singleton
cpu = Device("cpu", 0, "cpu:0")
cpu = Device("cpu", 0, "cpu")

# define the default device to be the CPU
__default_device = cpu
Expand All @@ -71,6 +71,7 @@ def __str__(self):
gpu = Device("gpu", gpu_id, "cuda:{}".format(gpu_id))
# add a GPU device string
__device_mapping[gpu.device_type] = gpu
__device_mapping["cuda"] = gpu
# the GPU device should be exported as global symbol
__all__.append("gpu")

Expand Down
11 changes: 3 additions & 8 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,8 @@ def __init__(self, array, gshape, dtype, split, device, comm):
self.__halo_next = None
self.__halo_prev = None

# handle inconsistencies between torch and heat devices
if (
isinstance(array, torch.Tensor)
and isinstance(device, devices.Device)
and array.device.type != device.device_type
):
self.__array = self.__array.to(devices.sanitize_device(self.__device).torch_device)
# check for inconsistencies between torch and heat devices
assert str(array.device) == device.torch_device

@property
def halo_next(self):
Expand Down Expand Up @@ -1509,7 +1504,7 @@ def __getitem__(self, key):
chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
chunk_start = chunk_starts[rank]
chunk_end = chunk_ends[rank]
arr = torch.Tensor()
arr = torch.tensor([], device=self.device.torch_device)
# all keys should be tuples here
gout = [0] * len(self.gshape)
# handle the dimensional reduction for integers
Expand Down
27 changes: 24 additions & 3 deletions heat/core/factories.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch
import warnings

from .communication import MPI, sanitize_comm
from .stride_tricks import sanitize_axis, sanitize_shape
Expand Down Expand Up @@ -289,6 +290,10 @@ def array(
if dtype is not None:
dtype = types.canonical_heat_type(dtype)

# sanitize device
if device is not None:
device = devices.sanitize_device(device)

# initialize the array
if bool(copy):
if isinstance(obj, torch.Tensor):
Expand All @@ -297,7 +302,13 @@ def array(
obj = obj.clone().detach()
else:
try:
obj = torch.tensor(obj, dtype=dtype.torch_type() if dtype is not None else None)
obj = torch.tensor(
obj,
dtype=dtype.torch_type() if dtype is not None else None,
device=device.torch_device
if device is not None
else devices.get_device().torch_device,
)
except RuntimeError:
raise TypeError("invalid data of type {}".format(type(obj)))

Expand All @@ -309,6 +320,17 @@ def array(
if obj.dtype != torch_dtype:
obj = obj.type(torch_dtype)

# infer device from obj if not explicitly given
if device is None:
device = devices.sanitize_device(obj.device.type)

if str(obj.device) != device.torch_device:
warnings.warn(
"Array 'obj' is not on device '{}'. It will be copied to it.".format(device),
UserWarning,
)
obj = obj.to(device.torch_device)

# sanitize minimum number of dimensions
if not isinstance(ndmin, int):
raise TypeError("expected ndmin to be int, but was {}".format(type(ndmin)))
Expand All @@ -326,8 +348,7 @@ def array(
if split is not None and is_split is not None:
raise ValueError("split and is_split are mutually exclusive parameters")

# sanitize device and object
device = devices.sanitize_device(device)
# sanitize comm object
comm = sanitize_comm(comm)

# determine the local and the global shape, if not split is given, they are identical
Expand Down
4 changes: 3 additions & 1 deletion heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,9 @@ def load_csv(
# Create empty tensor and iteratively fill it with the values
local_shape = (len(line_starts), columns)
actual_length = 0
local_tensor = torch.empty(local_shape, dtype=dtype.torch_type())
local_tensor = torch.empty(
local_shape, dtype=dtype.torch_type(), device=device.torch_device
)
for ind, start in enumerate(line_starts):
if ind == len(line_starts) - 1:
f.seek(displs[rank] + start, 0)
Expand Down
2 changes: 1 addition & 1 deletion heat/core/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False):
# If x is distributed, then y is also distributed along the same axis
if t1.comm.is_distributed() and t1.split is not None:
output_gshape = stride_tricks.broadcast_shape(t1.gshape, t2.gshape)
res = torch.empty(output_gshape).bool()
res = torch.empty(output_gshape, device=t1.device.torch_device).bool()
t1.comm.Allgather(_local_isclose, res)
result = factories.array(res, dtype=types.bool, device=t1.device, split=t1.split)
else:
Expand Down
10 changes: 6 additions & 4 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,9 @@ def pad(array, pad_width, mode="constant", constant_values=0):
0 if i == array.split else output_shape[i] for i in range(len(output_shape))
]
adapted_lshape = tuple(adapted_lshape_list)
padded_torch_tensor = torch.empty(adapted_lshape, dtype=array._DNDarray__array.dtype)
padded_torch_tensor = torch.empty(
adapted_lshape, dtype=array._DNDarray__array.dtype, device=array.device.torch_device
)
else:
if array.split is None or array.split not in pad_dim or amount_of_processes == 1:
# values = scalar
Expand Down Expand Up @@ -1987,8 +1989,8 @@ def stack(arrays, axis=0, out=None):
devices = list(array.device for array in arrays)
if devices.count(devices[0]) != num_arrays:
raise RuntimeError(
"DNDarrays in sequence must reside on the same device, got devices {}".format(
devices
"DNDarrays in sequence must reside on the same device, got devices {} {} {}".format(
devices, devices[0].device_id, devices[1].device_id
)
)
balance = list(array.is_balanced() for array in arrays)
Expand Down Expand Up @@ -2155,7 +2157,7 @@ def unique(a, sorted=False, return_inverse=False, axis=None):
# Gather all unique vectors
counts = list(uniques_buf.tolist())
displs = list([0] + uniques_buf.cumsum(0).tolist()[:-1])
gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type())
gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device)
a.comm.Allgatherv(lres, (gres_buf, counts, displs), recv_axis=0)

if return_inverse:
Expand Down
14 changes: 7 additions & 7 deletions heat/core/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __counter_sequence(shape, dtype, split, device, comm):
lrange[0], lrange[1] = lrange[0] - diff, lrange[1] - diff

# create x_1 counter sequence
x_1 = torch.arange(*lrange, dtype=dtype)
x_1 = torch.arange(*lrange, dtype=dtype, device=device.torch_device)
while diff > signed_mask:
# signed_mask is the maximum that can be added at a time because torch does not support uint64 or uint32
x_1 += signed_mask
Expand Down Expand Up @@ -661,9 +661,9 @@ def __threefry32(X_0, X_1):
seed_32 = __seed & 0x7FFFFFFF

# set up key buffer
ks_0 = torch.full((samples,), seed_32, dtype=torch.int32)
ks_1 = torch.full((samples,), seed_32, dtype=torch.int32)
ks_2 = torch.full((samples,), 466688986, dtype=torch.int32)
ks_0 = torch.full((samples,), seed_32, dtype=torch.int32, device=X_0.device)
ks_1 = torch.full((samples,), seed_32, dtype=torch.int32, device=X_1.device)
ks_2 = torch.full((samples,), 466688986, dtype=torch.int32, device=X_0.device)
ks_2 ^= ks_0
ks_2 ^= ks_0

Expand Down Expand Up @@ -751,9 +751,9 @@ def __threefry64(X_0, X_1):
samples = len(X_0)

# set up key buffer
ks_0 = torch.full((samples,), __seed, dtype=torch.int64)
ks_1 = torch.full((samples,), __seed, dtype=torch.int64)
ks_2 = torch.full((samples,), 2004413935125273122, dtype=torch.int64)
ks_0 = torch.full((samples,), __seed, dtype=torch.int64, device=X_0.device)
ks_1 = torch.full((samples,), __seed, dtype=torch.int64, device=X_1.device)
ks_2 = torch.full((samples,), 2004413935125273122, dtype=torch.int64, device=X_0.device)
ks_2 ^= ks_0
ks_2 ^= ks_0

Expand Down
3 changes: 2 additions & 1 deletion heat/core/tests/test_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def test_expm1(self):
def test_exp2(self):
elements = 10
tmp = np.exp2(torch.arange(elements, dtype=torch.float64))
comparison = ht.array(tmp)
tmp = tmp.to(self.device.torch_device)
comparison = ht.array(tmp, device=self.device)

# exponential of float32
float32_tensor = ht.arange(elements, dtype=ht.float32)
Expand Down
12 changes: 8 additions & 4 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def test_pad(self):
# test padding of non-distributed tensor
# ======================================

data = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
data = torch.arange(2 * 3 * 4, device=self.device.torch_device).reshape(2, 3, 4)
data_ht = ht.array(data, device=self.device)
data_np = data_ht.numpy()

Expand Down Expand Up @@ -1307,7 +1307,7 @@ def test_pad(self):
# test padding of large distributed tensor
# =========================================

data = torch.arange(8 * 3 * 4).reshape(8, 3, 4)
data = torch.arange(8 * 3 * 4, device=self.device.torch_device).reshape(8, 3, 4)
data_ht_split = ht.array(data, split=0)
data_np = data_ht_split.numpy()

Expand Down Expand Up @@ -2005,7 +2005,9 @@ def test_topk(self):
if size == 1:
size = 4

torch_array = torch.arange(size, dtype=torch.int32).expand(size, size)
torch_array = torch.arange(size, dtype=torch.int32, device=self.device.torch_device).expand(
size, size
)
split_zero = ht.array(torch_array, split=0)
split_one = ht.array(torch_array, split=1)

Expand All @@ -2027,7 +2029,9 @@ def test_topk(self):
self.assertTrue((indcs._DNDarray__array == exp_one_indcs._DNDarray__array).all())
self.assertTrue(indcs._DNDarray__array.dtype == exp_one_indcs._DNDarray__array.dtype)

torch_array = torch.arange(size, dtype=torch.float64).expand(size, size)
torch_array = torch.arange(
size, dtype=torch.float64, device=self.device.torch_device
).expand(size, size)
split_zero = ht.array(torch_array, split=0)
split_one = ht.array(torch_array, split=1)

Expand Down
3 changes: 2 additions & 1 deletion heat/core/tests/test_suites/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,11 @@ def assert_func_equal_for_tensor(

if isinstance(tensor, np.ndarray):
torch_tensor = torch.from_numpy(tensor.copy())
torch_tensor = torch_tensor.to(self.device.torch_device)
np_array = tensor
elif isinstance(tensor, torch.Tensor):
torch_tensor = tensor
np_array = tensor.numpy().copy()
np_array = tensor.cpu().numpy().copy()
else:
raise TypeError(
"The input tensors type must be one of [tuple, list, "
Expand Down
6 changes: 3 additions & 3 deletions heat/core/tests/test_suites/test_basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def test_assert_array_equal(self):
self.assert_array_equal(heat_array, expected_array)

if self.get_rank() == 0:
data = torch.arange(self.get_size(), dtype=torch.int32)
data = torch.arange(self.get_size(), dtype=torch.int32, device=self.device.torch_device)
else:
data = torch.empty((0,), dtype=torch.int32)
data = torch.empty((0,), dtype=torch.int32, device=self.device.torch_device)

ht_array = ht.array(data, is_split=0)
np_array = np.arange(self.get_size(), dtype=np.int32)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_assert_func_equal_for_tensor(self):
array, ht_func, np_func, heat_args=ht_args, numpy_args=np_args
)

array = torch.randn(15, 15)
array = torch.randn(15, 15, device=self.device.torch_device)
ht_func = ht.exp
np_func = np.exp
self.assert_func_equal_for_tensor(array, heat_func=ht_func, numpy_func=np_func)
Expand Down

0 comments on commit 9c90cb4

Please sign in to comment.