FIX-#5164: Fix unwrap_partitions for virtual partitions when `axis=None` (#6560)

Co-authored-by: Rehan Durrani <[email protected]>
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev and Rehan Durrani authored Sep 15, 2023
1 parent b95d9b3 commit aa75256
Showing 2 changed files with 45 additions and 2 deletions.
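
For context, a minimal usage sketch of the API this commit fixes. This is an assumption-laden illustration, not part of the commit: the concat step mirrors the new test below, which is what produces axis-wide "virtual" partitions; exact partition layout depends on the engine and configuration.

    import numpy as np
    import modin.pandas as pd
    from modin.distributed.dataframe.pandas import unwrap_partitions

    # Concatenating several frames yields column/row ("virtual") partitions.
    # Before this fix, unwrap_partitions(..., axis=None) rejected such frames
    # because the virtual partition types were not among the accepted type names.
    df = pd.concat([pd.DataFrame(np.random.rand(256, 8))] * 10)
    blocks = unwrap_partitions(df, axis=None)  # nested list of engine futures/object refs
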
25 changes: 23 additions & 2 deletions modin/distributed/dataframe/pandas/partitions.py
@@ -23,18 +23,30 @@
if TYPE_CHECKING:
from modin.core.execution.ray.implementations.pandas_on_ray.partitioning import (
PandasOnRayDataframePartition,
PandasOnRayDataframeColumnPartition,
PandasOnRayDataframeRowPartition,
)
from modin.core.execution.dask.implementations.pandas_on_dask.partitioning import (
PandasOnDaskDataframePartition,
PandasOnDaskDataframeColumnPartition,
PandasOnDaskDataframeRowPartition,
)
-    from modin.core.execution.unidist.implementations.pandas_on_unidist.partitioning.partition import (
+    from modin.core.execution.unidist.implementations.pandas_on_unidist.partitioning import (
PandasOnUnidistDataframePartition,
PandasOnUnidistDataframeColumnPartition,
PandasOnUnidistDataframeRowPartition,
)

PartitionUnionType = Union[
PandasOnRayDataframePartition,
PandasOnDaskDataframePartition,
PandasOnUnidistDataframePartition,
PandasOnRayDataframeColumnPartition,
PandasOnRayDataframeRowPartition,
PandasOnDaskDataframeColumnPartition,
PandasOnDaskDataframeRowPartition,
PandasOnUnidistDataframeColumnPartition,
PandasOnUnidistDataframeRowPartition,
]
else:
from typing import Any
@@ -85,7 +97,10 @@ def _unwrap_partitions() -> list:
[p.drain_call_queue() for p in modin_frame._partitions.flatten()]

def get_block(partition: PartitionUnionType) -> np.ndarray:
-    blocks = partition.list_of_blocks
+    if hasattr(partition, "force_materialization"):
+        blocks = partition.force_materialization().list_of_blocks
+    else:
+        blocks = partition.list_of_blocks
assert (
len(blocks) == 1
), f"Implementation assumes that partition contains a single block, but {len(blocks)} recieved."
@@ -109,6 +124,12 @@ def get_block(partition: PartitionUnionType) -> np.ndarray:
"PandasOnRayDataframePartition",
"PandasOnDaskDataframePartition",
"PandasOnUnidistDataframePartition",
"PandasOnRayDataframeColumnPartition",
"PandasOnRayDataframeRowPartition",
"PandasOnDaskDataframeColumnPartition",
"PandasOnDaskDataframeRowPartition",
"PandasOnUnidistDataframeColumnPartition",
"PandasOnUnidistDataframeRowPartition",
):
return _unwrap_partitions()
raise ValueError(
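
A hedged sketch of the normalization the get_block change above performs, as a standalone illustration; single_block is a hypothetical helper, and the attribute names follow the diff.

    # Virtual (column/row) partitions wrap several block partitions; calling
    # force_materialization() collapses them into one block, so list_of_blocks
    # then holds exactly one engine future, matching the single-block assumption.
    def single_block(partition):
        if hasattr(partition, "force_materialization"):
            partition = partition.force_materialization()
        (block,) = partition.list_of_blocks  # unpacking asserts exactly one block
        return block
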
22 changes: 22 additions & 0 deletions modin/test/test_partition_api.py
@@ -114,6 +114,28 @@ def get_df(lib, data):
)


def test_unwrap_virtual_partitions():
# see #5164 for details
data = test_data["int_data"]
df = pd.DataFrame(data)
virtual_partitioned_df = pd.concat([df] * 10)
actual_partitions = np.array(unwrap_partitions(virtual_partitioned_df, axis=None))
expected_df = pd.concat([pd.DataFrame(data)] * 10)
expected_partitions = expected_df._query_compiler._modin_frame._partitions
assert expected_partitions.shape == actual_partitions.shape

for row_idx in range(expected_partitions.shape[0]):
for col_idx in range(expected_partitions.shape[1]):
df_equals(
get_func(
expected_partitions[row_idx][col_idx]
.force_materialization()
.list_of_blocks[0]
),
get_func(actual_partitions[row_idx][col_idx]),
)


@pytest.mark.parametrize("column_widths", [None, "column_widths"])
@pytest.mark.parametrize("row_lengths", [None, "row_lengths"])
@pytest.mark.parametrize("columns", [None, "columns"])
