From 0510437144defe739740b48e0aa370badc3b30dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Wed, 20 Apr 2022 14:25:43 +0200 Subject: [PATCH] Improve save_csv string formatting (#948) * Improve save_csv string formatting Using .format can take up to 2x as long as using %. Also add a test covering an additional line of code. * Changelog message for issue #947 * Improve test coverage --- CHANGELOG.md | 1 + heat/core/io.py | 12 ++++----- heat/core/tests/test_io.py | 55 +++++++++++++++++++++++++++++++++++--- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 07aecda8a9..45ff3ea0d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use. - [#940](https://github.com/helmholtz-analytics/heat/pull/940) Duplicate MPI_COMM_WORLD to make library more independent. - [#941](https://github.com/helmholtz-analytics/heat/pull/941) Add function to save data as CSV. +- [#948](https://github.com/helmholtz-analytics/heat/pull/948) Improve CSV write performance. ## Bug Fixes - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension diff --git a/heat/core/io.py b/heat/core/io.py index c58e1d5fd4..0dcf6fc04f 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1004,16 +1004,16 @@ def save_csv( decimals = 0 dec_sep = 0 if sign == 1: - fmt = "{: %dd}" % (pre_point_digits + 1) + fmt = "%%%-dd" % (pre_point_digits + 1) else: - fmt = "{:%dd}" % (pre_point_digits) + fmt = "%%%dd" % (pre_point_digits) elif types.issubdtype(data.dtype, types.floating): if decimals == -1: decimals = 7 if data.dtype is types.float32 else 15 if sign == 1: - fmt = "{: %d.%df}" % (pre_point_digits + decimals + 2, decimals) + fmt = "%%%-d.%df" % (pre_point_digits + decimals + 2, decimals) else: - fmt = "{:%d.%df}" % (pre_point_digits + decimals + 1, decimals) + fmt = "%%%d.%df" % (pre_point_digits + decimals + 1, decimals) # sign + decimal separator + pre separator digits + decimals (post separator) item_size = decimals + dec_sep + sign + pre_point_digits @@ -1033,11 +1033,11 @@ def save_csv( for i in range(data.lshape[0]): # if lshape is of the form (x,), then there will only be a single element per row if len(data.lshape) == 1: - row = fmt.format(data.larray[i]) + row = fmt % (data.larray[i]) else: if data.lshape[1] == 0: break - row = sep.join(fmt.format(item) for item in data.larray[i]) + row = sep.join(fmt % (item) for item in data.larray[i]) if ( data.split is None diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index 47ad519cad..ccfb131f1b 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -152,11 +152,11 @@ def test_save_csv(self): ]: for separator in [",", ";", "|"]: for split in [None, 0, 1]: - for headers in [None, ["# This", "# is a", "# test."], ["an,ordinary,header"]]: + for headers in [None, ["# This", "# is a", "# test."]]: for shape in [(1, 1), (10, 10), (20, 1), (1, 20), (25, 4), (4, 25)]: if rnd_type[0] == ht.random.randint: data = rnd_type[0]( - -100, 1000, size=shape, dtype=rnd_type[1], split=split + -1000, 1000, size=shape, dtype=rnd_type[1], split=split ) else: data = rnd_type[0]( @@ -186,7 +186,7 @@ def test_save_csv(self): # split=split, header_lines=0 if headers is None else len(headers), sep=separator, - ) + ).reshape(shape) resid = data - comparison self.assertTrue( ht.max(resid).item() < 0.00001 and ht.min(resid).item() > -0.00001 @@ -195,6 +195,55 @@ def test_save_csv(self): if data.comm.rank == 0: os.unlink(filename) + # Test vector + data = ht.random.randint(0, 100, size=(150,)) + if data.comm.rank == 0: + tmpfile = tempfile.NamedTemporaryFile(prefix="test_io_", suffix=".csv", delete=False) + tmpfile.close() + filename = tmpfile.name + else: + filename = None + filename = data.comm.handle.bcast(filename, root=0) + data.save(filename) + comparison = ht.load(filename).reshape((150,)) + self.assertTrue((data == comparison).all()) + data.comm.handle.Barrier() + if data.comm.rank == 0: + os.unlink(filename) + + # Test 0 matrix + data = ht.zeros((10, 10)) + if data.comm.rank == 0: + tmpfile = tempfile.NamedTemporaryFile(prefix="test_io_", suffix=".csv", delete=False) + tmpfile.close() + filename = tmpfile.name + else: + filename = None + filename = data.comm.handle.bcast(filename, root=0) + data.save(filename) + comparison = ht.load(filename) + self.assertTrue((data == comparison).all()) + data.comm.handle.Barrier() + if data.comm.rank == 0: + os.unlink(filename) + + # Test negative float values + data = ht.random.rand(100, 100) + data = data - 500 + if data.comm.rank == 0: + tmpfile = tempfile.NamedTemporaryFile(prefix="test_io_", suffix=".csv", delete=False) + tmpfile.close() + filename = tmpfile.name + else: + filename = None + filename = data.comm.handle.bcast(filename, root=0) + data.save(filename) + comparison = ht.load(filename) + self.assertTrue((data == comparison).all()) + data.comm.handle.Barrier() + if data.comm.rank == 0: + os.unlink(filename) + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5():