diff --git a/dianna/utils/misc.py b/dianna/utils/misc.py index 944517d5..a2933dfe 100644 --- a/dianna/utils/misc.py +++ b/dianna/utils/misc.py @@ -1,6 +1,8 @@ import inspect import warnings from pathlib import Path +import numpy as np +import numpy.typing def get_function(model_or_function, preprocess_function=None): @@ -49,7 +51,7 @@ def get_kwargs_applicable_to_function(function, kwargs): } -def to_xarray(data, axis_labels, required_labels=None): +def to_xarray(data: numpy.typing.ArrayLike, axis_labels, required_labels=None): """Converts numpy data and axes labels to an xarray object.""" if isinstance(axis_labels, dict): # key = axis index, value = label @@ -59,10 +61,10 @@ def to_xarray(data, axis_labels, required_labels=None): indices = list(axis_labels.keys()) for index in indices: if index < 0: - axis_labels[data.ndim + index] = axis_labels.pop(index) + axis_labels[np.ndim(data) + index] = axis_labels.pop(index) labels = [ axis_labels[index] if index in axis_labels else f'dim_{index}' - for index in range(data.ndim) + for index in range(np.ndim(data)) ] else: labels = list(axis_labels) diff --git a/setup.cfg b/setup.cfg index a063f34e..97d15c0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ install_requires = ipython lime matplotlib - numpy + numpy>=1.20 onnx==1.14.1 onnx_tf onnxruntime