From ca2eba38c1ad2b48bf9614c176aabb28f9478d66 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 4 Nov 2022 17:54:36 +0000 Subject: [PATCH] Simplify copy_range dispatch Precondition is that range indices have been sanitised. --- python/cudf/cudf/_lib/copying.pyx | 41 ++++++++++++++++--------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index f5c910ca77d..1de91e6a3e9 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -132,36 +132,37 @@ def _copy_range(Column input_column, return Column.from_unique_ptr(move(c_result)) -def copy_range(Column input_column, +def copy_range(Column source_column, Column target_column, - size_type input_begin, - size_type input_end, + size_type source_begin, + size_type source_end, size_type target_begin, size_type target_end, bool inplace): """ - Copy input_column from input_begin to input_end to - target_column from target_begin to target_end - """ - - if abs(target_end - target_begin) < 1: - return target_column + Copy a contiguous range from a source to a target column - if target_begin < 0: - target_begin = target_begin + target_column.size - - if target_end < 0: - target_end = target_end + target_column.size + Notes + ----- + Expects the source and target ranges to have been sanitised to be + in-range for the source and target column respectively. For + example via ``slice.indices``. + """ - if target_begin > target_end: + assert ( + source_end - source_begin == target_end - target_begin, + "Source and target ranges must be same length" + ) + if target_end >= target_begin and inplace: + # FIXME: Are we allowed to do this when inplace=False? return target_column - if inplace is True: - _copy_range_in_place(input_column, target_column, - input_begin, input_end, target_begin) + if inplace: + _copy_range_in_place(source_column, target_column, + source_begin, source_end, target_begin) else: - return _copy_range(input_column, target_column, - input_begin, input_end, target_begin) + return _copy_range(source_column, target_column, + source_begin, source_end, target_begin) def gather(