Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type handling and resolve #339 #348

Merged
merged 9 commits into from
Aug 28, 2024
24 changes: 10 additions & 14 deletions examples/wave_example.ipynb

Large diffs are not rendered by default.

65 changes: 41 additions & 24 deletions mhkit/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,27 @@ def test_convert_to_dataarray(self):
# test data formats
test_n = d1
test_s = pd.Series(d1, t)
test_df = pd.DataFrame({"d1": d1}, index=t)
test_df2 = pd.DataFrame({"d1": d1, "d1_duplicate": d1}, index=t)
test_df_1d = pd.DataFrame({"d1": d1}, index=t)
test_df_2d = pd.DataFrame({"d1": d1, "d1_duplicate": d1}, index=t)
test_da = xr.DataArray(
data=d1,
dims="time",
coords=dict(time=t),
)
test_ds = xr.Dataset(
test_ds_1d_1v = xr.Dataset(
data_vars={"d1": (["time"], d1)}, coords={"time": t, "index": i}
)
test_ds2 = xr.Dataset(
test_ds_2d_1v = xr.Dataset(
data_vars={
"d1": (["time"], d1),
"d2": (["ind"], d2),
"d1_duplicate": (["time"], d1),
},
coords={"time": t},
)
test_ds_2d_2v = xr.Dataset(
data_vars={
"d1": (["time"], d1),
"d2": (["index"], d2),
},
coords={"time": t, "index": i},
)
Expand All @@ -205,15 +212,33 @@ def test_convert_to_dataarray(self):
self.assertIsInstance(da, xr.DataArray)
self.assertTrue(all(da.data == d1))

# Dataframe
df = utils.convert_to_dataarray(test_df)
self.assertIsInstance(df, xr.DataArray)
self.assertTrue(all(df.data == d1))

# Dataset
ds = utils.convert_to_dataarray(test_ds)
self.assertIsInstance(ds, xr.DataArray)
self.assertTrue(all(ds.data == d1))
# 1D Dataframe
df_1d = utils.convert_to_dataarray(test_df_1d)
self.assertIsInstance(df_1d, xr.DataArray)
self.assertTrue(all(df_1d.data == d1))
self.assertTrue("variable" not in df_1d.dims)

# Multivariate Dataframe
df_2d = utils.convert_to_dataarray(test_df_2d)
self.assertIsInstance(df_2d, xr.DataArray)
self.assertTrue(all(df_2d.sel(variable="d1").data == d1))
self.assertTrue(all(df_2d.sel(variable="d1_duplicate").data == d1))

# 1D Dataset
ds_1d_1v = utils.convert_to_dataarray(test_ds_1d_1v)
self.assertIsInstance(ds_1d_1v, xr.DataArray)
self.assertTrue(all(ds_1d_1v.data == d1))
self.assertTrue("variable" not in ds_1d_1v.dims)

# Multivariate 1D Dataset
ds_2d_1v = utils.convert_to_dataarray(test_ds_2d_1v)
self.assertIsInstance(ds_2d_1v, xr.DataArray)
self.assertTrue(all(ds_2d_1v.sel(variable="d1").data == d1))
self.assertTrue(all(ds_2d_1v.sel(variable="d1_duplicate").data == d1))

# Multivariate 2D Dataset (error)
with self.assertRaises(ValueError):
utils.convert_to_dataarray(test_ds_2d_2v)

# int (error)
with self.assertRaises(TypeError):
Expand All @@ -223,14 +248,6 @@ def test_convert_to_dataarray(self):
with self.assertRaises(TypeError):
utils.convert_to_dataarray(test_n, 5)

# Multivariate Dataframe (error)
with self.assertRaises(ValueError):
utils.convert_to_dataarray(test_df2)

# Multivariate Dataset (error)
with self.assertRaises(ValueError):
utils.convert_to_dataarray(test_ds2)

def test_convert_to_dataset(self):
# test data
a = 5
Expand All @@ -242,7 +259,7 @@ def test_convert_to_dataset(self):
# test data formats
test_n = d1
test_s = pd.Series(d1, t)
test_df2 = pd.DataFrame({"d1": d1, "d2": d2}, index=t)
test_df_2d = pd.DataFrame({"d1": d1, "d2": d2}, index=t)
test_da = xr.DataArray(
data=d1,
dims="time",
Expand All @@ -267,7 +284,7 @@ def test_convert_to_dataset(self):
self.assertTrue(all(da["test_name"].data == d1))

# Dataframe
df = utils.convert_to_dataset(test_df2)
df = utils.convert_to_dataset(test_df_2d)
self.assertIsInstance(df, xr.Dataset)
self.assertTrue(all(df["d1"].data == d1))
self.assertTrue(all(df["d2"].data == d2))
Expand Down
62 changes: 39 additions & 23 deletions mhkit/utils/type_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ def convert_to_dataarray(data, name="data"):
"""
Converts the given data to an xarray.DataArray.

This function is designed to handle inputs that can be either a numpy ndarray, pandas Series,
or an xarray DataArray. For convenience, pandas DataFrame and xarray Dataset can also be input
but may only contain a single variable. The function ensures that the output is consistently
an xarray.DataArray.
This function takes in a numpy ndarray, pandas Series, pandas DataFrame, or xarray Dataset
and outputs an equivalent xarray DataArray. DataArrays can be passed through with no changes.

Xarray Datasets can only be input when all variables have the same dimensions.

Multivariate pandas DataFrames become 2D DataArrays, which is especially useful when IO
functions return DataFrames with an extremely large number of variables. Use the function
convert_to_dataset to change a multivariate DataFrame into a multivariate Dataset.

Parameters
----------
Expand Down Expand Up @@ -138,7 +142,7 @@ def convert_to_dataarray(data, name="data"):
data, (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray, xr.Dataset)
):
raise TypeError(
"Input data must be of type np.ndarray, pandas.DataFrame, pandas.Series, "
"Input data must be of type np.ndarray, pandas.Series, pandas.DataFrame, "
f"xarray.DataArray, or xarray.Dataset. Got {type(data)}"
)

Expand All @@ -147,40 +151,52 @@ def convert_to_dataarray(data, name="data"):

# Checks pd.DataFrame input and converts to pd.Series if possible
if isinstance(data, pd.DataFrame):
if data.shape[1] > 1:
raise ValueError(
"If the input data is a pd.DataFrame or xr.Dataset, it must contain one variable. Got {data.shape[1]}"
)
else:
# use iloc instead of squeeze. For DataFrames/Series with only a
# single value, squeeze returns a scalar, which is unexpected.
# iloc will return a Series as expected
if data.shape[1] == 1:
# Convert the 1D, univariate case to a Series, which will be caught by the Series conversion below.
# This eliminates an unnecessary variable dimension and names the DataArray with the DataFrame variable name.
#
# Use iloc instead of squeeze. For DataFrames/Series with only a
# single value, squeeze returns a scalar which is unexpected.
# iloc returns a Series with one value as expected.
data = data.iloc[:, 0]
else:
index = data.index.values
columns = data.columns.values
data = xr.DataArray(
data=data.T,
dims=("variable", "index"),
coords={"variable": columns, "index": index},
)

# Checks xr.Dataset input and converts to xr.DataArray if possible
if isinstance(data, xr.Dataset):
keys = list(data.keys())
if len(keys) > 1:
raise ValueError(
"If the input data is a pd.DataFrame or xr.Dataset, it must contain one variable. Got {len(data.keys())}"
)
else:
if len(keys) == 1:
# if only one variable, remove the "variable" dimension and rename the DataArray to simplify
data = data.to_array()
data = data.sel(
variable=keys[0]
) # removes the variable dimension, further simplifying the dataarray
data = data.sel(variable=keys[0])
data.name = keys[0]
data.drop_vars("variable")
else:
# Allow multiple variables if they have the same dimensions
if all([data[keys[0]].dims == data[key].dims for key in keys]):
data = data.to_array()
else:
raise ValueError(
"Multivariate Datasets can only be input if all variables have the same dimensions."
)

# Converts pd.Series to xr.DataArray
if isinstance(data, pd.Series):
data = data.to_xarray()

# Converts np.ndarray to xr.DataArray. Assigns a simple 0-based dimension named index
# Converts np.ndarray to xr.DataArray. Assigns a simple 0-based dimension named index to match how pandas converts to xarray
if isinstance(data, np.ndarray):
data = xr.DataArray(
data=data, dims="index", coords={"index": np.arange(len(data))}
)

# If there's no data name, add one to prevent issues calling or converting the dataArray later one
# If there's no data name, add one to prevent issues calling or converting to a Dataset later on
if data.name == None:
data.name = name

Expand Down
Loading