From 95cc6f77f14079f7815f5a7cee3711a370299b2d Mon Sep 17 00:00:00 2001 From: Charlotte Date: Thu, 2 May 2019 12:28:46 +0200 Subject: [PATCH 1/7] Changed __binary_op for scalar values - Scalars are not converted to 1D tensors, but are passed directly to torch opperation - Error handling for non-supported types is passed on from torch --- heat/core/operations.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/heat/core/operations.py b/heat/core/operations.py index e6b21a8785..b02326c79e 100644 --- a/heat/core/operations.py +++ b/heat/core/operations.py @@ -36,43 +36,47 @@ def __binary_op(operation, t1, t2): """ if np.isscalar(t1): - try: - t1 = factories.array([t1]) - except (ValueError, TypeError,): - raise TypeError('Data type not supported, input was {}'.format(type(t1))) - if np.isscalar(t2): try: - t2 = factories.array([t2]) + result = operation(t1, t2) except (ValueError, TypeError,): - raise TypeError('Only numeric scalars are supported, but input was {}'.format(type(t2))) + raise TypeError('Only numeric scalars are supported, but inputs were {} and {}'.format(type(t1), type(t2))) output_shape = (1,) output_split = None output_device = None output_comm = MPI_WORLD + elif isinstance(t2, dndarray.DNDarray): + try: + result = operation(t1, t2._tensor__array) + except (ValueError, TypeError,): + raise TypeError('Data type not supported, input was {}'.format(type(t1))) + output_shape = t2.shape output_split = t2.split output_device = t2.device output_comm = t2.comm + else: raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2))) - if t1.dtype != t2.dtype: - t1 = t1.astype(t2.dtype) elif isinstance(t1, dndarray.DNDarray): if np.isscalar(t2): try: - t2 = factories.array([t2]) - output_shape = t1.shape - output_split = t1.split - output_device = t1.device - output_comm = t1.comm + result = operation(t1._tensor__array, t2) except (ValueError, TypeError,): raise TypeError('Data type not supported, input was {}'.format(type(t2))) + output_shape = t1.shape + output_split = t1.split + output_device = t1.device + output_comm = t1.comm + elif isinstance(t2, dndarray.DNDarray): + if t1.dtype != t2.dtype: + t1 = t1.astype(t2.dtype) + # TODO: implement complex NUMPY rules if t2.split is None or t2.split == t1.split: output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) From b7fd55ae160e419ce4c3cf018326f60e716acd3c Mon Sep 17 00:00:00 2001 From: Charlotte Date: Thu, 2 May 2019 13:01:24 +0200 Subject: [PATCH 2/7] Clean-up - Implementation of case: one tensor split, one tensor not split - some restructuring --- heat/core/operations.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/heat/core/operations.py b/heat/core/operations.py index b02326c79e..acf4efd77a 100644 --- a/heat/core/operations.py +++ b/heat/core/operations.py @@ -46,6 +46,9 @@ def __binary_op(operation, t1, t2): output_device = None output_comm = MPI_WORLD + return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + + elif isinstance(t2, dndarray.DNDarray): try: result = operation(t1, t2._tensor__array) @@ -57,6 +60,8 @@ def __binary_op(operation, t1, t2): output_device = t2.device output_comm = t2.comm + return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + else: raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2))) @@ -73,16 +78,26 @@ def __binary_op(operation, t1, t2): output_device = t1.device output_comm = t1.comm + return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + elif isinstance(t2, dndarray.DNDarray): - if t1.dtype != t2.dtype: - t1 = t1.astype(t2.dtype) + + if t2.dtype != t1.dtype: + t2 = t2.astype(t1.dtype) # TODO: implement complex NUMPY rules - if t2.split is None or t2.split == t1.split: + if t2.split == t1.split: output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) output_split = t1.split output_device = t1.device output_comm = t1.comm + + elif (t1.split is not None) and (t2.split is None): + t2.resplit(axis=t1.split) + + elif (t2.split is not None) and (t1.split is None): + t1.resplit(axis=t2.split) + else: # It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0 # and split=1 @@ -106,9 +121,6 @@ def __binary_op(operation, t1, t2): else: raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2))) - if t2.dtype != t1.dtype: - t2 = t2.astype(t1.dtype) - else: raise NotImplementedError('Not implemented for non scalar') @@ -118,7 +130,7 @@ def __binary_op(operation, t1, t2): result = t1._DNDarray__array.type(promoted_type) else: result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) - elif t1.split is not None: + elif t2.split is not None: if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0: result = t2._DNDarray__array.type(promoted_type) else: From d0e1de181caff5a6185073642e97c595aa2e803b Mon Sep 17 00:00:00 2001 From: Charlotte Date: Thu, 2 May 2019 13:25:06 +0200 Subject: [PATCH 3/7] Minor fixes --- heat/core/operations.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/heat/core/operations.py b/heat/core/operations.py index acf4efd77a..91a20921ff 100644 --- a/heat/core/operations.py +++ b/heat/core/operations.py @@ -46,12 +46,12 @@ def __binary_op(operation, t1, t2): output_device = None output_comm = MPI_WORLD - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(result.dtype), output_split, output_device, output_comm) elif isinstance(t2, dndarray.DNDarray): try: - result = operation(t1, t2._tensor__array) + result = operation(t1, t2._DNDarray__array) except (ValueError, TypeError,): raise TypeError('Data type not supported, input was {}'.format(type(t1))) @@ -60,7 +60,7 @@ def __binary_op(operation, t1, t2): output_device = t2.device output_comm = t2.comm - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(result.dtype), output_split, output_device, output_comm) else: raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2))) @@ -69,7 +69,7 @@ def __binary_op(operation, t1, t2): elif isinstance(t1, dndarray.DNDarray): if np.isscalar(t2): try: - result = operation(t1._tensor__array, t2) + result = operation(t1._DNDarray__array, t2) except (ValueError, TypeError,): raise TypeError('Data type not supported, input was {}'.format(type(t2))) @@ -78,7 +78,7 @@ def __binary_op(operation, t1, t2): output_device = t1.device output_comm = t1.comm - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(result.dtype), output_split, output_device, output_comm) elif isinstance(t2, dndarray.DNDarray): From 13290fa745b1ee7ec00ffec6ccd1bac1694641f0 Mon Sep 17 00:00:00 2001 From: Charlotte Date: Mon, 20 May 2019 15:29:46 +0200 Subject: [PATCH 4/7] Fixed Unittests for binary and relational operations --- heat/core/arithmetics.py | 32 +++++++- heat/core/operations.py | 17 ++-- heat/core/relational.py | 121 ++++++++++++++++++++++++++-- heat/core/tests/test_arithmetics.py | 26 ++---- heat/core/tests/test_relational.py | 22 ++--- 5 files changed, 174 insertions(+), 44 deletions(-) diff --git a/heat/core/arithmetics.py b/heat/core/arithmetics.py index 519c0671b9..1f0516655a 100644 --- a/heat/core/arithmetics.py +++ b/heat/core/arithmetics.py @@ -1,8 +1,10 @@ import torch +import numpy as np from .communication import MPI from . import operations from . import dndarray +from . import factories __all__ = [ @@ -128,7 +130,25 @@ def fmod(t1, t2): tensor([[0., 0.] [2., 2.]]) """ - return operations.__binary_op(torch.fmod, t1, t2) + + if np.isscalar(t1): + #Special treatment for fmod, since torch operation fmod only supports formats (Tensor input, Tensor other, Tensor out) or (Tensor input, Number other, Tensor out) + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.fmod, tmp, t2) + return result + else: + return operations.__binary_op(torch.fmod, t1, t2) + def mod(t1, t2): @@ -247,7 +267,15 @@ def pow(t1, t2): tensor([[1., 8.], [27., 64.]]) """ - return operations.__binary_op(torch.pow, t1, t2) + + if np.isscalar(t1) and np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + return operations.__binary_op(torch.pow, tmp, t2) + else: + return operations.__binary_op(torch.pow, t1, t2) def sub(t1, t2): diff --git a/heat/core/operations.py b/heat/core/operations.py index 91a20921ff..cedbd810e5 100644 --- a/heat/core/operations.py +++ b/heat/core/operations.py @@ -87,15 +87,17 @@ def __binary_op(operation, t1, t2): # TODO: implement complex NUMPY rules if t2.split == t1.split: - output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) - output_split = t1.split - output_device = t1.device - output_comm = t1.comm - + pass elif (t1.split is not None) and (t2.split is None): + if t2.shape[t1.split] == 1: + warnings.warn('Broadcasting requires transferring data of second operator between MPI ranks!') + t2.comm.Bcast(t2) t2.resplit(axis=t1.split) elif (t2.split is not None) and (t1.split is None): + if t1.shape[t2.split] == 1: + warnings.warn('Broadcasting requires transferring data of first operator between MPI ranks!') + t1.comm.Bcast(t1) t1.resplit(axis=t2.split) else: @@ -103,6 +105,11 @@ def __binary_op(operation, t1, t2): # and split=1 raise NotImplementedError('Not implemented for other splittings') + output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) + output_split = t1.split + output_device = t1.device + output_comm = t1.comm + # ToDo: Fine tuning in case of comm.size>t1.shape[t1.split]. Send torch tensors only to ranks, that will hold data. if t1.split is not None: if t1.shape[t1.split] == 1 and t1.comm.is_distributed(): diff --git a/heat/core/relational.py b/heat/core/relational.py index 52659d18b5..2548b9bf47 100644 --- a/heat/core/relational.py +++ b/heat/core/relational.py @@ -48,7 +48,22 @@ def eq(t1, t2): tensor([[0, 1], [0, 0]]) """ - return operations.__binary_op(torch.eq, t1, t2) + if np.isscalar(t1): + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.eq, tmp, t2) + return result + else: + return operations.__binary_op(torch.eq, t1, t2) def equal(t1, t2): @@ -80,7 +95,24 @@ def equal(t1, t2): >>> ht.eq(T1, 3.0) False """ - result_tensor = operations.__binary_op(torch.equal, t1, t2) + if np.isscalar(t1): + try: + tmp1 = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + tmp1 = t1 + if np.isscalar(t2): + try: + tmp2 = factories.array([t2]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + tmp2 = t2 + + result_tensor = operations.__binary_op(torch.equal, tmp1, tmp2) + + result_value = result_tensor._DNDarray__array if isinstance(result_value, torch.Tensor): result_value = True @@ -120,7 +152,22 @@ def ge(t1, t2): tensor([[0, 1], [1, 1]], dtype=torch.uint8) """ - return operations.__binary_op(torch.ge, t1, t2) + if np.isscalar(t1): + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.ge, tmp, t2) + return result + else: + return operations.__binary_op(torch.ge, t1, t2) def gt(t1, t2): @@ -156,7 +203,22 @@ def gt(t1, t2): tensor([[0, 0], [1, 1]], dtype=torch.uint8) """ - return operations.__binary_op(torch.gt, t1, t2) + if np.isscalar(t1): + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.gt, tmp, t2) + return result + else: + return operations.__binary_op(torch.gt, t1, t2) def le(t1, t2): @@ -191,7 +253,22 @@ def le(t1, t2): tensor([[1, 1], [0, 0]], dtype=torch.uint8) """ - return operations.__binary_op(torch.le, t1, t2) + if np.isscalar(t1): + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.le, tmp, t2) + return result + else: + return operations.__binary_op(torch.le, t1, t2) def lt(t1, t2): @@ -226,7 +303,22 @@ def lt(t1, t2): tensor([[1, 0], [0, 0]], dtype=torch.uint8) """ - return operations.__binary_op(torch.lt, t1, t2) + if np.isscalar(t1): + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.lt, tmp, t2) + return result + else: + return operations.__binary_op(torch.lt, t1, t2) def ne(t1, t2): @@ -260,4 +352,19 @@ def ne(t1, t2): tensor([[1, 0], [1, 1]]) """ - return operations.__binary_op(torch.ne, t1, t2) + if np.isscalar(t1): + if np.isscalar(t2): + try: + tmp = factories.array(t1) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + else: + try: + tmp = factories.array([t1]) + except (ValueError, TypeError,): + raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) + + result = operations.__binary_op(torch.ne, tmp, t2) + return result + else: + return operations.__binary_op(torch.ne, t1, t2) diff --git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py index 51c0c3456c..d2f7b84e83 100644 --- a/heat/core/tests/test_arithmetics.py +++ b/heat/core/tests/test_arithmetics.py @@ -30,8 +30,8 @@ def test_add(self): [3.0, 4.0], [5.0, 6.0] ]) - - self.assertTrue(ht.equal(ht.add(self.a_scalar, self.a_scalar), ht.float32([4.0]))) + + self.assertTrue(ht.equal(ht.add(self.a_scalar, self.a_scalar), ht.array(4.0))) self.assertTrue(ht.equal(ht.add(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.add(self.a_scalar, self.a_tensor), result)) self.assertTrue(ht.equal(ht.add(self.a_tensor, self.another_tensor), result)) @@ -41,8 +41,6 @@ def test_add(self): with self.assertRaises(ValueError): ht.add(self.a_tensor, self.another_vector) - with self.assertRaises(NotImplementedError): - ht.add(self.a_tensor, self.a_split_tensor) with self.assertRaises(TypeError): ht.add(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): @@ -58,7 +56,7 @@ def test_div(self): [2.0/3.0, 0.5] ]) - self.assertTrue(ht.equal(ht.div(self.a_scalar, self.a_scalar), ht.float32([1.0]))) + self.assertTrue(ht.equal(ht.div(self.a_scalar, self.a_scalar), ht.array(1.0))) self.assertTrue(ht.equal(ht.div(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.div(self.a_scalar, self.a_tensor), commutated_result)) self.assertTrue(ht.equal(ht.div(self.a_tensor, self.another_tensor), result)) @@ -68,8 +66,6 @@ def test_div(self): with self.assertRaises(ValueError): ht.div(self.a_tensor, self.another_vector) - with self.assertRaises(NotImplementedError): - ht.sub(self.a_tensor, self.a_split_tensor) with self.assertRaises(TypeError): ht.div(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): @@ -98,7 +94,7 @@ def test_fmod(self): another_float = ht.array([1.9]) result_float = ht.array([1.5]) - self.assertTrue(ht.equal(ht.fmod(self.a_scalar, self.a_scalar), ht.float32([0.0]))) + self.assertTrue(ht.equal(ht.fmod(self.a_scalar, self.a_scalar), ht.array(0.0))) self.assertTrue(ht.equal(ht.fmod(self.a_tensor, self.a_tensor), zero_tensor)) self.assertTrue(ht.equal(ht.fmod(self.a_tensor, self.an_int_scalar), result)) self.assertTrue(ht.equal(ht.fmod(self.a_tensor, self.another_tensor), result)) @@ -111,8 +107,6 @@ def test_fmod(self): with self.assertRaises(ValueError): ht.fmod(self.a_tensor, self.another_vector) - with self.assertRaises(NotImplementedError): - ht.fmod(self.a_tensor, self.a_split_tensor) with self.assertRaises(TypeError): ht.fmod(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): @@ -146,7 +140,7 @@ def test_mul(self): [6.0, 8.0] ]) - self.assertTrue(ht.equal(ht.mul(self.a_scalar, self.a_scalar), ht.array([4.0]))) + self.assertTrue(ht.equal(ht.mul(self.a_scalar, self.a_scalar), ht.array(4.0))) self.assertTrue(ht.equal(ht.mul(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.mul(self.a_scalar, self.a_tensor), result)) self.assertTrue(ht.equal(ht.mul(self.a_tensor, self.another_tensor), result)) @@ -156,8 +150,6 @@ def test_mul(self): with self.assertRaises(ValueError): ht.mul(self.a_tensor, self.another_vector) - with self.assertRaises(NotImplementedError): - ht.mul(self.a_tensor, self.a_split_tensor) with self.assertRaises(TypeError): ht.mul(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): @@ -173,7 +165,7 @@ def test_pow(self): [8.0, 16.0] ]) - self.assertTrue(ht.equal(ht.pow(self.a_scalar, self.a_scalar), ht.array([4.0]))) + self.assertTrue(ht.equal(ht.pow(self.a_scalar, self.a_scalar), ht.array(4.0))) self.assertTrue(ht.equal(ht.pow(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.pow(self.a_scalar, self.a_tensor), commutated_result)) self.assertTrue(ht.equal(ht.pow(self.a_tensor, self.another_tensor), result)) @@ -183,8 +175,6 @@ def test_pow(self): with self.assertRaises(ValueError): ht.pow(self.a_tensor, self.another_vector) - with self.assertRaises(NotImplementedError): - ht.pow(self.a_tensor, self.a_split_tensor) with self.assertRaises(TypeError): ht.pow(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): @@ -200,7 +190,7 @@ def test_sub(self): [-1.0, -2.0] ]) - self.assertTrue(ht.equal(ht.sub(self.a_scalar, self.a_scalar), ht.array([0.0]))) + self.assertTrue(ht.equal(ht.sub(self.a_scalar, self.a_scalar), ht.array(0.0))) self.assertTrue(ht.equal(ht.sub(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.sub(self.a_scalar, self.a_tensor), minus_result)) self.assertTrue(ht.equal(ht.sub(self.a_tensor, self.another_tensor), result)) @@ -210,8 +200,6 @@ def test_sub(self): with self.assertRaises(ValueError): ht.sub(self.a_tensor, self.another_vector) - with self.assertRaises(NotImplementedError): - ht.sub(self.a_tensor, self.a_split_tensor) with self.assertRaises(TypeError): ht.sub(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): diff --git a/heat/core/tests/test_relational.py b/heat/core/tests/test_relational.py index 7dbbc7288f..c9554a31a9 100644 --- a/heat/core/tests/test_relational.py +++ b/heat/core/tests/test_relational.py @@ -31,7 +31,7 @@ def test_eq(self): [0, 0] ]) - self.assertTrue(ht.equal(ht.eq(self.a_scalar, self.a_scalar), ht.uint8([1]))) + self.assertTrue(ht.equal(ht.eq(self.a_scalar, self.a_scalar), ht.uint8(1))) self.assertTrue(ht.equal(ht.eq(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.eq(self.a_scalar, self.a_tensor), result)) self.assertTrue(ht.equal(ht.eq(self.a_tensor, self.another_tensor), result)) @@ -64,7 +64,7 @@ def test_ge(self): [0, 0] ]) - self.assertTrue(ht.equal(ht.ge(self.a_scalar, self.a_scalar), ht.uint8([1]))) + self.assertTrue(ht.equal(ht.ge(self.a_scalar, self.a_scalar), ht.uint8(1))) self.assertTrue(ht.equal(ht.ge(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.ge(self.a_scalar, self.a_tensor), commutated_result)) self.assertTrue(ht.equal(ht.ge(self.a_tensor, self.another_tensor), result)) @@ -91,7 +91,7 @@ def test_gt(self): [0, 0] ]) - self.assertTrue(ht.equal(ht.gt(self.a_scalar, self.a_scalar), ht.uint8([0]))) + self.assertTrue(ht.equal(ht.gt(self.a_scalar, self.a_scalar), ht.uint8(0))) self.assertTrue(ht.equal(ht.gt(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.gt(self.a_scalar, self.a_tensor), commutated_result)) self.assertTrue(ht.equal(ht.gt(self.a_tensor, self.another_tensor), result)) @@ -118,7 +118,7 @@ def test_le(self): [1, 1] ]) - self.assertTrue(ht.equal(ht.le(self.a_scalar, self.a_scalar), ht.uint8([1]))) + self.assertTrue(ht.equal(ht.le(self.a_scalar, self.a_scalar), ht.uint8(1))) self.assertTrue(ht.equal(ht.le(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.le(self.a_scalar, self.a_tensor), commutated_result)) self.assertTrue(ht.equal(ht.le(self.a_tensor, self.another_tensor), result)) @@ -145,7 +145,7 @@ def test_lt(self): [1, 1] ]) - self.assertTrue(ht.equal(ht.lt(self.a_scalar, self.a_scalar), ht.uint8([0]))) + self.assertTrue(ht.equal(ht.lt(self.a_scalar, self.a_scalar), ht.uint8(0))) self.assertTrue(ht.equal(ht.lt(self.a_tensor, self.a_scalar), result)) self.assertTrue(ht.equal(ht.lt(self.a_scalar, self.a_tensor), commutated_result)) self.assertTrue(ht.equal(ht.lt(self.a_tensor, self.another_tensor), result)) @@ -168,12 +168,12 @@ def test_ne(self): [1, 1] ]) - # self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_scalar), ht.uint8([0]))) - # self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.a_scalar), result)) - # self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_tensor), result)) - # self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.another_tensor), result)) - # self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.a_vector), result)) - # self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.an_int_scalar), result)) + self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_scalar), ht.uint8(0))) + self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.a_scalar), result)) + self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_tensor), result)) + self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.another_tensor), result)) + self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.a_vector), result)) + self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.an_int_scalar), result)) self.assertTrue(ht.equal(ht.ne(self.a_split_tensor, self.a_tensor), result)) with self.assertRaises(ValueError): From e1b1a4a77bae64368a4d29cfabc3297d6a9282f8 Mon Sep 17 00:00:00 2001 From: Charlotte Date: Mon, 20 May 2019 16:45:54 +0200 Subject: [PATCH 5/7] Cosmetics --- heat/core/arithmetics.py | 10 ++++----- heat/core/relational.py | 46 ++++++++++++++++++++-------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/heat/core/arithmetics.py b/heat/core/arithmetics.py index 1f0516655a..0fc4bd5506 100644 --- a/heat/core/arithmetics.py +++ b/heat/core/arithmetics.py @@ -135,16 +135,16 @@ def fmod(t1, t2): #Special treatment for fmod, since torch operation fmod only supports formats (Tensor input, Tensor other, Tensor out) or (Tensor input, Number other, Tensor out) if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.fmod, tmp, t2) + result = operations.__binary_op(torch.fmod, tensor_1, t2) return result else: return operations.__binary_op(torch.fmod, t1, t2) @@ -270,10 +270,10 @@ def pow(t1, t2): if np.isscalar(t1) and np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - return operations.__binary_op(torch.pow, tmp, t2) + return operations.__binary_op(torch.pow, tensor_1, t2) else: return operations.__binary_op(torch.pow, t1, t2) diff --git a/heat/core/relational.py b/heat/core/relational.py index 2548b9bf47..28b3f00dd8 100644 --- a/heat/core/relational.py +++ b/heat/core/relational.py @@ -51,16 +51,16 @@ def eq(t1, t2): if np.isscalar(t1): if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.eq, tmp, t2) + result = operations.__binary_op(torch.eq, tensor_1, t2) return result else: return operations.__binary_op(torch.eq, t1, t2) @@ -97,20 +97,20 @@ def equal(t1, t2): """ if np.isscalar(t1): try: - tmp1 = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: - tmp1 = t1 + tensor_1 = t1 if np.isscalar(t2): try: - tmp2 = factories.array([t2]) + tensor_2 = factories.array([t2]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: - tmp2 = t2 + tensor_2 = t2 - result_tensor = operations.__binary_op(torch.equal, tmp1, tmp2) + result_tensor = operations.__binary_op(torch.equal, tensor_1, tensor_2) result_value = result_tensor._DNDarray__array @@ -155,16 +155,16 @@ def ge(t1, t2): if np.isscalar(t1): if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.ge, tmp, t2) + result = operations.__binary_op(torch.ge, tensor_1, t2) return result else: return operations.__binary_op(torch.ge, t1, t2) @@ -206,16 +206,16 @@ def gt(t1, t2): if np.isscalar(t1): if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.gt, tmp, t2) + result = operations.__binary_op(torch.gt, tensor_1, t2) return result else: return operations.__binary_op(torch.gt, t1, t2) @@ -256,16 +256,16 @@ def le(t1, t2): if np.isscalar(t1): if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.le, tmp, t2) + result = operations.__binary_op(torch.le, tensor_1, t2) return result else: return operations.__binary_op(torch.le, t1, t2) @@ -306,16 +306,16 @@ def lt(t1, t2): if np.isscalar(t1): if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.lt, tmp, t2) + result = operations.__binary_op(torch.lt, tensor_1, t2) return result else: return operations.__binary_op(torch.lt, t1, t2) @@ -355,16 +355,16 @@ def ne(t1, t2): if np.isscalar(t1): if np.isscalar(t2): try: - tmp = factories.array(t1) + tensor_1 = factories.array(t1) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) else: try: - tmp = factories.array([t1]) + tensor_1 = factories.array([t1]) except (ValueError, TypeError,): raise TypeError('First operand must be numeric scalar or NDNArray, but was {}'.format(type(t1))) - result = operations.__binary_op(torch.ne, tmp, t2) + result = operations.__binary_op(torch.ne, tensor_1, t2) return result else: return operations.__binary_op(torch.ne, t1, t2) From 27c101aaa08d2a715f7090facca46be96863d807 Mon Sep 17 00:00:00 2001 From: Charlotte Date: Tue, 21 May 2019 09:41:35 +0200 Subject: [PATCH 6/7] Added unit tests for code cov --- heat/core/tests/test_arithmetics.py | 2 ++ heat/core/tests/test_relational.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py index 02f5671777..ffbb1401d5 100644 --- a/heat/core/tests/test_arithmetics.py +++ b/heat/core/tests/test_arithmetics.py @@ -111,6 +111,8 @@ def test_fmod(self): ht.fmod(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): ht.fmod('T', 's') + with self.assertRaises(TypeError): + ht.fmod('2', self.a_tensor) def test_mod(self): a_tensor = ht.array([ diff --git a/heat/core/tests/test_relational.py b/heat/core/tests/test_relational.py index c9554a31a9..e6195004e7 100644 --- a/heat/core/tests/test_relational.py +++ b/heat/core/tests/test_relational.py @@ -45,15 +45,24 @@ def test_eq(self): ht.eq(self.a_tensor, self.split_ones_tensor) with self.assertRaises(TypeError): ht.eq(self.a_tensor, self.errorneous_type) + with self.assertRaises(TypeError): + ht.eq('2', self.a_vector) with self.assertRaises(TypeError): ht.eq('self.a_tensor', 's') def test_equal(self): + self.assertTrue(ht.equal(self.a_scalar, 2.0)) self.assertTrue(ht.equal(self.a_tensor, self.a_tensor)) self.assertFalse(ht.equal(self.a_tensor, self.another_tensor)) self.assertFalse(ht.equal(self.a_tensor, self.a_scalar)) self.assertFalse(ht.equal(self.another_tensor, self.a_scalar)) + with self.assertRaises(TypeError): + ht.equal('2', self.a_scalar) + with self.assertRaises(TypeError): + ht.equal(self.a_scalar, '2') + + def test_ge(self): result = ht.uint8([ [0, 1], @@ -80,6 +89,8 @@ def test_ge(self): ht.ge(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): ht.ge('self.a_tensor', 's') + with self.assertRaises(TypeError): + ht.ge('2', self.a_vector) def test_gt(self): result = ht.uint8([ @@ -107,6 +118,8 @@ def test_gt(self): ht.gt(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): ht.gt('self.a_tensor', 's') + with self.assertRaises(TypeError): + ht.gt('2', self.a_vector) def test_le(self): result = ht.uint8([ @@ -134,6 +147,8 @@ def test_le(self): ht.le(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): ht.le('self.a_tensor', 's') + with self.assertRaises(TypeError): + ht.le('2', self.a_vector) def test_lt(self): result = ht.uint8([ @@ -161,6 +176,8 @@ def test_lt(self): ht.lt(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): ht.lt('self.a_tensor', 's') + with self.assertRaises(TypeError): + ht.lt('2', self.a_vector) def test_ne(self): result = ht.uint8([ @@ -184,3 +201,5 @@ def test_ne(self): ht.ne(self.a_tensor, self.errorneous_type) with self.assertRaises(TypeError): ht.ne('self.a_tensor', 's') + with self.assertRaises(TypeError): + ht.ne('2', self.a_vector) From a65d95ebdbb99f8230bf7cde1f546276a595c91e Mon Sep 17 00:00:00 2001 From: Charlotte Date: Wed, 22 May 2019 16:26:19 +0200 Subject: [PATCH 7/7] Restructuring of _binary_op --- heat/core/operations.py | 61 ++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/heat/core/operations.py b/heat/core/operations.py index f5aafd3adb..a5c4fa1faa 100644 --- a/heat/core/operations.py +++ b/heat/core/operations.py @@ -42,13 +42,11 @@ def __binary_op(operation, t1, t2): except (ValueError, TypeError,): raise TypeError('Only numeric scalars are supported, but inputs were {} and {}'.format(type(t1), type(t2))) output_shape = (1,) + output_type = types.canonical_heat_type(result.dtype) output_split = None output_device = None output_comm = MPI_WORLD - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(result.dtype), output_split, output_device, output_comm) - - elif isinstance(t2, dndarray.DNDarray): try: result = operation(t1, t2._DNDarray__array) @@ -56,12 +54,11 @@ def __binary_op(operation, t1, t2): raise TypeError('Data type not supported, input was {}'.format(type(t1))) output_shape = t2.shape + output_type = types.canonical_heat_type(result.dtype) output_split = t2.split output_device = t2.device output_comm = t2.comm - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(result.dtype), output_split, output_device, output_comm) - else: raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2))) @@ -74,42 +71,30 @@ def __binary_op(operation, t1, t2): raise TypeError('Data type not supported, input was {}'.format(type(t2))) output_shape = t1.shape + output_type = types.canonical_heat_type(result.dtype) output_split = t1.split output_device = t1.device output_comm = t1.comm - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(result.dtype), output_split, output_device, output_comm) - elif isinstance(t2, dndarray.DNDarray): if t2.dtype != t1.dtype: t2 = t2.astype(t1.dtype) + output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) + # TODO: implement complex NUMPY rules if t2.split == t1.split: pass elif (t1.split is not None) and (t2.split is None): - if t2.shape[t1.split] == 1: - warnings.warn('Broadcasting requires transferring data of second operator between MPI ranks!') - t2.comm.Bcast(t2) t2.resplit(axis=t1.split) - elif (t2.split is not None) and (t1.split is None): - if t1.shape[t2.split] == 1: - warnings.warn('Broadcasting requires transferring data of first operator between MPI ranks!') - t1.comm.Bcast(t1) t1.resplit(axis=t2.split) - else: # It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0 # and split=1 raise NotImplementedError('Not implemented for other splittings') - output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) - output_split = t1.split - output_device = t1.device - output_comm = t1.comm - # ToDo: Fine tuning in case of comm.size>t1.shape[t1.split]. Send torch tensors only to ranks, that will hold data. if t1.split is not None: if t1.shape[t1.split] == 1 and t1.comm.is_distributed(): @@ -125,27 +110,33 @@ def __binary_op(operation, t1, t2): t2._DNDarray__array = torch.zeros(t2.shape, dtype=t2.dtype.torch_type()) t2.comm.Bcast(t2) + promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type() + if t1.split is not None: + if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0: + result = t1._DNDarray__array.type(promoted_type) + else: + result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) + elif t2.split is not None: + if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0: + result = t2._DNDarray__array.type(promoted_type) + else: + result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) + else: + result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) + + output_type = types.canonical_heat_type(t1.dtype) + output_split = t1.split + output_device = t1.device + output_comm = t1.comm + else: raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2))) else: - raise NotImplementedError('Not implemented for non scalar') + raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t1))) - promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type() - if t1.split is not None: - if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0: - result = t1._DNDarray__array.type(promoted_type) - else: - result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) - elif t2.split is not None: - if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0: - result = t2._DNDarray__array.type(promoted_type) - else: - result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) - else: - result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type)) - return dndarray.DNDarray(result, output_shape, types.canonical_heat_type(t1.dtype), output_split, output_device, output_comm) + return dndarray.DNDarray(result, output_shape, output_type, output_split, output_device, output_comm) def __local_op(operation, x, out, **kwargs):