Skip to content

Commit

Permalink
support ArrayLike data in to_xarray
Browse files Browse the repository at this point in the history
This commit makes dianna swallow more kinds of data items without users having to mangle them manually themselves. As long as your data is numpy ArrayLike it will go. The only blocker for this was the direct use of the .ndim attribute on data, which assumes it is already a numpy array rather than something that can trivially be converted to a numpy array. A PIL.Image is an example of such an ArrayLike type. We use this in explainable_embedding to feed items into the OpenAI CLIP model.
  • Loading branch information
egpbos authored and loostrum committed Jul 31, 2024
1 parent 430be09 commit de7da33
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions dianna/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import inspect
import warnings
from pathlib import Path
import numpy as np
import numpy.typing


def get_function(model_or_function, preprocess_function=None):
Expand Down Expand Up @@ -49,7 +51,7 @@ def get_kwargs_applicable_to_function(function, kwargs):
}


def to_xarray(data, axis_labels, required_labels=None):
def to_xarray(data: numpy.typing.ArrayLike, axis_labels, required_labels=None):
"""Converts numpy data and axes labels to an xarray object."""
if isinstance(axis_labels, dict):
# key = axis index, value = label
Expand All @@ -59,10 +61,10 @@ def to_xarray(data, axis_labels, required_labels=None):
indices = list(axis_labels.keys())
for index in indices:
if index < 0:
axis_labels[data.ndim + index] = axis_labels.pop(index)
axis_labels[np.ndim(data) + index] = axis_labels.pop(index)
labels = [
axis_labels[index] if index in axis_labels else f'dim_{index}'
for index in range(data.ndim)
for index in range(np.ndim(data))
]
else:
labels = list(axis_labels)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ install_requires =
ipython
lime
matplotlib
numpy
numpy>=1.20
onnx==1.14.1
onnx_tf
onnxruntime
Expand Down

0 comments on commit de7da33

Please sign in to comment.