From 2bcdb54ab477b949e63e039b17ab2bbb9b7e0433 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 24 Jan 2023 07:31:29 -0600 Subject: [PATCH 1/7] Fix SUM/MEAN aggregation type support. (#12503) This PR closes #8399. We cleaned up the logic by fixing SUM/MEAN aggregation type support, which also eliminated `TODO` comments in the target type definitions. We kept the restriction that rolling min/max requires fixed width types because min/max aggregations do support non-fixed width in other aggregation implementations (groupby does a argmin-and-gather approach on strings, for instance). This PR is collaborative work with @karthikeyann. Authors: - Bradley Dice (https://github.com/bdice) - Karthikeyan (https://github.com/karthikeyann) Approvers: - Mark Harris (https://github.com/harrism) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/12503 --- .../cudf/detail/aggregation/aggregation.cuh | 6 +++--- .../cudf/detail/aggregation/aggregation.hpp | 21 +++++++++---------- cpp/src/rolling/detail/rolling.cuh | 11 ++-------- cpp/tests/rolling/empty_input_test.cpp | 20 +++++++++++++++--- 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/cpp/include/cudf/detail/aggregation/aggregation.cuh b/cpp/include/cudf/detail/aggregation/aggregation.cuh index 818e8cd7cc6..f13166d5321 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.cuh +++ b/cpp/include/cudf/detail/aggregation/aggregation.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -232,8 +232,8 @@ struct update_target_element< aggregation::SUM, target_has_nulls, source_has_nulls, - std::enable_if_t() && cudf::has_atomic_support() && - !is_fixed_point()>> { + std::enable_if_t() && cudf::has_atomic_support() && + !cudf::is_fixed_point() && !cudf::is_timestamp()>> { __device__ void operator()(mutable_column_device_view target, size_type target_index, column_device_view source, diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 75027c78a68..360c314f2db 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1154,9 +1154,7 @@ struct target_type_impl { using type = bool; }; -// Always use `double` for MEAN -// Except for chrono types where result is chrono. (Use FloorDiv) -// TODO: MEAN should be only be enabled for duration types - not for timestamps +// Always use `double` for MEAN except for durations and fixed point types. template struct target_type_impl< Source, @@ -1167,10 +1165,10 @@ struct target_type_impl< }; template -struct target_type_impl< - Source, - k, - std::enable_if_t<(is_chrono() or is_fixed_point()) && (k == aggregation::MEAN)>> { +struct target_type_impl() or is_fixed_point()) && + (k == aggregation::MEAN)>> { using type = Source; }; @@ -1206,10 +1204,11 @@ struct target_type_impl< using type = Source; }; -// Summing/Multiplying chrono types, use same type accumulator -// TODO: Sum/Product should only be enabled for duration types - not for timestamps +// Summing duration types, use same type accumulator template -struct target_type_impl() && is_sum_product_agg(k)>> { +struct target_type_impl() && (k == aggregation::SUM)>> { using type = Source; }; diff --git a/cpp/src/rolling/detail/rolling.cuh b/cpp/src/rolling/detail/rolling.cuh index fcc85b4f913..d996f88ca49 100644 --- a/cpp/src/rolling/detail/rolling.cuh +++ b/cpp/src/rolling/detail/rolling.cuh @@ -84,16 +84,9 @@ struct DeviceRolling { static constexpr bool is_supported() { return cudf::detail::is_valid_aggregation() && has_corresponding_operator() && - // TODO: Delete all this extra logic once is_valid_aggregation<> cleans up some edge - // cases it isn't handling. - // MIN/MAX supports all fixed width types + // MIN/MAX only supports fixed width types (((O == aggregation::MIN || O == aggregation::MAX) && cudf::is_fixed_width()) || - - // SUM supports all fixed width types except timestamps - ((O == aggregation::SUM) && (cudf::is_fixed_width() && !cudf::is_timestamp())) || - - // MEAN supports numeric and duration - ((O == aggregation::MEAN) && (cudf::is_numeric() || cudf::is_duration()))); + (O == aggregation::SUM) || (O == aggregation::MEAN)); } // operations we do support diff --git a/cpp/tests/rolling/empty_input_test.cpp b/cpp/tests/rolling/empty_input_test.cpp index 626563d5eba..aca1cbf40b7 100644 --- a/cpp/tests/rolling/empty_input_test.cpp +++ b/cpp/tests/rolling/empty_input_test.cpp @@ -183,22 +183,36 @@ TYPED_TEST(TypedRollingEmptyInputTest, EmptyFixedWidthInputs) /// `SUM` returns 64-bit promoted types for integral/decimal input. /// For other fixed-width input types, the same type is returned. + /// Timestamp types are not supported. { auto aggs = agg_vector_t{}; aggs.emplace_back(sum()); using expected_type = cudf::detail::target_type_t; - rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + if constexpr (cudf::is_timestamp()) { + EXPECT_THROW( + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()), + cudf::logic_error); + } else { + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + } } /// `MEAN` returns float64 for all numeric types, - /// except for chrono-types, which yield the same chrono-type. + /// except for duration-types, which yield the same duration-type. + /// Timestamp types are not supported. { auto aggs = agg_vector_t{}; aggs.emplace_back(mean()); using expected_type = cudf::detail::target_type_t; - rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + if constexpr (cudf::is_timestamp()) { + EXPECT_THROW( + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()), + cudf::logic_error); + } else { + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + } } /// For an input type `T`, `COLLECT_LIST` returns a column of type `list`. From ed6daad3a7edca6b8873b1105a585e12757dcf1c Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 24 Jan 2023 07:31:48 -0600 Subject: [PATCH 2/7] Fix SUM/MEAN aggregation type support. (#12503) This PR closes #8399. We cleaned up the logic by fixing SUM/MEAN aggregation type support, which also eliminated `TODO` comments in the target type definitions. We kept the restriction that rolling min/max requires fixed width types because min/max aggregations do support non-fixed width in other aggregation implementations (groupby does a argmin-and-gather approach on strings, for instance). This PR is collaborative work with @karthikeyann. Authors: - Bradley Dice (https://github.com/bdice) - Karthikeyan (https://github.com/karthikeyann) Approvers: - Mark Harris (https://github.com/harrism) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/12503 From 2784f5890d107cbace7c6059ca25504fb8569e43 Mon Sep 17 00:00:00 2001 From: nvdbaranec <56695930+nvdbaranec@users.noreply.github.com> Date: Wed, 25 Jan 2023 14:17:03 -0600 Subject: [PATCH 3/7] Parquet reader optimization to address V100 regression. (#12577) Addresses https://github.com/rapidsai/cudf/issues/12316 Some recent changes caused a performance regression in the parquet reader benchmarks for lists. The culprit ended up being slightly different code generation happening for arch 70. In several memory hotspots, the code was reading values from global, modifying them and then storing them. Previously it had done a better job of loading and keeping them in registers and the L2 cache was helping keep things fast. But the extra store was causing twice as many L2 access in these places and causing many long scoreboard stalls. Ultimately the issue is that these values shouldn't be kept in global memory. The initial implementation did it this way because the data was variable in size (based on depth of column nesting). But in practice, we never see more than 2 or 3 levels of nesting. So the solution is: - Keep these values (in a struct called `PageNestingDecodeInfo`) that is kept in shared memory for up to N nesting levels. N is currently 10. - If the nesting information for the incoming column fits in the cache, use it. Otherwise fall back to the arrays in global memory. In practice, it is exceedingly rare to see columns nested >= 10 deep. This addresses the performance regression and actually gives some performance increases. Some comparisons for LIST benchmarks. ``` cudf 22.10 (prior to regression) | data_type | cardinality | run_length | bytes_per_second | |-----------|-------------|------------|------------------| | LIST | 0 | 1 | 892901208 | | LIST | 1000 | 1 | 952863876 | | LIST | 0 | 32 | 1246033395 | | LIST | 1000 | 32 | 1232884866 | ``` ``` cudf 22.12 (where the regression occurred) | data_type | cardinality | run_length | bytes_per_second | |-----------|-------------|------------|------------------| | LIST | 0 | 1 | 747758436 | | LIST | 1000 | 1 | 827763260 | | LIST | 0 | 32 | 1026048576 | | LIST | 1000 | 32 | 1022928119 | ``` ``` This PR | data_type | cardinality | run_length | bytes_per_second | |-----------|-------------|------------|------------------| | LIST | 0 | 1 | 927347737 | | LIST | 1000 | 1 | 1024566150 | | LIST | 0 | 32 | 1315972881 | | LIST | 1000 | 32 | 1303995168 | ``` Authors: - https://github.com/nvdbaranec Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/12577 --- cpp/src/io/parquet/page_data.cu | 204 ++++++++++++------- cpp/src/io/parquet/parquet_gpu.hpp | 69 ++++--- cpp/src/io/parquet/reader_impl.cpp | 16 +- cpp/src/io/parquet/reader_impl_preprocess.cu | 27 ++- cpp/tests/io/parquet_test.cpp | 44 ++++ 5 files changed, 245 insertions(+), 115 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 70176392ee9..23d130e1585 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -90,6 +90,13 @@ struct page_state_s { const uint8_t* lvl_start[NUM_LEVEL_TYPES]; // [def,rep] int32_t lvl_count[NUM_LEVEL_TYPES]; // how many of each of the streams we've decoded int32_t row_index_lower_bound; // lower bound of row indices we should process + + // a shared-memory cache of frequently used data when decoding. The source of this data is + // normally stored in global memory which can yield poor performance. So, when possible + // we copy that info here prior to decoding + PageNestingDecodeInfo nesting_decode_cache[max_cacheable_nesting_decode_info]; + // points to either nesting_decode_cache above when possible, or to the global source otherwise + PageNestingDecodeInfo* nesting_info; }; /** @@ -927,23 +934,49 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, int chunk_idx; // Fetch page info - if (t == 0) s->page = *p; + if (!t) s->page = *p; __syncthreads(); if (s->page.flags & PAGEINFO_FLAGS_DICTIONARY) { return false; } // Fetch column chunk info chunk_idx = s->page.chunk_idx; - if (t == 0) { s->col = chunks[chunk_idx]; } - - // zero nested value and valid counts - int d = 0; - while (d < s->page.num_output_nesting_levels) { - if (d + t < s->page.num_output_nesting_levels) { - s->page.nesting[d + t].valid_count = 0; - s->page.nesting[d + t].value_count = 0; - s->page.nesting[d + t].null_count = 0; + if (!t) { s->col = chunks[chunk_idx]; } + + // if we can use the decode cache, set it up now + auto const can_use_decode_cache = s->page.nesting_info_size <= max_cacheable_nesting_decode_info; + if (can_use_decode_cache) { + int depth = 0; + while (depth < s->page.nesting_info_size) { + int const thread_depth = depth + t; + if (thread_depth < s->page.nesting_info_size) { + // these values need to be copied over from global + s->nesting_decode_cache[thread_depth].max_def_level = + s->page.nesting_decode[thread_depth].max_def_level; + s->nesting_decode_cache[thread_depth].page_start_value = + s->page.nesting_decode[thread_depth].page_start_value; + s->nesting_decode_cache[thread_depth].start_depth = + s->page.nesting_decode[thread_depth].start_depth; + s->nesting_decode_cache[thread_depth].end_depth = + s->page.nesting_decode[thread_depth].end_depth; + } + depth += blockDim.x; + } + } + if (!t) { + s->nesting_info = can_use_decode_cache ? s->nesting_decode_cache : s->page.nesting_decode; + } + __syncthreads(); + + // zero counts + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + int const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->nesting_info[thread_depth].valid_count = 0; + s->nesting_info[thread_depth].value_count = 0; + s->nesting_info[thread_depth].null_count = 0; } - d += blockDim.x; + depth += blockDim.x; } __syncthreads(); @@ -1076,7 +1109,7 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, if (is_decode_step) { int max_depth = s->col.max_nesting_depth; for (int idx = 0; idx < max_depth; idx++) { - PageNestingInfo* pni = &s->page.nesting[idx]; + PageNestingDecodeInfo* nesting_info = &s->nesting_info[idx]; size_t output_offset; // schemas without lists @@ -1085,21 +1118,21 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, } // for schemas with lists, we've already got the exact value precomputed else { - output_offset = pni->page_start_value; + output_offset = nesting_info->page_start_value; } - pni->data_out = static_cast(s->col.column_data_base[idx]); + nesting_info->data_out = static_cast(s->col.column_data_base[idx]); - if (pni->data_out != nullptr) { + if (nesting_info->data_out != nullptr) { // anything below max depth with a valid data pointer must be a list, so the // element size is the size of the offset type. uint32_t len = idx < max_depth - 1 ? sizeof(cudf::size_type) : s->dtype_len; - pni->data_out += (output_offset * len); + nesting_info->data_out += (output_offset * len); } - pni->valid_map = s->col.valid_map_base[idx]; - if (pni->valid_map != nullptr) { - pni->valid_map += output_offset >> 5; - pni->valid_map_offset = (int32_t)(output_offset & 0x1f); + nesting_info->valid_map = s->col.valid_map_base[idx]; + if (nesting_info->valid_map != nullptr) { + nesting_info->valid_map += output_offset >> 5; + nesting_info->valid_map_offset = (int32_t)(output_offset & 0x1f); } } } @@ -1217,26 +1250,26 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, * @brief Store a validity mask containing value_count bits into the output validity buffer of the * page. * - * @param[in,out] pni The page/nesting information to store the mask in. The validity map offset is - * also updated + * @param[in,out] nesting_info The page/nesting information to store the mask in. The validity map + * offset is also updated * @param[in] valid_mask The validity mask to be stored * @param[in] value_count # of bits in the validity mask */ -static __device__ void store_validity(PageNestingInfo* pni, +static __device__ void store_validity(PageNestingDecodeInfo* nesting_info, uint32_t valid_mask, int32_t value_count) { - int word_offset = pni->valid_map_offset / 32; - int bit_offset = pni->valid_map_offset % 32; + int word_offset = nesting_info->valid_map_offset / 32; + int bit_offset = nesting_info->valid_map_offset % 32; // if we fit entirely in the output word if (bit_offset + value_count <= 32) { auto relevant_mask = static_cast((static_cast(1) << value_count) - 1); if (relevant_mask == ~0) { - pni->valid_map[word_offset] = valid_mask; + nesting_info->valid_map[word_offset] = valid_mask; } else { - atomicAnd(pni->valid_map + word_offset, ~(relevant_mask << bit_offset)); - atomicOr(pni->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset); + atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset)); + atomicOr(nesting_info->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset); } } // we're going to spill over into the next word. @@ -1250,17 +1283,17 @@ static __device__ void store_validity(PageNestingInfo* pni, // first word. strip bits_left bits off the beginning and store that uint32_t relevant_mask = ((1 << bits_left) - 1); uint32_t mask_word0 = valid_mask & relevant_mask; - atomicAnd(pni->valid_map + word_offset, ~(relevant_mask << bit_offset)); - atomicOr(pni->valid_map + word_offset, mask_word0 << bit_offset); + atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset)); + atomicOr(nesting_info->valid_map + word_offset, mask_word0 << bit_offset); // second word. strip the remainder of the bits off the end and store that relevant_mask = ((1 << (value_count - bits_left)) - 1); uint32_t mask_word1 = valid_mask & (relevant_mask << bits_left); - atomicAnd(pni->valid_map + word_offset + 1, ~(relevant_mask)); - atomicOr(pni->valid_map + word_offset + 1, mask_word1 >> bits_left); + atomicAnd(nesting_info->valid_map + word_offset + 1, ~(relevant_mask)); + atomicOr(nesting_info->valid_map + word_offset + 1, mask_word1 >> bits_left); } - pni->valid_map_offset += value_count; + nesting_info->valid_map_offset += value_count; } /** @@ -1294,8 +1327,8 @@ inline __device__ void get_nesting_bounds(int& start_depth, // bound what nesting levels we apply values to if (s->col.max_level[level_type::REPETITION] > 0) { int r = s->rep[index]; - start_depth = s->page.nesting[r].start_depth; - end_depth = s->page.nesting[d].end_depth; + start_depth = s->nesting_info[r].start_depth; + end_depth = s->nesting_info[d].end_depth; } // for columns without repetition (even ones involving structs) we always // traverse the entire hierarchy. @@ -1326,6 +1359,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // how many rows we've processed in the page so far int input_row_count = s->input_row_count; + PageNestingDecodeInfo* nesting_info_base = s->nesting_info; + // process until we've reached the target while (input_value_count < target_input_value_count) { // determine the nesting bounds for this thread (the range of nesting depths we @@ -1367,14 +1402,14 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // walk from 0 to max_depth uint32_t next_thread_value_count, next_warp_value_count; for (int s_idx = 0; s_idx < max_depth; s_idx++) { - PageNestingInfo* pni = &s->page.nesting[s_idx]; + PageNestingDecodeInfo* nesting_info = &nesting_info_base[s_idx]; // if we are within the range of nesting levels we should be adding value indices for int const in_nesting_bounds = ((s_idx >= start_depth && s_idx <= end_depth) && in_row_bounds) ? 1 : 0; // everything up to the max_def_level is a non-null value - uint32_t const is_valid = d >= pni->max_def_level && in_nesting_bounds ? 1 : 0; + uint32_t const is_valid = d >= nesting_info->max_def_level && in_nesting_bounds ? 1 : 0; // compute warp and thread valid counts uint32_t const warp_valid_mask = @@ -1395,8 +1430,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // if this is the value column emit an index for value decoding if (is_valid && s_idx == max_depth - 1) { - int const src_pos = pni->valid_count + thread_valid_count; - int const dst_pos = pni->value_count + thread_value_count; + int const src_pos = nesting_info->valid_count + thread_valid_count; + int const dst_pos = nesting_info->value_count + thread_value_count; // nz_idx is a mapping of src buffer indices to destination buffer indices s->nz_idx[rolling_index(src_pos)] = dst_pos; } @@ -1414,12 +1449,12 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // if we're -not- at a leaf column and we're within nesting/row bounds // and we have a valid data_out pointer, it implies this is a list column, so // emit an offset. - if (in_nesting_bounds && pni->data_out != nullptr) { - int const idx = pni->value_count + thread_value_count; - cudf::size_type const ofs = s->page.nesting[s_idx + 1].value_count + + if (in_nesting_bounds && nesting_info->data_out != nullptr) { + int const idx = nesting_info->value_count + thread_value_count; + cudf::size_type const ofs = nesting_info_base[s_idx + 1].value_count + next_thread_value_count + - s->page.nesting[s_idx + 1].page_start_value; - (reinterpret_cast(pni->data_out))[idx] = ofs; + nesting_info_base[s_idx + 1].page_start_value; + (reinterpret_cast(nesting_info->data_out))[idx] = ofs; } } @@ -1441,14 +1476,14 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // increment count of valid values, count of total values, and update validity mask if (!t) { - if (pni->valid_map != nullptr && warp_valid_mask_bit_count > 0) { + if (nesting_info->valid_map != nullptr && warp_valid_mask_bit_count > 0) { uint32_t const warp_output_valid_mask = warp_valid_mask >> first_thread_in_write_range; - store_validity(pni, warp_output_valid_mask, warp_valid_mask_bit_count); + store_validity(nesting_info, warp_output_valid_mask, warp_valid_mask_bit_count); - pni->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask); + nesting_info->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask); } - pni->valid_count += warp_valid_count; - pni->value_count += warp_value_count; + nesting_info->valid_count += warp_valid_count; + nesting_info->value_count += warp_value_count; } // propagate value counts for the next level @@ -1463,7 +1498,7 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // update if (!t) { // update valid value count for decoding and total # of values we've processed - s->nz_count = s->page.nesting[max_depth - 1].valid_count; + s->nz_count = nesting_info_base[max_depth - 1].valid_count; s->input_value_count = input_value_count; s->input_row_count = input_row_count; } @@ -1545,7 +1580,7 @@ static __device__ void gpuUpdatePageSizes(page_state_s* s, // count rows and leaf values int const is_new_row = start_depth == 0 ? 1 : 0; uint32_t const warp_row_count_mask = ballot(is_new_row); - int const is_new_leaf = (d >= s->page.nesting[max_depth - 1].max_def_level) ? 1 : 0; + int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level) ? 1 : 0; uint32_t const warp_leaf_count_mask = ballot(is_new_leaf); // is this thread within row bounds? on the first pass we don't know the bounds, so we will be // computing the full size of the column. on the second pass, we will know our actual row @@ -1673,14 +1708,14 @@ __global__ void __launch_bounds__(block_size) // to do the expensive work of traversing the level data to determine sizes. we can just compute // it directly. if (!has_repetition && !compute_string_sizes) { - int d = 0; - while (d < s->page.num_output_nesting_levels) { - auto const i = d + t; - if (i < s->page.num_output_nesting_levels) { - if (is_base_pass) { pp->nesting[i].size = pp->num_input_values; } - pp->nesting[i].batch_size = pp->num_input_values; + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + if (is_base_pass) { pp->nesting[thread_depth].size = pp->num_input_values; } + pp->nesting[thread_depth].batch_size = pp->num_input_values; } - d += blockDim.x; + depth += blockDim.x; } return; } @@ -1688,25 +1723,29 @@ __global__ void __launch_bounds__(block_size) // in the trim pass, for anything with lists, we only need to fully process bounding pages (those // at the beginning or the end of the row bounds) if (!is_base_pass && !is_bounds_page(s, min_row, num_rows)) { - int d = 0; - while (d < s->page.num_output_nesting_levels) { - auto const i = d + t; - if (i < s->page.num_output_nesting_levels) { + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { // if we are not a bounding page (as checked above) then we are either // returning 0 rows from the page (completely outside the bounds) or all // rows in the page (completely within the bounds) - pp->nesting[i].batch_size = s->num_rows == 0 ? 0 : pp->nesting[i].size; + pp->nesting[thread_depth].batch_size = + s->num_rows == 0 ? 0 : pp->nesting[thread_depth].size; } - d += blockDim.x; + depth += blockDim.x; } return; } // zero sizes - int d = 0; - while (d < s->page.num_output_nesting_levels) { - if (d + t < s->page.num_output_nesting_levels) { s->page.nesting[d + t].batch_size = 0; } - d += blockDim.x; + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->page.nesting[thread_depth].batch_size = 0; + } + depth += blockDim.x; } __syncthreads(); @@ -1754,13 +1793,13 @@ __global__ void __launch_bounds__(block_size) if (!t) { pp->num_rows = s->page.nesting[0].batch_size; } // store off this batch size as the "full" size - int d = 0; - while (d < s->page.num_output_nesting_levels) { - auto const i = d + t; - if (i < s->page.num_output_nesting_levels) { - pp->nesting[i].size = pp->nesting[i].batch_size; + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + pp->nesting[thread_depth].size = pp->nesting[thread_depth].batch_size; } - d += blockDim.x; + depth += blockDim.x; } } @@ -1808,6 +1847,8 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( ((s->col.data_type & 7) == BOOLEAN || (s->col.data_type & 7) == BYTE_ARRAY) ? 64 : 32; } + PageNestingDecodeInfo* nesting_info_base = s->nesting_info; + // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t skipped_leaf_values = s->page.skipped_leaf_values; while (!s->error && (s->input_value_count < s->num_input_values || s->src_pos < s->nz_count)) { @@ -1876,7 +1917,7 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( uint32_t dtype_len = s->dtype_len; void* dst = - s->page.nesting[leaf_level_index].data_out + static_cast(dst_pos) * dtype_len; + nesting_info_base[leaf_level_index].data_out + static_cast(dst_pos) * dtype_len; if (dtype == BYTE_ARRAY) { if (s->col.converted_type == DECIMAL) { auto const [ptr, len] = gpuGetStringData(s, val_src_pos); @@ -1931,6 +1972,19 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( } __syncthreads(); } + + // if we are using the nesting decode cache, copy null count back + if (s->nesting_info == s->nesting_decode_cache) { + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + int const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->page.nesting_decode[thread_depth].null_count = + s->nesting_decode_cache[thread_depth].null_count; + } + depth += blockDim.x; + } + } } } // anonymous namespace diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 33a189cdf87..9b156745e41 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -83,6 +83,40 @@ enum level_type { NUM_LEVEL_TYPES }; +/** + * @brief Nesting information specifically needed by the decode and preprocessing + * kernels. + * + * This data is kept separate from PageNestingInfo to keep it as small as possible. + * It is used in a cached form in shared memory when possible. + */ +struct PageNestingDecodeInfo { + // set up prior to decoding + int32_t max_def_level; + // input repetition/definition levels are remapped with these values + // into the corresponding real output nesting depths. + int32_t start_depth; + int32_t end_depth; + + // computed during preprocessing + int32_t page_start_value; + + // computed during decoding + int32_t null_count; + + // used internally during decoding + int32_t valid_map_offset; + int32_t valid_count; + int32_t value_count; + uint8_t* data_out; + bitmask_type* valid_map; +}; + +// Use up to 512 bytes of shared memory as a cache for nesting information. +// As of 1/20/23, this gives us a max nesting depth of 10 (after which it falls back to +// global memory). This handles all but the most extreme cases. +constexpr int max_cacheable_nesting_decode_info = (512) / sizeof(PageNestingDecodeInfo); + /** * @brief Nesting information * @@ -94,30 +128,15 @@ enum level_type { * */ struct PageNestingInfo { - // input repetition/definition levels are remapped with these values - // into the corresponding real output nesting depths. - int32_t start_depth; - int32_t end_depth; - - // set at initialization - int32_t max_def_level; - int32_t max_rep_level; + // set at initialization (see start_offset_output_iterator in reader_impl_preprocess.cu) cudf::type_id type; // type of the corresponding cudf output column bool nullable; - // set during preprocessing + // TODO: these fields might make sense to move into PageNestingDecodeInfo for memory performance + // reasons. int32_t size; // this page/nesting-level's row count contribution to the output column, if fully // decoded - int32_t batch_size; // the size of the page for this batch - int32_t page_start_value; // absolute output start index in output column data - - // set during data decoding - int32_t valid_count; // # of valid values decoded in this page/nesting-level - int32_t value_count; // total # of values decoded in this page/nesting-level - int32_t null_count; // null count - int32_t valid_map_offset; // current offset in bits relative to valid_map - uint8_t* data_out; // pointer into output buffer - uint32_t* valid_map; // pointer into output validity buffer + int32_t batch_size; // the size of the page for this batch }; /** @@ -159,9 +178,9 @@ struct PageInfo { // skipped_leaf_values will always be 0. // // # of values skipped in the repetition/definition level stream - int skipped_values; + int32_t skipped_values; // # of values skipped in the actual data stream. - int skipped_leaf_values; + int32_t skipped_leaf_values; // for string columns only, the size of all the chars in the string for // this page. only valid/computed during the base preprocess pass int32_t str_bytes; @@ -170,9 +189,10 @@ struct PageInfo { // input column nesting information, output column nesting information and // mappings between the two. the length of the array, nesting_info_size is // max(num_output_nesting_levels, max_definition_levels + 1) - int num_output_nesting_levels; - int nesting_info_size; + int32_t num_output_nesting_levels; + int32_t nesting_info_size; PageNestingInfo* nesting; + PageNestingDecodeInfo* nesting_decode; }; /** @@ -242,7 +262,7 @@ struct ColumnChunkDesc { PageInfo* page_info; // output page info for up to num_dict_pages + // num_data_pages (dictionary pages first) string_index_pair* str_dict_index; // index for string dictionary - uint32_t** valid_map_base; // base pointers of valid bit map for this column + bitmask_type** valid_map_base; // base pointers of valid bit map for this column void** column_data_base; // base pointers of column data int8_t codec; // compressed codec enum int8_t converted_type; // converted type enum @@ -263,6 +283,7 @@ struct file_intermediate_data { hostdevice_vector chunks{}; hostdevice_vector pages_info{}; hostdevice_vector page_nesting_info{}; + hostdevice_vector page_nesting_decode_info{}; }; /** diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index d5dac10b8f6..b1c4dd22c0d 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -24,9 +24,10 @@ namespace cudf::io::detail::parquet { void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) { - auto& chunks = _file_itm_data.chunks; - auto& pages = _file_itm_data.pages_info; - auto& page_nesting = _file_itm_data.page_nesting_info; + auto& chunks = _file_itm_data.chunks; + auto& pages = _file_itm_data.pages_info; + auto& page_nesting = _file_itm_data.page_nesting_info; + auto& page_nesting_decode = _file_itm_data.page_nesting_decode_info; // Should not reach here if there is no page data. CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); @@ -39,7 +40,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) // In order to reduce the number of allocations of hostdevice_vector, we allocate a single vector // to store all per-chunk pointers to nested data/nullmask. `chunk_offsets[i]` will store the // offset into `chunk_nested_data`/`chunk_nested_valids` for the array of pointers for chunk `i` - auto chunk_nested_valids = hostdevice_vector(sum_max_depths, _stream); + auto chunk_nested_valids = hostdevice_vector(sum_max_depths, _stream); auto chunk_nested_data = hostdevice_vector(sum_max_depths, _stream); auto chunk_offsets = std::vector(); @@ -124,6 +125,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) pages.device_to_host(_stream); page_nesting.device_to_host(_stream); + page_nesting_decode.device_to_host(_stream); _stream.synchronize(); // for list columns, add the final offset to every offset buffer. @@ -166,8 +168,8 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) gpu::ColumnChunkDesc* col = &chunks[pi->chunk_idx]; input_column_info const& input_col = _input_columns[col->src_col_index]; - int index = pi->nesting - page_nesting.device_ptr(); - gpu::PageNestingInfo* pni = &page_nesting[index]; + int index = pi->nesting_decode - page_nesting_decode.device_ptr(); + gpu::PageNestingDecodeInfo* pndi = &page_nesting_decode[index]; auto* cols = &_output_buffers; for (size_t l_idx = 0; l_idx < input_col.nesting_depth(); l_idx++) { @@ -178,7 +180,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) if (chunk_nested_valids.host_ptr(chunk_offsets[pi->chunk_idx])[l_idx] == nullptr) { continue; } - out_buf.null_count() += pni[l_idx].null_count; + out_buf.null_count() += pndi[l_idx].null_count; } } diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 435fdb1a411..6577a1a3f0f 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -523,9 +523,10 @@ void decode_page_headers(hostdevice_vector& chunks, void reader::impl::allocate_nesting_info() { - auto const& chunks = _file_itm_data.chunks; - auto& pages = _file_itm_data.pages_info; - auto& page_nesting_info = _file_itm_data.page_nesting_info; + auto const& chunks = _file_itm_data.chunks; + auto& pages = _file_itm_data.pages_info; + auto& page_nesting_info = _file_itm_data.page_nesting_info; + auto& page_nesting_decode_info = _file_itm_data.page_nesting_decode_info; // compute total # of page_nesting infos needed and allocate space. doing this in one // buffer to keep it to a single gpu allocation @@ -539,6 +540,8 @@ void reader::impl::allocate_nesting_info() }); page_nesting_info = hostdevice_vector{total_page_nesting_infos, _stream}; + page_nesting_decode_info = + hostdevice_vector{total_page_nesting_infos, _stream}; // retrieve from the gpu so we can update pages.device_to_host(_stream, true); @@ -556,6 +559,9 @@ void reader::impl::allocate_nesting_info() target_page_index += chunks[idx].num_dict_pages; for (int p_idx = 0; p_idx < chunks[idx].num_data_pages; p_idx++) { pages[target_page_index + p_idx].nesting = page_nesting_info.device_ptr() + src_info_index; + pages[target_page_index + p_idx].nesting_decode = + page_nesting_decode_info.device_ptr() + src_info_index; + pages[target_page_index + p_idx].nesting_info_size = per_page_nesting_info_size; pages[target_page_index + p_idx].num_output_nesting_levels = _metadata->get_output_nesting_depth(src_col_schema); @@ -601,6 +607,9 @@ void reader::impl::allocate_nesting_info() gpu::PageNestingInfo* pni = &page_nesting_info[nesting_info_index + (p_idx * per_page_nesting_info_size)]; + gpu::PageNestingDecodeInfo* nesting_info = + &page_nesting_decode_info[nesting_info_index + (p_idx * per_page_nesting_info_size)]; + // if we have lists, set our start and end depth remappings if (schema.max_repetition_level > 0) { auto remap = depth_remapping.find(src_col_schema); @@ -610,17 +619,16 @@ void reader::impl::allocate_nesting_info() std::vector const& def_depth_remap = (remap->second.second); for (size_t m = 0; m < rep_depth_remap.size(); m++) { - pni[m].start_depth = rep_depth_remap[m]; + nesting_info[m].start_depth = rep_depth_remap[m]; } for (size_t m = 0; m < def_depth_remap.size(); m++) { - pni[m].end_depth = def_depth_remap[m]; + nesting_info[m].end_depth = def_depth_remap[m]; } } // values indexed by output column index - pni[cur_depth].max_def_level = cur_schema.max_definition_level; - pni[cur_depth].max_rep_level = cur_schema.max_repetition_level; - pni[cur_depth].size = 0; + nesting_info[cur_depth].max_def_level = cur_schema.max_definition_level; + pni[cur_depth].size = 0; pni[cur_depth].type = to_type_id(cur_schema, _strings_to_categorical, _timestamp_type.id()); pni[cur_depth].nullable = cur_schema.repetition_type == OPTIONAL; @@ -640,6 +648,7 @@ void reader::impl::allocate_nesting_info() // copy nesting info to the device page_nesting_info.host_to_device(_stream); + page_nesting_decode_info.host_to_device(_stream); } void reader::impl::load_and_decompress_data(std::vector const& row_groups_info, @@ -1256,7 +1265,7 @@ struct start_offset_output_iterator { if (p.src_col_schema != src_col_schema || p.flags & gpu::PAGEINFO_FLAGS_DICTIONARY) { return empty; } - return p.nesting[nesting_depth].page_start_value; + return p.nesting_decode[nesting_depth].page_start_value; } }; diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 2cd6e49d7bb..21752196430 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -38,6 +38,7 @@ #include #include +#include #include @@ -4826,6 +4827,49 @@ TEST_F(ParquetReaderTest, StructByteArray) CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view()); } +TEST_F(ParquetReaderTest, NestingOptimizationTest) +{ + // test nesting levels > cudf::io::parquet::gpu::max_cacheable_nesting_decode_info deep. + constexpr cudf::size_type num_nesting_levels = 16; + static_assert(num_nesting_levels > cudf::io::parquet::gpu::max_cacheable_nesting_decode_info); + constexpr cudf::size_type rows_per_level = 2; + + constexpr cudf::size_type num_values = (1 << num_nesting_levels) * rows_per_level; + auto value_iter = thrust::make_counting_iterator(0); + auto validity = + cudf::detail::make_counting_transform_iterator(0, [](cudf::size_type i) { return i % 2; }); + cudf::test::fixed_width_column_wrapper values(value_iter, value_iter + num_values, validity); + + // ~256k values with num_nesting_levels = 16 + int total_values_produced = num_values; + auto prev_col = values.release(); + for (int idx = 0; idx < num_nesting_levels; idx++) { + auto const depth = num_nesting_levels - idx; + auto const num_rows = (1 << (num_nesting_levels - idx)); + + auto offsets_iter = cudf::detail::make_counting_transform_iterator( + 0, [depth, rows_per_level](cudf::size_type i) { return i * rows_per_level; }); + total_values_produced += (num_rows + 1); + + cudf::test::fixed_width_column_wrapper offsets(offsets_iter, + offsets_iter + num_rows + 1); + auto c = cudf::make_lists_column(num_rows, offsets.release(), std::move(prev_col), 0, {}); + prev_col = std::move(c); + } + auto const& expect = prev_col; + + auto filepath = temp_env->get_temp_filepath("NestingDecodeCache.parquet"); + cudf::io::parquet_writer_options opts = + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, table_view{{*expect}}); + cudf::io::write_parquet(opts); + + cudf::io::parquet_reader_options in_opts = + cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}); + auto result = cudf::io::read_parquet(in_opts); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expect, result.tbl->get_column(0)); +} + TEST_F(ParquetWriterTest, SingleValueDictionaryTest) { constexpr unsigned int expected_bits = 1; From ee937e5ebb22877a9b07b52def1d025542787022 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 25 Jan 2023 23:32:01 +0100 Subject: [PATCH 4/7] Handle when spillable buffers own each other (#12607) #12587 exposed a bug triggered when `as_buffer()` is given a pointer already owned by another spillable buffer. In this PR, we make this illegal and return a `SpillableBufferSlice` instead. cc. @galipremsagar ## Authors: - Mads R. B. Kristensen (https://github.com/madsbk) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - Ashwin Srinath (https://github.com/shwina) URL: https://github.com/rapidsai/cudf/pull/12607 --- .../cudf/cudf/core/buffer/spillable_buffer.py | 92 ++++++++++++++++++- python/cudf/cudf/core/buffer/utils.py | 10 +- python/cudf/cudf/tests/test_spilling.py | 34 +++++++ 3 files changed, 128 insertions(+), 8 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 4f625e3b7c8..7ca85a307bf 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -7,7 +7,16 @@ import time import weakref from threading import RLock -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, +) import numpy @@ -16,6 +25,7 @@ from cudf.core.buffer.buffer import ( Buffer, cuda_array_interface_wrapper, + get_ptr_and_size, host_memory_allocation, ) from cudf.utils.string import format_bytes @@ -27,6 +37,86 @@ T = TypeVar("T", bound="SpillableBuffer") +def get_spillable_owner(data) -> Optional[SpillableBuffer]: + """Get the spillable owner of `data`, if any exist + + Search through the stack of data owners in order to find an + owner of type `SpillableBuffer` (not subclasses). + + Parameters + ---------- + data : buffer-like or array-like + A buffer-like or array-like object that represent C-contiguous memory. + + Return + ------ + SpillableBuffer or None + The owner of `data` if spillable or None. + """ + + if type(data) is SpillableBuffer: + return data + if hasattr(data, "owner"): + return get_spillable_owner(data.owner) + return None + + +def as_spillable_buffer(data, exposed: bool) -> SpillableBuffer: + """Factory function to wrap `data` in a SpillableBuffer object. + + If `data` isn't a buffer already, a new buffer that points to the memory of + `data` is created. If `data` represents host memory, it is copied to a new + `rmm.DeviceBuffer` device allocation. Otherwise, the memory of `data` is + **not** copied, instead the new buffer keeps a reference to `data` in order + to retain its lifetime. + + If `data` is owned by a spillable buffer, a "slice" of the buffer is + returned. In this case, the spillable buffer must either be "exposed" or + spilled locked (called within an acquire_spill_lock context). This is to + guarantee that the memory of `data` isn't spilled before this function gets + to calculate the offset of the new slice. + + It is illegal for a spillable buffer to own another spillable buffer. + + Parameters + ---------- + data : buffer-like or array-like + A buffer-like or array-like object that represent C-contiguous memory. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). + + Return + ------ + SpillableBuffer + A spillabe buffer instance that represents the device memory of `data`. + """ + + from cudf.core.buffer.utils import get_spill_lock + + if not hasattr(data, "__cuda_array_interface__"): + if exposed: + raise ValueError("cannot created exposed host memory") + return SpillableBuffer._from_host_memory(data) + + spillable_owner = get_spillable_owner(data) + if spillable_owner is None: + return SpillableBuffer._from_device_memory(data, exposed=exposed) + + if not spillable_owner.exposed and get_spill_lock() is None: + raise ValueError( + "A owning spillable buffer must " + "either be exposed or spilled locked." + ) + + # At this point, we know that `data` is owned by a spillable buffer, + # which is exposed or spilled locked. + ptr, size = get_ptr_and_size(data.__cuda_array_interface__) + base_ptr = spillable_owner.memory_info()[0] + return SpillableBufferSlice( + spillable_owner, offset=ptr - base_ptr, size=size + ) + + class SpillLock: pass diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index 062e86d0cb1..fac28f52b64 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. from __future__ import annotations @@ -8,7 +8,7 @@ from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper from cudf.core.buffer.spill_manager import get_global_manager -from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.core.buffer.spillable_buffer import SpillLock, as_spillable_buffer def as_buffer( @@ -72,11 +72,7 @@ def as_buffer( ) if get_global_manager() is not None: - if hasattr(data, "__cuda_array_interface__"): - return SpillableBuffer._from_device_memory(data, exposed=exposed) - if exposed: - raise ValueError("cannot created exposed host memory") - return SpillableBuffer._from_host_memory(data) + return as_spillable_buffer(data, exposed=exposed) if hasattr(data, "__cuda_array_interface__"): return Buffer._from_device_memory(data) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 4788736966a..bafe51b62ec 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -540,6 +540,40 @@ def test_df_transpose(manager: SpillManager): assert df2._data._data[1].data.exposed +def test_as_buffer_of_spillable_buffer(manager: SpillManager): + data = cupy.arange(10, dtype="u1") + b1 = as_buffer(data, exposed=False) + assert isinstance(b1, SpillableBuffer) + assert b1.owner is data + b2 = as_buffer(b1) + assert b1 is b2 + + with pytest.raises( + ValueError, + match="buffer must either be exposed or spilled locked", + ): + # Use `memory_info` to access device point _without_ making + # the buffer unspillable. + b3 = as_buffer(b1.memory_info()[0], size=b1.size, owner=b1) + + with acquire_spill_lock(): + b3 = as_buffer(b1.get_ptr(), size=b1.size, owner=b1) + assert isinstance(b3, SpillableBufferSlice) + assert b3.owner is b1 + + b4 = as_buffer( + b1.ptr + data.itemsize, size=b1.size - data.itemsize, owner=b3 + ) + assert isinstance(b4, SpillableBufferSlice) + assert b4.owner is b1 + assert all(cupy.array(b4.memoryview()) == data[1:]) + + b5 = as_buffer(b4.ptr, size=b4.size - 1, owner=b4) + assert isinstance(b5, SpillableBufferSlice) + assert b5.owner is b1 + assert all(cupy.array(b5.memoryview()) == data[1:-1]) + + @pytest.mark.parametrize("dtype", ["uint8", "uint64"]) def test_memoryview_slice(manager: SpillManager, dtype): """Check .memoryview() of a sliced spillable buffer""" From 35e90ff80d1e9b5b835b4a2195ec11c7e159fcfc Mon Sep 17 00:00:00 2001 From: Sevag H Date: Wed, 25 Jan 2023 18:51:46 -0500 Subject: [PATCH 5/7] Use CTK 118/cp310 branch of wheel workflows (#12602) This PR builds wheels using the cuda-118 branch of the workflows, which bumps CTK 11.5.1 to CTK 11.8.0 and cp39 to cp310. Authors: - Sevag H (https://github.com/sevagh) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/cudf/pull/12602 --- .github/workflows/build.yaml | 8 ++++---- .github/workflows/pr.yaml | 12 +++++++----- .github/workflows/test.yaml | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 093063c21e7..5c4328cb7a2 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -54,7 +54,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-cudf: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -66,7 +66,7 @@ jobs: wheel-publish-cudf: needs: wheel-build-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -76,7 +76,7 @@ jobs: wheel-build-dask-cudf: needs: wheel-publish-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -87,7 +87,7 @@ jobs: wheel-publish-dask-cudf: needs: wheel-build-dask-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-publish.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-publish.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index cb83aab31cd..89e14d3e421 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -85,7 +85,7 @@ jobs: wheel-build-cudf: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 with: build_type: pull-request package-name: cudf @@ -94,17 +94,19 @@ jobs: wheel-tests-cudf: needs: wheel-build-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118 with: build_type: pull-request package-name: cudf + # Install cupy-cuda11x for arm from a special index url + # Install tokenizers last binary wheel to avoid a Rust compile from the latest sdist test-before-arm64: "pip install tokenizers==0.10.2 cupy-cuda11x -f https://pip.cupy.dev/aarch64" test-unittest: "pytest -v -n 8 ./python/cudf/cudf/tests" test-smoketest: "python ./ci/wheel_smoke_test_cudf.py" wheel-build-dask-cudf: - needs: wheel-build-cudf + needs: wheel-tests-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@cuda-118 with: build_type: pull-request package-name: dask_cudf @@ -113,7 +115,7 @@ jobs: wheel-tests-dask-cudf: needs: wheel-build-dask-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@cuda-118 with: build_type: pull-request package-name: dask_cudf diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2b583773e05..b383d185564 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -67,7 +67,7 @@ jobs: run_script: "ci/test_notebooks.sh" wheel-tests-cudf: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118 with: build_type: nightly branch: ${{ inputs.branch }} @@ -78,7 +78,7 @@ jobs: test-unittest: "pytest -v -n 8 ./python/cudf/cudf/tests" wheel-tests-dask-cudf: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@cuda-118 with: build_type: nightly branch: ${{ inputs.branch }} From f7d434d6b633b3d5c68cb8f11063d193c504eee9 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 25 Jan 2023 21:51:12 -0600 Subject: [PATCH 6/7] Change ways to access `ptr` in `Buffer` (#12587) This PR: With the introduction of `copy-on-write`(https://github.com/rapidsai/cudf/pull/11718), and `spilling`, there are different ways each `Buffer` implementation expects the code to behave when a `.ptr` is accessed, i.e., in case of COW, a `ptr` can be accessed for read or write purposes, in case of spilling the ptr can be accessed with a spill lock and the buffer will be spillable after the spill lock goes out of scope during execution. For these reasons we introduced `ptr`, `mutable_ptr` and `data_array_view`, `_data_array_view`, `mask_array_view`, and `_mask_array_view`. With so many ways to access buffer ptr and array views, this has become quite difficult for devs to know when to use what unless you are fully familiar with the implementation details. It will also lead us to a lot of special case handling for `Buffer`, `SpillableBuffer`, and `CopyonWriteBuffer`. For this reason, we have decided to simplify fetching the pointer with a `get_ptr(mode="read"/"write")` API, fetching data & mask array views will also become methods that accept `mode` like `data_array_view(mode="read"/"write")` & `mask_array_view(mode="read"/"write")`. It is the expectation that when the caller passed "read", they don't tamper with the buffer or the memory it is pointing to. In the case of "write", they are good to mutate the memory it is pointing to. Note that even with `mode="read"/"write"` the caller should still, if appropriate, acquire a spill lock for the duration of the access. If this is not done, and the buffer is a `SpillableBuffer`, it will permanently be marked as unspillable. - [x] Introduces `get_ptr()` to replace `ptr` property. - [x] Replaces `data_array_view` & `mask_array_view` methods with `data_array_view(mode=r/w)` & `mask_array_view(mode=r/w)` Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - https://github.com/brandon-b-miller - Lawrence Mitchell (https://github.com/wence-) - Ashwin Srinath (https://github.com/shwina) URL: https://github.com/rapidsai/cudf/pull/12587 --- .../source/developer_guide/library_design.md | 6 +- python/cudf/cudf/_lib/column.pyi | 2 - python/cudf/cudf/_lib/column.pyx | 34 ++++----- python/cudf/cudf/_lib/copying.pyx | 4 +- python/cudf/cudf/_lib/transform.pyx | 4 +- python/cudf/cudf/core/buffer/buffer.py | 69 ++++++++++++++++-- .../cudf/cudf/core/buffer/spillable_buffer.py | 32 +++----- python/cudf/cudf/core/column/categorical.py | 7 +- python/cudf/cudf/core/column/column.py | 73 ++++++++++++++++--- python/cudf/cudf/core/column/numerical.py | 17 +++-- python/cudf/cudf/core/column/string.py | 5 +- python/cudf/cudf/core/column/timedelta.py | 11 ++- python/cudf/cudf/core/df_protocol.py | 4 +- python/cudf/cudf/core/frame.py | 4 +- python/cudf/cudf/core/indexed_frame.py | 2 + python/cudf/cudf/core/series.py | 6 +- python/cudf/cudf/core/window/rolling.py | 16 ++-- python/cudf/cudf/testing/_utils.py | 2 +- python/cudf/cudf/tests/test_buffer.py | 6 +- python/cudf/cudf/tests/test_column.py | 4 +- python/cudf/cudf/tests/test_dataframe.py | 8 +- python/cudf/cudf/tests/test_dataframe_copy.py | 4 +- python/cudf/cudf/tests/test_df_protocol.py | 4 +- python/cudf/cudf/tests/test_multiindex.py | 32 ++++++-- python/cudf/cudf/tests/test_pack.py | 6 +- python/cudf/cudf/tests/test_repr.py | 4 +- python/cudf/cudf/tests/test_spilling.py | 25 ++++--- python/cudf/cudf/utils/applyutils.py | 23 ++++-- python/cudf/cudf/utils/queryutils.py | 6 +- python/strings_udf/strings_udf/_typing.py | 8 +- 30 files changed, 284 insertions(+), 144 deletions(-) diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index 54a28db1b58..bac5eae6b34 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -236,13 +236,11 @@ Spilling consists of two components: - A spill manager that tracks all instances of `SpillableBuffer` and spills them on demand. A global spill manager is used throughout cudf when spilling is enabled, which makes `as_buffer()` return `SpillableBuffer` instead of the default `Buffer` instances. -Accessing `Buffer.ptr`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.ptr`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. +Accessing `Buffer.get_ptr(...)`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.get_ptr(...)`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. To address this, we mark the `SpillableBuffer` as unspillable, we say that the buffer has been _exposed_. This can either be permanent if the device pointer is exposed to external projects or temporary while `libcudf` accesses the device memory. -The `SpillableBuffer.get_ptr()` returns the device pointer of the buffer memory just like `.ptr` but if given an instance of `SpillLock`, the buffer is only unspillable as long as the instance of `SpillLock` is alive. - -For convenience, one can use the decorator/context `acquire_spill_lock` to associate a `SpillLock` with a lifetime bound to the context automatically. +The `SpillableBuffer.get_ptr(...)` returns the device pointer of the buffer memory but if called within an `acquire_spill_lock` decorator/context, the buffer is only marked unspillable while running within the decorator/context. #### Statistics cuDF supports spilling statistics, which can be very useful for performance profiling and to identify code that renders buffers unspillable. diff --git a/python/cudf/cudf/_lib/column.pyi b/python/cudf/cudf/_lib/column.pyi index 612f3cdf95a..013cba3ae03 100644 --- a/python/cudf/cudf/_lib/column.pyi +++ b/python/cudf/cudf/_lib/column.pyi @@ -52,8 +52,6 @@ class Column: @property def base_mask(self) -> Optional[Buffer]: ... @property - def base_mask_ptr(self) -> int: ... - @property def mask(self) -> Optional[Buffer]: ... @property def mask_ptr(self) -> int: ... diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index a5d72193049..11b4a900896 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -101,7 +101,7 @@ cdef class Column: if self.data is None: return 0 else: - return self.data.ptr + return self.data.get_ptr(mode="write") def set_base_data(self, value): if value is not None and not isinstance(value, Buffer): @@ -124,13 +124,6 @@ cdef class Column: def base_mask(self): return self._base_mask - @property - def base_mask_ptr(self): - if self.base_mask is None: - return 0 - else: - return self.base_mask.ptr - @property def mask(self): if self._mask is None: @@ -145,7 +138,7 @@ cdef class Column: if self.mask is None: return 0 else: - return self.mask.ptr + return self.mask.get_ptr(mode="write") def set_base_mask(self, value): """ @@ -206,7 +199,7 @@ cdef class Column: elif hasattr(value, "__cuda_array_interface__"): if value.__cuda_array_interface__["typestr"] not in ("|i1", "|u1"): if isinstance(value, Column): - value = value.data_array_view + value = value.data_array_view(mode="write") value = cp.asarray(value).view('|u1') mask = as_buffer(value) if mask.size < required_num_bytes: @@ -329,10 +322,10 @@ cdef class Column: if col.base_data is None: data = NULL - elif isinstance(col.base_data, SpillableBuffer): - data = (col.base_data).get_ptr() else: - data = (col.base_data.ptr) + data = (col.base_data.get_ptr( + mode="write") + ) cdef Column child_column if col.base_children: @@ -341,7 +334,9 @@ cdef class Column: cdef libcudf_types.bitmask_type* mask if self.nullable: - mask = (self.base_mask_ptr) + mask = ( + self.base_mask.get_ptr(mode="write") + ) else: mask = NULL @@ -387,10 +382,8 @@ cdef class Column: if col.base_data is None: data = NULL - elif isinstance(col.base_data, SpillableBuffer): - data = (col.base_data).get_ptr() else: - data = (col.base_data.ptr) + data = (col.base_data.get_ptr(mode="read")) cdef Column child_column if col.base_children: @@ -399,7 +392,9 @@ cdef class Column: cdef libcudf_types.bitmask_type* mask if self.nullable: - mask = (self.base_mask_ptr) + mask = ( + self.base_mask.get_ptr(mode="read") + ) else: mask = NULL @@ -549,7 +544,8 @@ cdef class Column: f"{data_owner} is spilled, which invalidates " f"the exposed data_ptr ({hex(data_ptr)})" ) - data_owner.ptr # accessing the pointer marks it exposed. + # accessing the pointer marks it exposed permanently. + data_owner.mark_exposed() else: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, size=0) diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index c01709322ed..6a53586396f 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. import pickle @@ -765,7 +765,7 @@ cdef class _CPackedColumns: gpu_data = Buffer.deserialize(header["data"], frames) dbuf = DeviceBuffer( - ptr=gpu_data.ptr, + ptr=gpu_data.get_ptr(mode="write"), size=gpu_data.nbytes ) diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index 6f17dbab86c..a0a8279b213 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -62,7 +62,9 @@ def mask_to_bools(object mask_buffer, size_type begin_bit, size_type end_bit): if not isinstance(mask_buffer, cudf.core.buffer.Buffer): raise TypeError("mask_buffer is not an instance of " "cudf.core.buffer.Buffer") - cdef bitmask_type* bit_mask = (mask_buffer.ptr) + cdef bitmask_type* bit_mask = ( + mask_buffer.get_ptr(mode="read") + ) cdef unique_ptr[column] result with nogil: diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index ebc4d76b6a0..71f48e0ab0c 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -176,7 +176,9 @@ def _getitem(self, offset: int, size: int) -> Buffer: """ return self._from_device_memory( cuda_array_interface_wrapper( - ptr=self.ptr + offset, size=size, owner=self.owner + ptr=self.get_ptr(mode="read") + offset, + size=size, + owner=self.owner, ) ) @@ -202,11 +204,6 @@ def nbytes(self) -> int: """Size of the buffer in bytes.""" return self._size - @property - def ptr(self) -> int: - """Device pointer to the start of the buffer.""" - return self._ptr - @property def owner(self) -> Any: """Object owning the memory of the buffer.""" @@ -215,18 +212,74 @@ def owner(self) -> Any: @property def __cuda_array_interface__(self) -> Mapping: """Implementation of the CUDA Array Interface.""" + return self._get_cuda_array_interface(readonly=False) + + def _get_cuda_array_interface(self, readonly=False): + """Helper function to create a CUDA Array Interface. + + Parameters + ---------- + readonly : bool, default False + If True, returns a CUDA Array Interface with + readonly flag set to True. + If False, returns a CUDA Array Interface with + readonly flag set to False. + + Returns + ------- + dict + """ return { - "data": (self.ptr, False), + "data": ( + self.get_ptr(mode="read" if readonly else "write"), + readonly, + ), "shape": (self.size,), "strides": None, "typestr": "|u1", "version": 0, } + @property + def _readonly_proxy_cai_obj(self): + """ + Returns a proxy object with a read-only CUDA Array Interface. + """ + return cuda_array_interface_wrapper( + ptr=self.get_ptr(mode="read"), + size=self.size, + owner=self, + readonly=True, + typestr="|u1", + version=0, + ) + + def get_ptr(self, *, mode) -> int: + """Device pointer to the start of the buffer. + + Parameters + ---------- + mode : str + Supported values are {"read", "write"} + If "write", the data pointed to may be modified + by the caller. If "read", the data pointed to + must not be modified by the caller. + Failure to fulfill this contract will cause + incorrect behavior. + + + See Also + -------- + SpillableBuffer.get_ptr + """ + return self._ptr + def memoryview(self) -> memoryview: """Read-only access to the buffer through host memory.""" host_buf = host_memory_allocation(self.size) - rmm._lib.device_buffer.copy_ptr_to_host(self.ptr, host_buf) + rmm._lib.device_buffer.copy_ptr_to_host( + self.get_ptr(mode="read"), host_buf + ) return memoryview(host_buf).toreadonly() def serialize(self) -> Tuple[dict, list]: diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 7ca85a307bf..2064c1fd133 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -145,7 +145,7 @@ def __len__(self): def __getitem__(self, i): if i == 0: - return self._buf.ptr + return self._buf.get_ptr(mode="write") elif i == 1: return False raise IndexError("tuple index out of range") @@ -359,7 +359,7 @@ def spill_lock(self, spill_lock: SpillLock) -> None: self.spill(target="gpu") self._spill_locks.add(spill_lock) - def get_ptr(self) -> int: + def get_ptr(self, *, mode) -> int: """Get a device pointer to the memory of the buffer. If this is called within an `acquire_spill_lock` context, @@ -369,8 +369,8 @@ def get_ptr(self) -> int: If this is *not* called within a `acquire_spill_lock` context, this buffer is marked as unspillable permanently. - Return - ------ + Returns + ------- int The device pointer as an integer """ @@ -409,18 +409,6 @@ def memory_info(self) -> Tuple[int, int, str]: ).__array_interface__["data"][0] return (ptr, self.nbytes, self._ptr_desc["type"]) - @property - def ptr(self) -> int: - """Access the memory directly - - Notice, this will mark the buffer as "exposed" and make - it unspillable permanently. - - Consider using `.get_ptr()` instead. - """ - self.mark_exposed() - return self._ptr - @property def owner(self) -> Any: return self._owner @@ -559,12 +547,12 @@ def __init__(self, base: SpillableBuffer, offset: int, size: int) -> None: self._owner = base self.lock = base.lock - @property - def ptr(self) -> int: - return self._base.ptr + self._offset - - def get_ptr(self) -> int: - return self._base.get_ptr() + self._offset + def get_ptr(self, *, mode) -> int: + """ + A passthrough method to `SpillableBuffer.get_ptr` + with factoring in the `offset`. + """ + return self._base.get_ptr(mode=mode) + self._offset def _getitem(self, offset: int, size: int) -> Buffer: return SpillableBufferSlice( diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index ef9f515fff7..af21d7545ee 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -956,9 +956,10 @@ def clip(self, lo: ScalarLike, hi: ScalarLike) -> "column.ColumnBase": self.astype(self.categories.dtype).clip(lo, hi).astype(self.dtype) ) - @property - def data_array_view(self) -> cuda.devicearray.DeviceNDArray: - return self.codes.data_array_view + def data_array_view( + self, *, mode="write" + ) -> cuda.devicearray.DeviceNDArray: + return self.codes.data_array_view(mode=mode) def unique(self, preserve_order=False) -> CategoricalColumn: codes = self.as_numerical.unique(preserve_order=preserve_order) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 69319e2f775..2f4d9e28314 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -64,7 +64,7 @@ ) from cudf.core._compat import PANDAS_GE_150 from cudf.core.abc import Serializable -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import Buffer, acquire_spill_lock, as_buffer from cudf.core.dtypes import ( CategoricalDtype, IntervalDtype, @@ -113,19 +113,71 @@ def as_frame(self) -> "cudf.core.frame.Frame": {None: self.copy(deep=False)} ) - @property - def data_array_view(self) -> "cuda.devicearray.DeviceNDArray": + def data_array_view( + self, *, mode="write" + ) -> "cuda.devicearray.DeviceNDArray": """ View the data as a device array object + + Parameters + ---------- + mode : str, default 'write' + Supported values are {'read', 'write'} + If 'write' is passed, a device array object + with readonly flag set to False in CAI is returned. + If 'read' is passed, a device array object + with readonly flag set to True in CAI is returned. + This also means, If the caller wishes to modify + the data returned through this view, they must + pass mode="write", else pass mode="read". + + Returns + ------- + numba.cuda.cudadrv.devicearray.DeviceNDArray """ - return cuda.as_cuda_array(self.data).view(self.dtype) + if self.data is not None: + if mode == "read": + obj = self.data._readonly_proxy_cai_obj + elif mode == "write": + obj = self.data + else: + raise ValueError(f"Unsupported mode: {mode}") + else: + obj = None + return cuda.as_cuda_array(obj).view(self.dtype) - @property - def mask_array_view(self) -> "cuda.devicearray.DeviceNDArray": + def mask_array_view( + self, *, mode="write" + ) -> "cuda.devicearray.DeviceNDArray": """ View the mask as a device array + + Parameters + ---------- + mode : str, default 'write' + Supported values are {'read', 'write'} + If 'write' is passed, a device array object + with readonly flag set to False in CAI is returned. + If 'read' is passed, a device array object + with readonly flag set to True in CAI is returned. + This also means, If the caller wishes to modify + the data returned through this view, they must + pass mode="write", else pass mode="read". + + Returns + ------- + numba.cuda.cudadrv.devicearray.DeviceNDArray """ - return cuda.as_cuda_array(self.mask).view(mask_dtype) + if self.mask is not None: + if mode == "read": + obj = self.mask._readonly_proxy_cai_obj + elif mode == "write": + obj = self.mask + else: + raise ValueError(f"Unsupported mode: {mode}") + else: + obj = None + return cuda.as_cuda_array(obj).view(mask_dtype) def __len__(self) -> int: return self.size @@ -163,7 +215,8 @@ def values_host(self) -> "np.ndarray": if self.has_nulls(): raise ValueError("Column must have no nulls.") - return self.data_array_view.copy_to_host() + with acquire_spill_lock(): + return self.data_array_view(mode="read").copy_to_host() @property def values(self) -> "cupy.ndarray": @@ -176,7 +229,7 @@ def values(self) -> "cupy.ndarray": if self.has_nulls(): raise ValueError("Column must have no nulls.") - return cupy.asarray(self.data_array_view) + return cupy.asarray(self.data_array_view(mode="write")) def find_and_replace( self: T, @@ -363,7 +416,7 @@ def nullmask(self) -> Buffer: """The gpu buffer for the null-mask""" if not self.nullable: raise ValueError("Column has no null mask") - return self.mask_array_view + return self.mask_array_view(mode="read") def copy(self: T, deep: bool = True) -> T: """Columns are immutable, so a deep copy produces a copy of the diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 7943135afe1..8ee3b6e15b6 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from __future__ import annotations @@ -35,7 +35,11 @@ is_number, is_scalar, ) -from cudf.core.buffer import Buffer, cuda_array_interface_wrapper +from cudf.core.buffer import ( + Buffer, + acquire_spill_lock, + cuda_array_interface_wrapper, +) from cudf.core.column import ( ColumnBase, as_column, @@ -110,8 +114,8 @@ def __contains__(self, item: ScalarLike) -> bool: # Handles improper item types # Fails if item is of type None, so the handler. try: - if np.can_cast(item, self.data_array_view.dtype): - item = self.data_array_view.dtype.type(item) + if np.can_cast(item, self.dtype): + item = self.dtype.type(item) else: return False except (TypeError, ValueError): @@ -564,6 +568,7 @@ def fillna( return super(NumericalColumn, col).fillna(fill_value, method) + @acquire_spill_lock() def _find_value( self, value: ScalarLike, closest: bool, find: Callable, compare: str ) -> int: @@ -573,14 +578,14 @@ def _find_value( found = 0 if len(self): found = find( - self.data_array_view, + self.data_array_view(mode="read"), value, mask=self.mask, ) if found == -1: if self.is_monotonic_increasing and closest: found = find( - self.data_array_view, + self.data_array_view(mode="read"), value, mask=self.mask, compare=compare, diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 4ca3a9ff04d..9c30585a541 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5395,8 +5395,9 @@ def base_size(self) -> int: else: return self.base_children[0].size - 1 - @property - def data_array_view(self) -> cuda.devicearray.DeviceNDArray: + def data_array_view( + self, *, mode="write" + ) -> cuda.devicearray.DeviceNDArray: raise ValueError("Cannot get an array view of a StringColumn") def to_arrow(self) -> pa.Array: diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index b7d1724a342..e7979fa4d27 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -13,7 +13,7 @@ from cudf import _lib as libcudf from cudf._typing import ColumnBinaryOperand, DatetimeLikeScalar, Dtype from cudf.api.types import is_scalar, is_timedelta64_dtype -from cudf.core.buffer import Buffer +from cudf.core.buffer import Buffer, acquire_spill_lock from cudf.core.column import ColumnBase, column, string from cudf.utils.dtypes import np_to_pa_dtype from cudf.utils.utils import _fillna_natwise @@ -125,11 +125,16 @@ def values(self): "TimeDelta Arrays is not yet implemented in cudf" ) + @acquire_spill_lock() def to_arrow(self) -> pa.Array: mask = None if self.nullable: - mask = pa.py_buffer(self.mask_array_view.copy_to_host()) - data = pa.py_buffer(self.as_numerical.data_array_view.copy_to_host()) + mask = pa.py_buffer( + self.mask_array_view(mode="read").copy_to_host() + ) + data = pa.py_buffer( + self.as_numerical.data_array_view(mode="read").copy_to_host() + ) pa_dtype = np_to_pa_dtype(self.dtype) return pa.Array.from_buffers( type=pa_dtype, diff --git a/python/cudf/cudf/core/df_protocol.py b/python/cudf/cudf/core/df_protocol.py index b38d3048ed7..2090906380e 100644 --- a/python/cudf/cudf/core/df_protocol.py +++ b/python/cudf/cudf/core/df_protocol.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. import enum from collections import abc @@ -89,7 +89,7 @@ def ptr(self) -> int: """ Pointer to start of the buffer as an integer. """ - return self._buf.ptr + return self._buf.get_ptr(mode="write") def __dlpack__(self): # DLPack not implemented in NumPy yet, so leave it out here. diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 32764c6c2f0..8b508eac324 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. from __future__ import annotations @@ -1448,7 +1448,7 @@ def searchsorted( # Return result as cupy array if the values is non-scalar # If values is scalar, result is expected to be scalar. - result = cupy.asarray(outcol.data_array_view) + result = cupy.asarray(outcol.data_array_view(mode="read")) if scalar_flag: return result[0].item() else: diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 6526ba1e7c3..c8016786be9 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -49,6 +49,7 @@ is_scalar, ) from cudf.core._base_index import BaseIndex +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ColumnBase, as_column, full from cudf.core.column_accessor import ColumnAccessor from cudf.core.dtypes import ListDtype @@ -2105,6 +2106,7 @@ def add_suffix(self, suffix): Use `Series.add_suffix` or `DataFrame.add_suffix`" ) + @acquire_spill_lock() @_cudf_nvtx_annotate def _apply(self, func, kernel_getter, *args, **kwargs): """Apply `func` across the rows of the frame.""" diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index faad5275abd..1c697a2d824 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -39,6 +39,7 @@ is_struct_dtype, ) from cudf.core.abc import Serializable +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ( ColumnBase, DatetimeColumn, @@ -4855,6 +4856,7 @@ def _align_indices(series_list, how="outer", allow_non_unique=False): return result +@acquire_spill_lock() @_cudf_nvtx_annotate def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): r"""Returns a boolean array where two arrays are equal within a tolerance. @@ -4959,10 +4961,10 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): index = as_index(a.index) a_col = column.as_column(a) - a_array = cupy.asarray(a_col.data_array_view) + a_array = cupy.asarray(a_col.data_array_view(mode="read")) b_col = column.as_column(b) - b_array = cupy.asarray(b_col.data_array_view) + b_array = cupy.asarray(b_col.data_array_view(mode="read")) result = cupy.isclose( a=a_array, b=b_array, rtol=rtol, atol=atol, equal_nan=equal_nan diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py index fb1cafa5625..cac4774400a 100644 --- a/python/cudf/cudf/core/window/rolling.py +++ b/python/cudf/cudf/core/window/rolling.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION +# Copyright (c) 2020-2023, NVIDIA CORPORATION import itertools @@ -11,6 +11,7 @@ from cudf.api.types import is_integer, is_number from cudf.core import column from cudf.core._compat import PANDAS_GE_150 +from cudf.core.buffer import acquire_spill_lock from cudf.core.column.column import as_column from cudf.core.mixins import Reducible from cudf.utils import cudautils @@ -487,9 +488,11 @@ def _window_to_window_sizes(self, window): if is_integer(window): return window else: - return cudautils.window_sizes_from_offset( - self.obj.index._values.data_array_view, window - ) + with acquire_spill_lock(): + return cudautils.window_sizes_from_offset( + self.obj.index._values.data_array_view(mode="write"), + window, + ) def __repr__(self): return "{} [window={},min_periods={},center={}]".format( @@ -524,16 +527,17 @@ def __init__(self, groupby, window, min_periods=None, center=False): super().__init__(obj, window, min_periods=min_periods, center=center) + @acquire_spill_lock() def _window_to_window_sizes(self, window): if is_integer(window): return cudautils.grouped_window_sizes_from_offset( - column.arange(len(self.obj)).data_array_view, + column.arange(len(self.obj)).data_array_view(mode="read"), self._group_starts, window, ) else: return cudautils.grouped_window_sizes_from_offset( - self.obj.index._values.data_array_view, + self.obj.index._values.data_array_view(mode="read"), self._group_starts, window, ) diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index cbaf47a4c68..fb4daba1209 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -336,7 +336,7 @@ def assert_column_memory_eq( """ def get_ptr(x) -> int: - return x.ptr if x else 0 + return x.get_ptr(mode="read") if x else 0 assert get_ptr(lhs.base_data) == get_ptr(rhs.base_data) assert get_ptr(lhs.base_mask) == get_ptr(rhs.base_mask) diff --git a/python/cudf/cudf/tests/test_buffer.py b/python/cudf/cudf/tests/test_buffer.py index df7152d53a6..1c9e7475080 100644 --- a/python/cudf/cudf/tests/test_buffer.py +++ b/python/cudf/cudf/tests/test_buffer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. import cupy as cp import pytest @@ -52,7 +52,7 @@ def test_buffer_creation_from_any(): ary = cp.arange(arr_len) b = as_buffer(ary, exposed=True) assert isinstance(b, Buffer) - assert ary.data.ptr == b.ptr + assert ary.data.ptr == b.get_ptr(mode="read") assert ary.nbytes == b.size with pytest.raises( @@ -62,7 +62,7 @@ def test_buffer_creation_from_any(): b = as_buffer(ary.data.ptr, size=ary.nbytes, owner=ary, exposed=True) assert isinstance(b, Buffer) - assert ary.data.ptr == b.ptr + assert ary.data.ptr == b.get_ptr(mode="read") assert ary.nbytes == b.size assert b.owner.owner is ary diff --git a/python/cudf/cudf/tests/test_column.py b/python/cudf/cudf/tests/test_column.py index 75b82baf2e8..7d113bbb9e2 100644 --- a/python/cudf/cudf/tests/test_column.py +++ b/python/cudf/cudf/tests/test_column.py @@ -285,8 +285,8 @@ def test_column_view_valid_numeric_to_numeric(data, from_dtype, to_dtype): expect = pd.Series(cpu_data_view, dtype=cpu_data_view.dtype) got = cudf.Series(gpu_data_view, dtype=gpu_data_view.dtype) - gpu_ptr = gpu_data.data.ptr - assert gpu_ptr == got._column.data.ptr + gpu_ptr = gpu_data.data.get_ptr(mode="read") + assert gpu_ptr == got._column.data.get_ptr(mode="read") assert_eq(expect, got) diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 0e0b0a37255..65e24c7c704 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1868,7 +1868,7 @@ def test_to_from_arrow_nulls(data_type): # number of bytes, so only check the first byte in this case np.testing.assert_array_equal( np.asarray(s1.buffers()[0]).view("u1")[0], - gs1._column.mask_array_view.copy_to_host().view("u1")[0], + gs1._column.mask_array_view(mode="read").copy_to_host().view("u1")[0], ) assert pa.Array.equals(s1, gs1.to_arrow()) @@ -1879,7 +1879,7 @@ def test_to_from_arrow_nulls(data_type): # number of bytes, so only check the first byte in this case np.testing.assert_array_equal( np.asarray(s2.buffers()[0]).view("u1")[0], - gs2._column.mask_array_view.copy_to_host().view("u1")[0], + gs2._column.mask_array_view(mode="read").copy_to_host().view("u1")[0], ) assert pa.Array.equals(s2, gs2.to_arrow()) @@ -2659,11 +2659,11 @@ def query_GPU_memory(note=""): cudaDF = cudaDF[boolmask] assert ( - cudaDF.index._values.data_array_view.device_ctypes_pointer + cudaDF.index._values.data_array_view(mode="read").device_ctypes_pointer == cudaDF["col0"].index._values.data_array_view.device_ctypes_pointer ) assert ( - cudaDF.index._values.data_array_view.device_ctypes_pointer + cudaDF.index._values.data_array_view(mode="read").device_ctypes_pointer == cudaDF["col1"].index._values.data_array_view.device_ctypes_pointer ) diff --git a/python/cudf/cudf/tests/test_dataframe_copy.py b/python/cudf/cudf/tests/test_dataframe_copy.py index 1a9098c70db..85e994bd733 100644 --- a/python/cudf/cudf/tests/test_dataframe_copy.py +++ b/python/cudf/cudf/tests/test_dataframe_copy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from copy import copy, deepcopy import numpy as np @@ -160,7 +160,7 @@ def test_kernel_deep_copy(): cdf = gdf.copy(deep=True) sr = gdf["b"] - add_one[1, len(sr)](sr._column.data_array_view) + add_one[1, len(sr)](sr._column.data_array_view(mode="write")) assert not gdf.to_string().split() == cdf.to_string().split() diff --git a/python/cudf/cudf/tests/test_df_protocol.py b/python/cudf/cudf/tests/test_df_protocol.py index 0981e850c10..7dbca90ab03 100644 --- a/python/cudf/cudf/tests/test_df_protocol.py +++ b/python/cudf/cudf/tests/test_df_protocol.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. from typing import Any, Tuple @@ -41,7 +41,7 @@ def assert_buffer_equal(buffer_and_dtype: Tuple[_CuDFBuffer, Any], cudfcol): if dtype[0] != _DtypeKind.BOOL: array_from_dlpack = cp.from_dlpack(buf.__dlpack__()).get() - col_array = cp.asarray(cudfcol.data_array_view).get() + col_array = cp.asarray(cudfcol.data_array_view(mode="read")).get() assert_eq( array_from_dlpack[non_null_idxs.to_numpy()].flatten(), col_array[non_null_idxs.to_numpy()].flatten(), diff --git a/python/cudf/cudf/tests/test_multiindex.py b/python/cudf/cudf/tests/test_multiindex.py index d27d6732226..3e1f001e7d1 100644 --- a/python/cudf/cudf/tests/test_multiindex.py +++ b/python/cudf/cudf/tests/test_multiindex.py @@ -804,8 +804,8 @@ def test_multiindex_copy_deep(data, deep): lchildren = reduce(operator.add, lchildren) rchildren = reduce(operator.add, rchildren) - lptrs = [child.base_data.ptr for child in lchildren] - rptrs = [child.base_data.ptr for child in rchildren] + lptrs = [child.base_data.get_ptr(mode="read") for child in lchildren] + rptrs = [child.base_data.get_ptr(mode="read") for child in rchildren] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) @@ -814,20 +814,36 @@ def test_multiindex_copy_deep(data, deep): mi2 = mi1.copy(deep=deep) # Assert ._levels identity - lptrs = [lv._data._data[None].base_data.ptr for lv in mi1._levels] - rptrs = [lv._data._data[None].base_data.ptr for lv in mi2._levels] + lptrs = [ + lv._data._data[None].base_data.get_ptr(mode="read") + for lv in mi1._levels + ] + rptrs = [ + lv._data._data[None].base_data.get_ptr(mode="read") + for lv in mi2._levels + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._codes identity - lptrs = [c.base_data.ptr for _, c in mi1._codes._data.items()] - rptrs = [c.base_data.ptr for _, c in mi2._codes._data.items()] + lptrs = [ + c.base_data.get_ptr(mode="read") + for _, c in mi1._codes._data.items() + ] + rptrs = [ + c.base_data.get_ptr(mode="read") + for _, c in mi2._codes._data.items() + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._data identity - lptrs = [d.base_data.ptr for _, d in mi1._data.items()] - rptrs = [d.base_data.ptr for _, d in mi2._data.items()] + lptrs = [ + d.base_data.get_ptr(mode="read") for _, d in mi1._data.items() + ] + rptrs = [ + d.base_data.get_ptr(mode="read") for _, d in mi2._data.items() + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) diff --git a/python/cudf/cudf/tests/test_pack.py b/python/cudf/cudf/tests/test_pack.py index b6bda7ef5fa..9972071122e 100644 --- a/python/cudf/cudf/tests/test_pack.py +++ b/python/cudf/cudf/tests/test_pack.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -129,7 +129,9 @@ def assert_packed_frame_unique_pointers(df): for col in df: if df._data[col].data: - assert df._data[col].data.ptr != unpacked._data[col].data.ptr + assert df._data[col].data.get_ptr(mode="read") != unpacked._data[ + col + ].data.get_ptr(mode="read") def test_packed_dataframe_unique_pointers_numeric(): diff --git a/python/cudf/cudf/tests/test_repr.py b/python/cudf/cudf/tests/test_repr.py index 5ba0bec3dc4..bae0fde6463 100644 --- a/python/cudf/cudf/tests/test_repr.py +++ b/python/cudf/cudf/tests/test_repr.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. import textwrap @@ -31,7 +31,7 @@ def test_null_series(nrows, dtype): sr[np.random.choice([False, True], size=size)] = None if dtype != "category" and cudf.dtype(dtype).kind in {"u", "i"}: ps = pd.Series( - sr._column.data_array_view.copy_to_host(), + sr._column.data_array_view(mode="read").copy_to_host(), dtype=np_dtypes_to_pandas_dtypes.get( cudf.dtype(dtype), cudf.dtype(dtype) ), diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index bafe51b62ec..88ce908aa5f 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -119,7 +119,7 @@ def test_spillable_buffer(manager: SpillManager): buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) assert isinstance(buf, SpillableBuffer) assert buf.spillable - buf.ptr # Expose pointer + buf.mark_exposed() assert buf.exposed assert not buf.spillable buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) @@ -137,7 +137,6 @@ def test_spillable_buffer(manager: SpillManager): @pytest.mark.parametrize( "attribute", [ - "ptr", "get_ptr", "memoryview", "is_spilled", @@ -210,7 +209,7 @@ def test_spilling_buffer(manager: SpillManager): buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) buf.spill(target="cpu") assert buf.is_spilled - buf.ptr # Expose pointer and trigger unspill + buf.mark_exposed() # Expose pointer and trigger unspill assert not buf.is_spilled with pytest.raises(ValueError, match="unspillable buffer"): buf.spill(target="cpu") @@ -378,10 +377,10 @@ def test_get_ptr(manager: SpillManager, target): assert buf.spillable assert len(buf._spill_locks) == 0 with acquire_spill_lock(): - buf.get_ptr() + buf.get_ptr(mode="read") assert not buf.spillable with acquire_spill_lock(): - buf.get_ptr() + buf.get_ptr(mode="read") assert not buf.spillable assert not buf.spillable assert buf.spillable @@ -501,7 +500,7 @@ def test_serialize_cuda_dataframe(manager: SpillManager): assert len(buf._base._spill_locks) == 1 assert len(frames) == 1 assert isinstance(frames[0], Buffer) - assert frames[0].ptr == buf.ptr + assert frames[0].get_ptr(mode="read") == buf.get_ptr(mode="read") frames[0] = cupy.array(frames[0], copy=True) df2 = protocol.deserialize(header, frames) @@ -557,18 +556,20 @@ def test_as_buffer_of_spillable_buffer(manager: SpillManager): b3 = as_buffer(b1.memory_info()[0], size=b1.size, owner=b1) with acquire_spill_lock(): - b3 = as_buffer(b1.get_ptr(), size=b1.size, owner=b1) + b3 = as_buffer(b1.get_ptr(mode="read"), size=b1.size, owner=b1) assert isinstance(b3, SpillableBufferSlice) assert b3.owner is b1 b4 = as_buffer( - b1.ptr + data.itemsize, size=b1.size - data.itemsize, owner=b3 + b1.get_ptr(mode="write") + data.itemsize, + size=b1.size - data.itemsize, + owner=b3, ) assert isinstance(b4, SpillableBufferSlice) assert b4.owner is b1 assert all(cupy.array(b4.memoryview()) == data[1:]) - b5 = as_buffer(b4.ptr, size=b4.size - 1, owner=b4) + b5 = as_buffer(b4.get_ptr(mode="write"), size=b4.size - 1, owner=b4) assert isinstance(b5, SpillableBufferSlice) assert b5.owner is b1 assert all(cupy.array(b5.memoryview()) == data[1:-1]) @@ -623,7 +624,7 @@ def test_statistics_expose(manager: SpillManager): ] # Expose the first buffer - buffers[0].ptr + buffers[0].mark_exposed() assert len(manager.statistics.exposes) == 1 stat = list(manager.statistics.exposes.values())[0] assert stat.count == 1 @@ -632,7 +633,7 @@ def test_statistics_expose(manager: SpillManager): # Expose all 10 buffers for i in range(10): - buffers[i].ptr + buffers[i].mark_exposed() # The rest of the ptr accesses should accumulate to a single stat # because they resolve to the same traceback. @@ -652,7 +653,7 @@ def test_statistics_expose(manager: SpillManager): # Expose the new buffers and check that they are counted as spilled for i in range(10): - buffers[i].ptr + buffers[i].mark_exposed() assert len(manager.statistics.exposes) == 3 stat = list(manager.statistics.exposes.values())[2] assert stat.count == 10 diff --git a/python/cudf/cudf/utils/applyutils.py b/python/cudf/cudf/utils/applyutils.py index 89331b933a8..7e998413642 100644 --- a/python/cudf/cudf/utils/applyutils.py +++ b/python/cudf/cudf/utils/applyutils.py @@ -1,13 +1,15 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. import functools from typing import Any, Dict +import cupy as cp from numba import cuda from numba.core.utils import pysignature import cudf from cudf import _lib as libcudf +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import column from cudf.utils import utils from cudf.utils.docutils import docfmt_partial @@ -139,21 +141,25 @@ def __init__( self.cache_key = cache_key self.kernel = self.compile(func, sig.parameters.keys(), kwargs.keys()) + @acquire_spill_lock() def run(self, df, **launch_params): # Get input columns if isinstance(self.incols, dict): inputs = { - v: df[k]._column.data_array_view + v: df[k]._column.data_array_view(mode="read") for (k, v) in self.incols.items() } else: - inputs = {k: df[k]._column.data_array_view for k in self.incols} + inputs = { + k: df[k]._column.data_array_view(mode="read") + for k in self.incols + } # Allocate output columns outputs = {} for k, dt in self.outcols.items(): outputs[k] = column.column_empty( len(df), dt, False - ).data_array_view + ).data_array_view(mode="write") # Bind argument args = {} for dct in [inputs, outputs, self.kwargs]: @@ -174,7 +180,7 @@ def run(self, df, **launch_params): ) if out_mask is not None: outdf._data[k] = outdf[k]._column.set_mask( - out_mask.data_array_view + out_mask.data_array_view(mode="write") ) return outdf @@ -213,11 +219,12 @@ def launch_kernel(self, df, args, chunks, blkct=None, tpb=None): def normalize_chunks(self, size, chunks): if isinstance(chunks, int): # *chunks* is the chunksize - return column.arange(0, size, chunks).data_array_view + return cuda.as_cuda_array( + cp.arange(start=0, stop=size, step=chunks) + ).view("int64") else: # *chunks* is an array of chunk leading offset - chunks = column.as_column(chunks) - return chunks.data_array_view + return cuda.as_cuda_array(cp.asarray(chunks)).view("int64") def _make_row_wise_kernel(func, argnames, extras): diff --git a/python/cudf/cudf/utils/queryutils.py b/python/cudf/cudf/utils/queryutils.py index 25b3d517e1c..4ce89b526d6 100644 --- a/python/cudf/cudf/utils/queryutils.py +++ b/python/cudf/cudf/utils/queryutils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. import ast import datetime @@ -8,6 +8,7 @@ from numba import cuda import cudf +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import column_empty from cudf.utils import applyutils from cudf.utils.dtypes import ( @@ -191,6 +192,7 @@ def _add_prefix(arg): return kernel +@acquire_spill_lock() def query_execute(df, expr, callenv): """Compile & execute the query expression @@ -220,7 +222,7 @@ def query_execute(df, expr, callenv): "or bool dtypes." ) - colarrays = [col.data_array_view for col in colarrays] + colarrays = [col.data_array_view(mode="read") for col in colarrays] kernel = compiled["kernel"] # process env args diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index 99e4046b0b3..80deb881ec8 100644 --- a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. import operator @@ -12,6 +12,8 @@ from numba.cuda.cudadecl import registry as cuda_decl_registry from numba.cuda.cudadrv import nvvm +import rmm + data_layout = nvvm.data_layout # libcudf size_type @@ -112,7 +114,9 @@ def prepare_args(self, ty, val, **kwargs): if isinstance(ty, types.CPointer) and isinstance( ty.dtype, (StringView, UDFString) ): - return types.uint64, val.ptr + return types.uint64, val.ptr if isinstance( + val, rmm._lib.device_buffer.DeviceBuffer + ) else val.get_ptr(mode="read") else: return ty, val From 20c945be7295efbd6c55d4f388bfcf898484915a Mon Sep 17 00:00:00 2001 From: Cindy Jiang <47068112+cindyyuanjiang@users.noreply.github.com> Date: Thu, 26 Jan 2023 07:24:22 -0800 Subject: [PATCH 7/7] Add `regex_program` java APIs and unit tests (#12548) Adds a set of java regex APIs that take in a `regex_program` as parameter and java unit tests. This is part of the solution for https://github.com/NVIDIA/spark-rapids/issues/7295. Authors: - Cindy Jiang (https://github.com/cindyyuanjiang) Approvers: - MithunR (https://github.com/mythrocks) - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/12548 --- .../java/ai/rapids/cudf/CaptureGroups.java | 36 ++ .../main/java/ai/rapids/cudf/ColumnView.java | 305 ++++++++++++-- .../main/java/ai/rapids/cudf/RegexFlag.java | 37 ++ .../java/ai/rapids/cudf/RegexProgram.java | 130 ++++++ java/src/main/native/src/ColumnViewJni.cpp | 167 ++++---- .../java/ai/rapids/cudf/ColumnVectorTest.java | 398 +++++++++++------- 6 files changed, 799 insertions(+), 274 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/CaptureGroups.java create mode 100644 java/src/main/java/ai/rapids/cudf/RegexFlag.java create mode 100644 java/src/main/java/ai/rapids/cudf/RegexProgram.java diff --git a/java/src/main/java/ai/rapids/cudf/CaptureGroups.java b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java new file mode 100644 index 00000000000..2ab778dbc35 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java @@ -0,0 +1,36 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Capture groups setting, closely following cudf::strings::capture_groups. + * + * For processing a regex pattern containing capture groups. These can be used + * to optimize the generated regex instructions where the capture groups do not + * require extracting the groups. + */ +public enum CaptureGroups { + EXTRACT(0), // capture groups processed normally for extract + NON_CAPTURE(1); // convert all capture groups to non-capture groups + + final int nativeId; // Native id, for use with libcudf. + private CaptureGroups(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 47d6b7573cd..8ffe5b4aa09 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2531,12 +2531,35 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) { * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + return stringSplit(regexProg, limit); + } else { + return stringSplit(pattern, limit); + } + } + + /** + * Returns a list of columns by splitting each string using the specified regex program. The + * number of rows in the output columns will be the same as the input column. Null entries + * are added for a row where split results have been exhausted. Null input entries result in + * all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; - return new Table(stringSplit(this.getNativeView(), pattern, limit, splitByRegex)); + return new Table(stringSplit(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit, true)); } /** @@ -2550,6 +2573,7 @@ public final Table stringSplit(String pattern, int limit, boolean splitByRegex) * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, boolean splitByRegex) { return stringSplit(pattern, -1, splitByRegex); } @@ -2567,7 +2591,11 @@ public final Table stringSplit(String pattern, boolean splitByRegex) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter, int limit) { - return stringSplit(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be a String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new Table(stringSplit(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId, + CaptureGroups.NON_CAPTURE.nativeId, limit, false)); } /** @@ -2580,7 +2608,21 @@ public final Table stringSplit(String delimiter, int limit) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter) { - return stringSplit(delimiter, -1, false); + return stringSplit(delimiter, -1); + } + + /** + * Returns a list of columns by splitting each string using the specified regex program with + * string literal delimiter. The number of rows in the output columns will be the same as the + * input column. Null entries are added for a row where split results have been exhausted. + * Null input entries result in all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg) { + return stringSplit(regexProg, -1); } /** @@ -2595,13 +2637,34 @@ public final Table stringSplit(String delimiter) { * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + return stringSplitRecord(regexProg, limit); + } else { + return stringSplitRecord(pattern, limit); + } + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program pattern. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; return new ColumnVector( - stringSplitRecord(this.getNativeView(), pattern, limit, splitByRegex)); + stringSplitRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit, true)); } /** @@ -2613,6 +2676,7 @@ public final ColumnVector stringSplitRecord(String pattern, int limit, boolean s * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex) { return stringSplitRecord(pattern, -1, splitByRegex); } @@ -2628,7 +2692,12 @@ public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter, int limit) { - return stringSplitRecord(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new ColumnVector( + stringSplitRecord(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId, + CaptureGroups.NON_CAPTURE.nativeId, limit, false)); } /** @@ -2639,7 +2708,19 @@ public final ColumnVector stringSplitRecord(String delimiter, int limit) { * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter) { - return stringSplitRecord(delimiter, -1, false); + return stringSplitRecord(delimiter, -1); + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program with string literal delimiter. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg) { + return stringSplitRecord(regexProg, -1); } /** @@ -2846,10 +2927,23 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) { * @param repl The string scalar to replace for each pattern match. * @return A new column vector containing the string results. */ + @Deprecated public final ColumnVector replaceRegex(String pattern, Scalar repl) { return replaceRegex(pattern, repl, -1); } + /** + * For each string, replaces any character sequence matching the given regex program pattern + * using the replacement string scalar. + * + * @param regexProg The regex program with pattern to search within each string. + * @param repl The string scalar to replace for each pattern match. + * @return A new column vector containing the string results. + */ + public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl) { + return replaceRegex(regexProg, repl, -1); + } + /** * For each string, replaces any character sequence matching the given pattern using the * replacement string scalar. @@ -2859,12 +2953,27 @@ public final ColumnVector replaceRegex(String pattern, Scalar repl) { * @param maxRepl The maximum number of times a replacement should occur within each string. * @return A new column vector containing the string results. */ + @Deprecated public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) { + return replaceRegex(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), repl, maxRepl); + } + + /** + * For each string, replaces any character sequence matching the given regex program pattern + * using the replacement string scalar. + * + * @param regexProg The regex program with pattern to search within each string. + * @param repl The string scalar to replace for each pattern match. + * @param maxRepl The maximum number of times a replacement should occur within each string. + * @return A new column vector containing the string results. + */ + public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl, int maxRepl) { if (!repl.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Replacement must be a string scalar"); } - return new ColumnVector(replaceRegex(getNativeView(), pattern, repl.getScalarHandle(), - maxRepl)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector(replaceRegex(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, repl.getScalarHandle(), maxRepl)); } /** @@ -2890,9 +2999,25 @@ public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls) * @param replace The replacement template for creating the output string. * @return A new java column vector containing the string results. */ + @Deprecated public final ColumnVector stringReplaceWithBackrefs(String pattern, String replace) { - return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), pattern, - replace)); + return stringReplaceWithBackrefs(new RegexProgram(pattern), replace); + } + + /** + * For each string, replaces any character sequence matching the given regex program + * pattern using the replace template for back-references. + * + * Any null string entries return corresponding null output column entries. + * + * @param regexProg The regex program with pattern to search within each string. + * @param replace The replacement template for creating the output string. + * @return A new java column vector containing the string results. + */ + public final ColumnVector stringReplaceWithBackrefs(RegexProgram regexProg, String replace) { + assert regexProg != null : "regex program may not be null"; + return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), regexProg.pattern(), + regexProg.combinedFlags(), regexProg.capture().nativeId, replace)); } /** @@ -3164,11 +3289,32 @@ public final ColumnVector clamp(Scalar lo, Scalar loReplace, Scalar hi, Scalar h * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector matchesRe(String pattern) { + return matchesRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given regex program but only at the beginning of the string. + * + * ``` + * cv = ["abc","123","def456"] + * result = cv.matches_re("\\d+") + * r is now [false, true, false] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector matchesRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(matchesRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(matchesRe(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3177,7 +3323,7 @@ public final ColumnVector matchesRe(String pattern) { * * ``` * cv = ["abc","123","def456"] - * result = cv.matches_re("\\d+") + * result = cv.contains_re("\\d+") * r is now [false, true, true] * ``` * Any null string entries return corresponding null output column entries. @@ -3187,11 +3333,32 @@ public final ColumnVector matchesRe(String pattern) { * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector containsRe(String pattern) { + return containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given RegexProgram object starting at any location. + * + * ``` + * cv = ["abc","123","def456"] + * result = cv.contains_re("\\d+") + * r is now [false, true, true] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector containsRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(containsRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(containsRe(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3206,10 +3373,27 @@ public final ColumnVector containsRe(String pattern) { * @throws CudfException if any error happens including if the RE does * not contain any capture groups. */ + @Deprecated public final Table extractRe(String pattern) throws CudfException { + return extractRe(new RegexProgram(pattern)); + } + + /** + * For each captured group specified in the given regex program + * return a column in the table. Null entries are added if the string + * does not match. Any null inputs also result in null output entries. + * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg the regex program to use + * @return the table of extracted matches + * @throws CudfException if any error happens including if the RE does + * not contain any capture groups. + */ + public final Table extractRe(RegexProgram regexProg) throws CudfException { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - return new Table(extractRe(this.getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + return new Table(extractRe(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3222,11 +3406,28 @@ public final Table extractRe(String pattern) throws CudfException { * @param idx The regex group index * @return A new column vector of extracted matches */ + @Deprecated public final ColumnVector extractAllRecord(String pattern, int idx) { + return extractAllRecord(new RegexProgram(pattern), idx); + } + + /** + * Extracts all strings that match the given regex program and corresponds to the + * regular expression group index. Any null inputs also result in null output entries. + * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg The regex program + * @param idx The regex group index + * @return A new column vector of extracted matches + */ + public final ColumnVector extractAllRecord(RegexProgram regexProg, int idx) { assert type.equals(DType.STRING) : "column type must be a String"; assert idx >= 0 : "group index must be at least 0"; - - return new ColumnVector(extractAllRecord(this.getNativeView(), pattern, idx)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector( + extractAllRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, idx)); } /** @@ -3881,14 +4082,16 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle * * @param nativeHandle native handle of the input strings column that being operated on. * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. * @param splitByRegex a boolean flag indicating whether the input strings will be split by a * regular expression pattern or just by a string literal delimiter. */ - private static native long[] stringSplit(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long[] stringSplit(long nativeHandle, String pattern, int flags, + int capture, int limit, boolean splitByRegex); /** * Returns a column that are lists of strings in which each list is made by splitting the @@ -3896,14 +4099,16 @@ private static native long[] stringSplit(long nativeHandle, String pattern, int * * @param nativeHandle native handle of the input strings column that being operated on. * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. * @param splitByRegex a boolean flag indicating whether the input strings will be split by a * regular expression pattern or just by a string literal delimiter. */ - private static native long stringSplitRecord(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long stringSplitRecord(long nativeHandle, String pattern, int flags, + int capture, int limit, boolean splitByRegex); /** * Native method to calculate substring from a given string column. 0 indexing. @@ -3941,12 +4146,14 @@ private static native long substringColumn(long columnView, long startColumn, lo * Native method for replacing each regular expression pattern match with the specified * replacement string. * @param columnView native handle of the cudf::column_view being operated on. - * @param pattern The regular expression pattern to search within each string. + * @param pattern regular expression pattern to search within each string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param repl native handle of the cudf::scalar containing the replacement string. * @param maxRepl maximum number of times to replace the pattern within a string * @return native handle of the resulting cudf column containing the string results. */ - private static native long replaceRegex(long columnView, String pattern, + private static native long replaceRegex(long columnView, String pattern, int flags, int capture, long repl, long maxRepl) throws CudfException; /** @@ -3960,15 +4167,17 @@ private static native long replaceMultiRegex(long columnView, String[] patterns, long repls) throws CudfException; /** - * Native method for replacing any character sequence matching the given pattern - * using the replace template for back-references. + * Native method for replacing any character sequence matching the given regex program + * pattern using the replace template for back-references. * @param columnView native handle of the cudf::column_view being operated on. - * @param pattern The regular expression patterns to search within each string. + * @param pattern regular expression pattern to search within each string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param replace The replacement template for creating the output string. * @return native handle of the resulting cudf column containing the string results. */ - private static native long stringReplaceWithBackrefs(long columnView, String pattern, - String replace) throws CudfException; + private static native long stringReplaceWithBackrefs(long columnView, String pattern, int flags, + int capture, String replace) throws CudfException; /** * Native method for checking if strings in a column starts with a specified comparison string. @@ -3995,21 +4204,25 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringStrip(long columnView, int type, long toStrip) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern from the + * Native method for checking if strings match the passed in regex program from the * beginning of the string. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long matchesRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long matchesRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern starting at any location. + * Native method for checking if strings match the passed in regex program starting at any location. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long containsRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long containsRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** * Native method for checking if strings match the passed in like pattern @@ -4030,19 +4243,21 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringContains(long cudfViewHandle, long compString) throws CudfException; /** - * Native method for extracting results from an regular expressions. Returns a table handle. + * Native method for extracting results from a regex program. Returns a table handle. */ - private static native long[] extractRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long[] extractRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for extracting all results corresponding to group idx from a regular expression. + * Native method for extracting all results corresponding to group idx from a regex program. * * @param nativeHandle Native handle of the cudf::column_view being operated on. - * @param pattern String regex pattern. + * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param idx Regex group index. A 0 value means matching the entire regex. * @return Native handle of a string column of the result. */ - private static native long extractAllRecord(long nativeHandle, String pattern, int idx); + private static native long extractAllRecord(long nativeHandle, String pattern, int flags, int capture, int idx); private static native long urlDecode(long cudfViewHandle); diff --git a/java/src/main/java/ai/rapids/cudf/RegexFlag.java b/java/src/main/java/ai/rapids/cudf/RegexFlag.java new file mode 100644 index 00000000000..7ed8e0354c9 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexFlag.java @@ -0,0 +1,37 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Regex flags setting, closely following cudf::strings::regex_flags. + * + * These types can be or'd to combine them. The values are chosen to + * leave room for future flags and to match the Python flag values. + */ +public enum RegexFlag { + DEFAULT(0), // default + MULTILINE(8), // the '^' and '$' honor new-line characters + DOTALL(16), // the '.' matching includes new-line characters + ASCII(256); // use only ASCII when matching built-in character classes + + final int nativeId; // Native id, for use with libcudf. + private RegexFlag(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RegexProgram.java b/java/src/main/java/ai/rapids/cudf/RegexProgram.java new file mode 100644 index 00000000000..358eea8ba43 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexProgram.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +import java.util.EnumSet; + +/** + * Regex program class, closely following cudf::strings::regex_program. + */ +public class RegexProgram { + private String pattern; // regex pattern + private EnumSet flags; // regex flags for interpreting special characters in the pattern + // controls how capture groups in the pattern are used + // default is to extract a capture group + private CaptureGroups capture; + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + */ + public RegexProgram(String pattern) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + */ + public RegexProgram(String pattern, EnumSet flags) { + this(pattern, flags, CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, CaptureGroups capture) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), capture); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, EnumSet flags, CaptureGroups capture) { + assert pattern != null : "pattern may not be null"; + this.pattern = pattern; + this.flags = flags; + this.capture = capture; + } + + /** + * Get the pattern used to create this instance + * + * @param return A regex pattern as a string + */ + public String pattern() { + return pattern; + } + + /** + * Get the regex flags setting used to create this instance + * + * @param return Regex flags setting + */ + public EnumSet flags() { + return flags; + } + + /** + * Reset the regex flags setting for this instance + * + * @param flags Regex flags setting + */ + public void setFlags(EnumSet flags) { + this.flags = flags; + } + + /** + * Get the capture groups setting used to create this instance + * + * @param return Capture groups setting + */ + public CaptureGroups capture() { + return capture; + } + + /** + * Reset the capture groups setting for this instance + * + * @param capture Capture groups setting + */ + public void setCapture(CaptureGroups capture) { + this.capture = capture; + } + + /** + * Combine the regex flags using 'or' + * + * @param return An integer representing the value of combined (or'ed) flags + */ + public int combinedFlags() { + int allFlags = 0; + for (RegexFlag flag : flags) { + allFlags = allFlags | flag.nativeId; + } + return allFlags; + } +} diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index b48ddae196b..c17e16bce73 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -62,6 +62,7 @@ #include #include #include +#include #include #include #include @@ -678,11 +679,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reverseStringsOrLists(JNI CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, - jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit, jboolean split_by_regex) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -696,31 +695,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv * try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; - + auto const column_view = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } - auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); auto const max_split = limit > 1 ? limit - 1 : limit; + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); auto result = split_by_regex ? - cudf::strings::split_re(strs_input, pattern, max_split) : - cudf::strings::split(strs_input, cudf::string_scalar{pattern}, max_split); + cudf::strings::split_re(strings_column, *regex_prog, max_split) : + cudf::strings::split(strings_column, cudf::string_scalar{pattern}, max_split); return cudf::jni::convert_table_for_return(env, std::move(result)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv *env, jclass, - jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit, jboolean split_by_regex) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -734,22 +727,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; - + auto const column_view = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } - auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); auto const max_split = limit > 1 ? limit - 1 : limit; auto result = split_by_regex ? - cudf::strings::split_record_re(strs_input, pattern, max_split) : - cudf::strings::split_record(strs_input, cudf::string_scalar{pattern}, max_split); + cudf::strings::split_record_re(strings_column, *regex_prog, max_split) : + cudf::strings::split_record(strings_column, cudf::string_scalar{pattern}, max_split); return release_as_jlong(result); } CATCH_STD(env, 0); @@ -1290,32 +1279,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringContains(JNIEnv *en JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_matchesRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::matches_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return release_as_jlong(cudf::strings::matches_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::contains_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const capture = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, capture); + return release_as_jlong(cudf::strings::contains_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } @@ -1555,21 +1554,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass, - jlong j_column_view, - jstring j_pattern, jlong j_repl, - jlong j_maxrepl) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex( + JNIEnv *env, jclass, jlong j_column_view, jstring j_pattern, jint regex_flags, + jint capture_groups, jlong j_repl, jlong j_maxrepl) { JNI_NULL_CHECK(env, j_column_view, "column is null", 0); JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0); JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0); try { cudf::jni::auto_set_device(env); - auto cv = reinterpret_cast(j_column_view); - cudf::strings_column_view scv(*cv); - cudf::jni::native_jstring pattern(env, j_pattern); - auto repl = reinterpret_cast(j_repl); - return release_as_jlong(cudf::strings::replace_re(scv, pattern.get(), *repl, j_maxrepl)); + auto const column_view = reinterpret_cast(j_column_view); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, j_pattern); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto const repl = reinterpret_cast(j_repl); + return release_as_jlong( + cudf::strings::replace_re(strings_column, *regex_prog, *repl, j_maxrepl)); } CATCH_STD(env, 0); } @@ -1595,19 +1597,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs( - JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) { + JNIEnv *env, jclass, jlong column_view, jstring pattern_obj, jint regex_flags, + jint capture_groups, jstring replaceObj) { JNI_NULL_CHECK(env, column_view, "column is null", 0); - JNI_NULL_CHECK(env, patternObj, "pattern string is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern string is null", 0); JNI_NULL_CHECK(env, replaceObj, "replace string is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = reinterpret_cast(column_view); - cudf::strings_column_view scv(*cv); - cudf::jni::native_jstring ss_pattern(env, patternObj); + auto const cv = reinterpret_cast(column_view); + auto const strings_column = cudf::strings_column_view{*cv}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); cudf::jni::native_jstring ss_replace(env, replaceObj); return release_as_jlong( - cudf::strings::replace_with_backrefs(scv, ss_pattern.get(), ss_replace.get())); + cudf::strings::replace_with_backrefs(strings_column, *regex_prog, ss_replace.get())); } CATCH_STD(env, 0); } @@ -1663,37 +1669,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringStrip(JNIEnv *env, JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_extractRe(JNIEnv *env, jclass, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", nullptr); - JNI_NULL_CHECK(env, patternObj, "pattern is null", nullptr); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", nullptr); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, patternObj); - - return cudf::jni::convert_table_for_return( - env, cudf::strings::extract(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return cudf::jni::convert_table_for_return(env, + cudf::strings::extract(strings_column, *regex_prog)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord(JNIEnv *env, jclass, - jlong j_view_handle, - jstring pattern_obj, - jint idx) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord( + JNIEnv *env, jclass, jlong j_view_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint idx) { JNI_NULL_CHECK(env, j_view_handle, "column is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", 0); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, pattern_obj); - - auto result = (idx == 0) ? cudf::strings::findall(strings_column, pattern.get()) : - cudf::strings::extract_all_record(strings_column, pattern.get()); - + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto result = (idx == 0) ? cudf::strings::findall(strings_column, *regex_prog) : + cudf::strings::extract_all_record(strings_column, *regex_prog); return release_as_jlong(result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index fc0a542e0a7..5b846545906 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4040,41 +4040,50 @@ void testStringFindOperations() { @Test void testExtractRe() { - try (ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null); - Table expected = new Table.TestBuilder() - .column("a", "b", null, null) - .column("1", "2", null, null) - .build(); - Table found = input.extractRe("([ab])(\\d)")) { - assertTablesAreEqual(expected, found); - } + ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null); + Table expected = new Table.TestBuilder() + .column("a", "b", null, null) + .column("1", "2", null, null) + .build(); + try (Table found = input.extractRe("([ab])(\\d)")) { + assertTablesAreEqual(expected, found); + } + try (Table found = input.extractRe(new RegexProgram("([ab])(\\d)"))) { + assertTablesAreEqual(expected, found); + } } @Test void testExtractAllRecord() { String pattern = "([ab])(\\d)"; - try (ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); - ColumnVector expectedIdx0 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a1"), - Arrays.asList("b2"), - Arrays.asList(), - null, - Arrays.asList("a1", "b1", "a2")); - ColumnVector expectedIdx12 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a", "1"), - Arrays.asList("b", "2"), - null, - null, - Arrays.asList("a", "1", "b", "1", "a", "2")); - - ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); - ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); - ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2); - ) { + RegexProgram regexProg = new RegexProgram(pattern); + ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); + ColumnVector expectedIdx0 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a1"), + Arrays.asList("b2"), + Arrays.asList(), + null, + Arrays.asList("a1", "b1", "a2")); + ColumnVector expectedIdx12 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a", "1"), + Arrays.asList("b", "2"), + null, + null, + Arrays.asList("a", "1", "b", "1", "a", "2")); + try (ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); + ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); + ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2)) { + assertColumnsAreEqual(expectedIdx0, resultIdx0); + assertColumnsAreEqual(expectedIdx12, resultIdx1); + assertColumnsAreEqual(expectedIdx12, resultIdx2); + } + try (ColumnVector resultIdx0 = v.extractAllRecord(regexProg, 0); + ColumnVector resultIdx1 = v.extractAllRecord(regexProg, 1); + ColumnVector resultIdx2 = v.extractAllRecord(regexProg, 2)) { assertColumnsAreEqual(expectedIdx0, resultIdx0); assertColumnsAreEqual(expectedIdx12, resultIdx1); assertColumnsAreEqual(expectedIdx12, resultIdx2); @@ -4087,25 +4096,37 @@ void testMatchesRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res1 = testStrings.matchesRe(patternString1); + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); + ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00"); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, + true, true); + ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, + false, false); + ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, + true, true); + try (ColumnVector res1 = testStrings.matchesRe(patternString1); ColumnVector res2 = testStrings.matchesRe(patternString2); - ColumnVector res3 = testStrings.matchesRe(patternString3); - ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, - true, true); - ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, - false, false); - ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, - true, true)) { + ColumnVector res3 = testStrings.matchesRe(patternString3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + try (ColumnVector res1 = testStrings.matchesRe(regexProg1); + ColumnVector res2 = testStrings.matchesRe(regexProg2); + ColumnVector res3 = testStrings.matchesRe(regexProg3)) { assertColumnsAreEqual(expected1, res1); assertColumnsAreEqual(expected2, res2); assertColumnsAreEqual(expected3, res3); } assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res = testStrings.matchesRe(patternString4)) {} + try (ColumnVector res = testStrings.matchesRe(patternString4)) {} + }); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.matchesRe(regexProg4)) {} }); } @@ -4115,36 +4136,51 @@ void testContainsRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; - try (ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); + ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, + true, true, true, true); + ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, + false, false, false, true); + ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, + true, true, true, true); + try (ColumnVector res1 = testStrings.containsRe(patternString1); ColumnVector res2 = testStrings.containsRe(patternString2); - ColumnVector res3 = testStrings.containsRe(patternString3); - ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, - true, true, true, true); - ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, - false, false, false, true); - ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, - true, true, true, true)) { + ColumnVector res3 = testStrings.containsRe(patternString3)) { assertColumnsAreEqual(expected1, res1); assertColumnsAreEqual(expected2, res2); assertColumnsAreEqual(expected3, res3); } + try (ColumnVector res1 = testStrings.containsRe(regexProg1); + ColumnVector res2 = testStrings.containsRe(regexProg2); + ColumnVector res3 = testStrings.containsRe(regexProg3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + ColumnVector testStringsError = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStringsError.containsRe(patternString4)) {}}); assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res = testStrings.containsRe(patternString4)) {} + try (ColumnVector res = testStringsError.containsRe(regexProg4)) {} }); } @Test - @Disabled("Needs fix for https://github.com/rapidsai/cudf/issues/4671") void testContainsReEmptyInput() { String patternString1 = ".*"; + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); try (ColumnVector testStrings = ColumnVector.fromStrings(""); ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector resRe1 = testStrings.containsRe(regexProg1); ColumnVector expected1 = ColumnVector.fromBoxedBooleans(true)) { assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected1, resRe1); } } @@ -4405,9 +4441,13 @@ void testsubstring() { @Test void testExtractListElements() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); - ColumnVector list = v.stringSplitRecord(" "); + ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); + try (ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(0)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector list = v.stringSplitRecord(new RegexProgram(" ", CaptureGroups.NON_CAPTURE)); ColumnVector result = list.extractListElement(0)) { assertColumnsAreEqual(expected, result); } @@ -4415,10 +4455,14 @@ void testExtractListElements() { @Test void testExtractListElementsV() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); - ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); - ColumnVector list = v.stringSplitRecord(" "); + ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); + ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); + try (ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(indices)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector list = v.stringSplitRecord(new RegexProgram(" ", CaptureGroups.NON_CAPTURE)); ColumnVector result = list.extractListElement(indices)) { assertColumnsAreEqual(expected, result); } @@ -4947,103 +4991,127 @@ void testReverseList() { @Test void testStringSplit() { String pattern = " "; - try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); - Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some things", "strings here") - .build(); - Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2); + Table expectedSplitLimit2 = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some things", "strings here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2); Table resultSplitAll = v.stringSplit(pattern)) { - assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); - assertTablesAreEqual(expectedSplitAll, resultSplitAll); + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); } } @Test void testStringSplitByRegularExpression() { String pattern = "[_ ]"; - try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", "ARé some_things", "test_strings_here"); - Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some_things", "strings_here") - .build(); - Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); + Table expectedSplitLimit2 = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some_things", "strings_here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); Table resultSplitAll = v.stringSplit(pattern, true)) { assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); assertTablesAreEqual(expectedSplitAll, resultSplitAll); } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } } @Test void testStringSplitRecord() { String pattern = " "; - try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); - ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some things"), - Arrays.asList("test", "strings here")); - ColumnVector expectedSplitAll = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there", "all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2); + ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some things"), + Arrays.asList("test", "strings here")); + ColumnVector expectedSplitAll = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there", "all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some", "things"), + Arrays.asList("test", "strings", "here")); + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2); ColumnVector resultSplitAll = v.stringSplitRecord(pattern)) { assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); assertColumnsAreEqual(expectedSplitAll, resultSplitAll); } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } @Test void testStringSplitRecordByRegularExpression() { String pattern = "[_ ]"; - try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", "ARé some_things", "test_strings_here"); - ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some_things"), - Arrays.asList("test", "strings_here")); - ColumnVector expectedSplitAll = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there", "all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); + ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some_things"), + Arrays.asList("test", "strings_here")); + ColumnVector expectedSplitAll = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there", "all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some", "things"), + Arrays.asList("test", "strings", "here")); + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); ColumnVector resultSplitAll = v.stringSplitRecord(pattern, true)) { assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); assertColumnsAreEqual(expectedSplitAll, resultSplitAll); } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } @Test @@ -5091,26 +5159,37 @@ void teststringReplaceThrowsException() { @Test void testReplaceRegex() { - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl); + ColumnVector v = ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); + Scalar repl = Scalar.fromString("Repl"); + String pattern = "[tT]itle"; + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + try (ColumnVector actual = v.replaceRegex(pattern, repl); ColumnVector expected = ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) { assertColumnsAreEqual(expected, actual); } - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl, 0)) { + try (ColumnVector actual = v.replaceRegex(pattern, repl, 0)) { assertColumnsAreEqual(v, actual); } - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl, 1); + try (ColumnVector actual = v.replaceRegex(pattern, repl, 1); + ColumnVector expected = + ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) { + assertColumnsAreEqual(expected, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl); + ColumnVector expected = + ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) { + assertColumnsAreEqual(expected, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl, 0)) { + assertColumnsAreEqual(v, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl, 1); ColumnVector expected = ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) { assertColumnsAreEqual(expected, actual); @@ -5132,45 +5211,56 @@ void testReplaceMultiRegex() { @Test void testStringReplaceWithBackrefs() { - try (ColumnVector v = ColumnVector.fromStrings("

title

", "

another title

", - null); + try (ColumnVector v = ColumnVector.fromStrings("

title

", "

another title

", null); ColumnVector expected = ColumnVector.fromStrings("

title

", "

another title

", null); - ColumnVector actual = v.stringReplaceWithBackrefs("

(.*)

", "

\\1

")) { + ColumnVector actual = v.stringReplaceWithBackrefs("

(.*)

", "

\\1

"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("

(.*)

"), "

\\1

")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-1-01", "2020-2-02", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", null); - ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])-"), "-0\\1-")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2", - "2020-03-3invalid", null); + "2020-03-3invalid", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", "2020-03-3invalid", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "-([0-9])$", "-0\\1")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])$", "-0\\1"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])$"), "-0\\1")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-01-1 random_text", "2020-02-2T12:34:56", - "2020-03-3invalid", null); + "2020-03-3invalid", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01 random_text", "2020-02-02T12:34:56", "2020-03-3invalid", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "-([0-9])([ T])", "-0\\1\\2")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])([ T])", "-0\\1\\2"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])([ T])"), "-0\\1\\2")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } // test zero as group index try (ColumnVector v = ColumnVector.fromStrings("aa-11 b2b-345", "aa-11a 1c-2b2 b2-c3", "11-aa", null); ColumnVector expected = ColumnVector.fromStrings("aa-11:aa:11; b2b-345:b:345;", "aa-11:aa:11;a 1c-2:c:2;b2 b2-c3", "11-aa", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "([a-z]+)-([0-9]+)", "${0}:${1}:${2};")) { + ColumnVector actual = v.stringReplaceWithBackrefs("([a-z]+)-([0-9]+)", "${0}:${1}:${2};"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("([a-z]+)-([0-9]+)"), "${0}:${1}:${2};")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } // group index exceeds group count @@ -5180,6 +5270,12 @@ void testStringReplaceWithBackrefs() { } }); + assertThrows(CudfException.class, () -> { + try (ColumnVector v = ColumnVector.fromStrings("ABC123defgh"); + ColumnVector r = v.stringReplaceWithBackrefs( + new RegexProgram("([A-Z]+)([0-9]+)([a-z]+)"), "\\4")) { + } + }); } @Test