From 25c94dc0d66c18c25332c540e01ba15bbbccad4f Mon Sep 17 00:00:00 2001 From: jukejian Date: Wed, 11 Dec 2024 17:35:26 +0800 Subject: [PATCH] [data] fix write lance failed from high version Signed-off-by: jukejian --- .../data/_internal/planner/plan_write_op.py | 20 +++++++++--- python/ray/data/datasource/datasink.py | 31 +++++++++++++++++-- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_write_op.py b/python/ray/data/_internal/planner/plan_write_op.py index 690aab5b46fd2..ca0298f7a01c7 100644 --- a/python/ray/data/_internal/planner/plan_write_op.py +++ b/python/ray/data/_internal/planner/plan_write_op.py @@ -26,10 +26,14 @@ def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: # Create a copy of the iterator, so we can return the original blocks. it1, it2 = itertools.tee(blocks, 2) if isinstance(datasink_or_legacy_datasource, Datasink): - datasink_or_legacy_datasource.write(it1, ctx) + task_result = datasink_or_legacy_datasource.write(it1, ctx) else: - datasink_or_legacy_datasource.write(it1, ctx, **write_args) - return it2 + task_result = datasink_or_legacy_datasource.write(it1, ctx, **write_args) + + import pandas as pd + + block = pd.DataFrame({"task_result": [task_result], "origin_block": [it2]}) + return iter([block]) return fn @@ -43,7 +47,11 @@ def generate_collect_write_stats_fn() -> ( # execution outcomes with `on_write_complete()`` and `on_write_failed()``. def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: """Handles stats collection for block writes.""" - block_accessors = [BlockAccessor.for_block(block) for block in blocks] + # only have one element in the iterator + first_element = dict(next(blocks).iloc[0]) + origin_block = first_element["origin_block"] + task_result = first_element["task_result"] + block_accessors = [BlockAccessor.for_block(block) for block in origin_block] total_num_rows = sum(ba.num_rows() for ba in block_accessors) total_size_bytes = sum(ba.size_bytes() for ba in block_accessors) @@ -51,7 +59,9 @@ def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: # type. import pandas as pd - write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes) + write_result = WriteResult( + num_rows=total_num_rows, size_bytes=total_size_bytes, result=task_result + ) block = pd.DataFrame({"write_result": [write_result]}) return iter([block]) diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index fe4d4cf4ef9a3..3503ad7d1e1f3 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass, fields -from typing import Iterable, List, Optional +from typing import Any, Iterable, List, Optional import ray from ray.data._internal.execution.interfaces import TaskContext @@ -19,10 +19,32 @@ class WriteResult: Attributes: total_num_rows: The total number of rows written. total_size_bytes: The total size of the written data in bytes. + result: every task can return a result. """ num_rows: int = 0 size_bytes: int = 0 + result: Any = None + + def __init__( + self, num_rows: int = None, size_bytes: int = None, result: Any = None + ) -> None: + if result is not None: + self.result = result + if num_rows is not None: + self.num_rows = num_rows + if size_bytes is not None: + self.size_bytes = size_bytes + + def __getitem__(self, key): + if key == "total_num_rows": + return self.num_rows + elif key == "total_size_bytes": + return self.size_bytes + elif key == "result": + return self.result + else: + raise KeyError(f"Key {key} not found in WriteResult") @staticmethod def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult": @@ -36,14 +58,19 @@ def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult" """ total_num_rows = 0 total_size_bytes = 0 + total_result = [] for write_result in write_results: total_num_rows += write_result.num_rows total_size_bytes += write_result.size_bytes + total_result += ( + write_result.result if write_result.result is not None else [] + ) return WriteResult( num_rows=total_num_rows, size_bytes=total_size_bytes, + result=total_result, ) @@ -67,7 +94,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> None: + ) -> Any: """Write blocks. This is used by a single write task. Args: