Skip to content

Commit

Permalink
Pointers instead of proper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lenablind committed Jun 14, 2021
1 parent 500a167 commit 17781c6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 253 deletions.
170 changes: 15 additions & 155 deletions heat/core/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,37 +152,9 @@ def ge(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__ge__ = lambda self, other: ge(self, other)
DNDarray.__ge__.__doc__ = ge.__doc__


def greater_equal(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Returns a D:class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich greater than or equal comparison between values from operand ``x`` with respect to values of
operand ``y`` (i.e. ``x>=y``), not commutative. Takes the first and second operand (scalar or
:class:`~heat.core.dndarray.DNDarray`) whose elements are to be compared as argument.
Parameters
----------
x: DNDarray or scalar
The first operand to be compared greater than or equal to second operand
y: DNDarray or scalar
The second operand to be compared less than or equal to first operand
Examples
-------
>>> import heat as ht
>>> x = ht.float32([[1, 2],[3, 4]])
>>> ht.ge(x, 3.0)
DNDarray([[False, False],
[ True, True]], dtype=ht.bool, device=cpu:0, split=None)
>>> y = ht.float32([[2, 2], [2, 2]])
>>> ht.ge(x, y)
DNDarray([[False, True],
[ True, True]], dtype=ht.bool, device=cpu:0, split=None)
"""
return ge(x, y)


DNDarray.__greater_equal__ = lambda self, other: ge(self, other)
DNDarray.__greater_equal__.__doc__ = ge.__doc__
# alias
greater_equal = ge
greater_equal.__doc__ = ge.__doc__


def gt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
Expand Down Expand Up @@ -229,37 +201,9 @@ def gt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__gt__ = lambda self, other: gt(self, other)
DNDarray.__gt__.__doc__ = gt.__doc__


def greater(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich greater than comparison between values from operand ``x`` with respect to values of
operand ``y`` (i.e. ``x>y``), not commutative. Takes the first and second operand (scalar or
:class:`~heat.core.dndarray.DNDarray`) whose elements are to be compared as argument.
Parameters
----------
x: DNDarray or scalar
The first operand to be compared greater than second operand
y: DNDarray or scalar
The second operand to be compared less than first operand
Examples
-------
>>> import heat as ht
>>> x = ht.float32([[1, 2],[3, 4]])
>>> ht.gt(x, 3.0)
DNDarray([[False, False],
[False, True]], dtype=ht.bool, device=cpu:0, split=None)
>>> y = ht.float32([[2, 2], [2, 2]])
>>> ht.gt(x, y)
DNDarray([[False, False],
[ True, True]], dtype=ht.bool, device=cpu:0, split=None)
"""
return gt(x, y)


DNDarray.__greater__ = lambda self, other: gt(self, other)
DNDarray.__greater__.__doc__ = gt.__doc__
# alias
greater = gt
greater.__doc__ = gt.__doc__


def le(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
Expand Down Expand Up @@ -306,37 +250,9 @@ def le(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__le__ = lambda self, other: le(self, other)
DNDarray.__le__.__doc__ = le.__doc__


def less_equal(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Return a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich less than or equal comparison between values from operand ``x`` with respect to values of
operand ``y`` (i.e. ``x<=y``), not commutative. Takes the first and second operand (scalar or
:class:`~heat.core.dndarray.DNDarray`) whose elements are to be compared as argument.
Parameters
----------
x: DNDarray or scalar
The first operand to be compared less than or equal to second operand
y: DNDarray or scalar
The second operand to be compared greater than or equal to first operand
Examples
-------
>>> import heat as ht
>>> x = ht.float32([[1, 2],[3, 4]])
>>> ht.le(x, 3.0)
DNDarray([[ True, True],
[ True, False]], dtype=ht.bool, device=cpu:0, split=None)
>>> y = ht.float32([[2, 2], [2, 2]])
>>> ht.le(x, y)
DNDarray([[ True, True],
[False, False]], dtype=ht.bool, device=cpu:0, split=None)
"""
return le(x, y)


DNDarray.__less_equal__ = lambda self, other: le(self, other)
DNDarray.__less_equal__.__doc__ = le.__doc__
# alias
less_equal = le
less_equal.__doc__ = le.__doc__


def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
Expand Down Expand Up @@ -383,37 +299,9 @@ def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__lt__ = lambda self, other: lt(self, other)
DNDarray.__lt__.__doc__ = lt.__doc__


def less(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich less than comparison between values from operand ``x`` with respect to values of
operand ``y`` (i.e. ``x<y``), not commutative. Takes the first and second operand (scalar or
:class:`~heat.core.dndarray.DNDarray`) whose elements are to be compared as argument.
Parameters
----------
x: DNDarray or scalar
The first operand to be compared less than second operand
y: DNDarray or scalar
The second operand to be compared greater than first operand
Examples
-------
>>> import heat as ht
>>> x = ht.float32([[1, 2],[3, 4]])
>>> ht.lt(x, 3.0)
DNDarray([[ True, True],
[False, False]], dtype=ht.bool, device=cpu:0, split=None)
>>> y = ht.float32([[2, 2], [2, 2]])
>>> ht.lt(x, y)
DNDarray([[ True, False],
[False, False]], dtype=ht.bool, device=cpu:0, split=None)
"""
return lt(x, y)


DNDarray.__less__ = lambda self, other: lt(self, other)
DNDarray.__less__.__doc__ = lt.__doc__
# alias
less = lt
less.__doc__ = lt.__doc__


def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
Expand Down Expand Up @@ -460,34 +348,6 @@ def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__ne__ = lambda self, other: ne(self, other)
DNDarray.__ne__.__doc__ = ne.__doc__


def not_equal(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich comparison of non-equality between values from two operands, commutative.
Takes the first and second operand (scalar or :class:`~heat.core.dndarray.DNDarray`) whose elements are to be
compared as argument.
Parameters
----------
x: DNDarray or scalar
The first operand involved in the comparison
y: DNDarray or scalar
The second operand involved in the comparison
Examples
---------
>>> import heat as ht
>>> x = ht.float32([[1, 2],[3, 4]])
>>> ht.ne(x, 3.0)
DNDarray([[ True, True],
[False, True]], dtype=ht.bool, device=cpu:0, split=None)
>>> y = ht.float32([[2, 2], [2, 2]])
>>> ht.ne(x, y)
DNDarray([[ True, False],
[ True, True]], dtype=ht.bool, device=cpu:0, split=None)
"""
return ne(x, y)


DNDarray.__not_equal__ = lambda self, other: ne(self, other)
DNDarray.__not_equal__.__doc__ = ne.__doc__
# alias
not_equal = ne
not_equal.__doc__ = ne.__doc__
98 changes: 0 additions & 98 deletions heat/core/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,6 @@ def test_ge(self):
with self.assertRaises(TypeError):
ht.ge("self.a_tensor", "s")

def test_greater_equal(self):
result = ht.uint8([[False, True], [True, True]])
commutated_result = ht.array([[True, True], [False, False]])

self.assertTrue(ht.equal(ht.greater_equal(self.a_scalar, self.a_scalar), ht.array(True)))
self.assertTrue(ht.equal(ht.greater_equal(self.a_tensor, self.a_scalar), result))
self.assertTrue(ht.equal(ht.greater_equal(self.a_scalar, self.a_tensor), commutated_result))
self.assertTrue(ht.equal(ht.greater_equal(self.a_tensor, self.another_tensor), result))
self.assertTrue(ht.equal(ht.greater_equal(self.a_tensor, self.a_vector), result))
self.assertTrue(ht.equal(ht.greater_equal(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(
ht.equal(ht.greater_equal(self.a_split_tensor, self.a_tensor), commutated_result)
)

self.assertEqual(ht.greater_equal(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.greater_equal(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.greater_equal(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.greater_equal("self.a_tensor", "s")

def test_gt(self):
result = ht.array([[False, False], [True, True]])
commutated_result = ht.array([[True, False], [False, False]])
Expand All @@ -110,27 +87,6 @@ def test_gt(self):
with self.assertRaises(TypeError):
ht.gt("self.a_tensor", "s")

def test_greater(self):
result = ht.array([[False, False], [True, True]])
commutated_result = ht.array([[True, False], [False, False]])

self.assertTrue(ht.equal(ht.greater(self.a_scalar, self.a_scalar), ht.array(False)))
self.assertTrue(ht.equal(ht.greater(self.a_tensor, self.a_scalar), result))
self.assertTrue(ht.equal(ht.greater(self.a_scalar, self.a_tensor), commutated_result))
self.assertTrue(ht.equal(ht.greater(self.a_tensor, self.another_tensor), result))
self.assertTrue(ht.equal(ht.greater(self.a_tensor, self.a_vector), result))
self.assertTrue(ht.equal(ht.greater(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.greater(self.a_split_tensor, self.a_tensor), commutated_result))

self.assertEqual(ht.greater(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.greater(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.greater(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.greater("self.a_tensor", "s")

def test_le(self):
result = ht.array([[True, True], [False, False]])
commutated_result = ht.array([[False, True], [True, True]])
Expand All @@ -152,29 +108,6 @@ def test_le(self):
with self.assertRaises(TypeError):
ht.le("self.a_tensor", "s")

def test_less_equal(self):
result = ht.array([[True, True], [False, False]])
commutated_result = ht.array([[False, True], [True, True]])

self.assertTrue(ht.equal(ht.less_equal(self.a_scalar, self.a_scalar), ht.array(True)))
self.assertTrue(ht.equal(ht.less_equal(self.a_tensor, self.a_scalar), result))
self.assertTrue(ht.equal(ht.less_equal(self.a_scalar, self.a_tensor), commutated_result))
self.assertTrue(ht.equal(ht.less_equal(self.a_tensor, self.another_tensor), result))
self.assertTrue(ht.equal(ht.less_equal(self.a_tensor, self.a_vector), result))
self.assertTrue(ht.equal(ht.less_equal(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(
ht.equal(ht.less_equal(self.a_split_tensor, self.a_tensor), commutated_result)
)

self.assertEqual(ht.le(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.less_equal(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.less_equal(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.less_equal("self.a_tensor", "s")

def test_lt(self):
result = ht.array([[True, False], [False, False]])
commutated_result = ht.array([[False, False], [True, True]])
Expand All @@ -196,27 +129,6 @@ def test_lt(self):
with self.assertRaises(TypeError):
ht.lt("self.a_tensor", "s")

def test_less(self):
result = ht.array([[True, False], [False, False]])
commutated_result = ht.array([[False, False], [True, True]])

self.assertTrue(ht.equal(ht.less(self.a_scalar, self.a_scalar), ht.array(False)))
self.assertTrue(ht.equal(ht.less(self.a_tensor, self.a_scalar), result))
self.assertTrue(ht.equal(ht.less(self.a_scalar, self.a_tensor), commutated_result))
self.assertTrue(ht.equal(ht.less(self.a_tensor, self.another_tensor), result))
self.assertTrue(ht.equal(ht.less(self.a_tensor, self.a_vector), result))
self.assertTrue(ht.equal(ht.less(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.less(self.a_split_tensor, self.a_tensor), commutated_result))

self.assertEqual(ht.less(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.less(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.less(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.less("self.a_tensor", "s")

def test_ne(self):
result = ht.array([[True, False], [True, True]])

Expand All @@ -237,13 +149,3 @@ def test_ne(self):
ht.ne(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.ne("self.a_tensor", "s")

def test_not_equal(self):
self.assertEqual(ht.not_equal(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.ne(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.ne(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.ne("self.a_tensor", "s")

0 comments on commit 17781c6

Please sign in to comment.