pass write result to on_write_complete #49091

Closed
wants to merge 3 commits into from
4 changes: 2 additions & 2 deletions python/ray/data/_internal/datasource/bigquery_datasink.py
@@ -3,7 +3,7 @@
import tempfile
import time
import uuid
from typing import Iterable, Optional
from typing import Any, Iterable, Optional

import pyarrow.parquet as pq

@@ -70,7 +70,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
from google.api_core import exceptions
from google.cloud import bigquery
4 changes: 2 additions & 2 deletions python/ray/data/_internal/datasource/mongo_datasink.py
@@ -1,5 +1,5 @@
import logging
from typing import Iterable
from typing import Any, Iterable

from ray.data._internal.datasource.mongo_datasource import (
_validate_database_collection_exist,
@@ -26,7 +26,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
import pymongo

_validate_database_collection_exist(
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/parquet_datasink.py
@@ -57,7 +57,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
import pyarrow as pa
import pyarrow.parquet as pq

4 changes: 2 additions & 2 deletions python/ray/data/_internal/datasource/sql_datasink.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterable
from typing import Any, Callable, Iterable

from ray.data._internal.datasource.sql_datasource import Connection, _connect
from ray.data._internal.execution.interfaces import TaskContext
@@ -18,7 +18,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
with _connect(self.connection_factory) as cursor:
for block in blocks:
block_accessor = BlockAccessor.for_block(block)
36 changes: 23 additions & 13 deletions python/ray/data/_internal/planner/plan_write_op.py
@@ -1,5 +1,6 @@
import itertools
from typing import Callable, Iterator, List, Union
import pickle
from typing import Callable, Iterable, List, Union

from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator
@@ -8,6 +9,7 @@
from ray.data._internal.execution.operators.map_transformer import (
BlockMapTransformFn,
MapTransformer,
MapTransformFn,
)
from ray.data._internal.logical.operators.write_operator import Write
from ray.data.block import Block, BlockAccessor
@@ -18,30 +20,40 @@

def generate_write_fn(
datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args
) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
) -> Callable[[Iterable[Block], TaskContext], Iterable[Block]]:
stats_fn = generate_collect_write_stats_fn()

def fn(blocks: Iterable[Block], ctx) -> Iterable[Block]:
"""Writes the blocks to the given datasink or legacy datasource.

Outputs the original blocks to be written."""
# Create a copy of the iterator, so we can return the original blocks.
it1, it2 = itertools.tee(blocks, 2)
if isinstance(datasink_or_legacy_datasource, Datasink):
datasink_or_legacy_datasource.write(it1, ctx)
write_result = datasink_or_legacy_datasource.write(it1, ctx)
else:
datasink_or_legacy_datasource.write(it1, ctx, **write_args)
return it2
write_result = datasink_or_legacy_datasource.write(it1, ctx, **write_args)

import pandas as pd

payload = pd.DataFrame({"payload": [pickle.dumps(write_result)]})

stats = list(stats_fn(it2, ctx))
assert len(stats) == 1
block = pd.concat([stats[0], payload], axis=1)
return iter([block])

return fn


def generate_collect_write_stats_fn() -> (
Callable[[Iterator[Block], TaskContext], Iterator[Block]]
):
def generate_collect_write_stats_fn() -> Callable[
[Iterable[Block], TaskContext], Iterable[Block]
]:
# If the write op succeeds, the resulting Dataset is a list of
# one Block which contains stats/metrics about the write.
# Otherwise, an error will be raised. The Datasource can handle
# execution outcomes with `on_write_complete()` and `on_write_failed()`.
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
def fn(blocks: Iterable[Block], ctx) -> Iterable[Block]:
"""Handles stats collection for block writes."""
block_accessors = [BlockAccessor.for_block(block) for block in blocks]
total_num_rows = sum(ba.num_rows() for ba in block_accessors)
@@ -67,11 +79,9 @@ def plan_write_op(
input_physical_dag = physical_children[0]

write_fn = generate_write_fn(op._datasink_or_legacy_datasource, **op._write_args)
collect_stats_fn = generate_collect_write_stats_fn()
# Create a MapTransformer for a write operator
transform_fns = [
transform_fns: List[MapTransformFn] = [
BlockMapTransformFn(write_fn),
BlockMapTransformFn(collect_stats_fn),
]
map_transformer = MapTransformer(transform_fns)
return MapOperator.create(
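The core of the change is here: instead of chaining a separate stats transform, `generate_write_fn` now calls the stats fn itself, pickles whatever the datasink's `write()` returned into a `payload` column, and concatenates it onto the single stats block for the task. A minimal standalone sketch of that round trip (the stats column names and the `fake_write_result` value are illustrative, not taken from the diff):

```python
import pickle

import pandas as pd

# One-row stats block, as produced by the collect-stats fn (column names illustrative).
stats_block = pd.DataFrame({"num_rows": [100], "size_bytes": [4096]})

# Whatever the datasink's write() returned for this task (illustrative value).
fake_write_result = {"uri": "mock://output/part-0"}

# The write fn pickles the return value into a "payload" column and
# concatenates it onto the stats block, yielding one result block per task.
payload = pd.DataFrame({"payload": [pickle.dumps(fake_write_result)]})
result_block = pd.concat([stats_block, payload], axis=1)

# On the driver side, on_write_complete() can recover the per-task value.
assert pickle.loads(result_block["payload"].iloc[0]) == fake_write_result
```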
9 changes: 5 additions & 4 deletions python/ray/data/datasource/datasink.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass, fields
from typing import Iterable, List, Optional
from typing import Any, Iterable, List, Optional

import ray
from ray.data._internal.execution.interfaces import TaskContext
@@ -67,7 +67,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
"""Write blocks. This is used by a single write task.

Args:
@@ -82,6 +82,7 @@ def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
This can be used to "commit" a write output. This method must
succeed prior to ``write_datasink()`` returning to the user. If this
method fails, then ``on_write_failed()`` is called.
The return value of ``write`` is stored in the ``payload`` column of the write result blocks.

Args:
write_result_blocks: The blocks resulting from executing
@@ -165,7 +166,7 @@ def __init__(self):
self.rows_written = 0
self.enabled = True

def write(self, block: Block) -> None:
def write(self, block: Block) -> Any:
block = BlockAccessor.for_block(block)
self.rows_written += block.num_rows()

@@ -181,7 +182,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
tasks = []
if not self.enabled:
raise ValueError("disabled")
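Because `write()` may now return an arbitrary value, a custom `Datasink` can recover each task's return value in `on_write_complete()` by unpickling the `payload` column of the result blocks — essentially the pattern the new test below exercises. A rough sketch, assuming pandas-backed result blocks as in that test; the `_UriSink` name and its mock URIs are made up for illustration:

```python
import pickle
from typing import Any, Iterable, List

import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block
from ray.data.datasource import Datasink
from ray.data.datasource.datasink import WriteResult


class _UriSink(Datasink):
    """Illustrative sink that records where each write task put its data."""

    def __init__(self):
        self.uris: List[str] = []

    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any:
        # ... write `blocks` somewhere, then return per-task metadata ...
        return f"mock://output/task-{ctx.task_idx}"

    def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
        # Each result block carries the task's write() return value,
        # pickled into its "payload" column.
        self.uris = [
            pickle.loads(block["payload"].iloc[0]) for block in write_result_blocks
        ]
        return super().on_write_complete(write_result_blocks)


sink = _UriSink()
ray.data.range(100, override_num_blocks=4).write_datasink(sink)
print(sink.uris)  # one mock URI per write task
```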
34 changes: 32 additions & 2 deletions python/ray/data/tests/test_datasink.py
@@ -1,11 +1,13 @@
from typing import Iterable
import pickle
from typing import Any, Iterable, List

import pytest

import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block
from ray.data.datasource import Datasink
from ray.data.datasource.datasink import WriteResult


@pytest.mark.parametrize("num_rows_per_write", [5, 10, 50])
@@ -14,7 +16,7 @@ class MockDatasink(Datasink):
def __init__(self, num_rows_per_write):
self._num_rows_per_write = num_rows_per_write

def write(self, blocks: Iterable[Block], ctx: TaskContext) -> None:
def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any:
assert sum(len(block) for block in blocks) == self._num_rows_per_write

@property
@@ -26,6 +28,34 @@ def num_rows_per_write(self):
)


@pytest.mark.parametrize("num_rows_per_write", [5, 10, 50])
def test_on_write_complete(tmp_path, ray_start_regular_shared, num_rows_per_write):
class MockDatasink(Datasink):
def __init__(self, num_rows_per_write):
self._num_rows_per_write = num_rows_per_write
self.payloads = None

def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any:
assert sum(len(block) for block in blocks) == self._num_rows_per_write
return f"task-{ctx.task_idx}"

def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
self.payloads = [
pickle.loads(result["payload"].iloc[0])
for result in write_result_blocks
]
return super().on_write_complete(write_result_blocks)

@property
def num_rows_per_write(self):
return self._num_rows_per_write

sink = MockDatasink(num_rows_per_write)
ray.data.range(100, override_num_blocks=20).write_datasink(sink)
expect = [f"task-{i}" for i in range(100 // num_rows_per_write)]
assert sink.payloads == expect


if __name__ == "__main__":
import sys

6 changes: 3 additions & 3 deletions python/ray/data/tests/test_formats.py
@@ -1,6 +1,6 @@
import os
import sys
from typing import Iterable, List
from typing import Any, Iterable, List

import pandas as pd
import pyarrow as pa
@@ -236,7 +236,7 @@ def __init__(self):
self.rows_written = 0
self.node_ids = set()

def write(self, node_id: str, block: Block) -> str:
def write(self, node_id: str, block: Block) -> Any:
block = BlockAccessor.for_block(block)
self.rows_written += block.num_rows()
self.node_ids.add(node_id)
@@ -255,7 +255,7 @@ def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> None:
) -> Any:
data_sink = self.data_sink

def write(b):