From fc57d7180ed68eb59d61edad9de8e8552426a0ac Mon Sep 17 00:00:00 2001 From: Michael Tarnawa <m.tarnawa@fz-juelich.de> Date: Fri, 12 Mar 2021 16:27:58 +0100 Subject: [PATCH 1/3] add iscomplex, isreal --- heat/core/types.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/heat/core/types.py b/heat/core/types.py index 1d3b5d305a..a97d755999 100644 --- a/heat/core/types.py +++ b/heat/core/types.py @@ -27,6 +27,8 @@ from . import communication from . import devices from . import factories +from . import _operations +from . import sanitation __all__ = [ @@ -55,6 +57,8 @@ "double", "flexible", "can_cast", + "iscomplex", + "isreal", "issubdtype", "promote_types", "complex64", @@ -607,6 +611,42 @@ def can_cast(from_, to, casting="intuitive"): break +def iscomplex(x): + """ + Test element-wise if input is complex. + + Parameters + ---------- + x : DNDarray + + Examples + -------- + >>> ht.iscomplex(ht.array([1+1j, 1])) + DNDarray([ True, False], dtype=ht.bool, device=cpu:0, split=None) + """ + sanitation.sanitize_in(x) + + if issubclass(x.dtype, _complexfloating): + return x.imag != 0 + else: + return factories.zeros(x.shape, bool, split=x.split, device=x.device, comm=x.comm) + + +def isreal(x): + """ + Test element-wise if input is real-valued. + + Parameters + ---------- + x : DNDarray + + Examples + -------- + ht.iscomplex(ht.array([1+1j, 1])) + """ + return _operations.__local_op(torch.isreal, x, None, no_cast=True) + + def issubdtype(arg1, arg2): """ Returns True if first argument is a typecode lower/equal in type hierarchy. From 8fe424d800b2f6a7296bf34f9d59607834a48877 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa <m.tarnawa@fz-juelich.de> Date: Fri, 12 Mar 2021 16:31:34 +0100 Subject: [PATCH 2/3] add tests for iscomplex isreal --- heat/core/tests/test_types.py | 66 +++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/heat/core/tests/test_types.py b/heat/core/tests/test_types.py index 5b98defb6c..7108650ad0 100644 --- a/heat/core/tests/test_types.py +++ b/heat/core/tests/test_types.py @@ -113,6 +113,72 @@ def test_complex128(self): self.assertEqual(ht.complex128.char(), "c16") + def test_iscomplex(self): + a = ht.array([1, 1.2, 1 + 1j, 1 + 0j]) + s = ht.array([False, False, True, False]) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.array([1, 1.2, True], split=0) + s = ht.array([False, False, False], split=0) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.ones((6, 6), dtype=ht.bool, split=0) + s = ht.zeros((6, 6), dtype=ht.bool, split=0) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.full((5, 5), 1 + 1j, dtype=ht.int, split=1) + s = ht.ones((5, 5), dtype=ht.bool, split=1) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + def test_isreal(self): + a = ht.array([1, 1.2, 1 + 1j, 1 + 0j]) + s = ht.array([True, True, False, True]) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.array([1, 1.2, True], split=0) + s = ht.array([True, True, True], split=0) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.ones((6, 6), dtype=ht.bool, split=0) + s = ht.ones((6, 6), dtype=ht.bool, split=0) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.full((5, 5), 1 + 1j, dtype=ht.int, split=1) + s = ht.zeros((5, 5), dtype=ht.bool, split=1) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + class TestTypeConversion(TestCase): def test_can_cast(self): From ad507a014e47fd6d285e113f785d76f8ef7fd444 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa <m.tarnawa@fz-juelich.de> Date: Fri, 12 Mar 2021 16:37:58 +0100 Subject: [PATCH 3/3] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b85c068518..00c6e0284d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ ### Logical - [#711](https://github.com/helmholtz-analytics/heat/pull/711) `isfinite()`, `isinf()`, `isnan()` +### Types +- [#738](https://github.com/helmholtz-analytics/heat/pull/738) `iscomplex()`, `isreal()` + ## Bug fixes - [#709](https://github.com/helmholtz-analytics/heat/pull/709) Set the encoding for README.md in setup.py explicitly.