From de7da33c3d077efc3ca7a55db72b114c1735e543 Mon Sep 17 00:00:00 2001 From: "E. G. Patrick Bos" Date: Fri, 26 Jul 2024 15:46:52 +0200 Subject: [PATCH] support ArrayLike data in to_xarray This commit makes dianna swallow more kinds of data items without users having to mangle them manually themselves. As long as your data is numpy ArrayLike it will go. The only blocker for this was the direct use of the .ndim attribute on data, which assumes it is already a numpy array rather than something that can trivially be converted to a numpy array. A PIL.Image is an example of such an ArrayLike type. We use this in explainable_embedding to feed items into the OpenAI CLIP model. --- dianna/utils/misc.py | 8 +++++--- setup.cfg | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dianna/utils/misc.py b/dianna/utils/misc.py index 944517d5..a2933dfe 100644 --- a/dianna/utils/misc.py +++ b/dianna/utils/misc.py @@ -1,6 +1,8 @@ import inspect import warnings from pathlib import Path +import numpy as np +import numpy.typing def get_function(model_or_function, preprocess_function=None): @@ -49,7 +51,7 @@ def get_kwargs_applicable_to_function(function, kwargs): } -def to_xarray(data, axis_labels, required_labels=None): +def to_xarray(data: numpy.typing.ArrayLike, axis_labels, required_labels=None): """Converts numpy data and axes labels to an xarray object.""" if isinstance(axis_labels, dict): # key = axis index, value = label @@ -59,10 +61,10 @@ def to_xarray(data, axis_labels, required_labels=None): indices = list(axis_labels.keys()) for index in indices: if index < 0: - axis_labels[data.ndim + index] = axis_labels.pop(index) + axis_labels[np.ndim(data) + index] = axis_labels.pop(index) labels = [ axis_labels[index] if index in axis_labels else f'dim_{index}' - for index in range(data.ndim) + for index in range(np.ndim(data)) ] else: labels = list(axis_labels) diff --git a/setup.cfg b/setup.cfg index a063f34e..97d15c0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ install_requires = ipython lime matplotlib - numpy + numpy>=1.20 onnx==1.14.1 onnx_tf onnxruntime