diff --git a/modin/engines/base/frame/data.py b/modin/engines/base/frame/data.py
index 6c31bf1d00d..d4fba1c686d 100644
--- a/modin/engines/base/frame/data.py
+++ b/modin/engines/base/frame/data.py
@@ -1483,9 +1483,10 @@ def broadcast_apply(
         BasePandasFrame
         """
         # Only sort the indices if they do not match
-        left_parts, right_parts, joined_index = self._copartition(
-            axis, other, join_type, sort=not self.axes[axis].equals(other.axes[axis])
-        )
+        other_index = other.axes[axis]
+        sort = not self.axes[axis].equals(other_index)
+        joined_index = self._join_index_objects(axis, other_index, join_type, sort)
+        left_parts, right_parts = self._copartition(axis, other, joined_index)
         # unwrap list returned by `copartition`.
         right_parts = right_parts[0]
         new_frame = self._frame_mgr_cls.broadcast_apply(
@@ -1690,7 +1691,12 @@ def broadcast_apply_full_axis(
         )
 
     def _copartition(
-        self, axis, other, how, sort, force_repartition=False, reindexer=None
+        self,
+        axis,
+        other,
+        joined_index,
+        force_repartition=False,
+        make_map_reindexer=None,
     ):
         """
         Copartition two dataframes.
@@ -1703,18 +1709,15 @@ def _copartition(
             The axis to copartition along (0 - rows, 1 - columns).
         other : BasePandasFrame
             The other dataframes(s) to copartition against.
-        how : str
-            How to manage joining the index object ("left", "right", etc.)
-        sort : bool
-            Whether or not to sort the joined index.
+        joined_index : Index
         force_repartition : bool, default False
             Whether or not to force the repartitioning. By default,
             this method will skip repartitioning if it is possible. This is because
            reindexing is extremely inefficient. Because this method is used to
            `join` or `append`, it is vital that the internal indices match.
-        reindexer : str, default None
-            Defines the operation for which `_copartition` is executed.
-            Allows us to add some specifics (for example, how to make reindex).
+        make_map_reindexer : callable, default None
+            Builds the reindexing function for the caller's specific case
+            (`_copartition` is used by `concat`, `_binary_op` and `broadcast_apply`).
 
         Returns
         -------
@@ -1728,34 +1731,23 @@ def _copartition(
             return (
                 self._partitions,
                 [self._simple_shuffle(axis, o) for o in other],
-                self.axes[axis].copy(),
             )
 
-        index_other_obj = [o.axes[axis] for o in other]
-        joined_index = self._join_index_objects(axis, index_other_obj, how, sort)
-
         # We have to set these because otherwise when we perform the functions it may
         # end up serializing this entire object.
         left_old_idx = self.axes[axis]
-        right_old_idxes = index_other_obj
+        right_old_idxes = [o.axes[axis] for o in other]
 
-        def make_map_func(index, left=True):
-            # left - specific argument for case of binary operation;
-            # it choose indexer for left or right index
-            if index.equals(joined_index):
-                return lambda df: df
-            if reindexer == "binary":
-                # case for binary operation with duplicate values; way from pandas
-                _join_index, ilidx, iridx = self.axes[axis].join(
-                    other[0].axes[axis], how=how, sort=sort, return_indexers=True
-                )
+        if make_map_reindexer is None:
 
-                return lambda df: df._reindex_with_indexers(
-                    {axis: [_join_index, ilidx if left else iridx]},
-                    copy=True,
-                    allow_dups=True,
-                )
-            return lambda df: df.reindex(joined_index, axis=axis)
+            def make_map_func(index, left=True):
+                # `left` is a specific argument for the binary operation case;
+                # it chooses the indexer for the left or the right index
+                if index.equals(joined_index):
+                    return lambda df: df
+                return lambda df: df.reindex(joined_index, axis=axis)
+
+            make_map_reindexer = make_map_func
 
         # Start with this and we'll repartition the first time, and then not again.
         if not force_repartition and left_old_idx.equals(joined_index):
@@ -1764,7 +1756,7 @@ def make_map_func(index, left=True):
             reindexed_self = self._frame_mgr_cls.map_axis_partitions(
                 axis,
                 self._partitions,
-                make_map_func(left_old_idx),
+                make_map_reindexer(left_old_idx),
             )
 
         def get_column_widths(partitions):
@@ -1783,13 +1775,13 @@ def get_row_lengths(partitions):
             reindexed_other = other[i]._frame_mgr_cls.map_axis_partitions(
                 axis,
                 other[i]._partitions,
-                make_map_func(right_old_idxes[i], left=False),
+                make_map_reindexer(right_old_idxes[i], left=False),
                 lengths=get_row_lengths(reindexed_self)
                 if axis == 0
                 else get_column_widths(reindexed_self),
             )
             reindexed_other_list.append(reindexed_other)
-        return reindexed_self, reindexed_other_list, joined_index
+        return reindexed_self, reindexed_other_list
 
     def _simple_shuffle(self, axis, other):
         """
@@ -1841,12 +1833,28 @@ def _binary_op(self, op, right_frame, join_type="outer"):
         BasePandasFrame
             A new dataframe.
         """
-        left_parts, right_parts, joined_index = self._copartition(
+        joined_index, ilidx, iridx = self.axes[0].join(
+            right_frame.axes[0], how=join_type, sort=True, return_indexers=True
+        )
+
+        def make_map_reindexer(index, left=True):
+            # `left` is a specific argument for the binary operation case;
+            # it chooses the indexer for the left or the right index
+            if index.equals(joined_index):
+                return lambda df: df
+
+            # case with duplicate values; follows the pandas approach
+            return lambda df: df._reindex_with_indexers(
+                {0: [joined_index, ilidx if left else iridx]},
+                copy=True,
+                allow_dups=True,
+            )
+
+        left_parts, right_parts = self._copartition(
             0,
             right_frame,
-            join_type,
-            sort=True,
-            reindexer="binary",
+            joined_index,
+            make_map_reindexer=make_map_reindexer,
         )
         # unwrap list returned by `copartition`.
         right_parts = right_parts[0]
@@ -1900,8 +1908,16 @@ def _concat(self, axis, others, how, sort):
                 length for o in others for length in o._column_widths
             ]
         else:
-            left_parts, right_parts, joined_index = self._copartition(
-                axis ^ 1, others, how, sort, force_repartition=True
+            copartition_axis = axis ^ 1
+            others_index = [o.axes[copartition_axis] for o in others]
+            joined_index = self._join_index_objects(
+                copartition_axis, others_index, how, sort
+            )
+            left_parts, right_parts = self._copartition(
+                copartition_axis,
+                others,
+                joined_index,
+                force_repartition=True,
             )
             new_lengths = None
             new_widths = None
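
Below is a minimal, standalone sketch (not part of the patch) of the reindexer-closure pattern that the reworked `_binary_op` builds and hands to `_copartition` through `make_map_reindexer`. It uses plain pandas objects in place of Modin frames and partitions; the sample DataFrames and the `left_side` parameter name are invented for illustration, while the `Index.join(..., return_indexers=True)` / `_reindex_with_indexers` calls mirror the ones in the diff.

import pandas as pd

# Plain pandas stand-ins for the left and right Modin frames (note the duplicate label 1).
left = pd.DataFrame({"a": [1, 2, 3]}, index=[0, 1, 1])
right = pd.DataFrame({"a": [10, 20]}, index=[1, 2])

# Join the row indexes once, keeping positional indexers so duplicate labels
# are aligned the same way pandas does internally.
joined_index, ilidx, iridx = left.index.join(
    right.index, how="outer", sort=True, return_indexers=True
)


def make_map_reindexer(index, left_side=True):
    # Skip the reindex entirely when one side already matches the joined index.
    if index.equals(joined_index):
        return lambda df: df
    indexer = ilidx if left_side else iridx
    # pandas-internal helper (also used in the patch) that tolerates duplicate labels.
    return lambda df: df._reindex_with_indexers(
        {0: [joined_index, indexer]}, copy=True, allow_dups=True
    )


reindexed_left = make_map_reindexer(left.index)(left)
reindexed_right = make_map_reindexer(right.index, left_side=False)(right)
print(reindexed_left.index.equals(reindexed_right.index))  # True

In the patch itself the closures are applied per axis partition via map_axis_partitions rather than to whole DataFrames, but the alignment logic is the same.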