add pandas pyarrow and polars docs
lhoestq committed Jan 31, 2025
1 parent c2b7303 commit 581087e
Showing 8 changed files with 607 additions and 24 deletions.
12 changes: 10 additions & 2 deletions docs/source/_toctree.yml
@@ -30,12 +30,20 @@
title: Process
- local: stream
title: Stream
- local: use_with_pytorch
title: Use with PyTorch
- local: use_with_tensorflow
title: Use with TensorFlow
- local: use_with_numpy
title: Use with NumPy
- local: use_with_jax
title: Use with JAX
- local: use_with_pandas
title: Use with Pandas
- local: use_with_polars
title: Use with Polars
- local: use_with_pyarrow
title: Use with PyArrow
- local: use_with_spark
title: Use with Spark
- local: cache
90 changes: 73 additions & 17 deletions docs/source/process.mdx
@@ -630,53 +630,109 @@ Note that if no sampling probabilities are specified, the new dataset will have

## Format

The [`~Dataset.with_format`] function changes the format of a column to be compatible with some common data formats. Specify the output you'd like in the `type` parameter. You can also choose which columns you want to format using `columns=`. Formatting is applied on-the-fly.

For example, create PyTorch tensors by setting `type="torch"`:

```py
>>> dataset = dataset.with_format(type="torch")
```
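
To format only a subset of columns, for example when the other columns contain strings that can't be converted to tensors, pass them explicitly with `columns=` (the column names below are illustrative):

```py
>>> dataset = dataset.with_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])
```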

The [`~Dataset.set_format`] function also changes the format of a column, except it runs in-place:

```py
>>> dataset.set_format(type="torch")
```

If you need to reset the dataset to its original format, set the format to `None` (or use [`~Dataset.reset_format`]):

```py
>>> dataset.format
{'type': 'torch', 'format_kwargs': {}, 'columns': [...], 'output_all_columns': False}
>>> dataset = dataset.with_format(None)
>>> dataset.format
{'type': None, 'format_kwargs': {}, 'columns': [...], 'output_all_columns': False}
```

### Tensor formats

Several tensor and array formats are supported. It is generally recommended to use these formats instead of manually converting the outputs of a dataset to tensors or arrays, to avoid unnecessary data copies and accelerate data loading.

Here is the list of supported tensor and array formats:

- NumPy: format name is "numpy", for more information see [Using Datasets with NumPy](use_with_numpy)
- PyTorch: format name is "torch", for more information see [Using Datasets with PyTorch](use_with_pytorch)
- TensorFlow: format name is "tensorflow", for more information see [Using Datasets with TensorFlow](use_with_tensorflow)
- JAX: format name is "jax", for more information see [Using Datasets with JAX](use_with_jax)

<Tip>

Check out the [Using Datasets with TensorFlow](use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.

</Tip>
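
Switching between these formats only requires changing the format name; a quick sketch (each call assumes the corresponding library is installed):

```py
>>> ds_np = dataset.with_format("numpy")
>>> ds_tf = dataset.with_format("tensorflow")
>>> ds_jax = dataset.with_format("jax")
```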

When a dataset is formatted in a tensor or array format, all the data is formatted as tensors or arrays (except for unsupported types, such as strings in the case of PyTorch):

```python
>>> ds = Dataset.from_dict({"text": ["foo", "bar"], "tokens": [[0, 1, 2], [3, 4, 5]]})
>>> ds = ds.with_format("torch")
>>> ds[0]
{'text': 'foo', 'tokens': tensor([0, 1, 2])}
>>> ds[:2]
{'text': ['foo', 'bar'],
'tokens': tensor([[0, 1, 2],
[3, 4, 5]])}
```

### Tabular formats

You can use a dataframe or table format to optimize data loading and data processing, since these formats generally offer zero-copy operations and transforms written in low-level languages.

Here is the list of supported dataframe and table formats:

- Pandas: format name is "pandas", for more information see [Using Datasets with Pandas](use_with_pandas)
- Polars: format name is "polars", for more information see [Using Datasets with Polars](use_with_polars)
- PyArrow: format name is "arrow", for more information see [Using Datasets with PyArrow](use_with_pyarrow)

When a dataset is formatted in a dataframe or table format, every dataset row and every batch of rows is formatted as a dataframe or table, and dataset columns are formatted as a series or array:

```python
>>> ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})
>>> ds = ds.with_format("pandas")
>>> ds[:2]
text label
0 foo 0
1 bar 1
```
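
Accessing a single column likewise returns a pandas `Series`:

```python
>>> ds["text"]
0    foo
1    bar
Name: text, dtype: object
```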

Those formats make it possible to iterate over the data faster by avoiding data copies, and they also enable faster data processing in [`~Dataset.map`] or [`~Dataset.filter`]:

```python
>>> ds = ds.map(lambda df: df.assign(upper_text=df.text.str.upper()), batched=True)
>>> ds[:2]
text label upper_text
0 foo 0 FOO
1 bar 1 BAR
```
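
Here is a sketch of the same idea with [`~Dataset.filter`]: the function receives a pandas `DataFrame`, and the boolean mask is converted to a plain list before being returned:

```python
>>> ds = ds.filter(lambda df: (df["label"] == 1).tolist(), batched=True)
>>> ds[:2]
  text  label upper_text
0  bar      1        BAR
```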

### Custom format transform

The [`~Dataset.with_transform`] function applies a custom formatting transform on-the-fly. This function replaces any previously specified format. For example, you can use this function to tokenize and pad tokens on-the-fly. Tokenization is only applied when examples are accessed:

```py
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> def encode(batch):
... return tokenizer(batch["sentence1"], batch["sentence2"], padding="longest", truncation=True, max_length=512, return_tensors="pt")
>>> dataset = dataset.with_transform(encode)
>>> dataset.format
{'type': 'custom', 'format_kwargs': {'transform': <function __main__.encode(batch)>}, 'columns': ['idx', 'label', 'sentence1', 'sentence2'], 'output_all_columns': False}
```
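
Accessing examples now runs `encode` on the fly; a quick check, assuming the dataset has `sentence1` and `sentence2` columns (e.g. GLUE MRPC):

```py
>>> out = dataset[:2]  # encode() is applied to just these two examples
>>> list(out.keys())
['input_ids', 'token_type_ids', 'attention_mask']
```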

There is also [`~Dataset.set_transform`] which does the same but runs in-place.

You can also use the [`~Dataset.with_transform`] function to decode formats not supported by [`Features`]. For example, the [`Audio`] feature uses [`soundfile`](https://python-soundfile.readthedocs.io/en/0.11.0/) - a fast and simple library to install - but it does not provide support for less common audio formats. Here is where you can use [`~Dataset.with_transform`] to apply a custom decoding transform on the fly. You're free to use any library you like to decode the audio files.

The example below uses the [`pydub`](http://pydub.com/) package to open an audio format not supported by `soundfile`:
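
A minimal sketch of such a transform, assuming the `audio` column stores file paths (the exact decoding logic of the original example may differ):

```py
>>> import numpy as np
>>> from pydub import AudioSegment

>>> def decode_with_pydub(batch, sampling_rate=16_000):
...     def decode(path):
...         sound = AudioSegment.from_file(path)
...         if sound.frame_rate != sampling_rate:
...             sound = sound.set_frame_rate(sampling_rate)
...         array = np.array(sound.get_array_of_samples(), dtype=np.float32)
...         array /= 1 << (8 * sound.sample_width - 1)  # scale integer samples to [-1.0, 1.0]
...         return {"array": array, "sampling_rate": sampling_rate}
...     batch["audio"] = [decode(path) for path in batch["audio"]]
...     return batch

>>> dataset = dataset.with_transform(decode_with_pydub)
```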

2 changes: 1 addition & 1 deletion docs/source/use_with_jax.mdx
@@ -108,7 +108,7 @@ To avoid this, you must explicitly use the [`Array`] feature type and specify the shape of your tensors:
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> ds = ds.with_format("torch")
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': Array([[1, 2],
[3, 4]], dtype=int32)}
203 changes: 203 additions & 0 deletions docs/source/use_with_numpy.mdx
@@ -0,0 +1,203 @@
# Use with NumPy

This document is a quick introduction to using `datasets` with NumPy, with a particular focus on how to get
`numpy.ndarray` objects out of our datasets, and how to use them to train NumPy-based models.

<Tip>

`numpy` is required to reproduce the code below, so please make sure you
install it with `pip install numpy`.

</Tip>

## Dataset format

By default, datasets return regular Python objects: integers, floats, strings, lists, etc.

To get NumPy arrays instead, you can set the format of the dataset to `numpy`:

```py
>>> from datasets import Dataset
>>> data = [[1, 2], [3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("numpy")
>>> ds[0]
{'data': array([1, 2])}
>>> ds[:2]
{'data': array([
[1, 2],
[3, 4]])}
```

<Tip>

A [`Dataset`] object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to NumPy arrays.

</Tip>

Note that the exact same procedure applies to `DatasetDict` objects, so that
when setting the format of a `DatasetDict` to `numpy`, all the `Dataset`s it contains
will be formatted as `numpy`:

```py
>>> from datasets import Dataset, DatasetDict
>>> data = {"train": {"data": [[1, 2], [3, 4]]}, "test": {"data": [[5, 6], [7, 8]]}}
>>> dds = DatasetDict({split: Dataset.from_dict(d) for split, d in data.items()})
>>> dds = dds.with_format("numpy")
>>> dds["train"][:2]
{'data': array([
[1, 2],
[3, 4]])}
```


### N-dimensional arrays

If your dataset consists of N-dimensional arrays, you will see that by default they are considered as a single array if the shape is fixed:

```py
>>> from datasets import Dataset
>>> data = [[[1, 2],[3, 4]], [[5, 6],[7, 8]]] # fixed shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("numpy")
>>> ds[0]
{'data': array([[1, 2],
[3, 4]])}
```

```py
>>> from datasets import Dataset
>>> data = [[[1, 2],[3]], [[4, 5, 6],[7, 8]]] # varying shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("numpy")
>>> ds[0]
{'data': array([array([1, 2]), array([3])], dtype=object)}
```

However, this logic often requires slow shape comparisons and data copies.
To avoid this, you must explicitly use the [`Array`] feature type and specify the shape of your tensors:

```py
>>> from datasets import Dataset, Features, Array2D
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> ds = ds.with_format("numpy")
>>> ds[0]
{'data': array([[1, 2],
[3, 4]])}
>>> ds[:2]
{'data': array([[[1, 2],
[3, 4]],

[[5, 6],
[7, 8]]])}
```

### Other feature types

[`ClassLabel`] data is properly converted to arrays:

```py
>>> from datasets import Dataset, Features, ClassLabel
>>> labels = [0, 0, 1]
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"label": labels}, features=features)
>>> ds = ds.with_format("numpy")
>>> ds[:3]
{'label': array([0, 0, 1])}
```
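
If you need the label names back, [`ClassLabel.int2str`] converts the integers to their string names:

```py
>>> ds.features["label"].int2str([0, 0, 1])
['negative', 'negative', 'positive']
```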

String and binary objects are unchanged, since NumPy is primarily designed for numerical data.

The [`Image`] and [`Audio`] feature types are also supported.

<Tip>

To use the [`Image`] feature type, you'll need to install the `vision` extra as
`pip install datasets[vision]`.

</Tip>

```py
>>> from datasets import Dataset, Features, Image
>>> images = ["path/to/image.png"] * 10
>>> features = Features({"image": Image()})
>>> ds = Dataset.from_dict({"image": images}, features=features)
>>> ds = ds.with_format("numpy")
>>> ds[0]["image"].shape
(512, 512, 3)
>>> ds[0]
{'image': array([[[ 255, 255, 255],
[ 255, 255, 255],
...,
[ 255, 255, 255],
[ 255, 255, 255]]], dtype=uint8)}
>>> ds[:2]["image"].shape
(2, 512, 512, 3)
>>> ds[:2]
{'image': array([[[[ 255, 255, 255],
[ 255, 255, 255],
...,
[ 255, 255, 255],
[ 255, 255, 255]]]], dtype=uint8)}
```
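
Since the images come back as `uint8` arrays, a common first step is to scale them to floats in `[0, 1]` before feeding them to a model:

```py
>>> batch = ds[:2]["image"].astype("float32") / 255.0
>>> batch.shape
(2, 512, 512, 3)
```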

<Tip>

To use the [`Audio`] feature type, you'll need to install the `audio` extra as
`pip install datasets[audio]`.

</Tip>

```py
>>> from datasets import Dataset, Features, Audio
>>> audio = ["path/to/audio.wav"] * 10
>>> features = Features({"audio": Audio()})
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
>>> ds = ds.with_format("numpy")
>>> ds[0]["audio"]["array"]
array([-0.059021 , -0.03894043, -0.00735474, ..., 0.0133667 ,
0.01809692, 0.00268555], dtype=float32)
>>> ds[0]["audio"]["sampling_rate"]
array(44100, weak_type=True)
```
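
Since the decoded audio is a plain array, quantities such as the clip duration are straightforward to compute:

```py
>>> arr = ds[0]["audio"]["array"]
>>> duration_seconds = len(arr) / int(ds[0]["audio"]["sampling_rate"])
```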

## Data loading

NumPy doesn't have any built-in data loading capabilities, so you'll need to use a library such
as [PyTorch](https://pytorch.org/), with its `DataLoader`, or [TensorFlow](https://www.tensorflow.org/),
with its `tf.data.Dataset`, to load your data.

This is why the NumPy formatting in `datasets` is so useful: it lets you use
any model from the HuggingFace Hub with NumPy, without having to worry about the data loading
part.
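
For example, a NumPy-formatted [`Dataset`] can be passed directly to a PyTorch `DataLoader`, since it implements `__getitem__` and `__len__`; a sketch, assuming `torch` is installed:

```py
>>> from datasets import load_dataset
>>> from torch.utils.data import DataLoader
>>> ds = load_dataset("mnist").with_format("numpy")
>>> dataloader = DataLoader(ds["train"], batch_size=32, shuffle=True)
>>> batch = next(iter(dataloader))  # the default collate function stacks the NumPy arrays into torch tensors
```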

### Using `with_format('numpy')`

The easiest way to get NumPy arrays out of a dataset is to use the `with_format('numpy')` method. Let's assume
that we want to train a neural network on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) available
on the Hugging Face Hub at https://huggingface.co/datasets/mnist.

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("mnist")
>>> ds = ds.with_format("numpy")
>>> ds["train"][0]
{'image': array([[ 0, 0, 0, ...],
[ 0, 0, 0, ...],
...,
[ 0, 0, 0, ...],
[ 0, 0, 0, ...]], dtype=uint8),
'label': array(5)}
```

Once the format is set, we can feed the dataset to the NumPy model in batches using the `Dataset.iter()`
method:

```py
>>> for epoch in range(epochs):
... for batch in ds["train"].iter(batch_size=32):
... x, y = batch["image"], batch["label"]
... ...
```
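
As an illustration, here is a minimal forward pass for a linear classifier inside that loop (the weights `W` and the preprocessing are purely illustrative):

```py
>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> W = rng.normal(scale=0.01, size=(28 * 28, 10))  # hypothetical weights for 10 MNIST classes
>>> for batch in ds["train"].iter(batch_size=32):
...     x = batch["image"].reshape(len(batch["image"]), -1).astype("float32") / 255.0
...     logits = x @ W  # add your loss and parameter update here
```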