FEAT-#5816: Implement '.split' method for axis partitions #5856

Merged · 7 commits · Apr 13, 2023
31 changes: 29 additions & 2 deletions modin/core/dataframe/base/partitioning/axis_partition.py
@@ -22,6 +22,11 @@ class BaseDataframeAxisPartition(ABC): # pragma: no cover
An abstract class that represents the parent class for any axis partition class.

This class is intended to simplify the way that operations are performed.

Attributes
----------
_PARTITIONS_METADATA_LEN : int
    The number of metadata values that a `partition_type` object consumes.
"""

@property
@@ -87,15 +92,21 @@ def apply(
# Child classes must have these in order to correctly subclass.
instance_type = None
partition_type = None
_PARTITIONS_METADATA_LEN = 0

def _wrap_partitions(self, partitions: list) -> list:
def _wrap_partitions(
    self, partitions: list, extract_metadata: Optional[bool] = None
) -> list:
"""
Wrap remote partition objects with `BaseDataframePartition` class.

Parameters
----------
partitions : list
    List of remote partition objects to be wrapped with the `BaseDataframePartition` class.
extract_metadata : bool, optional
    Whether the `partitions` list contains information about the partitions' metadata.
    If `None` is passed, the value is inferred from `cls._PARTITIONS_METADATA_LEN`.

Returns
-------
@@ -105,7 +116,23 @@ def _wrap_partitions(self, partitions: list) -> list:
assert self.partition_type is not None
assert self.instance_type is not None # type: ignore

return [self.partition_type(obj) for obj in partitions]
if extract_metadata is None:
    # If `_PARTITIONS_METADATA_LEN == 0`, the execution doesn't support metadata
    # and thus we should never try extracting it; otherwise, assume that the common
    # approach of always passing the metadata is used.
    extract_metadata = bool(self._PARTITIONS_METADATA_LEN)

if extract_metadata:
    return [
        self.partition_type(*init_args)
        for init_args in zip(
            # `partition_type` consumes `(object_id, *metadata)`, thus the `+ 1`
            *[iter(partitions)] * (self._PARTITIONS_METADATA_LEN + 1)
        )
    ]
else:
    return [self.partition_type(object_id) for object_id in partitions]
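For context, `zip(*[iter(seq)] * n)` groups a flat list into consecutive n-tuples, which is how a flat `[object_id, length, width, ip, object_id, ...]` stream becomes per-partition constructor arguments. A minimal sketch of the idiom, with illustrative values:

```python
# A flat result stream as produced when `_PARTITIONS_METADATA_LEN = 3`.
partitions = ["obj0", 10, 4, "ip0", "obj1", 12, 4, "ip1"]
metadata_len = 3  # (length, width, ip)

# `zip(*[iter(partitions)] * 4)` pulls 4 consecutive items at a time from a
# single shared iterator, yielding one argument tuple per partition.
grouped = list(zip(*[iter(partitions)] * (metadata_len + 1)))
assert grouped == [("obj0", 10, 4, "ip0"), ("obj1", 12, 4, "ip1")]
```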

def force_materialization(
self, get_ip: bool = False
82 changes: 82 additions & 0 deletions modin/core/dataframe/pandas/partitioning/axis_partition.py
@@ -134,6 +134,88 @@ def apply(
)
)

def split(
    self, split_func, num_splits, f_args=None, f_kwargs=None, extract_metadata=False
):
"""
Split axis partition into multiple partitions using the `split_func`.

Parameters
----------
split_func : callable(pandas.DataFrame) -> list[pandas.DataFrame]
A function that takes partition's content and split it into multiple chunks.
num_splits : int
The number of splits the `split_func` return.
f_args : iterable, optional
Positional arguments to pass to the `split_func`.
f_kwargs : dict, optional
Keyword arguments to pass to the `split_func`.
extract_metadata : bool, default: False
[Review comment by dchigarev (Collaborator, Author):]

The original `.split` method on block partitions is already implemented without extracting metadata:

    outputs = self.execution_wrapper.deploy(
        split_func, [self._data] + list(args), num_returns=num_splits
    )
    self._is_debug(log) and log.debug(f"EXIT::Partition.split::{self._identity}")
    return [self.__constructor__(output) for output in outputs]

So the full-axis implementation just follows the initial approach.

We don't want to extract metadata because:

  1. Partitions generated by this function are temporary: in the reshuffling flow, the `split_row_partitions` are immediately replaced by `new_partitions` holding new metadata, meaning that the metadata of `split_row_partitions` is never accessed:

        # We need to convert every partition that came from the splits into a full-axis column partition.
        new_partitions = [
            [
                cls._column_partitions_class(row_partition, full_axis=False).apply(
                    final_shuffle_func
                )
            ]
            for row_partition in split_row_partitions
        ]

  2. The splitting stage generates a lot of partitions (up to ncores ^ 2). It's already not an easy task for Ray to put that many futures into storage at once, and the situation becomes even worse when we ask it to store the metadata futures as well (4 * (ncores ^ 2) futures at once; see the sketch below). I've measured the case from [PERF] Slow sort_values in value_counts #5533 with and without the partition's metadata, and got a stable 9% speed-up (~0.12 s) for the case without metadata.
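To make the futures arithmetic concrete, a toy count under the assumptions in the comment (the numbers are illustrative):

```python
ncores = 16
num_partitions = ncores**2  # partitions produced by the splitting stage

# (length, width, ip) accompany every object when metadata is extracted.
metadata_len = 3

futures_without_metadata = num_partitions                    # 256
futures_with_metadata = num_partitions * (metadata_len + 1)  # 1024 == 4 * ncores**2
```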

        Whether to return metadata (length, width, ip) of the result. Passing `False` may relax
        the load on object storage as the remote function would return X times fewer futures
        (where X is the number of metadata values). Passing `False` makes sense for temporary
        results where you know for sure that the metadata will never be requested.

    Returns
    -------
    list
        List of wrapped remote partition objects.
    """
    f_args = tuple() if f_args is None else f_args
    f_kwargs = {} if f_kwargs is None else f_kwargs
    return self._wrap_partitions(
        self.deploy_splitting_func(
            self.axis,
            split_func,
            f_args,
            f_kwargs,
            num_splits,
            *self.list_of_blocks,
            extract_metadata=extract_metadata,
        ),
        extract_metadata=extract_metadata,
    )
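For illustration, a hypothetical call that splits a row partition into two chunks (`split_in_two` and `row_partition` are assumptions for the sketch, not part of this diff):

```python
import pandas

def split_in_two(df: pandas.DataFrame) -> list:
    # Chop a frame into two row chunks; a `split_func` must return `num_splits` frames.
    mid = len(df) // 2
    return [df.iloc[:mid], df.iloc[mid:]]

# Assuming `row_partition` is an instance of a PandasDataframeAxisPartition subclass:
# chunks = row_partition.split(split_in_two, num_splits=2, extract_metadata=False)
# `chunks` is then a list of wrapped remote partitions carrying no metadata futures.
```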

@classmethod
def deploy_splitting_func(
    cls,
    axis,
    split_func,
    f_args,
    f_kwargs,
    num_splits,
    *partitions,
    extract_metadata=False,
):
"""
Deploy a splitting function along a full axis.

Parameters
----------
axis : {0, 1}
The axis to perform the function along.
split_func : callable(pandas.DataFrame) -> list[pandas.DataFrame]
The function to perform.
f_args : list or tuple
Positional arguments to pass to `split_func`.
f_kwargs : dict
Keyword arguments to pass to `split_func`.
num_splits : int
The number of splits the `split_func` return.
*partitions : iterable
All partitions that make up the full axis (row or column).
extract_metadata : bool, default: False
Whether to return metadata (length, width, ip) of the result. Note that `True` value
is not supported in `PandasDataframeAxisPartition` class.

Returns
-------
list
A list of pandas DataFrames.
"""
    dataframe = pandas.concat(list(partitions), axis=axis, copy=False)
    return split_func(dataframe, *f_args, **f_kwargs)
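In other words, the base implementation glues the blocks of the axis partition into one frame and only then delegates to `split_func`. A small sketch of the equivalent behavior for a row partition (`axis = 1` in this hierarchy), with illustrative block contents:

```python
import pandas

# Two blocks that sit side by side within one row partition.
block_a = pandas.DataFrame({"x": [1, 2]})
block_b = pandas.DataFrame({"y": [3, 4]})

# For a row partition the blocks are concatenated column-wise (axis=1),
# so `split_func` sees the whole axis at once.
full = pandas.concat([block_a, block_b], axis=1, copy=False)
assert list(full.columns) == ["x", "y"]
```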

@classmethod
def deploy_axis_func(
cls,
14 changes: 8 additions & 6 deletions modin/core/dataframe/pandas/partitioning/partition_manager.py
@@ -1581,16 +1581,18 @@ def shuffle_partitions(
# Convert our list of block partitions to row partitions. We need to create full-axis
# row partitions since we need to send the whole partition to the split step as otherwise
# we wouldn't know how to split the block partitions that don't contain the shuffling key.
row_partitions = [
    partition.force_materialization().list_of_block_partitions[0]
    for partition in cls.row_partitions(partitions)
]
row_partitions = [partition for partition in cls.row_partitions(partitions)]
if len(pivots):
    # Gather together all of the sub-partitions
    split_row_partitions = np.array(
        [
            partition.split(
                shuffle_functions.split_function, len(pivots) + 1, pivots
                shuffle_functions.split_function,
                num_splits=len(pivots) + 1,
                f_args=(pivots,),
                # The partition's metadata will never be accessed for the split partitions,
                # thus no need to compute it.
                extract_metadata=False,
            )
            for partition in row_partitions
        ]
@@ -1608,5 +1610,5 @@
else:
    # If there are no pivots, we can simply apply the function row-wise
    return np.array(
        [[row_part.apply(final_shuffle_func)] for row_part in row_partitions]
        [row_part.apply(final_shuffle_func) for row_part in row_partitions]
[Review comment by dchigarev (Collaborator, Author):]

`row_part` is now actually a row partition returning a list, so there's no need to wrap it into a list anymore.

[Reply from a collaborator:]

Makes sense!

)
@@ -50,6 +50,7 @@ class PandasOnDaskDataframeVirtualPartition(PandasDataframeAxisPartition):
"""

axis = None
_PARTITIONS_METADATA_LEN = 3 # (length, width, ip)
partition_type = PandasOnDaskDataframePartition
instance_type = Future

@@ -145,6 +146,34 @@ def list_of_ips(self):
result[idx] = partition._ip_cache
return result

@classmethod
@_inherit_docstrings(PandasDataframeAxisPartition.deploy_splitting_func)
def deploy_splitting_func(
    cls,
    axis,
    func,
    f_args,
    f_kwargs,
    num_splits,
    *partitions,
    extract_metadata=False,
):
    return DaskWrapper.deploy(
        func=_deploy_dask_func,
        f_args=(
            PandasDataframeAxisPartition.deploy_splitting_func,
            axis,
            func,
            f_args,
            f_kwargs,
            num_splits,
            *partitions,
        ),
        f_kwargs={"extract_metadata": extract_metadata},
        num_returns=num_splits * 4 if extract_metadata else num_splits,
        pure=False,
    )
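The `num_returns` arithmetic matters here: with metadata, each of the `num_splits` results expands into four futures (object, length, width, ip), so the worker's flattened output must match. A sketch of that flattening under those assumptions (`flatten_results` is a hypothetical helper, not part of this diff):

```python
import pandas

def flatten_results(results: list, extract_metadata: bool) -> list:
    # Flatten per-split results into the flat list a `num_returns`-aware engine expects.
    if not extract_metadata:
        return list(results)  # num_splits futures: one object per split
    flat = []
    for df in results:
        # 4 futures per split: (object, length, width, ip); the ip is illustrative.
        flat += [df, len(df), len(df.columns), "127.0.0.1"]
    return flat

splits = [pandas.DataFrame({"a": [1]}), pandas.DataFrame({"a": [2]})]
assert len(flatten_results(splits, extract_metadata=False)) == 2      # num_splits
assert len(flatten_results(splits, extract_metadata=True)) == 2 * 4   # num_splits * 4
```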

@classmethod
def deploy_axis_func(
cls,
@@ -266,25 +295,6 @@ def deploy_func_between_two_axis_partitions(
pure=False,
)

def _wrap_partitions(self, partitions):
    """
    Wrap partitions passed as a list of distributed.Future with ``PandasOnDaskDataframePartition`` class.

    Parameters
    ----------
    partitions : list
        List of distributed.Future.

    Returns
    -------
    list
        List of ``PandasOnDaskDataframePartition`` objects.
    """
    return [
        self.partition_type(future, length, width, ip)
        for (future, length, width, ip) in zip(*[iter(partitions)] * 4)
    ]

def apply(
self,
func,
@@ -505,7 +515,16 @@ class PandasOnDaskDataframeRowPartition(PandasOnDaskDataframeVirtualPartition):
axis = 1


def _deploy_dask_func(deployer, axis, f_to_deploy, f_args, f_kwargs, *args, **kwargs):
def _deploy_dask_func(
    deployer,
    axis,
    f_to_deploy,
    f_args,
    f_kwargs,
    *args,
    extract_metadata=True,
    **kwargs,
):
"""
Execute a function on an axis partition in a worker process.

@@ -527,6 +546,11 @@ def _deploy_dask_func(deployer, axis, f_to_deploy, f_args, f_kwargs, *args, **kw
    Keyword arguments to pass to ``f_to_deploy``.
*args : list
    Positional arguments to pass to ``func``.
extract_metadata : bool, default: True
    Whether to return metadata (length, width, ip) of the result. Passing `False` may relax
    the load on object storage as the remote function would return 4 times fewer futures.
    Passing `False` makes sense for temporary results where you know for sure that the
    metadata will never be requested.
**kwargs : dict
    Keyword arguments to pass to ``func``.

@@ -536,6 +560,8 @@ def _deploy_dask_func(deployer, axis, f_to_deploy, f_args, f_kwargs, *args, **kw
The result of the function ``func`` and metadata for it.
"""
result = deployer(axis, f_to_deploy, f_args, f_kwargs, *args, **kwargs)
if not extract_metadata:
    return result
ip = get_ip()
if isinstance(result, pandas.DataFrame):
return result, len(result), len(result.columns), ip
@@ -29,6 +29,7 @@
# If Ray has not been initialized yet by Modin,
# it will be initialized when calling `RayWrapper.put`.
_DEPLOY_AXIS_FUNC = RayWrapper.put(PandasDataframeAxisPartition.deploy_axis_func)
_DEPLOY_SPLIT_FUNC = RayWrapper.put(PandasDataframeAxisPartition.deploy_splitting_func)
_DRAIN = RayWrapper.put(PandasDataframeAxisPartition.drain)


@@ -54,6 +55,7 @@ class PandasOnRayDataframeVirtualPartition(PandasDataframeAxisPartition):
Width, or reference to width, of wrapped ``pandas.DataFrame``.
"""

_PARTITIONS_METADATA_LEN = 3 # (length, width, ip)
partition_type = PandasOnRayDataframePartition
instance_type = ray.ObjectRef
axis = None
@@ -150,6 +152,31 @@ def list_of_ips(self):
result[idx] = partition._ip_cache
return result

@classmethod
@_inherit_docstrings(PandasDataframeAxisPartition.deploy_splitting_func)
def deploy_splitting_func(
    cls,
    axis,
    func,
    f_args,
    f_kwargs,
    num_splits,
    *partitions,
    extract_metadata=False,
):
    return _deploy_ray_func.options(
        num_returns=num_splits * 4 if extract_metadata else num_splits,
    ).remote(
        _DEPLOY_SPLIT_FUNC,
        axis,
        func,
        f_args,
        f_kwargs,
        num_splits,
        *partitions,
        extract_metadata=extract_metadata,
    )
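For reference, `options(num_returns=...)` is the standard Ray way to set a per-call number of return values: the remote function must return that many items, and each becomes its own `ObjectRef`. A minimal standalone sketch, assuming a local Ray session (Ray auto-initializes on first use):

```python
import ray

@ray.remote
def chunk(seq, n):
    # Return `n` roughly equal slices of `seq`; one future per slice.
    step = (len(seq) + n - 1) // n
    return tuple(seq[i : i + step] for i in range(0, len(seq), step))

# `num_returns=3` makes `.remote` hand back a list of 3 ObjectRefs.
refs = chunk.options(num_returns=3).remote(list(range(9)), 3)
assert [ray.get(r) for r in refs] == [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```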

@classmethod
def deploy_axis_func(
cls,
@@ -264,25 +291,6 @@ def deploy_func_between_two_axis_partitions(
*partitions,
)

def _wrap_partitions(self, partitions):
    """
    Wrap partitions passed as a list of ``ray.ObjectRef`` with ``PandasOnRayDataframePartition`` class.

    Parameters
    ----------
    partitions : list
        List of ``ray.ObjectRef``.

    Returns
    -------
    list
        List of ``PandasOnRayDataframePartition`` objects.
    """
    return [
        self.partition_type(object_id, length, width, ip)
        for (object_id, length, width, ip) in zip(*[iter(partitions)] * 4)
    ]

def apply(
self,
func,
@@ -522,7 +530,14 @@ class PandasOnRayDataframeRowPartition(PandasOnRayDataframeVirtualPartition):

@ray.remote
def _deploy_ray_func(
    deployer, axis, f_to_deploy, f_args, f_kwargs, *args, **kwargs
    deployer,
    axis,
    f_to_deploy,
    f_args,
    f_kwargs,
    *args,
    extract_metadata=True,
    **kwargs,
):  # pragma: no cover
"""
Execute a function on an axis partition in a worker process.
@@ -547,6 +562,11 @@ def _deploy_ray_func(
    Keyword arguments to pass to ``f_to_deploy``.
*args : list
    Positional arguments to pass to ``deployer``.
extract_metadata : bool, default: True
    Whether to return metadata (length, width, ip) of the result. Passing `False` may relax
    the load on object storage as the remote function would return 4 times fewer futures.
    Passing `False` makes sense for temporary results where you know for sure that the
    metadata will never be requested.
**kwargs : dict
    Keyword arguments to pass to ``deployer``.

@@ -561,6 +581,8 @@ def _deploy_ray_func(
"""
f_args = deserialize(f_args)
result = deployer(axis, f_to_deploy, f_args, f_kwargs, *args, **kwargs)
if not extract_metadata:
    return result
ip = get_node_ip_address()
if isinstance(result, pandas.DataFrame):
return result, len(result), len(result.columns), ip