diff --git a/CHANGELOG.md b/CHANGELOG.md index 5326ef1f29..fa6cc50365 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ - [#711](https://github.com/helmholtz-analytics/heat/pull/711) `isfinite()`, `isinf()`, `isnan()` - [#743](https://github.com/helmholtz-analytics/heat/pull/743) `isneginf()`, `isposinf()` +### Types +- [#738](https://github.com/helmholtz-analytics/heat/pull/738) `iscomplex()`, `isreal()` + ## Bug fixes - [#709](https://github.com/helmholtz-analytics/heat/pull/709) Set the encoding for README.md in setup.py explicitly. diff --git a/heat/core/tests/test_types.py b/heat/core/tests/test_types.py index 5b98defb6c..7108650ad0 100644 --- a/heat/core/tests/test_types.py +++ b/heat/core/tests/test_types.py @@ -113,6 +113,72 @@ def test_complex128(self): self.assertEqual(ht.complex128.char(), "c16") + def test_iscomplex(self): + a = ht.array([1, 1.2, 1 + 1j, 1 + 0j]) + s = ht.array([False, False, True, False]) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.array([1, 1.2, True], split=0) + s = ht.array([False, False, False], split=0) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.ones((6, 6), dtype=ht.bool, split=0) + s = ht.zeros((6, 6), dtype=ht.bool, split=0) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.full((5, 5), 1 + 1j, dtype=ht.int, split=1) + s = ht.ones((5, 5), dtype=ht.bool, split=1) + r = ht.iscomplex(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + def test_isreal(self): + a = ht.array([1, 1.2, 1 + 1j, 1 + 0j]) + s = ht.array([True, True, False, True]) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.array([1, 1.2, True], split=0) + s = ht.array([True, True, True], split=0) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.ones((6, 6), dtype=ht.bool, split=0) + s = ht.ones((6, 6), dtype=ht.bool, split=0) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + + a = ht.full((5, 5), 1 + 1j, dtype=ht.int, split=1) + s = ht.zeros((5, 5), dtype=ht.bool, split=1) + r = ht.isreal(a) + self.assertEqual(r.shape, s.shape) + self.assertEqual(r.dtype, s.dtype) + self.assertEqual(r.device, s.device) + self.assertTrue(ht.equal(r, s)) + class TestTypeConversion(TestCase): def test_can_cast(self): diff --git a/heat/core/types.py b/heat/core/types.py index 1d3b5d305a..a97d755999 100644 --- a/heat/core/types.py +++ b/heat/core/types.py @@ -27,6 +27,8 @@ from . import communication from . import devices from . import factories +from . import _operations +from . import sanitation __all__ = [ @@ -55,6 +57,8 @@ "double", "flexible", "can_cast", + "iscomplex", + "isreal", "issubdtype", "promote_types", "complex64", @@ -607,6 +611,42 @@ def can_cast(from_, to, casting="intuitive"): break +def iscomplex(x): + """ + Test element-wise if input is complex. + + Parameters + ---------- + x : DNDarray + + Examples + -------- + >>> ht.iscomplex(ht.array([1+1j, 1])) + DNDarray([ True, False], dtype=ht.bool, device=cpu:0, split=None) + """ + sanitation.sanitize_in(x) + + if issubclass(x.dtype, _complexfloating): + return x.imag != 0 + else: + return factories.zeros(x.shape, bool, split=x.split, device=x.device, comm=x.comm) + + +def isreal(x): + """ + Test element-wise if input is real-valued. + + Parameters + ---------- + x : DNDarray + + Examples + -------- + ht.iscomplex(ht.array([1+1j, 1])) + """ + return _operations.__local_op(torch.isreal, x, None, no_cast=True) + + def issubdtype(arg1, arg2): """ Returns True if first argument is a typecode lower/equal in type hierarchy.