Skip to content

Commit

Permalink
Fix ArrayXD cast (huggingface#6297)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko authored Oct 13, 2023
1 parent 292d627 commit 3e8d420
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,7 +1964,7 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
if isinstance(array, pa.ExtensionArray):
array = array.storage
if isinstance(pa_type, pa.ExtensionType):
return pa_type.wrap_array(array)
return pa_type.wrap_array(_c(array, pa_type.storage_type))
elif array.type == pa_type:
return array
elif pa.types.is_struct(array.type):
Expand Down
12 changes: 11 additions & 1 deletion tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import datasets
from datasets import Sequence, Value
from datasets.features.features import Array2DExtensionType, ClassLabel, Features, Image
from datasets.features.features import Array2D, Array2DExtensionType, ClassLabel, Features, Image
from datasets.table import (
ConcatenationTable,
InMemoryTable,
Expand Down Expand Up @@ -1165,6 +1165,16 @@ def test_cast_array_to_features_to_null_type():
cast_array_to_feature(arr, Sequence(Value("null")))


def test_cast_array_to_features_array_xd():
# same storage type
arr = pa.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], pa.list_(pa.list_(pa.int32(), 2), 2))
casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="int32"))
assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="int32")
# different storage type
casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="float32"))
assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="float32")


def test_cast_array_to_features_sequence_classlabel():
arr = pa.array([[], [1], [0, 1]], pa.list_(pa.int64()))
assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())
Expand Down

0 comments on commit 3e8d420

Please sign in to comment.