From e409b30591eee1053fb1732c59f7e06708c66fd5 Mon Sep 17 00:00:00 2001 From: Michael Tarnawa Date: Thu, 29 Jul 2021 15:24:33 +0200 Subject: [PATCH] substitute factories.array with DNDarray --- heat/core/logical.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/heat/core/logical.py b/heat/core/logical.py index 3942047906..5e299bd6e5 100644 --- a/heat/core/logical.py +++ b/heat/core/logical.py @@ -237,13 +237,29 @@ def isclose( output_gshape = stride_tricks.broadcast_shape(t1.gshape, t2.gshape) res = torch.empty(output_gshape, device=t1.device.torch_device).bool() t1.comm.Allgather(_local_isclose, res) - result = factories.array(res, dtype=types.bool, device=t1.device, split=t1.split) + result = DNDarray( + res, + gshape=output_gshape, + dtype=types.bool, + split=t1.split, + device=t1.device, + comm=t1.comm, + balanced=t1.is_balanced, + ) else: if _local_isclose.dim() == 0: # both x and y are scalars, return a single boolean value - result = bool(factories.array(_local_isclose).item()) + result = bool(_local_isclose.item()) else: - result = factories.array(_local_isclose, dtype=types.bool, device=t1.device) + result = DNDarray( + _local_isclose, + gshape=tuple(_local_isclose.shape), + dtype=types.bool, + split=None, + device=t1.device, + comm=t1.comm, + balanced=t1.is_balanced, + ) return result