Skip to content

Commit

Permalink
Improve save_csv string formatting (#948)
Browse files Browse the repository at this point in the history
* Improve save_csv string formatting

Using .format can take up to 2x as long as using %.

Also add a test covering an additional line of code.

* Changelog message for issue #947

* Improve test coverage
  • Loading branch information
bhagemeier authored Apr 20, 2022
1 parent 5f77902 commit 0510437
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use.
- [#940](https://github.com/helmholtz-analytics/heat/pull/940) Duplicate MPI_COMM_WORLD to make library more independent.
- [#941](https://github.com/helmholtz-analytics/heat/pull/941) Add function to save data as CSV.
- [#948](https://github.com/helmholtz-analytics/heat/pull/948) Improve CSV write performance.

## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
Expand Down
12 changes: 6 additions & 6 deletions heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,16 +1004,16 @@ def save_csv(
decimals = 0
dec_sep = 0
if sign == 1:
fmt = "{: %dd}" % (pre_point_digits + 1)
fmt = "%%%-dd" % (pre_point_digits + 1)
else:
fmt = "{:%dd}" % (pre_point_digits)
fmt = "%%%dd" % (pre_point_digits)
elif types.issubdtype(data.dtype, types.floating):
if decimals == -1:
decimals = 7 if data.dtype is types.float32 else 15
if sign == 1:
fmt = "{: %d.%df}" % (pre_point_digits + decimals + 2, decimals)
fmt = "%%%-d.%df" % (pre_point_digits + decimals + 2, decimals)
else:
fmt = "{:%d.%df}" % (pre_point_digits + decimals + 1, decimals)
fmt = "%%%d.%df" % (pre_point_digits + decimals + 1, decimals)

# sign + decimal separator + pre separator digits + decimals (post separator)
item_size = decimals + dec_sep + sign + pre_point_digits
Expand All @@ -1033,11 +1033,11 @@ def save_csv(
for i in range(data.lshape[0]):
# if lshape is of the form (x,), then there will only be a single element per row
if len(data.lshape) == 1:
row = fmt.format(data.larray[i])
row = fmt % (data.larray[i])
else:
if data.lshape[1] == 0:
break
row = sep.join(fmt.format(item) for item in data.larray[i])
row = sep.join(fmt % (item) for item in data.larray[i])

if (
data.split is None
Expand Down
55 changes: 52 additions & 3 deletions heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def test_save_csv(self):
]:
for separator in [",", ";", "|"]:
for split in [None, 0, 1]:
for headers in [None, ["# This", "# is a", "# test."], ["an,ordinary,header"]]:
for headers in [None, ["# This", "# is a", "# test."]]:
for shape in [(1, 1), (10, 10), (20, 1), (1, 20), (25, 4), (4, 25)]:
if rnd_type[0] == ht.random.randint:
data = rnd_type[0](
-100, 1000, size=shape, dtype=rnd_type[1], split=split
-1000, 1000, size=shape, dtype=rnd_type[1], split=split
)
else:
data = rnd_type[0](
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_save_csv(self):
# split=split,
header_lines=0 if headers is None else len(headers),
sep=separator,
)
).reshape(shape)
resid = data - comparison
self.assertTrue(
ht.max(resid).item() < 0.00001 and ht.min(resid).item() > -0.00001
Expand All @@ -195,6 +195,55 @@ def test_save_csv(self):
if data.comm.rank == 0:
os.unlink(filename)

# Test vector
data = ht.random.randint(0, 100, size=(150,))
if data.comm.rank == 0:
tmpfile = tempfile.NamedTemporaryFile(prefix="test_io_", suffix=".csv", delete=False)
tmpfile.close()
filename = tmpfile.name
else:
filename = None
filename = data.comm.handle.bcast(filename, root=0)
data.save(filename)
comparison = ht.load(filename).reshape((150,))
self.assertTrue((data == comparison).all())
data.comm.handle.Barrier()
if data.comm.rank == 0:
os.unlink(filename)

# Test 0 matrix
data = ht.zeros((10, 10))
if data.comm.rank == 0:
tmpfile = tempfile.NamedTemporaryFile(prefix="test_io_", suffix=".csv", delete=False)
tmpfile.close()
filename = tmpfile.name
else:
filename = None
filename = data.comm.handle.bcast(filename, root=0)
data.save(filename)
comparison = ht.load(filename)
self.assertTrue((data == comparison).all())
data.comm.handle.Barrier()
if data.comm.rank == 0:
os.unlink(filename)

# Test negative float values
data = ht.random.rand(100, 100)
data = data - 500
if data.comm.rank == 0:
tmpfile = tempfile.NamedTemporaryFile(prefix="test_io_", suffix=".csv", delete=False)
tmpfile.close()
filename = tmpfile.name
else:
filename = None
filename = data.comm.handle.bcast(filename, root=0)
data.save(filename)
comparison = ht.load(filename)
self.assertTrue((data == comparison).all())
data.comm.handle.Barrier()
if data.comm.rank == 0:
os.unlink(filename)

def test_load_exception(self):
# correct extension, file does not exist
if ht.io.supports_hdf5():
Expand Down

0 comments on commit 0510437

Please sign in to comment.