Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IO import issues #787

Merged
merged 10 commits into from
Jun 14, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ Example on 2 processes:
- [#768](https://github.com/helmholtz-analytics/heat/pull/768) Fixed an issue where `deg2rad` and `rad2deg` are not working with the 'out' parameter.
- [#785](https://github.com/helmholtz-analytics/heat/pull/785) Removed `storage_offset` when finding the mpi buffer (`communication.MPICommunication.as_mpi_memory()`).
- [#785](https://github.com/helmholtz-analytics/heat/pull/785) added allowance for 1-dimensional non-contiguous local tensors in `communication.MPICommunication.mpi_type_and_elements_of()`
- [#787](https://github.com/helmholtz-analytics/heat/pull/787) Fixed an issue where Heat cannot be imported when some optional dependencies are not available.
- [#790](https://github.com/helmholtz-analytics/heat/pull/790) catch incorrect device after `bcast` in `DNDarray.__getitem__`
### Linear Algebra
- [#718](https://github.com/helmholtz-analytics/heat/pull/718) New feature: `trace()`
- [#768](https://github.com/helmholtz-analytics/heat/pull/768) New feature: unary positive and negative operations
### Misc.
- [#761](https://github.com/helmholtz-analytics/heat/pull/761) New feature: `result_type`


# v1.0.0

## New features / Highlights
Expand Down
4 changes: 4 additions & 0 deletions heat/classification/tests/test_knn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import unittest
import heat as ht

from heat.classification.kneighborsclassifier import KNeighborsClassifier
from heat.core.tests.test_suites.basic_test import TestCase


class TestKNN(TestCase):
@unittest.skipUnless(ht.supports_hdf5(), "Requires HDF5")
def test_split_none(self):
x = ht.load_hdf5("heat/datasets/iris.h5", dataset="data")

Expand All @@ -27,6 +29,7 @@ def test_split_none(self):
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(result.shape, y.shape)

@unittest.skipUnless(ht.supports_hdf5(), "Requires HDF5")
def test_split_zero(self):
x = ht.load_hdf5("heat/datasets/iris.h5", dataset="data", split=0)

Expand Down Expand Up @@ -73,6 +76,7 @@ def test_utility(self,):
one_hot = KNeighborsClassifier.one_hot_encoding(a)
self.assertTrue((one_hot == b).all())

@unittest.skipUnless(ht.supports_hdf5(), "Requires HDF5")
def test_fit_one_hot(self,):
x = ht.load_hdf5("heat/datasets/iris.h5", dataset="data")

Expand Down
30 changes: 21 additions & 9 deletions heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
__NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", "netcdf"])
__NETCDF_DIM_TEMPLATE = "{}_dim_{}"

__all__ = ["load", "load_csv", "save", "supports_hdf5", "supports_netcdf", "save_hdf5"]
__all__ = ["load", "load_csv", "save", "supports_hdf5", "supports_netcdf"]

try:
import h5py
Expand Down Expand Up @@ -691,10 +691,16 @@ def load(

if extension in __CSV_EXTENSION:
return load_csv(path, *args, **kwargs)
elif supports_hdf5() and extension in __HDF5_EXTENSIONS:
return load_hdf5(path, *args, **kwargs)
elif supports_netcdf() and extension in __NETCDF_EXTENSIONS:
return load_netcdf(path, *args, **kwargs)
elif extension in __HDF5_EXTENSIONS:
if supports_hdf5():
return load_hdf5(path, *args, **kwargs)
else:
raise ImportError("hdf5 is required for file extension {}".format(extension))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be either RuntimeErrors or TypeErrors. The issue being reported here is that the package is not installed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was using ImportError because it's an external package that needs to be imported.
I don't know if TypeError is good here as the function gets the right argument types. ValueError maybe.
A general RuntimeError might be the best here.

elif extension in __NETCDF_EXTENSIONS:
if supports_netcdf():
return load_netcdf(path, *args, **kwargs)
else:
raise ImportError("netcdf is required for file extension {}".format(extension))
else:
raise ValueError("Unsupported file extension {}".format(extension))

Expand Down Expand Up @@ -944,10 +950,16 @@ def save(
raise TypeError("Expected path to be str, but was {}".format(type(path)))
extension = os.path.splitext(path)[-1].strip().lower()

if supports_hdf5() and extension in __HDF5_EXTENSIONS:
save_hdf5(data, path, *args, **kwargs)
elif supports_netcdf() and extension in __NETCDF_EXTENSIONS:
save_netcdf(data, path, *args, **kwargs)
if extension in __HDF5_EXTENSIONS:
if supports_hdf5():
save_hdf5(data, path, *args, **kwargs)
else:
raise ImportError("hdf5 is required for file extension {}".format(extension))
elif extension in __NETCDF_EXTENSIONS:
if supports_netcdf():
save_netcdf(data, path, *args, **kwargs)
else:
raise ImportError("netcdf is required for file extension {}".format(extension))
else:
raise ValueError("Unsupported file extension {}".format(extension))

Expand Down
29 changes: 16 additions & 13 deletions heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_load(self):
# content
self.assertTrue((self.IRIS == iris.larray).all())
else:
with self.assertRaises(ValueError):
with self.assertRaises(ImportError):
_ = ht.load(self.HDF5_PATH, dataset=self.HDF5_DATASET)

# netCDF
Expand All @@ -80,7 +80,7 @@ def test_load(self):
# content
self.assertTrue((self.IRIS == iris.larray).all())
else:
with self.assertRaises(ValueError):
with self.assertRaises(ImportError):
_ = ht.load(self.NETCDF_PATH, variable=self.NETCDF_VARIABLE)

def test_load_csv(self):
Expand Down Expand Up @@ -147,14 +147,14 @@ def test_load_exception(self):
with self.assertRaises(IOError):
ht.load("foo.h5", "data")
else:
with self.assertRaises(ValueError):
with self.assertRaises(ImportError):
ht.load("foo.h5", "data")

if ht.io.supports_netcdf():
with self.assertRaises(IOError):
ht.load("foo.nc", "data")
else:
with self.assertRaises(ValueError):
with self.assertRaises(ImportError):
ht.load("foo.nc", "data")

# unknown file extension
Expand Down Expand Up @@ -378,7 +378,7 @@ def test_save_exception(self):
with self.assertRaises(TypeError):
ht.save(data, self.HDF5_OUT_PATH, 1)
else:
with self.assertRaises(ValueError):
with self.assertRaises(ImportError):
ht.save(data, self.HDF5_OUT_PATH, self.HDF5_DATASET)

if ht.io.supports_netcdf():
Expand Down Expand Up @@ -409,13 +409,16 @@ def test_save_exception(self):
mode="a",
)
else:
with self.assertRaises(ValueError):
with self.assertRaises(ImportError):
ht.save(data, self.NETCDF_OUT_PATH, self.NETCDF_VARIABLE)

with self.assertRaises(ValueError):
ht.save(1, "data.dat")

def test_load_hdf5(self):
# HDF5 support is optional
if not ht.io.supports_hdf5():
return
self.skipTest("Requires HDF5")

# default parameters
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET)
Expand Down Expand Up @@ -453,7 +456,7 @@ def test_load_hdf5(self):
def test_load_hdf5_exception(self):
# HDF5 support is optional
if not ht.io.supports_hdf5():
return
self.skipTest("Requires HDF5")

# improper argument types
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -501,7 +504,7 @@ def test_save_hdf5(self):
def test_save_hdf5_exception(self):
# HDF5 support is optional
if not ht.io.supports_hdf5():
return
self.skipTest("Requires HDF5")

# dummy data
data = ht.arange(1)
Expand All @@ -516,7 +519,7 @@ def test_save_hdf5_exception(self):
def test_load_netcdf(self):
# netcdf support is optional
if not ht.io.supports_netcdf():
return
self.skipTest("Requires NetCDF")

# default parameters
iris = ht.load_netcdf(self.NETCDF_PATH, self.NETCDF_VARIABLE)
Expand Down Expand Up @@ -554,7 +557,7 @@ def test_load_netcdf(self):
def test_load_netcdf_exception(self):
# netcdf support is optional
if not ht.io.supports_netcdf():
return
self.skipTest("Requires NetCDF")

# improper argument types
with self.assertRaises(TypeError):
Expand All @@ -573,7 +576,7 @@ def test_load_netcdf_exception(self):
def test_save_netcdf(self):
# netcdf support is optional
if not ht.io.supports_netcdf():
return
self.skipTest("Requires NetCDF")

# local unsplit data
local_data = ht.arange(100)
Expand Down Expand Up @@ -602,7 +605,7 @@ def test_save_netcdf(self):
def test_save_netcdf_exception(self):
# netcdf support is optional
if not ht.io.supports_netcdf():
return
self.skipTest("Requires NetCDF")

# dummy data
data = ht.arange(1)
Expand Down
1 change: 1 addition & 0 deletions heat/utils/data/tests/test_partial_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


class TestPartialDataset(unittest.TestCase):
@unittest.skipUnless(ht.supports_hdf5(), "Requires HDF5")
def test_partial_h5_dataset(self):
# load h5 data and get the total shape
full_data = ht.load("heat/datasets/iris.h5", dataset="data", split=None)
Expand Down