diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 2e3728ef83a..a4bb1111a46 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -30,12 +30,20 @@
title: Process
- local: stream
title: Stream
- - local: use_with_tensorflow
- title: Use with TensorFlow
- local: use_with_pytorch
title: Use with PyTorch
+ - local: use_with_tensorflow
+ title: Use with TensorFlow
+ - local: use_with_numpy
+ title: Use with NumPy
- local: use_with_jax
title: Use with JAX
+ - local: use_with_pandas
+ title: Use with Pandas
+ - local: use_with_polars
+ title: Use with Polars
+ - local: use_with_pyarrow
+ title: Use with PyArrow
- local: use_with_spark
title: Use with Spark
- local: cache
diff --git a/docs/source/process.mdx b/docs/source/process.mdx
index 456a57f2d44..712dac4de4c 100644
--- a/docs/source/process.mdx
+++ b/docs/source/process.mdx
@@ -630,40 +630,94 @@ Note that if no sampling probabilities are specified, the new dataset will have
## Format
-The [`~Dataset.set_format`] function changes the format of a column to be compatible with some common data formats. Specify the output you'd like in the `type` parameter and the columns you want to format. Formatting is applied on-the-fly.
+The [`~Dataset.with_format`] function changes the format of the dataset so that it is compatible with some common data formats. Specify the output you'd like in the `type` parameter. You can also choose which columns you want to format using `columns=`. Formatting is applied on-the-fly.
For example, create PyTorch tensors by setting `type="torch"`:
```py
->>> import torch
->>> dataset.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])
+>>> dataset = dataset.with_format(type="torch")
```
-The [`~Dataset.with_format`] function also changes the format of a column, except it returns a new [`Dataset`] object:
+The [`~Dataset.set_format`] function also changes the format of the dataset, except it runs in-place:
```py
->>> dataset = dataset.with_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])
+>>> dataset.set_format(type="torch")
+```
+
+If you need to reset the dataset to its original format, set the format to `None` (or use [`~Dataset.reset_format`]):
+
+```py
+>>> dataset.format
+{'type': 'torch', 'format_kwargs': {}, 'columns': [...], 'output_all_columns': False}
+>>> dataset = dataset.with_format(None)
+>>> dataset.format
+{'type': None, 'format_kwargs': {}, 'columns': [...], 'output_all_columns': False}
```
+### Tensor formats
+
+Several tensor and array formats are supported. It is generally recommended to use these formats instead of converting the outputs of a dataset to tensors or arrays manually, since they avoid unnecessary data copies and speed up data loading.
+
+Here is the list of supported tensor and array formats:
+
+- NumPy: format name is "numpy", for more information see [Using Datasets with NumPy](use_with_numpy)
+- PyTorch: format name is "torch", for more information see [Using Datasets with PyTorch](use_with_pytorch)
+- TensorFlow: format name is "tensorflow", for more information see [Using Datasets with TensorFlow](use_with_tensorflow)
+- JAX: format name is "jax", for more information see [Using Datasets with JAX](use_with_jax)
+
-🤗 Datasets also provides support for other common data formats such as NumPy, TensorFlow, JAX, Arrow, Pandas and Polars. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.
+Check out the [Using Datasets with TensorFlow](use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.
-If you need to reset the dataset to its original format, use the [`~Dataset.reset_format`] function:
+When a dataset is formatted in a tensor or array format, all its data is formatted as tensors or arrays (except for unsupported types such as strings, e.g. in PyTorch):
-```py
->>> dataset.format
-{'type': 'torch', 'format_kwargs': {}, 'columns': ['label'], 'output_all_columns': False}
->>> dataset.reset_format()
->>> dataset.format
-{'type': 'python', 'format_kwargs': {}, 'columns': ['idx', 'label', 'sentence1', 'sentence2'], 'output_all_columns': False}
+```python
+>>> ds = Dataset.from_dict({"text": ["foo", "bar"], "tokens": [[0, 1, 2], [3, 4, 5]]})
+>>> ds = ds.with_format("torch")
+>>> ds[0]
+{'text': 'foo', 'tokens': tensor([0, 1, 2])}
+>>> ds[:2]
+{'text': ['foo', 'bar'],
+ 'tokens': tensor([[0, 1, 2],
+ [3, 4, 5]])}
+```
+
+### Tabular formats
+
+You can use a dataframe or table format to optimize data loading and data processing, since these formats generally offer zero-copy operations and transforms written in low-level languages.
+
+Here is the list of supported dataframe and table formats:
+
+- Pandas: format name is "pandas", for more information see [Using Datasets with Pandas](use_with_pandas)
+- Polars: format name is "polars", for more information see [Using Datasets with Polars](use_with_polars)
+- PyArrow: format name is "arrow", for more information see [Using Datasets with PyArrow](use_with_pyarrow)
+
+When a dataset is formatted in a dataframe or table format, every dataset row or batch of rows is formatted as a dataframe or table, and dataset columns are formatted as a series or array:
+
+```python
+>>> ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})
+>>> ds = ds.with_format("pandas")
+>>> ds[:2]
+ text label
+0 foo 0
+1 bar 1
```
-### Format transform
+These formats make it possible to iterate over the data faster by avoiding data copies, and also enable faster data processing in [`~Dataset.map`] or [`~Dataset.filter`]:
-The [`~Dataset.set_transform`] function applies a custom formatting transform on-the-fly. This function replaces any previously specified format. For example, you can use this function to tokenize and pad tokens on-the-fly. Tokenization is only applied when examples are accessed:
+```python
+>>> ds = ds.map(lambda df: df.assign(upper_text=df.text.str.upper()), batched=True)
+>>> ds[:2]
+ text label upper_text
+0 foo 0 FOO
+1 bar 1 BAR
+```
+
+### Custom format transform
+
+The [`~Dataset.with_transform`] function applies a custom formatting transform on-the-fly. This function replaces any previously specified format. For example, you can use this function to tokenize and pad tokens on-the-fly. Tokenization is only applied when examples are accessed:
```py
>>> from transformers import AutoTokenizer
@@ -671,12 +725,14 @@ The [`~Dataset.set_transform`] function applies a custom formatting transform on
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> def encode(batch):
... return tokenizer(batch["sentence1"], batch["sentence2"], padding="longest", truncation=True, max_length=512, return_tensors="pt")
->>> dataset.set_transform(encode)
+>>> dataset = dataset.with_transform(encode)
>>> dataset.format
{'type': 'custom', 'format_kwargs': {'transform': <function encode at 0x...>}, 'columns': ['idx', 'label', 'sentence1', 'sentence2'], 'output_all_columns': False}
```
-You can also use the [`~Dataset.set_transform`] function to decode formats not supported by [`Features`]. For example, the [`Audio`] feature uses [`soundfile`](https://python-soundfile.readthedocs.io/en/0.11.0/) - a fast and simple library to install - but it does not provide support for less common audio formats. Here is where you can use [`~Dataset.set_transform`] to apply a custom decoding transform on the fly. You're free to use any library you like to decode the audio files.
+There is also [`~Dataset.set_transform`] which does the same but runs in-place.
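+
+For example, the same transform as above can be applied in-place:
+
+```py
+>>> dataset.set_transform(encode)
+```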
+
+You can also use the [`~Dataset.with_transform`] function to decode formats not supported by [`Features`]. For example, the [`Audio`] feature uses [`soundfile`](https://python-soundfile.readthedocs.io/en/0.11.0/) - a fast and simple library to install - but it does not provide support for less common audio formats. Here is where you can use [`~Dataset.with_transform`] to apply a custom decoding transform on the fly. You're free to use any library you like to decode the audio files.
The example below uses the [`pydub`](http://pydub.com/) package to open an audio format not supported by `soundfile`:
diff --git a/docs/source/use_with_jax.mdx b/docs/source/use_with_jax.mdx
index dc73a8df778..89d1628df06 100644
--- a/docs/source/use_with_jax.mdx
+++ b/docs/source/use_with_jax.mdx
@@ -108,7 +108,7 @@ To avoid this, you must explicitly use the [`Array`] feature type and specify th
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
>>> ds = Dataset.from_dict({"data": data}, features=features)
->>> ds = ds.with_format("torch")
+>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': Array([[1, 2],
[3, 4]], dtype=int32)}
diff --git a/docs/source/use_with_numpy.mdx b/docs/source/use_with_numpy.mdx
new file mode 100644
index 00000000000..095c283c2f9
--- /dev/null
+++ b/docs/source/use_with_numpy.mdx
@@ -0,0 +1,203 @@
+# Use with NumPy
+
+This document is a quick introduction to using `datasets` with NumPy, with a particular focus on how to get
+`numpy.ndarray` objects out of our datasets, and how to use them in your data loading and training code.
+
+<Tip>
+
+`numpy` is required to reproduce the code below; it is installed by default as a
+dependency of `datasets`, so no extra installation step is needed.
+
+</Tip>
+
+## Dataset format
+
+By default, datasets return regular Python objects: integers, floats, strings, lists, etc.
+
+To get NumPy arrays instead, you can set the format of the dataset to `numpy`:
+
+```py
+>>> from datasets import Dataset
+>>> data = [[1, 2], [3, 4]]
+>>> ds = Dataset.from_dict({"data": data})
+>>> ds = ds.with_format("numpy")
+>>> ds[0]
+{'data': array([1, 2])}
+>>> ds[:2]
+{'data': array([
+ [1, 2],
+ [3, 4]])}
+```
+
+<Tip>
+
+A [`Dataset`] object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to NumPy arrays.
+
+</Tip>
+
+Note that the exact same procedure applies to `DatasetDict` objects, so that
+when setting the format of a `DatasetDict` to `numpy`, all the `Dataset` objects it
+contains will be formatted as `numpy`:
+
+```py
+>>> from datasets import Dataset, DatasetDict
+>>> data = {"train": {"data": [[1, 2], [3, 4]]}, "test": {"data": [[5, 6], [7, 8]]}}
+>>> dds = DatasetDict({split: Dataset.from_dict(d) for split, d in data.items()})
+>>> dds = dds.with_format("numpy")
+>>> dds["train"][:2]
+{'data': array([
+ [1, 2],
+ [3, 4]])}
+```
+
+
+### N-dimensional arrays
+
+If your dataset consists of N-dimensional arrays, you will see that by default they are returned as a single multidimensional array if their shape is fixed:
+
+```py
+>>> from datasets import Dataset
+>>> data = [[[1, 2],[3, 4]], [[5, 6],[7, 8]]] # fixed shape
+>>> ds = Dataset.from_dict({"data": data})
+>>> ds = ds.with_format("numpy")
+>>> ds[0]
+{'data': array([[1, 2],
+ [3, 4]])}
+```
+
+```py
+>>> from datasets import Dataset
+>>> data = [[[1, 2],[3]], [[4, 5, 6],[7, 8]]] # varying shape
+>>> ds = Dataset.from_dict({"data": data})
+>>> ds = ds.with_format("numpy")
+>>> ds[0]
+{'data': array([array([1, 2]), array([3])], dtype=object)}
+```
+
+However this logic often requires slow shape comparisons and data copies.
+To avoid this, you must explicitly use the [`Array`] feature type and specify the shape of your arrays:
+
+```py
+>>> from datasets import Dataset, Features, Array2D
+>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
+>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
+>>> ds = Dataset.from_dict({"data": data}, features=features)
+>>> ds = ds.with_format("numpy")
+>>> ds[0]
+{'data': array([[1, 2],
+ [3, 4]])}
+>>> ds[:2]
+{'data': array([[[1, 2],
+ [3, 4]],
+
+ [[5, 6],
+ [7, 8]]])}
+```
+
+### Other feature types
+
+[`ClassLabel`] data is properly converted to arrays:
+
+```py
+>>> from datasets import Dataset, Features, ClassLabel
+>>> labels = [0, 0, 1]
+>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
+>>> ds = Dataset.from_dict({"label": labels}, features=features)
+>>> ds = ds.with_format("numpy")
+>>> ds[:3]
+{'label': array([0, 0, 1])}
+```
+
+String and binary objects are unchanged, since NumPy only supports numbers.
+
+The [`Image`] and [`Audio`] feature types are also supported.
+
+<Tip>
+
+To use the [`Image`] feature type, you'll need to install the `vision` extra as
+`pip install datasets[vision]`.
+
+</Tip>
+
+```py
+>>> from datasets import Dataset, Features, Image
+>>> images = ["path/to/image.png"] * 10
+>>> features = Features({"image": Image()})
+>>> ds = Dataset.from_dict({"image": images}, features=features)
+>>> ds = ds.with_format("numpy")
+>>> ds[0]["image"].shape
+(512, 512, 3)
+>>> ds[0]
+{'image': array([[[ 255, 255, 255],
+ [ 255, 255, 255],
+ ...,
+ [ 255, 255, 255],
+ [ 255, 255, 255]]], dtype=uint8)}
+>>> ds[:2]["image"].shape
+(2, 512, 512, 3)
+>>> ds[:2]
+{'image': array([[[[ 255, 255, 255],
+ [ 255, 255, 255],
+ ...,
+ [ 255, 255, 255],
+ [ 255, 255, 255]]]], dtype=uint8)}
+```
+
+<Tip>
+
+To use the [`Audio`] feature type, you'll need to install the `audio` extra as
+`pip install datasets[audio]`.
+
+</Tip>
+
+```py
+>>> from datasets import Dataset, Features, Audio
+>>> audio = ["path/to/audio.wav"] * 10
+>>> features = Features({"audio": Audio()})
+>>> ds = Dataset.from_dict({"audio": audio}, features=features)
+>>> ds = ds.with_format("numpy")
+>>> ds[0]["audio"]["array"]
+array([-0.059021 , -0.03894043, -0.00735474, ..., 0.0133667 ,
+ 0.01809692, 0.00268555], dtype=float32)
+>>> ds[0]["audio"]["sampling_rate"]
+array(44100)
+```
+
+## Data loading
+
+NumPy doesn't have any built-in data loading capabilities, so you'll need to use a library such
+as [PyTorch](https://pytorch.org/) to load your data using a `DataLoader`, or [TensorFlow](https://www.tensorflow.org/)
+to load it using a `tf.data.Dataset`.
+
+This is why the NumPy formatting in `datasets` is so useful: it lets you use any model from the
+Hugging Face Hub with NumPy, without having to worry about the data loading part.
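+
+For example, here is a minimal sketch (assuming PyTorch is installed; the column names are just placeholders) of wrapping a NumPy-formatted [`Dataset`] in a PyTorch `DataLoader`; the default collate function turns the NumPy arrays into PyTorch tensors:
+
+```py
+>>> import torch
+>>> from datasets import Dataset
+>>> ds = Dataset.from_dict({"data": [[1, 2], [3, 4]], "label": [0, 1]}).with_format("numpy")
+>>> dataloader = torch.utils.data.DataLoader(ds, batch_size=2)
+>>> for batch in dataloader:
+...     ...  # batch["data"] and batch["label"] are PyTorch tensors here
+```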
+
+### Using `with_format('numpy')`
+
+The easiest way to get NumPy arrays out of a dataset is to use the `with_format('numpy')` method. Let's assume
+that we want to train a neural network on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) available
+on the Hugging Face Hub at https://huggingface.co/datasets/mnist.
+
+```py
+>>> from datasets import load_dataset
+>>> ds = load_dataset("mnist")
+>>> ds = ds.with_format("numpy")
+>>> ds["train"][0]
+{'image': array([[ 0, 0, 0, ...],
+ [ 0, 0, 0, ...],
+ ...,
+ [ 0, 0, 0, ...],
+ [ 0, 0, 0, ...]], dtype=uint8),
+ 'label': array(5)}
+```
+
+Once the format is set, we can feed the dataset to our model in batches using the `Dataset.iter()`
+method:
+
+```py
+>>> for epoch in range(epochs):
+... for batch in ds["train"].iter(batch_size=32):
+... x, y = batch["image"], batch["label"]
+... ...
+```
diff --git a/docs/source/use_with_pandas.mdx b/docs/source/use_with_pandas.mdx
new file mode 100644
index 00000000000..9c2cfb7e878
--- /dev/null
+++ b/docs/source/use_with_pandas.mdx
@@ -0,0 +1,83 @@
+# Use with Pandas
+
+This document is a quick introduction to using `datasets` with Pandas, with a particular focus on how to process
+datasets using Pandas functions, and how to convert a dataset to Pandas or from Pandas.
+
+This is particularly useful as it allows fast operations, since `datasets` uses PyArrow under the hood and PyArrow is well integrated with Pandas.
+
+## Dataset format
+
+By default, datasets return regular Python objects: integers, floats, strings, lists, etc.
+
+To get Pandas DataFrames or Series instead, you can set the format of the dataset to `pandas` using [`Dataset.with_format`]:
+
+```py
+>>> from datasets import Dataset
+>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
+>>> ds = Dataset.from_dict(data)
+>>> ds = ds.with_format("pandas")
+>>> ds[0] # pd.DataFrame
+ col_0 col_1
+0 a 0.0
+>>> ds[:2] # pd.DataFrame
+ col_0 col_1
+0 a 0.0
+1 b 0.0
+>>> ds["data"] # pd.Series
+0 a
+1 b
+2 c
+3 d
+Name: col_0, dtype: object
+```
+
+This also works for `IterableDataset` objects obtained e.g. using `load_dataset(..., streaming=True)`:
+
+```py
+>>> ds = ds.with_format("pandas")
+>>> for df in ds.iter(batch_size=2):
+... print(df)
+... break
+ col_0 col_1
+0 a 0.0
+1 b 0.0
+```
+
+## Process data
+
+Pandas functions are generally faster than regular hand-written Python functions, and therefore they are a good option to optimize data processing. You can use Pandas functions to process a dataset in [`Dataset.map`] or [`Dataset.filter`]:
+
+```python
+>>> from datasets import Dataset
+>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
+>>> ds = Dataset.from_dict(data)
+>>> ds = ds.with_format("pandas")
+>>> ds = ds.map(lambda df: df.assign(col_2=df.col_1 + 1), batched=True)
+>>> ds[:2]
+ col_0 col_1 col_2
+0 a 0.0 1.0
+1 b 0.0 1.0
+>>> ds = ds.filter(lambda df: df.col_0 == "b", batched=True)
+>>> ds[0]
+ col_0 col_1 col_2
+0 b 0.0 1.0
+```
+
+We use `batched=True` because it is faster to process batches of data in Pandas rather than row by row. It's also possible to use `batch_size=` in `map()` to set the size of each `df`.
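+
+For example, here is the same transform as above, processing chunks of 1,000 rows at a time:
+
+```python
+>>> ds = ds.map(lambda df: df.assign(col_2=df.col_1 + 1), batched=True, batch_size=1000)
+```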
+
+This also works for [`IterableDataset.map`] and [`IterableDataset.filter`].
+
+## Import or Export from Pandas
+
+To import data from Pandas, you can use [`Dataset.from_pandas`]:
+
+```python
+ds = Dataset.from_pandas(df)
+```
+
+And you can use [`Dataset.to_pandas`] to export a Dataset to a Pandas DataFrame:
+
+
+```python
+df = ds.to_pandas()
+```
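+
+For example, here is a minimal round trip using the same toy data as above:
+
+```python
+import pandas as pd
+from datasets import Dataset
+
+df = pd.DataFrame({"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]})
+ds = Dataset.from_pandas(df)  # import the DataFrame as a Dataset
+df = ds.to_pandas()           # export the Dataset back to a DataFrame
+```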
diff --git a/docs/source/use_with_polars.mdx b/docs/source/use_with_polars.mdx
new file mode 100644
index 00000000000..82144b05e31
--- /dev/null
+++ b/docs/source/use_with_polars.mdx
@@ -0,0 +1,117 @@
+# Use with Polars
+
+This document is a quick introduction to using `datasets` with Polars, with a particular focus on how to process
+datasets using Polars functions, and how to convert a dataset to Polars or from Polars.
+
+This is particularly useful as it allows fast zero-copy operations, since both `datasets` and Polars use Arrow under the hood.
+
+## Dataset format
+
+By default, datasets return regular Python objects: integers, floats, strings, lists, etc.
+
+To get Polars DataFrames or Series instead, you can set the format of the dataset to `polars` using [`Dataset.with_format`]:
+
+```py
+>>> from datasets import Dataset
+>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
+>>> ds = Dataset.from_dict(data)
+>>> ds = ds.with_format("polars")
+>>> ds[0] # pl.DataFrame
+shape: (1, 2)
+┌───────┬───────┐
+│ col_0 ┆ col_1 │
+│ --- ┆ --- │
+│ str ┆ f64 │
+╞═══════╪═══════╡
+│ a ┆ 0.0 │
+└───────┴───────┘
+>>> ds[:2] # pl.DataFrame
+shape: (2, 2)
+┌───────┬───────┐
+│ col_0 ┆ col_1 │
+│ --- ┆ --- │
+│ str ┆ f64 │
+╞═══════╪═══════╡
+│ a ┆ 0.0 │
+│ b ┆ 0.0 │
+└───────┴───────┘
+>>> ds["data"] # pl.Series
+shape: (4,)
+Series: 'col_0' [str]
+[
+ "a"
+ "b"
+ "c"
+ "d"
+]
+```
+
+This also works for `IterableDataset` objects obtained e.g. using `load_dataset(..., streaming=True)`:
+
+```py
+>>> ds = ds.with_format("polars")
+>>> for df in ds.iter(batch_size=2):
+... print(df)
+... break
+shape: (2, 2)
+┌───────┬───────┐
+│ col_0 ┆ col_1 │
+│ --- ┆ --- │
+│ str ┆ f64 │
+╞═══════╪═══════╡
+│ a ┆ 0.0 │
+│ b ┆ 0.0 │
+└───────┴───────┘
+```
+
+## Process data
+
+Polars functions are generally faster than regular hand-written Python functions, and therefore they are a good option to optimize data processing. You can use Polars functions to process a dataset in [`Dataset.map`] or [`Dataset.filter`]:
+
+```python
+>>> import polars as pl
+>>> from datasets import Dataset
+>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
+>>> ds = Dataset.from_dict(data)
+>>> ds = ds.with_format("polars")
+>>> ds = ds.map(lambda df: df.with_columns(pl.col("col_1").add(1).alias("col_2")), batched=True)
+>>> ds[:2]
+shape: (2, 3)
+┌───────┬───────┬───────┐
+│ col_0 ┆ col_1 ┆ col_2 │
+│ --- ┆ --- ┆ --- │
+│ str ┆ f64 ┆ f64 │
+╞═══════╪═══════╪═══════╡
+│ a ┆ 0.0 ┆ 1.0 │
+│ b ┆ 0.0 ┆ 1.0 │
+└───────┴───────┴───────┘
+>>> ds = ds.filter(lambda df: df["col_0"] == "b", batched=True)
+>>> ds[0]
+shape: (1, 3)
+┌───────┬───────┬───────┐
+│ col_0 ┆ col_1 ┆ col_2 │
+│ --- ┆ --- ┆ --- │
+│ str ┆ f64 ┆ f64 │
+╞═══════╪═══════╪═══════╡
+│ b ┆ 0.0 ┆ 1.0 │
+└───────┴───────┴───────┘
+```
+
+We use `batched=True` because it is faster to process batches of data in Polars rather than row by row. It's also possible to use `batch_size=` in `map()` to set the size of each `df`.
+
+This also works for [`IterableDataset.map`] and [`IterableDataset.filter`].
+
+## Import or Export from Polars
+
+To import data from Polars, you can use [`Dataset.from_polars`]:
+
+```python
+ds = Dataset.from_polars(df)
+```
+
+And you can use [`Dataset.to_polars`] to export a Dataset to a Polars DataFrame:
+
+
+```python
+df = ds.to_polars()
+```
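+
+For example, here is a minimal round trip using the same toy data as above:
+
+```python
+import polars as pl
+from datasets import Dataset
+
+df = pl.DataFrame({"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]})
+ds = Dataset.from_polars(df)  # import the DataFrame as a Dataset
+df = ds.to_polars()           # export the Dataset back to a DataFrame
+```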
diff --git a/docs/source/use_with_pyarrow.mdx b/docs/source/use_with_pyarrow.mdx
new file mode 100644
index 00000000000..6ab8ba5bf15
--- /dev/null
+++ b/docs/source/use_with_pyarrow.mdx
@@ -0,0 +1,108 @@
+# Use with PyArrow
+
+This document is a quick introduction to using `datasets` with PyArrow, with a particular focus on how to process
+datasets using Arrow compute functions, and how to convert a dataset to PyArrow or from PyArrow.
+
+This is particularly useful as it allows fast zero-copy operations, since `datasets` uses PyArrow under the hood.
+
+## Dataset format
+
+By default, datasets return regular Python objects: integers, floats, strings, lists, etc.
+
+To get PyArrow Tables or Arrays instead, you can set the format of the dataset to `arrow` using [`Dataset.with_format`]:
+
+```py
+>>> from datasets import Dataset
+>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
+>>> ds = Dataset.from_dict(data)
+>>> ds = ds.with_format("arrow")
+>>> ds[0] # pa.Table
+pyarrow.Table
+col_0: string
+col_1: double
+----
+col_0: [["a"]]
+col_1: [[0]]
+>>> ds[:2] # pa.Table
+pyarrow.Table
+col_0: string
+col_1: double
+----
+col_0: [["a","b"]]
+col_1: [[0,0]]
+>>> ds["data"] # pa.array
+
+[
+ [
+ "a",
+ "b",
+ "c",
+ "d"
+ ]
+]
+```
+
+This also works for `IterableDataset` objects obtained e.g. using `load_dataset(..., streaming=True)`:
+
+```py
+>>> ds = ds.with_format("arrow")
+>>> for df in ds.iter(batch_size=2):
+... print(df)
+... break
+pyarrow.Table
+col_0: string
+col_1: double
+----
+col_0: [["a","b"]]
+col_1: [[0,0]]
+```
+
+## Process data
+
+PyArrow functions are generally faster than regular hand-written Python functions, and therefore they are a good option to optimize data processing. You can use Arrow compute functions to process a dataset in [`Dataset.map`] or [`Dataset.filter`]:
+
+```python
+>>> import pyarrow.compute as pc
+>>> from datasets import Dataset
+>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
+>>> ds = Dataset.from_dict(data)
+>>> ds = ds.with_format("arrow")
+>>> ds = ds.map(lambda t: t.append_column("col_2", pc.add(t["col_1"], 1)), batched=True)
+>>> ds[:2]
+pyarrow.Table
+col_0: string
+col_1: double
+col_2: double
+----
+col_0: [["a","b"]]
+col_1: [[0,0]]
+col_2: [[1,1]]
+>>> ds = ds.filter(lambda t: pc.equal(t["col_0"], "b"), batched=True)
+>>> ds[0]
+pyarrow.Table
+col_0: string
+col_1: double
+col_2: double
+----
+col_0: [["b"]]
+col_1: [[0]]
+col_2: [[1]]
+```
+
+We use `batched=True` because it is faster to process batches of data in PyArrow rather than row by row. It's also possible to use `batch_size=` in `map()` to set the size of each batch.
+
+This also works for [`IterableDataset.map`] and [`IterableDataset.filter`].
+
+## Import or Export from PyArrow
+
+A [`Dataset`] is a wrapper of a PyArrow Table, so you can instantiate a Dataset directly from a Table:
+
+```python
+ds = Dataset(table)
+```
+
+You can access the PyArrow Table of a dataset using [`Dataset.data`], which returns a [`MemoryMappedTable`], an [`InMemoryTable`] or a [`ConcatenationTable`], depending on the origin of the Arrow data and the operations that were applied.
+
+Those objects wrap the underlying PyArrow table, accessible at `Dataset.data.table`. This table contains all the data of the dataset, but there might also be an indices mapping at `Dataset._indices` which maps the dataset row indices to the PyArrow Table row indices. This can happen if the dataset has been shuffled with [`Dataset.shuffle`] or if only a subset of the rows is used (e.g. after [`Dataset.select`]).
+
+In the general case, you can export a dataset to a PyArrow Table using `table = ds.with_format("arrow")[:]`.
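+
+For example, here is a minimal sketch of a round trip through PyArrow using the same toy data as above:
+
+```python
+import pyarrow as pa
+from datasets import Dataset
+
+table = pa.table({"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]})
+ds = Dataset(table)                 # instantiate a Dataset from a PyArrow Table
+table = ds.with_format("arrow")[:]  # export it back to a PyArrow Table
+```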
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f005b2374d1..927bc01709f 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -5205,9 +5205,12 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
from .iterable_dataset import ArrowExamplesIterable, IterableDataset
if self._format_type is not None:
- raise NotImplementedError(
- "Converting a formatted dataset to a formatted iterable dataset is not implemented yet. Please run `my_dataset = my_dataset.with_format(None)` before calling to_iterable_dataset"
- )
+ if self._format_kwargs or (
+ self._format_columns is not None and set(self._format_columns) != set(self.column_names)
+ ):
+ raise NotImplementedError(
+ "Converting a formatted dataset with kwargs or selected columns to a formatted iterable dataset is not implemented yet. Please run `my_dataset = my_dataset.with_format(None)` before calling to_iterable_dataset"
+ )
if num_shards > len(self):
raise ValueError(
f"Unable to shard a dataset of size {len(self)} into {num_shards} shards (the number of shards exceeds the number of samples)."
@@ -5228,7 +5231,10 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
Dataset._generate_tables_from_shards,
kwargs={"shards": shards, "batch_size": config.DEFAULT_MAX_BATCH_SIZE},
)
- return IterableDataset(ex_iterable, info=DatasetInfo(features=self.features))
+ ds = IterableDataset(ex_iterable, info=DatasetInfo(features=self.features))
+ if self._format_type:
+ ds = ds.with_format(self._format_type)
+ return ds
def _push_parquet_shards_to_hub(
self,
@@ -6308,6 +6314,8 @@ def get_indices_from_mask_function(
if with_rank:
additional_args += (rank,)
mask = function(*inputs, *additional_args, **fn_kwargs)
+ if isinstance(mask, (pa.Array, pa.ChunkedArray)):
+ mask = mask.to_pylist()
else:
# we get batched data (to do less look-ups) but `function` only accepts one example
# therefore we need to call `function` on each example of the batch to get the mask