Skip to content

Commit

Permalink
PERF-modin-project#4494: Get all partition widths/lengths in parallel
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Shi <[email protected]>
  • Loading branch information
noloerino committed Aug 9, 2022
1 parent 8e1190c commit cb4f35c
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,35 @@ def _get_partition_size_along_axis(self, partition, axis=0):
Returns
-------
list
A list of lengths along the specified axis that sum to the overall length of the partition
along the specified axis.
A list of Dask futures representing lengths along the specified axis that sum to
the overall length of the partition along the specified axis.
Notes
-----
This utility function is used to ensure that computation occurs asynchronously across all partitions
whether the partitions are virtual or physical partitions.
"""

def len_fn(df):
return len(df) if not axis else len(df.columns)

if isinstance(partition, self._partition_mgr_cls._partition_class):
return [
partition.apply(
lambda df: len(df) if not axis else len(df.columns)
)._data
]
return [partition.apply(len_fn)._data]
elif partition.axis == axis:
return [
ptn.apply(lambda df: len(df) if not axis else len(df.columns))._data
ptn.apply(len_fn)._data
for ptn in partition.list_of_block_partitions
]
return [
partition.list_of_block_partitions[0]
.apply(lambda df: len(df) if not axis else (len(df.columns)))
.apply(len_fn)
._data
]

@property
def _row_lengths(self):
"""
Compute ther row partitions lengths if they are not cached.
Compute the row partitions lengths if they are not cached.
Returns
-------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,80 @@ class PandasOnRayDataframe(PandasDataframe):
"""

_partition_mgr_cls = PandasOnRayDataframePartitionManager

def _get_partition_size_along_axis(self, partition, axis=0):
"""
Compute the length along the specified axis of the specified partition.
Parameters
----------
partition : ``PandasOnRayDataframeVirtualPartition`` or ``PandasOnRayDataframePartition``
The partition whose size to compute.
axis : int, default: 0
The axis along which to compute size.
Returns
-------
list
A list of ray object IDs representing lengths along the specified axis that sum to the overall length of the partition
along the specified axis.
Notes
-----
This utility function is used to ensure that computation occurs asynchronously across all partitions
whether the partitions are virtual or physical partitions.
"""

def len_fn(df):
return len(df) if not axis else len(df.columns)

if isinstance(partition, self._partition_mgr_cls._partition_class):
return [partition.apply(len_fn)._data]
elif partition.axis == axis:
return [
ptn.apply(len_fn)._data
for ptn in partition.list_of_partitions_to_combine
]
return [partition.list_of_partitions_to_combine[0].apply(len_fn)._data]

@property
def _row_lengths(self):
"""
Compute the row partitions lengths if they are not cached.
Returns
-------
list
A list of row partitions lengths.
"""
if self._row_lengths_cache is None:
row_lengths_list = ray.get(
[
self._get_partition_size_along_axis(obj, axis=0)
for obj in self._partitions.T[0]
]
)
self._row_lengths_cache = [sum(len_list) for len_list in row_lengths_list]
return self._row_lengths_cache

@property
def _column_widths(self):
"""
Compute the column partitions widths if they are not cached.
Returns
-------
list
A list of column partitions widths.
"""
if self._column_widths_cache is None:
col_widths_list = ray.get(
[
self._get_partition_size_along_axis(obj, axis=1)
for obj in self._partitions[0]
]
)
self._column_widths_cache = [
sum(width_list) for width_list in col_widths_list
]
return self._column_widths_cache

0 comments on commit cb4f35c

Please sign in to comment.