Unify ArrowTensorType tables and Tensor blocks #18867

Merged: 8 commits, Sep 27, 2021
64 changes: 16 additions & 48 deletions doc/source/data/dataset-tensor-support.rst
@@ -3,66 +3,34 @@
Dataset Tensor Support
======================

Tensor-typed values
-------------------
Tables with tensor columns
--------------------------

Datasets supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. This allows you, for example, to use Pandas and Ray Datasets to read, write, and manipulate tables of images. All conversions between Pandas, Arrow, and Parquet, and all applications of aggregations/operations to the underlying image ndarrays, are taken care of by Ray Datasets.

With our Pandas extension type, :class:`TensorDtype <ray.data.extensions.tensor_extension.TensorDtype>`, and extension array, :class:`TensorArray <ray.data.extensions.tensor_extension.TensorArray>`, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column, and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType <ray.data.extensions.tensor_extension.ArrowTensorType>`, and extension array, :class:`ArrowTensorArray <ray.data.extensions.tensor_extension.ArrowTensorArray>`, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format.

Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under the hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically.
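For example, you can build a tensor column in Pandas, cast it to the tensor extension type, and ingest the DataFrame into a ``Dataset``. The snippet below is a minimal sketch: the ``astype(TensorDtype())`` cast mirrors the extension types described above, and ``from_pandas`` taking a list of object refs mirrors the ``from_numpy`` usage shown below, but exact argument details may vary across Ray versions.

.. code-block:: python

import numpy as np
import pandas as pd
import ray
from ray.data.extensions import TensorDtype

# A small table of four 2x2 "images" plus a label column; the image
# column starts out as an object-dtype column of ndarrays.
images = [np.ones((2, 2), dtype=np.int64) * i for i in range(4)]
df = pd.DataFrame({"label": [0, 1, 0, 1], "image": images})

# Cast the column to the tensor extension type.
df["image"] = df["image"].astype(TensorDtype())

# Familiar arithmetic is applied to the underlying tensors.
df["image"] = df["image"] + 1

# Ingest into a Dataset; Pandas <-> Arrow conversion is automatic.
ds = ray.data.from_pandas([ray.put(df)])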

Datasets support tensor-typed values, which are represented in-memory as Arrow tensors (i.e., np.ndarray format). Tensor datasets can be read from and written to ``.npy`` files. Here are some examples:
Single-column tensor datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The most basic case is when a dataset only has a single column, which is of tensor type. This kind of dataset can be created with ``.range_tensor()``, and can be read from and written to ``.npy`` files. Here are some examples:

.. code-block:: python

# Create a Dataset of tensor-typed values.
ds = ray.data.range_tensor(10000, shape=(3, 5))
# -> Dataset(num_blocks=200, num_rows=10000,
# schema=<Tensor: shape=(None, 3, 5), dtype=int64>)

ds.map_batches(lambda t: t + 2).show(2)
# -> [[2 2 2 2 2]
# [2 2 2 2 2]
# [2 2 2 2 2]]
# [[3 3 3 3 3]
# [3 3 3 3 3]
# [3 3 3 3 3]]
# schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})

# Save to storage.
ds.write_numpy("/tmp/tensor_out")
ds.write_numpy("/tmp/tensor_out", column="value")

# Read from storage.
ray.data.read_numpy("/tmp/tensor_out")
# -> Dataset(num_blocks=200, num_rows=?,
# schema=<Tensor: shape=(None, 3, 5), dtype=int64>)
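After this change, ``Dataset.to_numpy`` also takes a ``column`` argument. The following is a minimal sketch, based on the API in this PR, of pulling the tensor column back out as NumPy ndarrays, one object ref per block:

.. code-block:: python

# Convert the "value" tensor column into ndarrays (one ref per block).
ndarray_refs = ds.to_numpy(column="value")
arrs = ray.get(ndarray_refs)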

Tensor datasets are also created whenever an array type is returned from a map function:

.. code-block:: python

# Create a dataset of Python integers.
ds = ray.data.range(10)
# -> Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)

# It is now converted into a Tensor dataset.
ds = ds.map_batches(lambda x: np.array(x))
# -> Dataset(num_blocks=10, num_rows=10,
# schema=<Tensor: shape=(None,), dtype=int64>)

Tensor datasets can also be created from NumPy ndarrays that are already stored in the Ray object store:

.. code-block:: python

import numpy as np

# Create a Dataset from a list of NumPy ndarray objects.
arr1 = np.arange(0, 10)
arr2 = np.arange(10, 20)
ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)])

Tables with tensor columns
--------------------------

In addition to tensor datasets, Datasets also supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. As an example, this allows you to use both Pandas and Ray Datasets to read, write, and manipulate a table with a column of e.g. images (2D arrays), with all conversions between Pandas, Arrow, and Parquet, and all application of aggregations/operations to the underlying image ndarrays, being taken care of by Ray Datasets.

With our Pandas extension type, :class:`TensorDtype <ray.data.extensions.tensor_extension.TensorDtype>`, and extension array, :class:`TensorArray <ray.data.extensions.tensor_extension.TensorArray>`, you can do familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column and the operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType <ray.data.extensions.tensor_extension.ArrowTensorType>`, and extension array, :class:`ArrowTensorArray <ray.data.extensions.tensor_extension.ArrowTensorArray>`, you'll be able to import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format.

Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under-the-hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically.
# schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})

Reading existing serialized tensor columns
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -150,7 +118,7 @@ Now that the tensor column is properly typed and in a ``Dataset``, we can perform

# Arrow and Pandas is now aware of this tensor column, so we can do the
# typical DataFrame operations on this column.
ds = ds.map_batches(lambda x: 2 * (x + 1), format="pandas")
ds = ds.map_batches(lambda x: 2 * (x + 1), batch_format="pandas")
# -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1123.54it/s]
print(ds)
# -> Dataset(
2 changes: 1 addition & 1 deletion doc/source/data/dataset.rst
@@ -16,7 +16,7 @@ Ray Datasets are the standard way to load and exchange data in Ray libraries and

Concepts
--------
Ray Datasets implement `Distributed Arrow <https://arrow.apache.org/>`__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table <https://arrow.apache.org/docs/python/data.html#tables>`__, `Arrow tensor <https://arrow.apache.org/docs/python/generated/pyarrow.Tensor.html>`__, or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data.
Ray Datasets implement `Distributed Arrow <https://arrow.apache.org/>`__. A Dataset consists of a list of Ray object references to *blocks*. Each block holds a set of items in either an `Arrow table <https://arrow.apache.org/docs/python/data.html#tables>`__ or a Python list (for Arrow incompatible objects). Having multiple blocks in a dataset allows for parallel transformation and ingest of the data.
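For example, a small dataset whose blocks are Arrow tables can be created as follows. This is a minimal sketch: ``range_tensor`` stores its data in Arrow tables after this PR, but the ``parallelism`` argument and the exact block layout are assumptions about this Ray version.

.. code-block:: python

import ray

# Three blocks of 1000 rows each; every block is an Arrow table held in
# the Ray object store and referenced by the Dataset.
ds = ray.data.range_tensor(3000, shape=(2, 2), parallelism=3)
print(ds.num_blocks())  # -> 3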

The following figure visualizes a Dataset that has three Arrow table blocks, each holding 1000 rows:

24 changes: 12 additions & 12 deletions python/ray/data/block.py
@@ -16,8 +16,8 @@
# Represents a batch of records to be stored in the Ray object store.
#
# Block data can be accessed in a uniform way via ``BlockAccessors`` such as
# ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``.
Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes]
# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
Block = Union[List[T], "pyarrow.Table", bytes]


@DeveloperAPI
@@ -52,8 +52,8 @@ class BlockAccessor(Generic[T]):
as a top-level Ray object, without a wrapping class (issue #17186).

There are two types of block accessors: ``SimpleBlockAccessor``, which
operates over a plain Python list, ``ArrowBlockAccessor``, for
``pyarrow.Table`` type blocks, and ``TensorBlockAccessor``, for tensors.
operates over a plain Python list, and ``ArrowBlockAccessor`` for
``pyarrow.Table`` type blocks.
"""

def num_rows(self) -> int:
@@ -85,12 +85,16 @@ def to_pandas(self) -> "pandas.DataFrame":
"""Convert this block into a Pandas dataframe."""
raise NotImplementedError

def to_numpy(self) -> np.ndarray:
"""Convert this block into a NumPy ndarray."""
def to_numpy(self, column: str = None) -> np.ndarray:
"""Convert this block (or column of block) into a NumPy ndarray.

Args:
column: Name of column to convert, or None.
"""
raise NotImplementedError

def to_arrow(self) -> Union["pyarrow.Table", "pyarrow.Tensor"]:
"""Convert this block into an Arrow table or tensor."""
def to_arrow(self) -> "pyarrow.Table":
"""Convert this block into an Arrow table."""
raise NotImplementedError

def size_bytes(self) -> int:
@@ -136,10 +140,6 @@ def for_block(block: Block) -> "BlockAccessor[T]":
from ray.data.impl.simple_block import \
SimpleBlockAccessor
return SimpleBlockAccessor(block)
elif isinstance(block, np.ndarray):
from ray.data.impl.tensor_block import \
TensorBlockAccessor
return TensorBlockAccessor(block)
else:
raise TypeError("Not a block type: {}".format(block))

45 changes: 25 additions & 20 deletions python/ray/data/dataset.py
@@ -51,8 +51,7 @@ class Dataset(Generic[T]):

Datasets are implemented as a list of ``ObjectRef[Block]``. The block
also determines the unit of parallelism. The default block type is the
``pyarrow.Table``. Tensor objects are held in ``np.ndarray`` blocks,
and other Arrow-incompatible objects are held in ``list`` blocks.
``pyarrow.Table``. Arrow-incompatible objects are held in ``list`` blocks.

Since Datasets are just lists of Ray object refs, they can be passed
between Ray tasks and actors just like any other object. Datasets support
@@ -169,7 +168,7 @@ def map_batches(self,
tasks, or "actors" to use an autoscaling Ray actor pool.
batch_format: Specify "native" to use the native block format,
"pandas" to select ``pandas.DataFrame`` as the batch format,
or "pyarrow" to select ``pyarrow.Table/Tensor``.
or "pyarrow" to select ``pyarrow.Table``.
ray_remote_args: Additional resource requirements to request from
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
"""
@@ -205,19 +204,15 @@ def transform(block: Block) -> Block:
"or 'pyarrow', got: {}".format(batch_format))

applied = fn(view)
if (isinstance(applied, list) or isinstance(applied, pa.Table)
or isinstance(applied, np.ndarray)):
if isinstance(applied, list) or isinstance(applied, pa.Table):
applied = applied
elif isinstance(applied, pd.core.frame.DataFrame):
applied = pa.Table.from_pandas(applied)
elif isinstance(applied, pa.Tensor):
applied = applied.to_numpy()
else:
raise ValueError("The map batches UDF returned a type "
f"{type(applied)}, which is not allowed. "
"The return type must be either list, "
"pandas.DataFrame, np.ndarray, "
"pyarrow.Tensor, or pyarrow.Table")
"pandas.DataFrame, or pyarrow.Table")
builder.add_block(applied)

return builder.build()
@@ -947,11 +942,13 @@ def write_numpy(
self,
path: str,
*,
column: str = "value",
filesystem: Optional["pyarrow.fs.FileSystem"] = None) -> None:
"""Write the dataset to npy files.
"""Write a tensor column of the dataset to npy files.

This is only supported for datasets of Tensor records.
To control the number of files, use ``.repartition()``.
This is only supported for datasets convertible to Arrow records that
contain a TensorArray column. To control the number of files, use
``.repartition()``.

The format of the output files will be {self._uuid}_{block_idx}.npy,
where ``uuid`` is a unique id for the dataset.
@@ -964,12 +961,15 @@ def write_numpy(
Args:
path: The path to the destination root directory, where npy
files will be written to.
column: The name of the table column that contains the tensor to
be written. This defaults to "value".
filesystem: The filesystem implementation to write to.
"""
self.write_datasource(
NumpyDatasource(),
path=path,
dataset_uuid=self._uuid,
column=column,
filesystem=filesystem)

def write_datasource(self, datasource: Datasource[T],
@@ -1042,7 +1042,7 @@ def iter_batches(self,
batch_format: The format in which to return each batch.
Specify "native" to use the current block format, "pandas" to
select ``pandas.DataFrame`` or "pyarrow" to select
``pyarrow.Table/Tensor``. Default is "native".
``pyarrow.Table``. Default is "native".
drop_last: Whether to drop the last batch if it's incomplete.

Returns:
@@ -1364,7 +1364,8 @@ def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]:
block_to_df = cached_remote_fn(_block_to_df)
return [block_to_df.remote(block) for block in self._blocks]

def to_numpy(self) -> List[ObjectRef[np.ndarray]]:
def to_numpy(self, *,
column: Optional[str] = None) -> List[ObjectRef[np.ndarray]]:
"""Convert this dataset into a distributed set of NumPy ndarrays.

This is only supported for datasets convertible to NumPy ndarrays.
@@ -1373,12 +1374,19 @@ def to_numpy(self) -> List[ObjectRef[np.ndarray]]:

Time complexity: O(dataset size / parallelism)

Args:
column: The name of the column to convert to numpy, or None to
specify the entire row. Required for Arrow tables.

Returns:
A list of remote NumPy ndarrays created from this dataset.
"""

block_to_ndarray = cached_remote_fn(_block_to_ndarray)
return [block_to_ndarray.remote(block) for block in self._blocks]
return [
block_to_ndarray.remote(block, column=column)
for block in self._blocks
]

def to_arrow(self) -> List[ObjectRef["pyarrow.Table"]]:
"""Convert this dataset into a distributed set of Arrow tables.
@@ -1585,9 +1593,6 @@ def __repr__(self) -> str:
schema = self.schema()
if schema is None:
schema_str = "Unknown schema"
elif isinstance(schema, dict):
schema_str = "<Tensor: shape={}, dtype={}>".format(
schema["shape"], schema["dtype"])
elif isinstance(schema, type):
schema_str = str(schema)
else:
@@ -1640,9 +1645,9 @@ def _block_to_df(block: Block):
return block.to_pandas()


def _block_to_ndarray(block: Block):
def _block_to_ndarray(block: Block, column: Optional[str]):
block = BlockAccessor.for_block(block)
return block.to_numpy()
return block.to_numpy(column)


def _block_to_arrow(block: Block):
16 changes: 12 additions & 4 deletions python/ray/data/datasource/datasource.py
@@ -130,10 +130,11 @@ def make_block(start: int, count: int) -> Block:
return pyarrow.Table.from_arrays(
[np.arange(start, start + count)], names=["value"])
elif block_format == "tensor":
return np.ones(
tensor_shape, dtype=np.int64) * np.expand_dims(
tensor = TensorArray(
np.ones(tensor_shape, dtype=np.int64) * np.expand_dims(
np.arange(start, start + count),
tuple(range(1, 1 + len(tensor_shape))))
tuple(range(1, 1 + len(tensor_shape)))))
return pyarrow.Table.from_pydict({"value": tensor})
else:
return list(builtins.range(start, start + count))

@@ -145,7 +146,14 @@ def make_block(start: int, count: int) -> Block:
import pyarrow
schema = pyarrow.Table.from_pydict({"value": [0]}).schema
elif block_format == "tensor":
schema = {"dtype": "int64", "shape": (None, ) + tensor_shape}
_check_pyarrow_version()
from ray.data.extensions import TensorArray
import pyarrow
tensor = TensorArray(
np.ones(tensor_shape, dtype=np.int64) * np.expand_dims(
np.arange(0, 10), tuple(
range(1, 1 + len(tensor_shape)))))
schema = pyarrow.Table.from_pydict({"value": tensor}).schema
elif block_format == "list":
schema = int
else:
3 changes: 2 additions & 1 deletion python/ray/data/datasource/file_based_datasource.py
@@ -133,7 +133,8 @@ def write_block(write_path: str, block: Block):
if _block_udf is not None:
block = _block_udf(block)
with fs.open_output_stream(write_path) as f:
_write_block_to_file(f, BlockAccessor.for_block(block))
_write_block_to_file(f, BlockAccessor.for_block(block),
**write_args)

write_block = cached_remote_fn(write_block)

13 changes: 9 additions & 4 deletions python/ray/data/datasource/numpy_datasource.py
@@ -7,7 +7,7 @@
import pyarrow

from ray.data.block import BlockAccessor
from ray.data.datasource.file_based_datasource import (FileBasedDatasource)
from ray.data.datasource.file_based_datasource import FileBasedDatasource


class NumpyDatasource(FileBasedDatasource):
@@ -21,17 +21,22 @@ class NumpyDatasource(FileBasedDatasource):
"""

def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
from ray.data.extensions import TensorArray
import pyarrow as pa
# TODO(ekl) Ideally numpy can read directly from the file, but it
# seems like it requires the file to be seekable.
buf = BytesIO()
data = f.readall()
buf.write(data)
buf.seek(0)
return np.load(buf)
return pa.Table.from_pydict({
"value": TensorArray(np.load(buf, allow_pickle=True))
})

def _write_block(self, f: "pyarrow.NativeFile", block: BlockAccessor,
**writer_args):
np.save(f, block.to_arrow())
column: str, **writer_args):
value = block.to_numpy(column)
np.save(f, value)

def _file_format(self):
return "npy"
4 changes: 4 additions & 0 deletions python/ray/data/extensions/tensor_extension.py
@@ -1155,6 +1155,10 @@ def __arrow_ext_class__(self):
"""
return ArrowTensorArray

def __str__(self):
return "<ArrowTensorType: shape={}, dtype={}>".format(
self.shape, self.storage_type.value_type)


@PublicAPI(stability="beta")
class ArrowTensorArray(pa.ExtensionArray):