Add support for categorical/dictionary types (#6892)
* Add support for dictionary types

* Add unit tests

* Style fix

* bump pyarrow

* convert in Dataset init

* remove beam from tests

---------

Co-authored-by: Quentin Lhoest <[email protected]>
3 people authored Jun 7, 2024
1 parent a2dc287 commit 686f5df
Showing 6 changed files with 37 additions and 6 deletions.
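Taken together, the changes let a pyarrow dictionary (categorical) column load as its underlying value type instead of raising NotImplementedError. A minimal sketch of the new behavior, assuming pyarrow >= 15.0.0 and a datasets build that includes this commit:

import pyarrow as pa
from datasets import Dataset

# Dictionary encoding stores each distinct value once; rows hold small integer indices.
animals = pa.array(["Flamingo", "Horse", "Flamingo"]).cast(pa.dictionary(pa.int32(), pa.string()))
table = pa.Table.from_arrays([animals], names=["animals"])

ds = Dataset(table)  # the init-time cast added below decodes the dictionary column
print(ds.features)   # {'animals': Value(dtype='string', id=None)}
print(ds[0])         # {'animals': 'Flamingo'}, plain Python values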
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -65,7 +65,7 @@ jobs:
       run: uv pip install --system --upgrade pyarrow huggingface-hub dill
     - name: Install dependencies (minimum versions)
       if: ${{ matrix.deps_versions != 'deps-latest' }}
-      run: uv pip install --system pyarrow==12.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
+      run: uv pip install --system pyarrow==15.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
     - name: Test with pytest
       run: |
         python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
5 changes: 2 additions & 3 deletions setup.py
@@ -113,8 +113,8 @@
     # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
     "numpy>=1.17",
     # Backend and serialization.
-    # Minimum 12.0.0 to be able to concatenate extension arrays
-    "pyarrow>=12.0.0",
+    # Minimum 15.0.0 to be able to cast dictionary types to their underlying types
+    "pyarrow>=15.0.0",
     # As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
     "pyarrow-hotfix",
     # For smart caching dataset processing
@@ -166,7 +166,6 @@
     "pytest-datadir",
     "pytest-xdist",
     # optional dependencies
-    "apache-beam>=2.26.0; sys_platform != 'win32' and python_version<'3.10'", # doesn't support recent dill versions for recent python versions and on windows requires pyarrow<12.0.0
     "elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
     "faiss-cpu>=1.6.4",
     "jax>=0.3.14; sys_platform != 'win32'",
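The floor moves from 12.0.0 (needed to concatenate extension arrays) to 15.0.0, which the new comment ties to casting dictionary types to their underlying types. Roughly, the pyarrow-level cast the library now relies on looks like this (an illustration, not code from this diff):

import pyarrow as pa

arr = pa.array(["a", "b", "a"]).cast(pa.dictionary(pa.int32(), pa.string()))
print(arr.type)  # dictionary<values=string, indices=int32, ordered=0>

# Casting a dictionary array to its value type decodes the indices into plain values.
print(arr.cast(pa.string()).to_pylist())  # ['a', 'b', 'a']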
5 changes: 5 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -711,6 +711,11 @@ def __init__(
                     f"{e}\nThe 'source' features come from dataset_info.json, and the 'target' ones are those of the dataset arrow file."
                 )
 
+        # In case there are types like pa.dictionary that we need to convert to the underlying type
+        if self.data.schema != self.info.features.arrow_schema:
+            self._data = self.data.cast(self.info.features.arrow_schema)
+
         # Infer fingerprint if None
         if self._fingerprint is None:
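The cast only fires when the stored schema and the features' schema disagree, which is exactly what happens for dictionary columns, since Features.arrow_schema only ever contains the underlying types. A stand-alone sketch of the same normalization, with assumed example data:

import pyarrow as pa

dict_type = pa.dictionary(pa.int32(), pa.string())
table = pa.table({"animals": pa.array(["Flamingo", "Horse"]).cast(dict_type)})

# The target schema holds the underlying type, so the two schemas differ...
target_schema = pa.schema({"animals": pa.string()})
assert table.schema != target_schema

# ...and a single Table.cast decodes every dictionary column at once.
print(table.cast(target_schema).schema)  # animals: string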
7 changes: 5 additions & 2 deletions src/datasets/features/features.py
@@ -109,6 +109,8 @@ def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
         return "string"
     elif pyarrow.types.is_large_string(arrow_type):
         return "large_string"
+    elif pyarrow.types.is_dictionary(arrow_type):
+        return _arrow_to_datasets_dtype(arrow_type.value_type)
     else:
         raise ValueError(f"Arrow type {arrow_type} does not have a datasets dtype equivalent.")
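Because the new branch recurses on value_type, any dictionary resolves to the dtype of its values, whatever the index type is. For example (calling the private helper directly, purely for illustration):

import pyarrow as pa
from datasets.features.features import _arrow_to_datasets_dtype

print(_arrow_to_datasets_dtype(pa.dictionary(pa.int32(), pa.string())))  # string
print(_arrow_to_datasets_dtype(pa.dictionary(pa.int8(), pa.int64())))    # int64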

@@ -1434,8 +1436,6 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
     elif isinstance(pa_type, _ArrayXDExtensionType):
         array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims]
         return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
-    elif isinstance(pa_type, pa.DictionaryType):
-        raise NotImplementedError  # TODO(thom) this will need access to the dictionary as well (for labels). I.e. to the py_table
     elif isinstance(pa_type, pa.DataType):
         return Value(dtype=_arrow_to_datasets_dtype(pa_type))
     else:
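With the NotImplementedError branch removed, a dictionary type now falls through to the generic pa.DataType case and becomes a plain Value feature:

import pyarrow as pa
from datasets.features.features import generate_from_arrow_type

print(generate_from_arrow_type(pa.dictionary(pa.int32(), pa.string())))
# Value(dtype='string', id=None), where this used to raise NotImplementedError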
@@ -1705,6 +1705,9 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features":
         It also checks the schema metadata for Hugging Face Datasets features.
         Non-nullable fields are not supported and set to nullable.
+        Also, pa.dictionary is not supported and its underlying type is used instead.
+        Therefore datasets converts DictionaryArray objects to their actual values.
+
         Args:
             pa_schema (`pyarrow.Schema`):
                 Arrow Schema.
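In practice, a schema containing a dictionary field maps to a plain Value feature, for example:

import pyarrow as pa
from datasets import Features

schema = pa.schema({"label": pa.dictionary(pa.int32(), pa.string())})
print(Features.from_arrow_schema(schema))  # {'label': Value(dtype='string', id=None)}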
6 changes: 6 additions & 0 deletions tests/features/test_features.py
@@ -98,6 +98,12 @@ def test_string_to_arrow_bijection_for_primitive_types(self):
         with self.assertRaises(ValueError):
             string_to_arrow(sdt)
 
+    def test_categorical_one_way(self):
+        # Categorical types (aka dictionary types) need special handling as there isn't a bijection
+        categorical_type = pa.dictionary(pa.int32(), pa.string())
+
+        self.assertEqual("string", _arrow_to_datasets_dtype(categorical_type))
+
     def test_feature_named_type(self):
         """reference: issue #1110"""
         features = Features({"_type": Value("string")})
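The test is deliberately one-directional: the reverse mapping from dtype strings to Arrow types has no way to reproduce the dictionary encoding. A quick check with the same helpers the test file imports:

import pyarrow as pa
from datasets.features.features import _arrow_to_datasets_dtype, string_to_arrow

dtype = _arrow_to_datasets_dtype(pa.dictionary(pa.int32(), pa.string()))
print(string_to_arrow(dtype))  # string, not dictionary<values=string, indices=int32>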
18 changes: 18 additions & 0 deletions tests/test_arrow_dataset.py
@@ -4826,3 +4826,21 @@ def test_dataset_getitem_raises():
         ds[False]
     with pytest.raises(TypeError):
         ds._getitem(True)
+
+
+def test_categorical_dataset(tmpdir):
+    n_legs = pa.array([2, 4, 5, 100])
+    animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]).cast(
+        pa.dictionary(pa.int32(), pa.string())
+    )
+    names = ["n_legs", "animals"]
+
+    table = pa.Table.from_arrays([n_legs, animals], names=names)
+    table_path = str(tmpdir / "data.parquet")
+    pa.parquet.write_table(table, table_path)
+
+    dataset = Dataset.from_parquet(table_path)
+    entry = dataset[0]
+
+    # Categorical types get transparently converted to string
+    assert entry["animals"] == "Flamingo"
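If the categories should survive as labels rather than plain strings, one option is a follow-up cast after loading; a hypothetical usage note continuing from the dataset built in the test above, not part of this PR:

from datasets import ClassLabel

# Re-encode the decoded string column as a ClassLabel feature.
dataset = dataset.cast_column(
    "animals", ClassLabel(names=["Flamingo", "Horse", "Brittle stars", "Centipede"])
)
print(dataset.features["animals"])  # ClassLabel with the four animal names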
