diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 093063c21e7..5c4328cb7a2 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -54,7 +54,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-cudf: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -66,7 +66,7 @@ jobs: wheel-publish-cudf: needs: wheel-build-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -76,7 +76,7 @@ jobs: wheel-build-dask-cudf: needs: wheel-publish-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -87,7 +87,7 @@ jobs: wheel-publish-dask-cudf: needs: wheel-build-dask-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-publish.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-publish.yml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index cb83aab31cd..89e14d3e421 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -85,7 +85,7 @@ jobs: wheel-build-cudf: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 with: build_type: pull-request package-name: cudf @@ -94,17 +94,19 @@ jobs: wheel-tests-cudf: needs: wheel-build-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118 with: build_type: pull-request package-name: cudf + # Install cupy-cuda11x for arm from a special index url + # Install tokenizers last binary wheel to avoid a Rust compile from the latest sdist test-before-arm64: "pip install tokenizers==0.10.2 cupy-cuda11x -f https://pip.cupy.dev/aarch64" test-unittest: "pytest -v -n 8 ./python/cudf/cudf/tests" test-smoketest: "python ./ci/wheel_smoke_test_cudf.py" wheel-build-dask-cudf: - needs: wheel-build-cudf + needs: wheel-tests-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-build.yml@cuda-118 with: build_type: pull-request package-name: dask_cudf @@ -113,7 +115,7 @@ jobs: wheel-tests-dask-cudf: needs: wheel-build-dask-cudf secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@cuda-118 with: build_type: pull-request package-name: dask_cudf diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2b583773e05..b383d185564 100644 --- a/.github/workflows/test.yaml +++ 
b/.github/workflows/test.yaml @@ -67,7 +67,7 @@ jobs: run_script: "ci/test_notebooks.sh" wheel-tests-cudf: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118 with: build_type: nightly branch: ${{ inputs.branch }} @@ -78,7 +78,7 @@ jobs: test-unittest: "pytest -v -n 8 ./python/cudf/cudf/tests" wheel-tests-dask-cudf: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@main + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-pure-test.yml@cuda-118 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/cpp/include/cudf/detail/aggregation/aggregation.cuh b/cpp/include/cudf/detail/aggregation/aggregation.cuh index 818e8cd7cc6..f13166d5321 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.cuh +++ b/cpp/include/cudf/detail/aggregation/aggregation.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -232,8 +232,8 @@ struct update_target_element< aggregation::SUM, target_has_nulls, source_has_nulls, - std::enable_if_t() && cudf::has_atomic_support() && - !is_fixed_point()>> { + std::enable_if_t() && cudf::has_atomic_support() && + !cudf::is_fixed_point() && !cudf::is_timestamp()>> { __device__ void operator()(mutable_column_device_view target, size_type target_index, column_device_view source, diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 75027c78a68..360c314f2db 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1154,9 +1154,7 @@ struct target_type_impl { using type = bool; }; -// Always use `double` for MEAN -// Except for chrono types where result is chrono. (Use FloorDiv) -// TODO: MEAN should be only be enabled for duration types - not for timestamps +// Always use `double` for MEAN except for durations and fixed point types. 
template struct target_type_impl< Source, @@ -1167,10 +1165,10 @@ struct target_type_impl< }; template -struct target_type_impl< - Source, - k, - std::enable_if_t<(is_chrono() or is_fixed_point()) && (k == aggregation::MEAN)>> { +struct target_type_impl() or is_fixed_point()) && + (k == aggregation::MEAN)>> { using type = Source; }; @@ -1206,10 +1204,11 @@ struct target_type_impl< using type = Source; }; -// Summing/Multiplying chrono types, use same type accumulator -// TODO: Sum/Product should only be enabled for duration types - not for timestamps +// Summing duration types, use same type accumulator template -struct target_type_impl() && is_sum_product_agg(k)>> { +struct target_type_impl() && (k == aggregation::SUM)>> { using type = Source; }; diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 70176392ee9..23d130e1585 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -90,6 +90,13 @@ struct page_state_s { const uint8_t* lvl_start[NUM_LEVEL_TYPES]; // [def,rep] int32_t lvl_count[NUM_LEVEL_TYPES]; // how many of each of the streams we've decoded int32_t row_index_lower_bound; // lower bound of row indices we should process + + // a shared-memory cache of frequently used data when decoding. The source of this data is + // normally stored in global memory which can yield poor performance. So, when possible + // we copy that info here prior to decoding + PageNestingDecodeInfo nesting_decode_cache[max_cacheable_nesting_decode_info]; + // points to either nesting_decode_cache above when possible, or to the global source otherwise + PageNestingDecodeInfo* nesting_info; }; /** @@ -927,23 +934,49 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, int chunk_idx; // Fetch page info - if (t == 0) s->page = *p; + if (!t) s->page = *p; __syncthreads(); if (s->page.flags & PAGEINFO_FLAGS_DICTIONARY) { return false; } // Fetch column chunk info chunk_idx = s->page.chunk_idx; - if (t == 0) { s->col = chunks[chunk_idx]; } - - // zero nested value and valid counts - int d = 0; - while (d < s->page.num_output_nesting_levels) { - if (d + t < s->page.num_output_nesting_levels) { - s->page.nesting[d + t].valid_count = 0; - s->page.nesting[d + t].value_count = 0; - s->page.nesting[d + t].null_count = 0; + if (!t) { s->col = chunks[chunk_idx]; } + + // if we can use the decode cache, set it up now + auto const can_use_decode_cache = s->page.nesting_info_size <= max_cacheable_nesting_decode_info; + if (can_use_decode_cache) { + int depth = 0; + while (depth < s->page.nesting_info_size) { + int const thread_depth = depth + t; + if (thread_depth < s->page.nesting_info_size) { + // these values need to be copied over from global + s->nesting_decode_cache[thread_depth].max_def_level = + s->page.nesting_decode[thread_depth].max_def_level; + s->nesting_decode_cache[thread_depth].page_start_value = + s->page.nesting_decode[thread_depth].page_start_value; + s->nesting_decode_cache[thread_depth].start_depth = + s->page.nesting_decode[thread_depth].start_depth; + s->nesting_decode_cache[thread_depth].end_depth = + s->page.nesting_decode[thread_depth].end_depth; + } + depth += blockDim.x; + } + } + if (!t) { + s->nesting_info = can_use_decode_cache ? 
s->nesting_decode_cache : s->page.nesting_decode; + } + __syncthreads(); + + // zero counts + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + int const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->nesting_info[thread_depth].valid_count = 0; + s->nesting_info[thread_depth].value_count = 0; + s->nesting_info[thread_depth].null_count = 0; } - d += blockDim.x; + depth += blockDim.x; } __syncthreads(); @@ -1076,7 +1109,7 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, if (is_decode_step) { int max_depth = s->col.max_nesting_depth; for (int idx = 0; idx < max_depth; idx++) { - PageNestingInfo* pni = &s->page.nesting[idx]; + PageNestingDecodeInfo* nesting_info = &s->nesting_info[idx]; size_t output_offset; // schemas without lists @@ -1085,21 +1118,21 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, } // for schemas with lists, we've already got the exact value precomputed else { - output_offset = pni->page_start_value; + output_offset = nesting_info->page_start_value; } - pni->data_out = static_cast(s->col.column_data_base[idx]); + nesting_info->data_out = static_cast(s->col.column_data_base[idx]); - if (pni->data_out != nullptr) { + if (nesting_info->data_out != nullptr) { // anything below max depth with a valid data pointer must be a list, so the // element size is the size of the offset type. uint32_t len = idx < max_depth - 1 ? sizeof(cudf::size_type) : s->dtype_len; - pni->data_out += (output_offset * len); + nesting_info->data_out += (output_offset * len); } - pni->valid_map = s->col.valid_map_base[idx]; - if (pni->valid_map != nullptr) { - pni->valid_map += output_offset >> 5; - pni->valid_map_offset = (int32_t)(output_offset & 0x1f); + nesting_info->valid_map = s->col.valid_map_base[idx]; + if (nesting_info->valid_map != nullptr) { + nesting_info->valid_map += output_offset >> 5; + nesting_info->valid_map_offset = (int32_t)(output_offset & 0x1f); } } } @@ -1217,26 +1250,26 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, * @brief Store a validity mask containing value_count bits into the output validity buffer of the * page. * - * @param[in,out] pni The page/nesting information to store the mask in. The validity map offset is - * also updated + * @param[in,out] nesting_info The page/nesting information to store the mask in. 
The validity map + * offset is also updated * @param[in] valid_mask The validity mask to be stored * @param[in] value_count # of bits in the validity mask */ -static __device__ void store_validity(PageNestingInfo* pni, +static __device__ void store_validity(PageNestingDecodeInfo* nesting_info, uint32_t valid_mask, int32_t value_count) { - int word_offset = pni->valid_map_offset / 32; - int bit_offset = pni->valid_map_offset % 32; + int word_offset = nesting_info->valid_map_offset / 32; + int bit_offset = nesting_info->valid_map_offset % 32; // if we fit entirely in the output word if (bit_offset + value_count <= 32) { auto relevant_mask = static_cast((static_cast(1) << value_count) - 1); if (relevant_mask == ~0) { - pni->valid_map[word_offset] = valid_mask; + nesting_info->valid_map[word_offset] = valid_mask; } else { - atomicAnd(pni->valid_map + word_offset, ~(relevant_mask << bit_offset)); - atomicOr(pni->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset); + atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset)); + atomicOr(nesting_info->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset); } } // we're going to spill over into the next word. @@ -1250,17 +1283,17 @@ static __device__ void store_validity(PageNestingInfo* pni, // first word. strip bits_left bits off the beginning and store that uint32_t relevant_mask = ((1 << bits_left) - 1); uint32_t mask_word0 = valid_mask & relevant_mask; - atomicAnd(pni->valid_map + word_offset, ~(relevant_mask << bit_offset)); - atomicOr(pni->valid_map + word_offset, mask_word0 << bit_offset); + atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset)); + atomicOr(nesting_info->valid_map + word_offset, mask_word0 << bit_offset); // second word. strip the remainder of the bits off the end and store that relevant_mask = ((1 << (value_count - bits_left)) - 1); uint32_t mask_word1 = valid_mask & (relevant_mask << bits_left); - atomicAnd(pni->valid_map + word_offset + 1, ~(relevant_mask)); - atomicOr(pni->valid_map + word_offset + 1, mask_word1 >> bits_left); + atomicAnd(nesting_info->valid_map + word_offset + 1, ~(relevant_mask)); + atomicOr(nesting_info->valid_map + word_offset + 1, mask_word1 >> bits_left); } - pni->valid_map_offset += value_count; + nesting_info->valid_map_offset += value_count; } /** @@ -1294,8 +1327,8 @@ inline __device__ void get_nesting_bounds(int& start_depth, // bound what nesting levels we apply values to if (s->col.max_level[level_type::REPETITION] > 0) { int r = s->rep[index]; - start_depth = s->page.nesting[r].start_depth; - end_depth = s->page.nesting[d].end_depth; + start_depth = s->nesting_info[r].start_depth; + end_depth = s->nesting_info[d].end_depth; } // for columns without repetition (even ones involving structs) we always // traverse the entire hierarchy. 
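To make the bit arithmetic in `store_validity` easier to follow, here is a minimal, self-contained host-side C++ sketch of the same word-boundary split. The helper name `store_validity_host` and the `std::vector`-backed bitmap are assumptions made only for this illustration; the device version above additionally skips the read-modify-write when a full word is written (`relevant_mask == ~0`) and uses `atomicAnd`/`atomicOr` because several warps may update the same word.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side analog of store_validity: write `value_count` bits of `valid_mask`
// into a packed 32-bit bitmap starting at bit `valid_map_offset`, spilling into
// the next word when the run crosses a word boundary.
void store_validity_host(std::vector<uint32_t>& valid_map,
                         int32_t& valid_map_offset,
                         uint32_t valid_mask,
                         int32_t value_count)
{
  assert(value_count > 0 && value_count <= 32);
  int const word_offset = valid_map_offset / 32;
  int const bit_offset  = valid_map_offset % 32;

  if (bit_offset + value_count <= 32) {
    // fits entirely in one word: clear the destination bits, then OR in the mask
    auto const relevant_mask =
      static_cast<uint32_t>((static_cast<uint64_t>(1) << value_count) - 1);
    valid_map[word_offset] &= ~(relevant_mask << bit_offset);
    valid_map[word_offset] |= (valid_mask & relevant_mask) << bit_offset;
  } else {
    // spills over: the low `bits_left` bits finish the current word,
    // the remaining high bits start the next word
    int const bits_left = 32 - bit_offset;
    uint32_t relevant_mask = (1u << bits_left) - 1;
    valid_map[word_offset] &= ~(relevant_mask << bit_offset);
    valid_map[word_offset] |= (valid_mask & relevant_mask) << bit_offset;

    relevant_mask = (1u << (value_count - bits_left)) - 1;
    valid_map[word_offset + 1] &= ~relevant_mask;
    valid_map[word_offset + 1] |= (valid_mask >> bits_left) & relevant_mask;
  }
  valid_map_offset += value_count;
}
```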
@@ -1326,6 +1359,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // how many rows we've processed in the page so far int input_row_count = s->input_row_count; + PageNestingDecodeInfo* nesting_info_base = s->nesting_info; + // process until we've reached the target while (input_value_count < target_input_value_count) { // determine the nesting bounds for this thread (the range of nesting depths we @@ -1367,14 +1402,14 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // walk from 0 to max_depth uint32_t next_thread_value_count, next_warp_value_count; for (int s_idx = 0; s_idx < max_depth; s_idx++) { - PageNestingInfo* pni = &s->page.nesting[s_idx]; + PageNestingDecodeInfo* nesting_info = &nesting_info_base[s_idx]; // if we are within the range of nesting levels we should be adding value indices for int const in_nesting_bounds = ((s_idx >= start_depth && s_idx <= end_depth) && in_row_bounds) ? 1 : 0; // everything up to the max_def_level is a non-null value - uint32_t const is_valid = d >= pni->max_def_level && in_nesting_bounds ? 1 : 0; + uint32_t const is_valid = d >= nesting_info->max_def_level && in_nesting_bounds ? 1 : 0; // compute warp and thread valid counts uint32_t const warp_valid_mask = @@ -1395,8 +1430,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // if this is the value column emit an index for value decoding if (is_valid && s_idx == max_depth - 1) { - int const src_pos = pni->valid_count + thread_valid_count; - int const dst_pos = pni->value_count + thread_value_count; + int const src_pos = nesting_info->valid_count + thread_valid_count; + int const dst_pos = nesting_info->value_count + thread_value_count; // nz_idx is a mapping of src buffer indices to destination buffer indices s->nz_idx[rolling_index(src_pos)] = dst_pos; } @@ -1414,12 +1449,12 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // if we're -not- at a leaf column and we're within nesting/row bounds // and we have a valid data_out pointer, it implies this is a list column, so // emit an offset. 
- if (in_nesting_bounds && pni->data_out != nullptr) { - int const idx = pni->value_count + thread_value_count; - cudf::size_type const ofs = s->page.nesting[s_idx + 1].value_count + + if (in_nesting_bounds && nesting_info->data_out != nullptr) { + int const idx = nesting_info->value_count + thread_value_count; + cudf::size_type const ofs = nesting_info_base[s_idx + 1].value_count + next_thread_value_count + - s->page.nesting[s_idx + 1].page_start_value; - (reinterpret_cast(pni->data_out))[idx] = ofs; + nesting_info_base[s_idx + 1].page_start_value; + (reinterpret_cast(nesting_info->data_out))[idx] = ofs; } } @@ -1441,14 +1476,14 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // increment count of valid values, count of total values, and update validity mask if (!t) { - if (pni->valid_map != nullptr && warp_valid_mask_bit_count > 0) { + if (nesting_info->valid_map != nullptr && warp_valid_mask_bit_count > 0) { uint32_t const warp_output_valid_mask = warp_valid_mask >> first_thread_in_write_range; - store_validity(pni, warp_output_valid_mask, warp_valid_mask_bit_count); + store_validity(nesting_info, warp_output_valid_mask, warp_valid_mask_bit_count); - pni->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask); + nesting_info->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask); } - pni->valid_count += warp_valid_count; - pni->value_count += warp_value_count; + nesting_info->valid_count += warp_valid_count; + nesting_info->value_count += warp_value_count; } // propagate value counts for the next level @@ -1463,7 +1498,7 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // update if (!t) { // update valid value count for decoding and total # of values we've processed - s->nz_count = s->page.nesting[max_depth - 1].valid_count; + s->nz_count = nesting_info_base[max_depth - 1].valid_count; s->input_value_count = input_value_count; s->input_row_count = input_row_count; } @@ -1545,7 +1580,7 @@ static __device__ void gpuUpdatePageSizes(page_state_s* s, // count rows and leaf values int const is_new_row = start_depth == 0 ? 1 : 0; uint32_t const warp_row_count_mask = ballot(is_new_row); - int const is_new_leaf = (d >= s->page.nesting[max_depth - 1].max_def_level) ? 1 : 0; + int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level) ? 1 : 0; uint32_t const warp_leaf_count_mask = ballot(is_new_leaf); // is this thread within row bounds? on the first pass we don't know the bounds, so we will be // computing the full size of the column. on the second pass, we will know our actual row @@ -1673,14 +1708,14 @@ __global__ void __launch_bounds__(block_size) // to do the expensive work of traversing the level data to determine sizes. we can just compute // it directly. 
if (!has_repetition && !compute_string_sizes) { - int d = 0; - while (d < s->page.num_output_nesting_levels) { - auto const i = d + t; - if (i < s->page.num_output_nesting_levels) { - if (is_base_pass) { pp->nesting[i].size = pp->num_input_values; } - pp->nesting[i].batch_size = pp->num_input_values; + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + if (is_base_pass) { pp->nesting[thread_depth].size = pp->num_input_values; } + pp->nesting[thread_depth].batch_size = pp->num_input_values; } - d += blockDim.x; + depth += blockDim.x; } return; } @@ -1688,25 +1723,29 @@ __global__ void __launch_bounds__(block_size) // in the trim pass, for anything with lists, we only need to fully process bounding pages (those // at the beginning or the end of the row bounds) if (!is_base_pass && !is_bounds_page(s, min_row, num_rows)) { - int d = 0; - while (d < s->page.num_output_nesting_levels) { - auto const i = d + t; - if (i < s->page.num_output_nesting_levels) { + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { // if we are not a bounding page (as checked above) then we are either // returning 0 rows from the page (completely outside the bounds) or all // rows in the page (completely within the bounds) - pp->nesting[i].batch_size = s->num_rows == 0 ? 0 : pp->nesting[i].size; + pp->nesting[thread_depth].batch_size = + s->num_rows == 0 ? 0 : pp->nesting[thread_depth].size; } - d += blockDim.x; + depth += blockDim.x; } return; } // zero sizes - int d = 0; - while (d < s->page.num_output_nesting_levels) { - if (d + t < s->page.num_output_nesting_levels) { s->page.nesting[d + t].batch_size = 0; } - d += blockDim.x; + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->page.nesting[thread_depth].batch_size = 0; + } + depth += blockDim.x; } __syncthreads(); @@ -1754,13 +1793,13 @@ __global__ void __launch_bounds__(block_size) if (!t) { pp->num_rows = s->page.nesting[0].batch_size; } // store off this batch size as the "full" size - int d = 0; - while (d < s->page.num_output_nesting_levels) { - auto const i = d + t; - if (i < s->page.num_output_nesting_levels) { - pp->nesting[i].size = pp->nesting[i].batch_size; + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + pp->nesting[thread_depth].size = pp->nesting[thread_depth].batch_size; } - d += blockDim.x; + depth += blockDim.x; } } @@ -1808,6 +1847,8 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( ((s->col.data_type & 7) == BOOLEAN || (s->col.data_type & 7) == BYTE_ARRAY) ? 64 : 32; } + PageNestingDecodeInfo* nesting_info_base = s->nesting_info; + // skipped_leaf_values will always be 0 for flat hierarchies. 
uint32_t skipped_leaf_values = s->page.skipped_leaf_values; while (!s->error && (s->input_value_count < s->num_input_values || s->src_pos < s->nz_count)) { @@ -1876,7 +1917,7 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( uint32_t dtype_len = s->dtype_len; void* dst = - s->page.nesting[leaf_level_index].data_out + static_cast(dst_pos) * dtype_len; + nesting_info_base[leaf_level_index].data_out + static_cast(dst_pos) * dtype_len; if (dtype == BYTE_ARRAY) { if (s->col.converted_type == DECIMAL) { auto const [ptr, len] = gpuGetStringData(s, val_src_pos); @@ -1931,6 +1972,19 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( } __syncthreads(); } + + // if we are using the nesting decode cache, copy null count back + if (s->nesting_info == s->nesting_decode_cache) { + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + int const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->page.nesting_decode[thread_depth].null_count = + s->nesting_decode_cache[thread_depth].null_count; + } + depth += blockDim.x; + } + } } } // anonymous namespace diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 33a189cdf87..9b156745e41 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -83,6 +83,40 @@ enum level_type { NUM_LEVEL_TYPES }; +/** + * @brief Nesting information specifically needed by the decode and preprocessing + * kernels. + * + * This data is kept separate from PageNestingInfo to keep it as small as possible. + * It is used in a cached form in shared memory when possible. + */ +struct PageNestingDecodeInfo { + // set up prior to decoding + int32_t max_def_level; + // input repetition/definition levels are remapped with these values + // into the corresponding real output nesting depths. + int32_t start_depth; + int32_t end_depth; + + // computed during preprocessing + int32_t page_start_value; + + // computed during decoding + int32_t null_count; + + // used internally during decoding + int32_t valid_map_offset; + int32_t valid_count; + int32_t value_count; + uint8_t* data_out; + bitmask_type* valid_map; +}; + +// Use up to 512 bytes of shared memory as a cache for nesting information. +// As of 1/20/23, this gives us a max nesting depth of 10 (after which it falls back to +// global memory). This handles all but the most extreme cases. +constexpr int max_cacheable_nesting_decode_info = (512) / sizeof(PageNestingDecodeInfo); + /** * @brief Nesting information * @@ -94,30 +128,15 @@ enum level_type { * */ struct PageNestingInfo { - // input repetition/definition levels are remapped with these values - // into the corresponding real output nesting depths. - int32_t start_depth; - int32_t end_depth; - - // set at initialization - int32_t max_def_level; - int32_t max_rep_level; + // set at initialization (see start_offset_output_iterator in reader_impl_preprocess.cu) cudf::type_id type; // type of the corresponding cudf output column bool nullable; - // set during preprocessing + // TODO: these fields might make sense to move into PageNestingDecodeInfo for memory performance + // reasons. 
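For the 512-byte shared-memory budget behind `max_cacheable_nesting_decode_info`: `PageNestingDecodeInfo` holds eight 4-byte integers plus two pointers, i.e. 48 bytes on a typical 64-bit target, so 512 / 48 = 10 cacheable nesting levels, matching the "max nesting depth of 10" comment above. The sketch below re-declares an equivalent layout locally, purely for illustration; the pointer size and absence of padding are stated assumptions, not guarantees.

```cpp
#include <cstdint>
#include <cstdio>

// Local stand-in for the PageNestingDecodeInfo layout described in this patch:
// eight 32-bit fields followed by two pointers.
struct PageNestingDecodeInfoSketch {
  int32_t max_def_level, start_depth, end_depth, page_start_value;
  int32_t null_count, valid_map_offset, valid_count, value_count;
  uint8_t* data_out;
  uint32_t* valid_map;
};

int main()
{
  constexpr int shared_budget_bytes = 512;
  static_assert(sizeof(PageNestingDecodeInfoSketch) == 48,
                "assumes 8-byte pointers and no padding");
  constexpr int max_cacheable = shared_budget_bytes / sizeof(PageNestingDecodeInfoSketch);
  // Pages nested deeper than this fall back to the global-memory nesting info.
  std::printf("cacheable nesting levels: %d\n", max_cacheable);  // prints 10
  return 0;
}
```

As `setupLocalPageInfo` shows earlier in this diff, the cache is used only when `page.nesting_info_size <= max_cacheable_nesting_decode_info`; deeper pages leave `s->nesting_info` pointing at the global `nesting_decode` array, and at the end of decoding only the `null_count` fields are copied back out of the cache.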
int32_t size; // this page/nesting-level's row count contribution to the output column, if fully // decoded - int32_t batch_size; // the size of the page for this batch - int32_t page_start_value; // absolute output start index in output column data - - // set during data decoding - int32_t valid_count; // # of valid values decoded in this page/nesting-level - int32_t value_count; // total # of values decoded in this page/nesting-level - int32_t null_count; // null count - int32_t valid_map_offset; // current offset in bits relative to valid_map - uint8_t* data_out; // pointer into output buffer - uint32_t* valid_map; // pointer into output validity buffer + int32_t batch_size; // the size of the page for this batch }; /** @@ -159,9 +178,9 @@ struct PageInfo { // skipped_leaf_values will always be 0. // // # of values skipped in the repetition/definition level stream - int skipped_values; + int32_t skipped_values; // # of values skipped in the actual data stream. - int skipped_leaf_values; + int32_t skipped_leaf_values; // for string columns only, the size of all the chars in the string for // this page. only valid/computed during the base preprocess pass int32_t str_bytes; @@ -170,9 +189,10 @@ struct PageInfo { // input column nesting information, output column nesting information and // mappings between the two. the length of the array, nesting_info_size is // max(num_output_nesting_levels, max_definition_levels + 1) - int num_output_nesting_levels; - int nesting_info_size; + int32_t num_output_nesting_levels; + int32_t nesting_info_size; PageNestingInfo* nesting; + PageNestingDecodeInfo* nesting_decode; }; /** @@ -242,7 +262,7 @@ struct ColumnChunkDesc { PageInfo* page_info; // output page info for up to num_dict_pages + // num_data_pages (dictionary pages first) string_index_pair* str_dict_index; // index for string dictionary - uint32_t** valid_map_base; // base pointers of valid bit map for this column + bitmask_type** valid_map_base; // base pointers of valid bit map for this column void** column_data_base; // base pointers of column data int8_t codec; // compressed codec enum int8_t converted_type; // converted type enum @@ -263,6 +283,7 @@ struct file_intermediate_data { hostdevice_vector chunks{}; hostdevice_vector pages_info{}; hostdevice_vector page_nesting_info{}; + hostdevice_vector page_nesting_decode_info{}; }; /** diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index d5dac10b8f6..b1c4dd22c0d 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -24,9 +24,10 @@ namespace cudf::io::detail::parquet { void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) { - auto& chunks = _file_itm_data.chunks; - auto& pages = _file_itm_data.pages_info; - auto& page_nesting = _file_itm_data.page_nesting_info; + auto& chunks = _file_itm_data.chunks; + auto& pages = _file_itm_data.pages_info; + auto& page_nesting = _file_itm_data.page_nesting_info; + auto& page_nesting_decode = _file_itm_data.page_nesting_decode_info; // Should not reach here if there is no page data. CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); @@ -39,7 +40,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) // In order to reduce the number of allocations of hostdevice_vector, we allocate a single vector // to store all per-chunk pointers to nested data/nullmask. 
`chunk_offsets[i]` will store the // offset into `chunk_nested_data`/`chunk_nested_valids` for the array of pointers for chunk `i` - auto chunk_nested_valids = hostdevice_vector(sum_max_depths, _stream); + auto chunk_nested_valids = hostdevice_vector(sum_max_depths, _stream); auto chunk_nested_data = hostdevice_vector(sum_max_depths, _stream); auto chunk_offsets = std::vector(); @@ -124,6 +125,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) pages.device_to_host(_stream); page_nesting.device_to_host(_stream); + page_nesting_decode.device_to_host(_stream); _stream.synchronize(); // for list columns, add the final offset to every offset buffer. @@ -166,8 +168,8 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) gpu::ColumnChunkDesc* col = &chunks[pi->chunk_idx]; input_column_info const& input_col = _input_columns[col->src_col_index]; - int index = pi->nesting - page_nesting.device_ptr(); - gpu::PageNestingInfo* pni = &page_nesting[index]; + int index = pi->nesting_decode - page_nesting_decode.device_ptr(); + gpu::PageNestingDecodeInfo* pndi = &page_nesting_decode[index]; auto* cols = &_output_buffers; for (size_t l_idx = 0; l_idx < input_col.nesting_depth(); l_idx++) { @@ -178,7 +180,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) if (chunk_nested_valids.host_ptr(chunk_offsets[pi->chunk_idx])[l_idx] == nullptr) { continue; } - out_buf.null_count() += pni[l_idx].null_count; + out_buf.null_count() += pndi[l_idx].null_count; } } diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 435fdb1a411..6577a1a3f0f 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -523,9 +523,10 @@ void decode_page_headers(hostdevice_vector& chunks, void reader::impl::allocate_nesting_info() { - auto const& chunks = _file_itm_data.chunks; - auto& pages = _file_itm_data.pages_info; - auto& page_nesting_info = _file_itm_data.page_nesting_info; + auto const& chunks = _file_itm_data.chunks; + auto& pages = _file_itm_data.pages_info; + auto& page_nesting_info = _file_itm_data.page_nesting_info; + auto& page_nesting_decode_info = _file_itm_data.page_nesting_decode_info; // compute total # of page_nesting infos needed and allocate space. 
doing this in one // buffer to keep it to a single gpu allocation @@ -539,6 +540,8 @@ void reader::impl::allocate_nesting_info() }); page_nesting_info = hostdevice_vector{total_page_nesting_infos, _stream}; + page_nesting_decode_info = + hostdevice_vector{total_page_nesting_infos, _stream}; // retrieve from the gpu so we can update pages.device_to_host(_stream, true); @@ -556,6 +559,9 @@ void reader::impl::allocate_nesting_info() target_page_index += chunks[idx].num_dict_pages; for (int p_idx = 0; p_idx < chunks[idx].num_data_pages; p_idx++) { pages[target_page_index + p_idx].nesting = page_nesting_info.device_ptr() + src_info_index; + pages[target_page_index + p_idx].nesting_decode = + page_nesting_decode_info.device_ptr() + src_info_index; + pages[target_page_index + p_idx].nesting_info_size = per_page_nesting_info_size; pages[target_page_index + p_idx].num_output_nesting_levels = _metadata->get_output_nesting_depth(src_col_schema); @@ -601,6 +607,9 @@ void reader::impl::allocate_nesting_info() gpu::PageNestingInfo* pni = &page_nesting_info[nesting_info_index + (p_idx * per_page_nesting_info_size)]; + gpu::PageNestingDecodeInfo* nesting_info = + &page_nesting_decode_info[nesting_info_index + (p_idx * per_page_nesting_info_size)]; + // if we have lists, set our start and end depth remappings if (schema.max_repetition_level > 0) { auto remap = depth_remapping.find(src_col_schema); @@ -610,17 +619,16 @@ void reader::impl::allocate_nesting_info() std::vector const& def_depth_remap = (remap->second.second); for (size_t m = 0; m < rep_depth_remap.size(); m++) { - pni[m].start_depth = rep_depth_remap[m]; + nesting_info[m].start_depth = rep_depth_remap[m]; } for (size_t m = 0; m < def_depth_remap.size(); m++) { - pni[m].end_depth = def_depth_remap[m]; + nesting_info[m].end_depth = def_depth_remap[m]; } } // values indexed by output column index - pni[cur_depth].max_def_level = cur_schema.max_definition_level; - pni[cur_depth].max_rep_level = cur_schema.max_repetition_level; - pni[cur_depth].size = 0; + nesting_info[cur_depth].max_def_level = cur_schema.max_definition_level; + pni[cur_depth].size = 0; pni[cur_depth].type = to_type_id(cur_schema, _strings_to_categorical, _timestamp_type.id()); pni[cur_depth].nullable = cur_schema.repetition_type == OPTIONAL; @@ -640,6 +648,7 @@ void reader::impl::allocate_nesting_info() // copy nesting info to the device page_nesting_info.host_to_device(_stream); + page_nesting_decode_info.host_to_device(_stream); } void reader::impl::load_and_decompress_data(std::vector const& row_groups_info, @@ -1256,7 +1265,7 @@ struct start_offset_output_iterator { if (p.src_col_schema != src_col_schema || p.flags & gpu::PAGEINFO_FLAGS_DICTIONARY) { return empty; } - return p.nesting[nesting_depth].page_start_value; + return p.nesting_decode[nesting_depth].page_start_value; } }; diff --git a/cpp/src/rolling/detail/rolling.cuh b/cpp/src/rolling/detail/rolling.cuh index fcc85b4f913..d996f88ca49 100644 --- a/cpp/src/rolling/detail/rolling.cuh +++ b/cpp/src/rolling/detail/rolling.cuh @@ -84,16 +84,9 @@ struct DeviceRolling { static constexpr bool is_supported() { return cudf::detail::is_valid_aggregation() && has_corresponding_operator() && - // TODO: Delete all this extra logic once is_valid_aggregation<> cleans up some edge - // cases it isn't handling. 
- // MIN/MAX supports all fixed width types + // MIN/MAX only supports fixed width types (((O == aggregation::MIN || O == aggregation::MAX) && cudf::is_fixed_width()) || - - // SUM supports all fixed width types except timestamps - ((O == aggregation::SUM) && (cudf::is_fixed_width() && !cudf::is_timestamp())) || - - // MEAN supports numeric and duration - ((O == aggregation::MEAN) && (cudf::is_numeric() || cudf::is_duration()))); + (O == aggregation::SUM) || (O == aggregation::MEAN)); } // operations we do support diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 2cd6e49d7bb..21752196430 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -38,6 +38,7 @@ #include #include +#include #include @@ -4826,6 +4827,49 @@ TEST_F(ParquetReaderTest, StructByteArray) CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view()); } +TEST_F(ParquetReaderTest, NestingOptimizationTest) +{ + // test nesting levels > cudf::io::parquet::gpu::max_cacheable_nesting_decode_info deep. + constexpr cudf::size_type num_nesting_levels = 16; + static_assert(num_nesting_levels > cudf::io::parquet::gpu::max_cacheable_nesting_decode_info); + constexpr cudf::size_type rows_per_level = 2; + + constexpr cudf::size_type num_values = (1 << num_nesting_levels) * rows_per_level; + auto value_iter = thrust::make_counting_iterator(0); + auto validity = + cudf::detail::make_counting_transform_iterator(0, [](cudf::size_type i) { return i % 2; }); + cudf::test::fixed_width_column_wrapper values(value_iter, value_iter + num_values, validity); + + // ~256k values with num_nesting_levels = 16 + int total_values_produced = num_values; + auto prev_col = values.release(); + for (int idx = 0; idx < num_nesting_levels; idx++) { + auto const depth = num_nesting_levels - idx; + auto const num_rows = (1 << (num_nesting_levels - idx)); + + auto offsets_iter = cudf::detail::make_counting_transform_iterator( + 0, [depth, rows_per_level](cudf::size_type i) { return i * rows_per_level; }); + total_values_produced += (num_rows + 1); + + cudf::test::fixed_width_column_wrapper offsets(offsets_iter, + offsets_iter + num_rows + 1); + auto c = cudf::make_lists_column(num_rows, offsets.release(), std::move(prev_col), 0, {}); + prev_col = std::move(c); + } + auto const& expect = prev_col; + + auto filepath = temp_env->get_temp_filepath("NestingDecodeCache.parquet"); + cudf::io::parquet_writer_options opts = + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, table_view{{*expect}}); + cudf::io::write_parquet(opts); + + cudf::io::parquet_reader_options in_opts = + cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}); + auto result = cudf::io::read_parquet(in_opts); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expect, result.tbl->get_column(0)); +} + TEST_F(ParquetWriterTest, SingleValueDictionaryTest) { constexpr unsigned int expected_bits = 1; diff --git a/cpp/tests/rolling/empty_input_test.cpp b/cpp/tests/rolling/empty_input_test.cpp index 626563d5eba..aca1cbf40b7 100644 --- a/cpp/tests/rolling/empty_input_test.cpp +++ b/cpp/tests/rolling/empty_input_test.cpp @@ -183,22 +183,36 @@ TYPED_TEST(TypedRollingEmptyInputTest, EmptyFixedWidthInputs) /// `SUM` returns 64-bit promoted types for integral/decimal input. /// For other fixed-width input types, the same type is returned. + /// Timestamp types are not supported. 
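The tightened `is_supported` logic above, together with the `target_type_impl` changes earlier in this diff, encodes a simple rule: durations can be summed and averaged (SUM and MEAN keep the duration type), numeric MEAN promotes to `double`, and timestamps are excluded because adding two points in time is not meaningful. The following standard-library sketch is an analogy only, independent of libcudf:

```cpp
#include <chrono>
#include <cstdio>
#include <vector>

int main()
{
  using namespace std::chrono;

  // Summing durations yields another duration of the same kind,
  // which is why SUM keeps the source duration type.
  std::vector<seconds> durations{seconds{2}, seconds{4}, seconds{9}};
  seconds total{0};
  for (auto d : durations) total += d;  // 15s

  // Their mean is also naturally a duration of the same kind
  // (the MEAN target_type specialization above maps durations to themselves).
  auto mean = total / static_cast<long long>(durations.size());  // 5s

  // Points in time, by contrast, do not add: time_point + time_point does not
  // even compile, which is the intuition behind excluding timestamps here.
  std::printf("total=%llds mean=%llds\n",
              static_cast<long long>(total.count()),
              static_cast<long long>(mean.count()));
  return 0;
}
```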
{ auto aggs = agg_vector_t{}; aggs.emplace_back(sum()); using expected_type = cudf::detail::target_type_t; - rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + if constexpr (cudf::is_timestamp()) { + EXPECT_THROW( + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()), + cudf::logic_error); + } else { + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + } } /// `MEAN` returns float64 for all numeric types, - /// except for chrono-types, which yield the same chrono-type. + /// except for duration-types, which yield the same duration-type. + /// Timestamp types are not supported. { auto aggs = agg_vector_t{}; aggs.emplace_back(mean()); using expected_type = cudf::detail::target_type_t; - rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + if constexpr (cudf::is_timestamp()) { + EXPECT_THROW( + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()), + cudf::logic_error); + } else { + rolling_output_type_matches(empty_input, aggs, cudf::type_to_id()); + } } /// For an input type `T`, `COLLECT_LIST` returns a column of type `list`. diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index 54a28db1b58..bac5eae6b34 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -236,13 +236,11 @@ Spilling consists of two components: - A spill manager that tracks all instances of `SpillableBuffer` and spills them on demand. A global spill manager is used throughout cudf when spilling is enabled, which makes `as_buffer()` return `SpillableBuffer` instead of the default `Buffer` instances. -Accessing `Buffer.ptr`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.ptr`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. +Accessing `Buffer.get_ptr(...)`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.get_ptr(...)`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. To address this, we mark the `SpillableBuffer` as unspillable, we say that the buffer has been _exposed_. This can either be permanent if the device pointer is exposed to external projects or temporary while `libcudf` accesses the device memory. -The `SpillableBuffer.get_ptr()` returns the device pointer of the buffer memory just like `.ptr` but if given an instance of `SpillLock`, the buffer is only unspillable as long as the instance of `SpillLock` is alive. - -For convenience, one can use the decorator/context `acquire_spill_lock` to associate a `SpillLock` with a lifetime bound to the context automatically. 
+The `SpillableBuffer.get_ptr(...)` returns the device pointer of the buffer memory but if called within an `acquire_spill_lock` decorator/context, the buffer is only marked unspillable while running within the decorator/context. #### Statistics cuDF supports spilling statistics, which can be very useful for performance profiling and to identify code that renders buffers unspillable. diff --git a/java/src/main/java/ai/rapids/cudf/CaptureGroups.java b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java new file mode 100644 index 00000000000..2ab778dbc35 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java @@ -0,0 +1,36 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Capture groups setting, closely following cudf::strings::capture_groups. + * + * For processing a regex pattern containing capture groups. These can be used + * to optimize the generated regex instructions where the capture groups do not + * require extracting the groups. + */ +public enum CaptureGroups { + EXTRACT(0), // capture groups processed normally for extract + NON_CAPTURE(1); // convert all capture groups to non-capture groups + + final int nativeId; // Native id, for use with libcudf. + private CaptureGroups(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 47d6b7573cd..8ffe5b4aa09 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2531,12 +2531,35 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) { * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + return stringSplit(regexProg, limit); + } else { + return stringSplit(pattern, limit); + } + } + + /** + * Returns a list of columns by splitting each string using the specified regex program. The + * number of rows in the output columns will be the same as the input column. Null entries + * are added for a row where split results have been exhausted. Null input entries result in + * all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return list of strings columns as a table. 
+ */ + public final Table stringSplit(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; - return new Table(stringSplit(this.getNativeView(), pattern, limit, splitByRegex)); + return new Table(stringSplit(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit, true)); } /** @@ -2550,6 +2573,7 @@ public final Table stringSplit(String pattern, int limit, boolean splitByRegex) * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, boolean splitByRegex) { return stringSplit(pattern, -1, splitByRegex); } @@ -2567,7 +2591,11 @@ public final Table stringSplit(String pattern, boolean splitByRegex) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter, int limit) { - return stringSplit(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be a String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new Table(stringSplit(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId, + CaptureGroups.NON_CAPTURE.nativeId, limit, false)); } /** @@ -2580,7 +2608,21 @@ public final Table stringSplit(String delimiter, int limit) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter) { - return stringSplit(delimiter, -1, false); + return stringSplit(delimiter, -1); + } + + /** + * Returns a list of columns by splitting each string using the specified regex program with + * string literal delimiter. The number of rows in the output columns will be the same as the + * input column. Null entries are added for a row where split results have been exhausted. + * Null input entries result in all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg) { + return stringSplit(regexProg, -1); } /** @@ -2595,13 +2637,34 @@ public final Table stringSplit(String delimiter) { * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + return stringSplitRecord(regexProg, limit); + } else { + return stringSplitRecord(pattern, limit); + } + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program pattern. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. 
Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; return new ColumnVector( - stringSplitRecord(this.getNativeView(), pattern, limit, splitByRegex)); + stringSplitRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit, true)); } /** @@ -2613,6 +2676,7 @@ public final ColumnVector stringSplitRecord(String pattern, int limit, boolean s * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex) { return stringSplitRecord(pattern, -1, splitByRegex); } @@ -2628,7 +2692,12 @@ public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter, int limit) { - return stringSplitRecord(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new ColumnVector( + stringSplitRecord(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId, + CaptureGroups.NON_CAPTURE.nativeId, limit, false)); } /** @@ -2639,7 +2708,19 @@ public final ColumnVector stringSplitRecord(String delimiter, int limit) { * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter) { - return stringSplitRecord(delimiter, -1, false); + return stringSplitRecord(delimiter, -1); + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program with string literal delimiter. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg) { + return stringSplitRecord(regexProg, -1); } /** @@ -2846,10 +2927,23 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) { * @param repl The string scalar to replace for each pattern match. * @return A new column vector containing the string results. */ + @Deprecated public final ColumnVector replaceRegex(String pattern, Scalar repl) { return replaceRegex(pattern, repl, -1); } + /** + * For each string, replaces any character sequence matching the given regex program pattern + * using the replacement string scalar. + * + * @param regexProg The regex program with pattern to search within each string. + * @param repl The string scalar to replace for each pattern match. + * @return A new column vector containing the string results. 
+ */ + public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl) { + return replaceRegex(regexProg, repl, -1); + } + /** * For each string, replaces any character sequence matching the given pattern using the * replacement string scalar. @@ -2859,12 +2953,27 @@ public final ColumnVector replaceRegex(String pattern, Scalar repl) { * @param maxRepl The maximum number of times a replacement should occur within each string. * @return A new column vector containing the string results. */ + @Deprecated public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) { + return replaceRegex(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), repl, maxRepl); + } + + /** + * For each string, replaces any character sequence matching the given regex program pattern + * using the replacement string scalar. + * + * @param regexProg The regex program with pattern to search within each string. + * @param repl The string scalar to replace for each pattern match. + * @param maxRepl The maximum number of times a replacement should occur within each string. + * @return A new column vector containing the string results. + */ + public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl, int maxRepl) { if (!repl.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Replacement must be a string scalar"); } - return new ColumnVector(replaceRegex(getNativeView(), pattern, repl.getScalarHandle(), - maxRepl)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector(replaceRegex(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, repl.getScalarHandle(), maxRepl)); } /** @@ -2890,9 +2999,25 @@ public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls) * @param replace The replacement template for creating the output string. * @return A new java column vector containing the string results. */ + @Deprecated public final ColumnVector stringReplaceWithBackrefs(String pattern, String replace) { - return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), pattern, - replace)); + return stringReplaceWithBackrefs(new RegexProgram(pattern), replace); + } + + /** + * For each string, replaces any character sequence matching the given regex program + * pattern using the replace template for back-references. + * + * Any null string entries return corresponding null output column entries. + * + * @param regexProg The regex program with pattern to search within each string. + * @param replace The replacement template for creating the output string. + * @return A new java column vector containing the string results. + */ + public final ColumnVector stringReplaceWithBackrefs(RegexProgram regexProg, String replace) { + assert regexProg != null : "regex program may not be null"; + return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), regexProg.pattern(), + regexProg.combinedFlags(), regexProg.capture().nativeId, replace)); } /** @@ -3164,11 +3289,32 @@ public final ColumnVector clamp(Scalar lo, Scalar loReplace, Scalar hi, Scalar h * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector matchesRe(String pattern) { + return matchesRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given regex program but only at the beginning of the string. 
+ * + * ``` + * cv = ["abc","123","def456"] + * result = cv.matches_re("\\d+") + * r is now [false, true, false] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector matchesRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(matchesRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(matchesRe(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3177,7 +3323,7 @@ public final ColumnVector matchesRe(String pattern) { * * ``` * cv = ["abc","123","def456"] - * result = cv.matches_re("\\d+") + * result = cv.contains_re("\\d+") * r is now [false, true, true] * ``` * Any null string entries return corresponding null output column entries. @@ -3187,11 +3333,32 @@ public final ColumnVector matchesRe(String pattern) { * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector containsRe(String pattern) { + return containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given RegexProgram object starting at any location. + * + * ``` + * cv = ["abc","123","def456"] + * result = cv.contains_re("\\d+") + * r is now [false, true, true] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector containsRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(containsRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(containsRe(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3206,10 +3373,27 @@ public final ColumnVector containsRe(String pattern) { * @throws CudfException if any error happens including if the RE does * not contain any capture groups. */ + @Deprecated public final Table extractRe(String pattern) throws CudfException { + return extractRe(new RegexProgram(pattern)); + } + + /** + * For each captured group specified in the given regex program + * return a column in the table. Null entries are added if the string + * does not match. Any null inputs also result in null output entries. 
+ * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg the regex program to use + * @return the table of extracted matches + * @throws CudfException if any error happens including if the RE does + * not contain any capture groups. + */ + public final Table extractRe(RegexProgram regexProg) throws CudfException { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - return new Table(extractRe(this.getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + return new Table(extractRe(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3222,11 +3406,28 @@ public final Table extractRe(String pattern) throws CudfException { * @param idx The regex group index * @return A new column vector of extracted matches */ + @Deprecated public final ColumnVector extractAllRecord(String pattern, int idx) { + return extractAllRecord(new RegexProgram(pattern), idx); + } + + /** + * Extracts all strings that match the given regex program and corresponds to the + * regular expression group index. Any null inputs also result in null output entries. + * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg The regex program + * @param idx The regex group index + * @return A new column vector of extracted matches + */ + public final ColumnVector extractAllRecord(RegexProgram regexProg, int idx) { assert type.equals(DType.STRING) : "column type must be a String"; assert idx >= 0 : "group index must be at least 0"; - - return new ColumnVector(extractAllRecord(this.getNativeView(), pattern, idx)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector( + extractAllRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, idx)); } /** @@ -3881,14 +4082,16 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle * * @param nativeHandle native handle of the input strings column that being operated on. * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. * @param splitByRegex a boolean flag indicating whether the input strings will be split by a * regular expression pattern or just by a string literal delimiter. */ - private static native long[] stringSplit(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long[] stringSplit(long nativeHandle, String pattern, int flags, + int capture, int limit, boolean splitByRegex); /** * Returns a column that are lists of strings in which each list is made by splitting the @@ -3896,14 +4099,16 @@ private static native long[] stringSplit(long nativeHandle, String pattern, int * * @param nativeHandle native handle of the input strings column that being operated on. * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. 
* @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. * @param splitByRegex a boolean flag indicating whether the input strings will be split by a * regular expression pattern or just by a string literal delimiter. */ - private static native long stringSplitRecord(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long stringSplitRecord(long nativeHandle, String pattern, int flags, + int capture, int limit, boolean splitByRegex); /** * Native method to calculate substring from a given string column. 0 indexing. @@ -3941,12 +4146,14 @@ private static native long substringColumn(long columnView, long startColumn, lo * Native method for replacing each regular expression pattern match with the specified * replacement string. * @param columnView native handle of the cudf::column_view being operated on. - * @param pattern The regular expression pattern to search within each string. + * @param pattern regular expression pattern to search within each string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param repl native handle of the cudf::scalar containing the replacement string. * @param maxRepl maximum number of times to replace the pattern within a string * @return native handle of the resulting cudf column containing the string results. */ - private static native long replaceRegex(long columnView, String pattern, + private static native long replaceRegex(long columnView, String pattern, int flags, int capture, long repl, long maxRepl) throws CudfException; /** @@ -3960,15 +4167,17 @@ private static native long replaceMultiRegex(long columnView, String[] patterns, long repls) throws CudfException; /** - * Native method for replacing any character sequence matching the given pattern - * using the replace template for back-references. + * Native method for replacing any character sequence matching the given regex program + * pattern using the replace template for back-references. * @param columnView native handle of the cudf::column_view being operated on. - * @param pattern The regular expression patterns to search within each string. + * @param pattern regular expression pattern to search within each string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param replace The replacement template for creating the output string. * @return native handle of the resulting cudf column containing the string results. */ - private static native long stringReplaceWithBackrefs(long columnView, String pattern, - String replace) throws CudfException; + private static native long stringReplaceWithBackrefs(long columnView, String pattern, int flags, + int capture, String replace) throws CudfException; /** * Native method for checking if strings in a column starts with a specified comparison string. @@ -3995,21 +4204,25 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringStrip(long columnView, int type, long toStrip) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern from the + * Native method for checking if strings match the passed in regex program from the * beginning of the string. * @param cudfViewHandle native handle of the cudf::column_view being operated on. 
* @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long matchesRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long matchesRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern starting at any location. + * Native method for checking if strings match the passed in regex program starting at any location. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long containsRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long containsRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** * Native method for checking if strings match the passed in like pattern @@ -4030,19 +4243,21 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringContains(long cudfViewHandle, long compString) throws CudfException; /** - * Native method for extracting results from an regular expressions. Returns a table handle. + * Native method for extracting results from a regex program. Returns a table handle. */ - private static native long[] extractRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long[] extractRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for extracting all results corresponding to group idx from a regular expression. + * Native method for extracting all results corresponding to group idx from a regex program. * * @param nativeHandle Native handle of the cudf::column_view being operated on. - * @param pattern String regex pattern. + * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param idx Regex group index. A 0 value means matching the entire regex. * @return Native handle of a string column of the result. */ - private static native long extractAllRecord(long nativeHandle, String pattern, int idx); + private static native long extractAllRecord(long nativeHandle, String pattern, int flags, int capture, int idx); private static native long urlDecode(long cudfViewHandle); diff --git a/java/src/main/java/ai/rapids/cudf/RegexFlag.java b/java/src/main/java/ai/rapids/cudf/RegexFlag.java new file mode 100644 index 00000000000..7ed8e0354c9 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexFlag.java @@ -0,0 +1,37 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Regex flags setting, closely following cudf::strings::regex_flags. + * + * These types can be or'd to combine them. The values are chosen to + * leave room for future flags and to match the Python flag values. + */ +public enum RegexFlag { + DEFAULT(0), // default + MULTILINE(8), // the '^' and '$' honor new-line characters + DOTALL(16), // the '.' matching includes new-line characters + ASCII(256); // use only ASCII when matching built-in character classes + + final int nativeId; // Native id, for use with libcudf. + private RegexFlag(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RegexProgram.java b/java/src/main/java/ai/rapids/cudf/RegexProgram.java new file mode 100644 index 00000000000..358eea8ba43 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexProgram.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +import java.util.EnumSet; + +/** + * Regex program class, closely following cudf::strings::regex_program. 
+ */ +public class RegexProgram { + private String pattern; // regex pattern + private EnumSet flags; // regex flags for interpreting special characters in the pattern + // controls how capture groups in the pattern are used + // default is to extract a capture group + private CaptureGroups capture; + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + */ + public RegexProgram(String pattern) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + */ + public RegexProgram(String pattern, EnumSet flags) { + this(pattern, flags, CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, CaptureGroups capture) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), capture); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, EnumSet flags, CaptureGroups capture) { + assert pattern != null : "pattern may not be null"; + this.pattern = pattern; + this.flags = flags; + this.capture = capture; + } + + /** + * Get the pattern used to create this instance + * + * @param return A regex pattern as a string + */ + public String pattern() { + return pattern; + } + + /** + * Get the regex flags setting used to create this instance + * + * @param return Regex flags setting + */ + public EnumSet flags() { + return flags; + } + + /** + * Reset the regex flags setting for this instance + * + * @param flags Regex flags setting + */ + public void setFlags(EnumSet flags) { + this.flags = flags; + } + + /** + * Get the capture groups setting used to create this instance + * + * @param return Capture groups setting + */ + public CaptureGroups capture() { + return capture; + } + + /** + * Reset the capture groups setting for this instance + * + * @param capture Capture groups setting + */ + public void setCapture(CaptureGroups capture) { + this.capture = capture; + } + + /** + * Combine the regex flags using 'or' + * + * @param return An integer representing the value of combined (or'ed) flags + */ + public int combinedFlags() { + int allFlags = 0; + for (RegexFlag flag : flags) { + allFlags = allFlags | flag.nativeId; + } + return allFlags; + } +} diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index b48ddae196b..c17e16bce73 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -62,6 +62,7 @@ #include #include #include +#include #include #include #include @@ -678,11 +679,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reverseStringsOrLists(JNI CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, - jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit, jboolean split_by_regex) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -696,31 +695,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv * try 
{ cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; - + auto const column_view = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } - auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); auto const max_split = limit > 1 ? limit - 1 : limit; + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); auto result = split_by_regex ? - cudf::strings::split_re(strs_input, pattern, max_split) : - cudf::strings::split(strs_input, cudf::string_scalar{pattern}, max_split); + cudf::strings::split_re(strings_column, *regex_prog, max_split) : + cudf::strings::split(strings_column, cudf::string_scalar{pattern}, max_split); return cudf::jni::convert_table_for_return(env, std::move(result)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv *env, jclass, - jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit, jboolean split_by_regex) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -734,22 +727,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; - + auto const column_view = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } - auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); auto const max_split = limit > 1 ? limit - 1 : limit; auto result = split_by_regex ? 
- cudf::strings::split_record_re(strs_input, pattern, max_split) : - cudf::strings::split_record(strs_input, cudf::string_scalar{pattern}, max_split); + cudf::strings::split_record_re(strings_column, *regex_prog, max_split) : + cudf::strings::split_record(strings_column, cudf::string_scalar{pattern}, max_split); return release_as_jlong(result); } CATCH_STD(env, 0); @@ -1290,32 +1279,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringContains(JNIEnv *en JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_matchesRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::matches_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return release_as_jlong(cudf::strings::matches_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::contains_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const capture = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, capture); + return release_as_jlong(cudf::strings::contains_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } @@ -1555,21 +1554,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass, - jlong j_column_view, - jstring j_pattern, jlong j_repl, - jlong j_maxrepl) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex( + JNIEnv *env, jclass, jlong j_column_view, jstring j_pattern, jint regex_flags, + jint capture_groups, jlong j_repl, jlong j_maxrepl) { JNI_NULL_CHECK(env, j_column_view, "column is null", 0); JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0); JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0); try { 
cudf::jni::auto_set_device(env); - auto cv = reinterpret_cast(j_column_view); - cudf::strings_column_view scv(*cv); - cudf::jni::native_jstring pattern(env, j_pattern); - auto repl = reinterpret_cast(j_repl); - return release_as_jlong(cudf::strings::replace_re(scv, pattern.get(), *repl, j_maxrepl)); + auto const column_view = reinterpret_cast(j_column_view); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, j_pattern); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto const repl = reinterpret_cast(j_repl); + return release_as_jlong( + cudf::strings::replace_re(strings_column, *regex_prog, *repl, j_maxrepl)); } CATCH_STD(env, 0); } @@ -1595,19 +1597,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs( - JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) { + JNIEnv *env, jclass, jlong column_view, jstring pattern_obj, jint regex_flags, + jint capture_groups, jstring replaceObj) { JNI_NULL_CHECK(env, column_view, "column is null", 0); - JNI_NULL_CHECK(env, patternObj, "pattern string is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern string is null", 0); JNI_NULL_CHECK(env, replaceObj, "replace string is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = reinterpret_cast(column_view); - cudf::strings_column_view scv(*cv); - cudf::jni::native_jstring ss_pattern(env, patternObj); + auto const cv = reinterpret_cast(column_view); + auto const strings_column = cudf::strings_column_view{*cv}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); cudf::jni::native_jstring ss_replace(env, replaceObj); return release_as_jlong( - cudf::strings::replace_with_backrefs(scv, ss_pattern.get(), ss_replace.get())); + cudf::strings::replace_with_backrefs(strings_column, *regex_prog, ss_replace.get())); } CATCH_STD(env, 0); } @@ -1663,37 +1669,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringStrip(JNIEnv *env, JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_extractRe(JNIEnv *env, jclass, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", nullptr); - JNI_NULL_CHECK(env, patternObj, "pattern is null", nullptr); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", nullptr); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, patternObj); - - return cudf::jni::convert_table_for_return( - env, cudf::strings::extract(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return 
cudf::jni::convert_table_for_return(env, + cudf::strings::extract(strings_column, *regex_prog)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord(JNIEnv *env, jclass, - jlong j_view_handle, - jstring pattern_obj, - jint idx) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord( + JNIEnv *env, jclass, jlong j_view_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint idx) { JNI_NULL_CHECK(env, j_view_handle, "column is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", 0); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, pattern_obj); - - auto result = (idx == 0) ? cudf::strings::findall(strings_column, pattern.get()) : - cudf::strings::extract_all_record(strings_column, pattern.get()); - + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto result = (idx == 0) ? cudf::strings::findall(strings_column, *regex_prog) : + cudf::strings::extract_all_record(strings_column, *regex_prog); return release_as_jlong(result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index fc0a542e0a7..5b846545906 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4040,41 +4040,50 @@ void testStringFindOperations() { @Test void testExtractRe() { - try (ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null); - Table expected = new Table.TestBuilder() - .column("a", "b", null, null) - .column("1", "2", null, null) - .build(); - Table found = input.extractRe("([ab])(\\d)")) { - assertTablesAreEqual(expected, found); - } + ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null); + Table expected = new Table.TestBuilder() + .column("a", "b", null, null) + .column("1", "2", null, null) + .build(); + try (Table found = input.extractRe("([ab])(\\d)")) { + assertTablesAreEqual(expected, found); + } + try (Table found = input.extractRe(new RegexProgram("([ab])(\\d)"))) { + assertTablesAreEqual(expected, found); + } } @Test void testExtractAllRecord() { String pattern = "([ab])(\\d)"; - try (ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); - ColumnVector expectedIdx0 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a1"), - Arrays.asList("b2"), - Arrays.asList(), - null, - Arrays.asList("a1", "b1", "a2")); - ColumnVector expectedIdx12 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a", "1"), - Arrays.asList("b", "2"), - null, - null, - Arrays.asList("a", "1", "b", "1", "a", "2")); - - ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); - ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); - ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2); - ) { + RegexProgram regexProg = new RegexProgram(pattern); + ColumnVector v = 
ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); + ColumnVector expectedIdx0 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a1"), + Arrays.asList("b2"), + Arrays.asList(), + null, + Arrays.asList("a1", "b1", "a2")); + ColumnVector expectedIdx12 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a", "1"), + Arrays.asList("b", "2"), + null, + null, + Arrays.asList("a", "1", "b", "1", "a", "2")); + try (ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); + ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); + ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2)) { + assertColumnsAreEqual(expectedIdx0, resultIdx0); + assertColumnsAreEqual(expectedIdx12, resultIdx1); + assertColumnsAreEqual(expectedIdx12, resultIdx2); + } + try (ColumnVector resultIdx0 = v.extractAllRecord(regexProg, 0); + ColumnVector resultIdx1 = v.extractAllRecord(regexProg, 1); + ColumnVector resultIdx2 = v.extractAllRecord(regexProg, 2)) { assertColumnsAreEqual(expectedIdx0, resultIdx0); assertColumnsAreEqual(expectedIdx12, resultIdx1); assertColumnsAreEqual(expectedIdx12, resultIdx2); @@ -4087,25 +4096,37 @@ void testMatchesRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res1 = testStrings.matchesRe(patternString1); + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); + ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00"); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, + true, true); + ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, + false, false); + ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, + true, true); + try (ColumnVector res1 = testStrings.matchesRe(patternString1); ColumnVector res2 = testStrings.matchesRe(patternString2); - ColumnVector res3 = testStrings.matchesRe(patternString3); - ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, - true, true); - ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, - false, false); - ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, - true, true)) { + ColumnVector res3 = testStrings.matchesRe(patternString3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + try (ColumnVector res1 = testStrings.matchesRe(regexProg1); + ColumnVector res2 = testStrings.matchesRe(regexProg2); + ColumnVector res3 = testStrings.matchesRe(regexProg3)) { assertColumnsAreEqual(expected1, res1); assertColumnsAreEqual(expected2, res2); assertColumnsAreEqual(expected3, res3); } assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", 
"1234", "00:0:00"); - ColumnVector res = testStrings.matchesRe(patternString4)) {} + try (ColumnVector res = testStrings.matchesRe(patternString4)) {} + }); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.matchesRe(regexProg4)) {} }); } @@ -4115,36 +4136,51 @@ void testContainsRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; - try (ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); + ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, + true, true, true, true); + ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, + false, false, false, true); + ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, + true, true, true, true); + try (ColumnVector res1 = testStrings.containsRe(patternString1); ColumnVector res2 = testStrings.containsRe(patternString2); - ColumnVector res3 = testStrings.containsRe(patternString3); - ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, - true, true, true, true); - ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, - false, false, false, true); - ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, - true, true, true, true)) { + ColumnVector res3 = testStrings.containsRe(patternString3)) { assertColumnsAreEqual(expected1, res1); assertColumnsAreEqual(expected2, res2); assertColumnsAreEqual(expected3, res3); } + try (ColumnVector res1 = testStrings.containsRe(regexProg1); + ColumnVector res2 = testStrings.containsRe(regexProg2); + ColumnVector res3 = testStrings.containsRe(regexProg3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + ColumnVector testStringsError = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStringsError.containsRe(patternString4)) {}}); assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res = testStrings.containsRe(patternString4)) {} + try (ColumnVector res = testStringsError.containsRe(regexProg4)) {} }); } @Test - @Disabled("Needs fix for https://github.com/rapidsai/cudf/issues/4671") void testContainsReEmptyInput() { String patternString1 = ".*"; + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); try (ColumnVector testStrings = ColumnVector.fromStrings(""); ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector resRe1 = testStrings.containsRe(regexProg1); ColumnVector expected1 = 
ColumnVector.fromBoxedBooleans(true)) { assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected1, resRe1); } } @@ -4405,9 +4441,13 @@ void testsubstring() { @Test void testExtractListElements() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); - ColumnVector list = v.stringSplitRecord(" "); + ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); + try (ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(0)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector list = v.stringSplitRecord(new RegexProgram(" ", CaptureGroups.NON_CAPTURE)); ColumnVector result = list.extractListElement(0)) { assertColumnsAreEqual(expected, result); } @@ -4415,10 +4455,14 @@ void testExtractListElements() { @Test void testExtractListElementsV() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); - ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); - ColumnVector list = v.stringSplitRecord(" "); + ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); + ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); + try (ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(indices)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector list = v.stringSplitRecord(new RegexProgram(" ", CaptureGroups.NON_CAPTURE)); ColumnVector result = list.extractListElement(indices)) { assertColumnsAreEqual(expected, result); } @@ -4947,103 +4991,127 @@ void testReverseList() { @Test void testStringSplit() { String pattern = " "; - try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); - Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some things", "strings here") - .build(); - Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2); + Table expectedSplitLimit2 = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some things", "strings here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2); Table resultSplitAll = v.stringSplit(pattern)) { - assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); - 
assertTablesAreEqual(expectedSplitAll, resultSplitAll); + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); } } @Test void testStringSplitByRegularExpression() { String pattern = "[_ ]"; - try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", "ARé some_things", "test_strings_here"); - Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some_things", "strings_here") - .build(); - Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); + Table expectedSplitLimit2 = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some_things", "strings_here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); Table resultSplitAll = v.stringSplit(pattern, true)) { assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); assertTablesAreEqual(expectedSplitAll, resultSplitAll); } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } } @Test void testStringSplitRecord() { String pattern = " "; - try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); - ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some things"), - Arrays.asList("test", "strings here")); - ColumnVector expectedSplitAll = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there", "all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2); + ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some things"), + 
Arrays.asList("test", "strings here")); + ColumnVector expectedSplitAll = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there", "all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some", "things"), + Arrays.asList("test", "strings", "here")); + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2); ColumnVector resultSplitAll = v.stringSplitRecord(pattern)) { assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); assertColumnsAreEqual(expectedSplitAll, resultSplitAll); } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } @Test void testStringSplitRecordByRegularExpression() { String pattern = "[_ ]"; - try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", "ARé some_things", "test_strings_here"); - ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some_things"), - Arrays.asList("test", "strings_here")); - ColumnVector expectedSplitAll = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there", "all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); + ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some_things"), + Arrays.asList("test", "strings_here")); + ColumnVector expectedSplitAll = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there", "all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some", "things"), + Arrays.asList("test", "strings", "here")); + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); ColumnVector resultSplitAll = v.stringSplitRecord(pattern, true)) { assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); assertColumnsAreEqual(expectedSplitAll, resultSplitAll); } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } @Test @@ -5091,26 +5159,37 @@ void teststringReplaceThrowsException() { @Test void testReplaceRegex() { - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl); + ColumnVector v = 
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); + Scalar repl = Scalar.fromString("Repl"); + String pattern = "[tT]itle"; + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + try (ColumnVector actual = v.replaceRegex(pattern, repl); ColumnVector expected = ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) { assertColumnsAreEqual(expected, actual); } - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl, 0)) { + try (ColumnVector actual = v.replaceRegex(pattern, repl, 0)) { assertColumnsAreEqual(v, actual); } - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl, 1); + try (ColumnVector actual = v.replaceRegex(pattern, repl, 1); + ColumnVector expected = + ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) { + assertColumnsAreEqual(expected, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl); + ColumnVector expected = + ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) { + assertColumnsAreEqual(expected, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl, 0)) { + assertColumnsAreEqual(v, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl, 1); ColumnVector expected = ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) { assertColumnsAreEqual(expected, actual); @@ -5132,45 +5211,56 @@ void testReplaceMultiRegex() { @Test void testStringReplaceWithBackrefs() { - try (ColumnVector v = ColumnVector.fromStrings("
<h1>title</h1>", "<h2>another title</h2>", - null); + try (ColumnVector v = ColumnVector.fromStrings("<h1>title</h1>", "<h2>another title</h2>", null); ColumnVector expected = ColumnVector.fromStrings("<h2>title</h2>", "<h2>another title</h2>", null); - ColumnVector actual = v.stringReplaceWithBackrefs("<h1>(.*)</h1>", "<h2>\\1</h2>")) { + ColumnVector actual = v.stringReplaceWithBackrefs("<h1>(.*)</h1>", "<h2>\\1</h2>"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("<h1>(.*)</h1>"), "<h2>\\1</h2>
")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-1-01", "2020-2-02", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", null); - ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])-"), "-0\\1-")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2", - "2020-03-3invalid", null); + "2020-03-3invalid", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", "2020-03-3invalid", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "-([0-9])$", "-0\\1")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])$", "-0\\1"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])$"), "-0\\1")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-01-1 random_text", "2020-02-2T12:34:56", - "2020-03-3invalid", null); + "2020-03-3invalid", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01 random_text", "2020-02-02T12:34:56", "2020-03-3invalid", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "-([0-9])([ T])", "-0\\1\\2")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])([ T])", "-0\\1\\2"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])([ T])"), "-0\\1\\2")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } // test zero as group index try (ColumnVector v = ColumnVector.fromStrings("aa-11 b2b-345", "aa-11a 1c-2b2 b2-c3", "11-aa", null); ColumnVector expected = ColumnVector.fromStrings("aa-11:aa:11; b2b-345:b:345;", "aa-11:aa:11;a 1c-2:c:2;b2 b2-c3", "11-aa", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "([a-z]+)-([0-9]+)", "${0}:${1}:${2};")) { + ColumnVector actual = v.stringReplaceWithBackrefs("([a-z]+)-([0-9]+)", "${0}:${1}:${2};"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("([a-z]+)-([0-9]+)"), "${0}:${1}:${2};")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } // group index exceeds group count @@ -5180,6 +5270,12 @@ void testStringReplaceWithBackrefs() { } }); + assertThrows(CudfException.class, () -> { + try (ColumnVector v = ColumnVector.fromStrings("ABC123defgh"); + ColumnVector r = v.stringReplaceWithBackrefs( + new RegexProgram("([A-Z]+)([0-9]+)([a-z]+)"), "\\4")) { + } + }); } @Test diff --git a/python/cudf/cudf/_lib/column.pyi b/python/cudf/cudf/_lib/column.pyi index 612f3cdf95a..013cba3ae03 100644 --- a/python/cudf/cudf/_lib/column.pyi +++ b/python/cudf/cudf/_lib/column.pyi @@ -52,8 +52,6 @@ class Column: @property def base_mask(self) -> Optional[Buffer]: ... @property - def base_mask_ptr(self) -> int: ... - @property def mask(self) -> Optional[Buffer]: ... @property def mask_ptr(self) -> int: ... 
diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index a5d72193049..11b4a900896 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -101,7 +101,7 @@ cdef class Column: if self.data is None: return 0 else: - return self.data.ptr + return self.data.get_ptr(mode="write") def set_base_data(self, value): if value is not None and not isinstance(value, Buffer): @@ -124,13 +124,6 @@ cdef class Column: def base_mask(self): return self._base_mask - @property - def base_mask_ptr(self): - if self.base_mask is None: - return 0 - else: - return self.base_mask.ptr - @property def mask(self): if self._mask is None: @@ -145,7 +138,7 @@ cdef class Column: if self.mask is None: return 0 else: - return self.mask.ptr + return self.mask.get_ptr(mode="write") def set_base_mask(self, value): """ @@ -206,7 +199,7 @@ cdef class Column: elif hasattr(value, "__cuda_array_interface__"): if value.__cuda_array_interface__["typestr"] not in ("|i1", "|u1"): if isinstance(value, Column): - value = value.data_array_view + value = value.data_array_view(mode="write") value = cp.asarray(value).view('|u1') mask = as_buffer(value) if mask.size < required_num_bytes: @@ -329,10 +322,10 @@ cdef class Column: if col.base_data is None: data = NULL - elif isinstance(col.base_data, SpillableBuffer): - data = (col.base_data).get_ptr() else: - data = (col.base_data.ptr) + data = (col.base_data.get_ptr( + mode="write") + ) cdef Column child_column if col.base_children: @@ -341,7 +334,9 @@ cdef class Column: cdef libcudf_types.bitmask_type* mask if self.nullable: - mask = (self.base_mask_ptr) + mask = ( + self.base_mask.get_ptr(mode="write") + ) else: mask = NULL @@ -387,10 +382,8 @@ cdef class Column: if col.base_data is None: data = NULL - elif isinstance(col.base_data, SpillableBuffer): - data = (col.base_data).get_ptr() else: - data = (col.base_data.ptr) + data = (col.base_data.get_ptr(mode="read")) cdef Column child_column if col.base_children: @@ -399,7 +392,9 @@ cdef class Column: cdef libcudf_types.bitmask_type* mask if self.nullable: - mask = (self.base_mask_ptr) + mask = ( + self.base_mask.get_ptr(mode="read") + ) else: mask = NULL @@ -549,7 +544,8 @@ cdef class Column: f"{data_owner} is spilled, which invalidates " f"the exposed data_ptr ({hex(data_ptr)})" ) - data_owner.ptr # accessing the pointer marks it exposed. + # accessing the pointer marks it exposed permanently. + data_owner.mark_exposed() else: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, size=0) diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index c01709322ed..6a53586396f 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. 
import pickle @@ -765,7 +765,7 @@ cdef class _CPackedColumns: gpu_data = Buffer.deserialize(header["data"], frames) dbuf = DeviceBuffer( - ptr=gpu_data.ptr, + ptr=gpu_data.get_ptr(mode="write"), size=gpu_data.nbytes ) diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index 6f17dbab86c..a0a8279b213 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -62,7 +62,9 @@ def mask_to_bools(object mask_buffer, size_type begin_bit, size_type end_bit): if not isinstance(mask_buffer, cudf.core.buffer.Buffer): raise TypeError("mask_buffer is not an instance of " "cudf.core.buffer.Buffer") - cdef bitmask_type* bit_mask = (mask_buffer.ptr) + cdef bitmask_type* bit_mask = ( + mask_buffer.get_ptr(mode="read") + ) cdef unique_ptr[column] result with nogil: diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index ebc4d76b6a0..71f48e0ab0c 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -176,7 +176,9 @@ def _getitem(self, offset: int, size: int) -> Buffer: """ return self._from_device_memory( cuda_array_interface_wrapper( - ptr=self.ptr + offset, size=size, owner=self.owner + ptr=self.get_ptr(mode="read") + offset, + size=size, + owner=self.owner, ) ) @@ -202,11 +204,6 @@ def nbytes(self) -> int: """Size of the buffer in bytes.""" return self._size - @property - def ptr(self) -> int: - """Device pointer to the start of the buffer.""" - return self._ptr - @property def owner(self) -> Any: """Object owning the memory of the buffer.""" @@ -215,18 +212,74 @@ def owner(self) -> Any: @property def __cuda_array_interface__(self) -> Mapping: """Implementation of the CUDA Array Interface.""" + return self._get_cuda_array_interface(readonly=False) + + def _get_cuda_array_interface(self, readonly=False): + """Helper function to create a CUDA Array Interface. + + Parameters + ---------- + readonly : bool, default False + If True, returns a CUDA Array Interface with + readonly flag set to True. + If False, returns a CUDA Array Interface with + readonly flag set to False. + + Returns + ------- + dict + """ return { - "data": (self.ptr, False), + "data": ( + self.get_ptr(mode="read" if readonly else "write"), + readonly, + ), "shape": (self.size,), "strides": None, "typestr": "|u1", "version": 0, } + @property + def _readonly_proxy_cai_obj(self): + """ + Returns a proxy object with a read-only CUDA Array Interface. + """ + return cuda_array_interface_wrapper( + ptr=self.get_ptr(mode="read"), + size=self.size, + owner=self, + readonly=True, + typestr="|u1", + version=0, + ) + + def get_ptr(self, *, mode) -> int: + """Device pointer to the start of the buffer. + + Parameters + ---------- + mode : str + Supported values are {"read", "write"} + If "write", the data pointed to may be modified + by the caller. If "read", the data pointed to + must not be modified by the caller. + Failure to fulfill this contract will cause + incorrect behavior. 
+ + + See Also + -------- + SpillableBuffer.get_ptr + """ + return self._ptr + def memoryview(self) -> memoryview: """Read-only access to the buffer through host memory.""" host_buf = host_memory_allocation(self.size) - rmm._lib.device_buffer.copy_ptr_to_host(self.ptr, host_buf) + rmm._lib.device_buffer.copy_ptr_to_host( + self.get_ptr(mode="read"), host_buf + ) return memoryview(host_buf).toreadonly() def serialize(self) -> Tuple[dict, list]: diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 4f625e3b7c8..2064c1fd133 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -7,7 +7,16 @@ import time import weakref from threading import RLock -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, +) import numpy @@ -16,6 +25,7 @@ from cudf.core.buffer.buffer import ( Buffer, cuda_array_interface_wrapper, + get_ptr_and_size, host_memory_allocation, ) from cudf.utils.string import format_bytes @@ -27,6 +37,86 @@ T = TypeVar("T", bound="SpillableBuffer") + +def get_spillable_owner(data) -> Optional[SpillableBuffer]: + """Get the spillable owner of `data`, if any exists + + Search through the stack of data owners in order to find an + owner of type `SpillableBuffer` (not subclasses). + + Parameters + ---------- + data : buffer-like or array-like + A buffer-like or array-like object that represents C-contiguous memory. + + Returns + ------- + SpillableBuffer or None + The owner of `data` if spillable or None. + """ + + if type(data) is SpillableBuffer: + return data + if hasattr(data, "owner"): + return get_spillable_owner(data.owner) + return None + + +def as_spillable_buffer(data, exposed: bool) -> SpillableBuffer: + """Factory function to wrap `data` in a SpillableBuffer object. + + If `data` isn't a buffer already, a new buffer that points to the memory of + `data` is created. If `data` represents host memory, it is copied to a new + `rmm.DeviceBuffer` device allocation. Otherwise, the memory of `data` is + **not** copied; instead, the new buffer keeps a reference to `data` in order + to retain its lifetime. + + If `data` is owned by a spillable buffer, a "slice" of the buffer is + returned. In this case, the spillable buffer must either be "exposed" or + spilled locked (called within an acquire_spill_lock context). This is to + guarantee that the memory of `data` isn't spilled before this function gets + to calculate the offset of the new slice. + + It is illegal for a spillable buffer to own another spillable buffer. + + Parameters + ---------- + data : buffer-like or array-like + A buffer-like or array-like object that represents C-contiguous memory. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). + + Returns + ------- + SpillableBuffer + A spillable buffer instance that represents the device memory of `data`. 
+ """ + + from cudf.core.buffer.utils import get_spill_lock + + if not hasattr(data, "__cuda_array_interface__"): + if exposed: + raise ValueError("cannot create exposed host memory") + return SpillableBuffer._from_host_memory(data) + + spillable_owner = get_spillable_owner(data) + if spillable_owner is None: + return SpillableBuffer._from_device_memory(data, exposed=exposed) + + if not spillable_owner.exposed and get_spill_lock() is None: + raise ValueError( + "An owning spillable buffer must " + "either be exposed or spilled locked." + ) + + # At this point, we know that `data` is owned by a spillable buffer, + # which is exposed or spilled locked. + ptr, size = get_ptr_and_size(data.__cuda_array_interface__) + base_ptr = spillable_owner.memory_info()[0] + return SpillableBufferSlice( + spillable_owner, offset=ptr - base_ptr, size=size + ) + + + class SpillLock: pass @@ -55,7 +145,7 @@ def __len__(self): def __getitem__(self, i): if i == 0: - return self._buf.ptr + return self._buf.get_ptr(mode="write") elif i == 1: return False raise IndexError("tuple index out of range") @@ -269,7 +359,7 @@ def spill_lock(self, spill_lock: SpillLock) -> None: self.spill(target="gpu") self._spill_locks.add(spill_lock) - def get_ptr(self) -> int: + def get_ptr(self, *, mode) -> int: """Get a device pointer to the memory of the buffer. If this is called within an `acquire_spill_lock` context, @@ -279,8 +369,8 @@ def get_ptr(self) -> int: If this is *not* called within a `acquire_spill_lock` context, this buffer is marked as unspillable permanently. - Return - ------ + Returns + ------- int The device pointer as an integer """ @@ -319,18 +409,6 @@ def memory_info(self) -> Tuple[int, int, str]: ).__array_interface__["data"][0] return (ptr, self.nbytes, self._ptr_desc["type"]) - @property - def ptr(self) -> int: - """Access the memory directly - - Notice, this will mark the buffer as "exposed" and make - it unspillable permanently. - - Consider using `.get_ptr()` instead. - """ - self.mark_exposed() - return self._ptr - @property def owner(self) -> Any: return self._owner @@ -469,12 +547,12 @@ def __init__(self, base: SpillableBuffer, offset: int, size: int) -> None: self._owner = base self.lock = base.lock - @property - def ptr(self) -> int: - return self._base.ptr + self._offset - - def get_ptr(self) -> int: - return self._base.get_ptr() + self._offset + def get_ptr(self, *, mode) -> int: + """ + A passthrough method to `SpillableBuffer.get_ptr` + that factors in the `offset`. + """ + return self._base.get_ptr(mode=mode) + self._offset def _getitem(self, offset: int, size: int) -> Buffer: return SpillableBufferSlice( diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index 062e86d0cb1..fac28f52b64 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. 
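A hedged usage sketch of the `as_spillable_buffer` factory introduced above: wrapping memory that is already owned by a SpillableBuffer only succeeds while a spill lock is held (or if the owner is exposed), and it yields a SpillableBufferSlice. This assumes spilling is enabled (e.g. CUDF_SPILL=on) so that `as_buffer` routes through the factory; `base` and `view` are illustrative names.

import rmm
from cudf.core.buffer import acquire_spill_lock, as_buffer

base = as_buffer(rmm.DeviceBuffer(size=10), exposed=False)  # SpillableBuffer
with acquire_spill_lock():
    # Wrap the last 8 bytes of `base`; without the spill lock this raises
    # ValueError because `base` could be spilled mid-construction.
    view = as_buffer(base.get_ptr(mode="read") + 2, size=8, owner=base)
assert view.owner is base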
from __future__ import annotations @@ -8,7 +8,7 @@ from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper from cudf.core.buffer.spill_manager import get_global_manager -from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.core.buffer.spillable_buffer import SpillLock, as_spillable_buffer def as_buffer( @@ -72,11 +72,7 @@ def as_buffer( ) if get_global_manager() is not None: - if hasattr(data, "__cuda_array_interface__"): - return SpillableBuffer._from_device_memory(data, exposed=exposed) - if exposed: - raise ValueError("cannot created exposed host memory") - return SpillableBuffer._from_host_memory(data) + return as_spillable_buffer(data, exposed=exposed) if hasattr(data, "__cuda_array_interface__"): return Buffer._from_device_memory(data) diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index ef9f515fff7..af21d7545ee 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -956,9 +956,10 @@ def clip(self, lo: ScalarLike, hi: ScalarLike) -> "column.ColumnBase": self.astype(self.categories.dtype).clip(lo, hi).astype(self.dtype) ) - @property - def data_array_view(self) -> cuda.devicearray.DeviceNDArray: - return self.codes.data_array_view + def data_array_view( + self, *, mode="write" + ) -> cuda.devicearray.DeviceNDArray: + return self.codes.data_array_view(mode=mode) def unique(self, preserve_order=False) -> CategoricalColumn: codes = self.as_numerical.unique(preserve_order=preserve_order) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 69319e2f775..2f4d9e28314 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -64,7 +64,7 @@ ) from cudf.core._compat import PANDAS_GE_150 from cudf.core.abc import Serializable -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import Buffer, acquire_spill_lock, as_buffer from cudf.core.dtypes import ( CategoricalDtype, IntervalDtype, @@ -113,19 +113,71 @@ def as_frame(self) -> "cudf.core.frame.Frame": {None: self.copy(deep=False)} ) - @property - def data_array_view(self) -> "cuda.devicearray.DeviceNDArray": + def data_array_view( + self, *, mode="write" + ) -> "cuda.devicearray.DeviceNDArray": """ View the data as a device array object + + Parameters + ---------- + mode : str, default 'write' + Supported values are {'read', 'write'} + If 'write' is passed, a device array object + with readonly flag set to False in CAI is returned. + If 'read' is passed, a device array object + with readonly flag set to True in CAI is returned. + This also means, If the caller wishes to modify + the data returned through this view, they must + pass mode="write", else pass mode="read". 
+ + Returns + ------- + numba.cuda.cudadrv.devicearray.DeviceNDArray """ - return cuda.as_cuda_array(self.data).view(self.dtype) + if self.data is not None: + if mode == "read": + obj = self.data._readonly_proxy_cai_obj + elif mode == "write": + obj = self.data + else: + raise ValueError(f"Unsupported mode: {mode}") + else: + obj = None + return cuda.as_cuda_array(obj).view(self.dtype) - @property - def mask_array_view(self) -> "cuda.devicearray.DeviceNDArray": + def mask_array_view( + self, *, mode="write" + ) -> "cuda.devicearray.DeviceNDArray": """ View the mask as a device array + + Parameters + ---------- + mode : str, default 'write' + Supported values are {'read', 'write'} + If 'write' is passed, a device array object + with readonly flag set to False in CAI is returned. + If 'read' is passed, a device array object + with readonly flag set to True in CAI is returned. + This also means, If the caller wishes to modify + the data returned through this view, they must + pass mode="write", else pass mode="read". + + Returns + ------- + numba.cuda.cudadrv.devicearray.DeviceNDArray """ - return cuda.as_cuda_array(self.mask).view(mask_dtype) + if self.mask is not None: + if mode == "read": + obj = self.mask._readonly_proxy_cai_obj + elif mode == "write": + obj = self.mask + else: + raise ValueError(f"Unsupported mode: {mode}") + else: + obj = None + return cuda.as_cuda_array(obj).view(mask_dtype) def __len__(self) -> int: return self.size @@ -163,7 +215,8 @@ def values_host(self) -> "np.ndarray": if self.has_nulls(): raise ValueError("Column must have no nulls.") - return self.data_array_view.copy_to_host() + with acquire_spill_lock(): + return self.data_array_view(mode="read").copy_to_host() @property def values(self) -> "cupy.ndarray": @@ -176,7 +229,7 @@ def values(self) -> "cupy.ndarray": if self.has_nulls(): raise ValueError("Column must have no nulls.") - return cupy.asarray(self.data_array_view) + return cupy.asarray(self.data_array_view(mode="write")) def find_and_replace( self: T, @@ -363,7 +416,7 @@ def nullmask(self) -> Buffer: """The gpu buffer for the null-mask""" if not self.nullable: raise ValueError("Column has no null mask") - return self.mask_array_view + return self.mask_array_view(mode="read") def copy(self: T, deep: bool = True) -> T: """Columns are immutable, so a deep copy produces a copy of the diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 7943135afe1..8ee3b6e15b6 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from __future__ import annotations @@ -35,7 +35,11 @@ is_number, is_scalar, ) -from cudf.core.buffer import Buffer, cuda_array_interface_wrapper +from cudf.core.buffer import ( + Buffer, + acquire_spill_lock, + cuda_array_interface_wrapper, +) from cudf.core.column import ( ColumnBase, as_column, @@ -110,8 +114,8 @@ def __contains__(self, item: ScalarLike) -> bool: # Handles improper item types # Fails if item is of type None, so the handler. 
try: - if np.can_cast(item, self.data_array_view.dtype): - item = self.data_array_view.dtype.type(item) + if np.can_cast(item, self.dtype): + item = self.dtype.type(item) else: return False except (TypeError, ValueError): @@ -564,6 +568,7 @@ def fillna( return super(NumericalColumn, col).fillna(fill_value, method) + @acquire_spill_lock() def _find_value( self, value: ScalarLike, closest: bool, find: Callable, compare: str ) -> int: @@ -573,14 +578,14 @@ def _find_value( found = 0 if len(self): found = find( - self.data_array_view, + self.data_array_view(mode="read"), value, mask=self.mask, ) if found == -1: if self.is_monotonic_increasing and closest: found = find( - self.data_array_view, + self.data_array_view(mode="read"), value, mask=self.mask, compare=compare, diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 4ca3a9ff04d..9c30585a541 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5395,8 +5395,9 @@ def base_size(self) -> int: else: return self.base_children[0].size - 1 - @property - def data_array_view(self) -> cuda.devicearray.DeviceNDArray: + def data_array_view( + self, *, mode="write" + ) -> cuda.devicearray.DeviceNDArray: raise ValueError("Cannot get an array view of a StringColumn") def to_arrow(self) -> pa.Array: diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index b7d1724a342..e7979fa4d27 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -13,7 +13,7 @@ from cudf import _lib as libcudf from cudf._typing import ColumnBinaryOperand, DatetimeLikeScalar, Dtype from cudf.api.types import is_scalar, is_timedelta64_dtype -from cudf.core.buffer import Buffer +from cudf.core.buffer import Buffer, acquire_spill_lock from cudf.core.column import ColumnBase, column, string from cudf.utils.dtypes import np_to_pa_dtype from cudf.utils.utils import _fillna_natwise @@ -125,11 +125,16 @@ def values(self): "TimeDelta Arrays is not yet implemented in cudf" ) + @acquire_spill_lock() def to_arrow(self) -> pa.Array: mask = None if self.nullable: - mask = pa.py_buffer(self.mask_array_view.copy_to_host()) - data = pa.py_buffer(self.as_numerical.data_array_view.copy_to_host()) + mask = pa.py_buffer( + self.mask_array_view(mode="read").copy_to_host() + ) + data = pa.py_buffer( + self.as_numerical.data_array_view(mode="read").copy_to_host() + ) pa_dtype = np_to_pa_dtype(self.dtype) return pa.Array.from_buffers( type=pa_dtype, diff --git a/python/cudf/cudf/core/df_protocol.py b/python/cudf/cudf/core/df_protocol.py index b38d3048ed7..2090906380e 100644 --- a/python/cudf/cudf/core/df_protocol.py +++ b/python/cudf/cudf/core/df_protocol.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. import enum from collections import abc @@ -89,7 +89,7 @@ def ptr(self) -> int: """ Pointer to start of the buffer as an integer. """ - return self._buf.ptr + return self._buf.get_ptr(mode="write") def __dlpack__(self): # DLPack not implemented in NumPy yet, so leave it out here. diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 32764c6c2f0..8b508eac324 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. 
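The host-copy changes above (for example `to_arrow` on timedelta columns) share one pattern: take a read-only device view while holding a spill lock, so the buffer stays resident for the duration of the copy and remains spillable afterwards. A minimal sketch; the helper name is hypothetical, everything it calls comes from the patch.

from cudf.core.buffer import acquire_spill_lock

def column_to_host(col):
    # `col` is any cudf column without nulls; mode="read" documents that the
    # host copy does not mutate device memory.
    with acquire_spill_lock():
        return col.data_array_view(mode="read").copy_to_host()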
from __future__ import annotations @@ -1448,7 +1448,7 @@ def searchsorted( # Return result as cupy array if the values is non-scalar # If values is scalar, result is expected to be scalar. - result = cupy.asarray(outcol.data_array_view) + result = cupy.asarray(outcol.data_array_view(mode="read")) if scalar_flag: return result[0].item() else: diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 6526ba1e7c3..c8016786be9 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -49,6 +49,7 @@ is_scalar, ) from cudf.core._base_index import BaseIndex +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ColumnBase, as_column, full from cudf.core.column_accessor import ColumnAccessor from cudf.core.dtypes import ListDtype @@ -2105,6 +2106,7 @@ def add_suffix(self, suffix): Use `Series.add_suffix` or `DataFrame.add_suffix`" ) + @acquire_spill_lock() @_cudf_nvtx_annotate def _apply(self, func, kernel_getter, *args, **kwargs): """Apply `func` across the rows of the frame.""" diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index faad5275abd..1c697a2d824 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -39,6 +39,7 @@ is_struct_dtype, ) from cudf.core.abc import Serializable +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ( ColumnBase, DatetimeColumn, @@ -4855,6 +4856,7 @@ def _align_indices(series_list, how="outer", allow_non_unique=False): return result +@acquire_spill_lock() @_cudf_nvtx_annotate def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): r"""Returns a boolean array where two arrays are equal within a tolerance. @@ -4959,10 +4961,10 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): index = as_index(a.index) a_col = column.as_column(a) - a_array = cupy.asarray(a_col.data_array_view) + a_array = cupy.asarray(a_col.data_array_view(mode="read")) b_col = column.as_column(b) - b_array = cupy.asarray(b_col.data_array_view) + b_array = cupy.asarray(b_col.data_array_view(mode="read")) result = cupy.isclose( a=a_array, b=b_array, rtol=rtol, atol=atol, equal_nan=equal_nan diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py index fb1cafa5625..cac4774400a 100644 --- a/python/cudf/cudf/core/window/rolling.py +++ b/python/cudf/cudf/core/window/rolling.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION +# Copyright (c) 2020-2023, NVIDIA CORPORATION import itertools @@ -11,6 +11,7 @@ from cudf.api.types import is_integer, is_number from cudf.core import column from cudf.core._compat import PANDAS_GE_150 +from cudf.core.buffer import acquire_spill_lock from cudf.core.column.column import as_column from cudf.core.mixins import Reducible from cudf.utils import cudautils @@ -487,9 +488,11 @@ def _window_to_window_sizes(self, window): if is_integer(window): return window else: - return cudautils.window_sizes_from_offset( - self.obj.index._values.data_array_view, window - ) + with acquire_spill_lock(): + return cudautils.window_sizes_from_offset( + self.obj.index._values.data_array_view(mode="write"), + window, + ) def __repr__(self): return "{} [window={},min_periods={},center={}]".format( @@ -524,16 +527,17 @@ def __init__(self, groupby, window, min_periods=None, center=False): super().__init__(obj, window, min_periods=min_periods, center=center) + @acquire_spill_lock() def _window_to_window_sizes(self, window): if 
is_integer(window): return cudautils.grouped_window_sizes_from_offset( - column.arange(len(self.obj)).data_array_view, + column.arange(len(self.obj)).data_array_view(mode="read"), self._group_starts, window, ) else: return cudautils.grouped_window_sizes_from_offset( - self.obj.index._values.data_array_view, + self.obj.index._values.data_array_view(mode="read"), self._group_starts, window, ) diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index cbaf47a4c68..fb4daba1209 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -336,7 +336,7 @@ def assert_column_memory_eq( """ def get_ptr(x) -> int: - return x.ptr if x else 0 + return x.get_ptr(mode="read") if x else 0 assert get_ptr(lhs.base_data) == get_ptr(rhs.base_data) assert get_ptr(lhs.base_mask) == get_ptr(rhs.base_mask) diff --git a/python/cudf/cudf/tests/test_buffer.py b/python/cudf/cudf/tests/test_buffer.py index df7152d53a6..1c9e7475080 100644 --- a/python/cudf/cudf/tests/test_buffer.py +++ b/python/cudf/cudf/tests/test_buffer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. import cupy as cp import pytest @@ -52,7 +52,7 @@ def test_buffer_creation_from_any(): ary = cp.arange(arr_len) b = as_buffer(ary, exposed=True) assert isinstance(b, Buffer) - assert ary.data.ptr == b.ptr + assert ary.data.ptr == b.get_ptr(mode="read") assert ary.nbytes == b.size with pytest.raises( @@ -62,7 +62,7 @@ def test_buffer_creation_from_any(): b = as_buffer(ary.data.ptr, size=ary.nbytes, owner=ary, exposed=True) assert isinstance(b, Buffer) - assert ary.data.ptr == b.ptr + assert ary.data.ptr == b.get_ptr(mode="read") assert ary.nbytes == b.size assert b.owner.owner is ary diff --git a/python/cudf/cudf/tests/test_column.py b/python/cudf/cudf/tests/test_column.py index 75b82baf2e8..7d113bbb9e2 100644 --- a/python/cudf/cudf/tests/test_column.py +++ b/python/cudf/cudf/tests/test_column.py @@ -285,8 +285,8 @@ def test_column_view_valid_numeric_to_numeric(data, from_dtype, to_dtype): expect = pd.Series(cpu_data_view, dtype=cpu_data_view.dtype) got = cudf.Series(gpu_data_view, dtype=gpu_data_view.dtype) - gpu_ptr = gpu_data.data.ptr - assert gpu_ptr == got._column.data.ptr + gpu_ptr = gpu_data.data.get_ptr(mode="read") + assert gpu_ptr == got._column.data.get_ptr(mode="read") assert_eq(expect, got) diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 0e0b0a37255..65e24c7c704 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1868,7 +1868,7 @@ def test_to_from_arrow_nulls(data_type): # number of bytes, so only check the first byte in this case np.testing.assert_array_equal( np.asarray(s1.buffers()[0]).view("u1")[0], - gs1._column.mask_array_view.copy_to_host().view("u1")[0], + gs1._column.mask_array_view(mode="read").copy_to_host().view("u1")[0], ) assert pa.Array.equals(s1, gs1.to_arrow()) @@ -1879,7 +1879,7 @@ def test_to_from_arrow_nulls(data_type): # number of bytes, so only check the first byte in this case np.testing.assert_array_equal( np.asarray(s2.buffers()[0]).view("u1")[0], - gs2._column.mask_array_view.copy_to_host().view("u1")[0], + gs2._column.mask_array_view(mode="read").copy_to_host().view("u1")[0], ) assert pa.Array.equals(s2, gs2.to_arrow()) @@ -2659,11 +2659,11 @@ def query_GPU_memory(note=""): cudaDF = cudaDF[boolmask] assert ( - cudaDF.index._values.data_array_view.device_ctypes_pointer + 
cudaDF.index._values.data_array_view(mode="read").device_ctypes_pointer == cudaDF["col0"].index._values.data_array_view.device_ctypes_pointer ) assert ( - cudaDF.index._values.data_array_view.device_ctypes_pointer + cudaDF.index._values.data_array_view(mode="read").device_ctypes_pointer == cudaDF["col1"].index._values.data_array_view.device_ctypes_pointer ) diff --git a/python/cudf/cudf/tests/test_dataframe_copy.py b/python/cudf/cudf/tests/test_dataframe_copy.py index 1a9098c70db..85e994bd733 100644 --- a/python/cudf/cudf/tests/test_dataframe_copy.py +++ b/python/cudf/cudf/tests/test_dataframe_copy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from copy import copy, deepcopy import numpy as np @@ -160,7 +160,7 @@ def test_kernel_deep_copy(): cdf = gdf.copy(deep=True) sr = gdf["b"] - add_one[1, len(sr)](sr._column.data_array_view) + add_one[1, len(sr)](sr._column.data_array_view(mode="write")) assert not gdf.to_string().split() == cdf.to_string().split() diff --git a/python/cudf/cudf/tests/test_df_protocol.py b/python/cudf/cudf/tests/test_df_protocol.py index 0981e850c10..7dbca90ab03 100644 --- a/python/cudf/cudf/tests/test_df_protocol.py +++ b/python/cudf/cudf/tests/test_df_protocol.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. from typing import Any, Tuple @@ -41,7 +41,7 @@ def assert_buffer_equal(buffer_and_dtype: Tuple[_CuDFBuffer, Any], cudfcol): if dtype[0] != _DtypeKind.BOOL: array_from_dlpack = cp.from_dlpack(buf.__dlpack__()).get() - col_array = cp.asarray(cudfcol.data_array_view).get() + col_array = cp.asarray(cudfcol.data_array_view(mode="read")).get() assert_eq( array_from_dlpack[non_null_idxs.to_numpy()].flatten(), col_array[non_null_idxs.to_numpy()].flatten(), diff --git a/python/cudf/cudf/tests/test_multiindex.py b/python/cudf/cudf/tests/test_multiindex.py index d27d6732226..3e1f001e7d1 100644 --- a/python/cudf/cudf/tests/test_multiindex.py +++ b/python/cudf/cudf/tests/test_multiindex.py @@ -804,8 +804,8 @@ def test_multiindex_copy_deep(data, deep): lchildren = reduce(operator.add, lchildren) rchildren = reduce(operator.add, rchildren) - lptrs = [child.base_data.ptr for child in lchildren] - rptrs = [child.base_data.ptr for child in rchildren] + lptrs = [child.base_data.get_ptr(mode="read") for child in lchildren] + rptrs = [child.base_data.get_ptr(mode="read") for child in rchildren] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) @@ -814,20 +814,36 @@ def test_multiindex_copy_deep(data, deep): mi2 = mi1.copy(deep=deep) # Assert ._levels identity - lptrs = [lv._data._data[None].base_data.ptr for lv in mi1._levels] - rptrs = [lv._data._data[None].base_data.ptr for lv in mi2._levels] + lptrs = [ + lv._data._data[None].base_data.get_ptr(mode="read") + for lv in mi1._levels + ] + rptrs = [ + lv._data._data[None].base_data.get_ptr(mode="read") + for lv in mi2._levels + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._codes identity - lptrs = [c.base_data.ptr for _, c in mi1._codes._data.items()] - rptrs = [c.base_data.ptr for _, c in mi2._codes._data.items()] + lptrs = [ + c.base_data.get_ptr(mode="read") + for _, c in mi1._codes._data.items() + ] + rptrs = [ + c.base_data.get_ptr(mode="read") + for _, c in mi2._codes._data.items() + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._data identity - lptrs = [d.base_data.ptr for _, d in mi1._data.items()] - rptrs = [d.base_data.ptr 
for _, d in mi2._data.items()] + lptrs = [ + d.base_data.get_ptr(mode="read") for _, d in mi1._data.items() + ] + rptrs = [ + d.base_data.get_ptr(mode="read") for _, d in mi2._data.items() + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) diff --git a/python/cudf/cudf/tests/test_pack.py b/python/cudf/cudf/tests/test_pack.py index b6bda7ef5fa..9972071122e 100644 --- a/python/cudf/cudf/tests/test_pack.py +++ b/python/cudf/cudf/tests/test_pack.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -129,7 +129,9 @@ def assert_packed_frame_unique_pointers(df): for col in df: if df._data[col].data: - assert df._data[col].data.ptr != unpacked._data[col].data.ptr + assert df._data[col].data.get_ptr(mode="read") != unpacked._data[ + col + ].data.get_ptr(mode="read") def test_packed_dataframe_unique_pointers_numeric(): diff --git a/python/cudf/cudf/tests/test_repr.py b/python/cudf/cudf/tests/test_repr.py index 5ba0bec3dc4..bae0fde6463 100644 --- a/python/cudf/cudf/tests/test_repr.py +++ b/python/cudf/cudf/tests/test_repr.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. import textwrap @@ -31,7 +31,7 @@ def test_null_series(nrows, dtype): sr[np.random.choice([False, True], size=size)] = None if dtype != "category" and cudf.dtype(dtype).kind in {"u", "i"}: ps = pd.Series( - sr._column.data_array_view.copy_to_host(), + sr._column.data_array_view(mode="read").copy_to_host(), dtype=np_dtypes_to_pandas_dtypes.get( cudf.dtype(dtype), cudf.dtype(dtype) ), diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 4788736966a..88ce908aa5f 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -119,7 +119,7 @@ def test_spillable_buffer(manager: SpillManager): buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) assert isinstance(buf, SpillableBuffer) assert buf.spillable - buf.ptr # Expose pointer + buf.mark_exposed() assert buf.exposed assert not buf.spillable buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) @@ -137,7 +137,6 @@ def test_spillable_buffer(manager: SpillManager): @pytest.mark.parametrize( "attribute", [ - "ptr", "get_ptr", "memoryview", "is_spilled", @@ -210,7 +209,7 @@ def test_spilling_buffer(manager: SpillManager): buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) buf.spill(target="cpu") assert buf.is_spilled - buf.ptr # Expose pointer and trigger unspill + buf.mark_exposed() # Expose pointer and trigger unspill assert not buf.is_spilled with pytest.raises(ValueError, match="unspillable buffer"): buf.spill(target="cpu") @@ -378,10 +377,10 @@ def test_get_ptr(manager: SpillManager, target): assert buf.spillable assert len(buf._spill_locks) == 0 with acquire_spill_lock(): - buf.get_ptr() + buf.get_ptr(mode="read") assert not buf.spillable with acquire_spill_lock(): - buf.get_ptr() + buf.get_ptr(mode="read") assert not buf.spillable assert not buf.spillable assert buf.spillable @@ -501,7 +500,7 @@ def test_serialize_cuda_dataframe(manager: SpillManager): assert len(buf._base._spill_locks) == 1 assert len(frames) == 1 assert isinstance(frames[0], Buffer) - assert frames[0].ptr == buf.ptr + assert frames[0].get_ptr(mode="read") == buf.get_ptr(mode="read") frames[0] = cupy.array(frames[0], copy=True) 
df2 = protocol.deserialize(header, frames) @@ -540,6 +539,42 @@ def test_df_transpose(manager: SpillManager): assert df2._data._data[1].data.exposed +def test_as_buffer_of_spillable_buffer(manager: SpillManager): + data = cupy.arange(10, dtype="u1") + b1 = as_buffer(data, exposed=False) + assert isinstance(b1, SpillableBuffer) + assert b1.owner is data + b2 = as_buffer(b1) + assert b1 is b2 + + with pytest.raises( + ValueError, + match="buffer must either be exposed or spilled locked", + ): + # Use `memory_info` to access device point _without_ making + # the buffer unspillable. + b3 = as_buffer(b1.memory_info()[0], size=b1.size, owner=b1) + + with acquire_spill_lock(): + b3 = as_buffer(b1.get_ptr(mode="read"), size=b1.size, owner=b1) + assert isinstance(b3, SpillableBufferSlice) + assert b3.owner is b1 + + b4 = as_buffer( + b1.get_ptr(mode="write") + data.itemsize, + size=b1.size - data.itemsize, + owner=b3, + ) + assert isinstance(b4, SpillableBufferSlice) + assert b4.owner is b1 + assert all(cupy.array(b4.memoryview()) == data[1:]) + + b5 = as_buffer(b4.get_ptr(mode="write"), size=b4.size - 1, owner=b4) + assert isinstance(b5, SpillableBufferSlice) + assert b5.owner is b1 + assert all(cupy.array(b5.memoryview()) == data[1:-1]) + + @pytest.mark.parametrize("dtype", ["uint8", "uint64"]) def test_memoryview_slice(manager: SpillManager, dtype): """Check .memoryview() of a sliced spillable buffer""" @@ -589,7 +624,7 @@ def test_statistics_expose(manager: SpillManager): ] # Expose the first buffer - buffers[0].ptr + buffers[0].mark_exposed() assert len(manager.statistics.exposes) == 1 stat = list(manager.statistics.exposes.values())[0] assert stat.count == 1 @@ -598,7 +633,7 @@ def test_statistics_expose(manager: SpillManager): # Expose all 10 buffers for i in range(10): - buffers[i].ptr + buffers[i].mark_exposed() # The rest of the ptr accesses should accumulate to a single stat # because they resolve to the same traceback. @@ -618,7 +653,7 @@ def test_statistics_expose(manager: SpillManager): # Expose the new buffers and check that they are counted as spilled for i in range(10): - buffers[i].ptr + buffers[i].mark_exposed() assert len(manager.statistics.exposes) == 3 stat = list(manager.statistics.exposes.values())[2] assert stat.count == 10 diff --git a/python/cudf/cudf/utils/applyutils.py b/python/cudf/cudf/utils/applyutils.py index 89331b933a8..7e998413642 100644 --- a/python/cudf/cudf/utils/applyutils.py +++ b/python/cudf/cudf/utils/applyutils.py @@ -1,13 +1,15 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. 
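For reference, a small sketch of the explicit exposure API exercised in the spilling tests above, which replaces reading `.ptr` for its side effect. It assumes spilling is enabled so that `as_buffer` returns a SpillableBuffer; exposures are also what the expose statistics above count.

import rmm
from cudf.core.buffer import as_buffer

buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False)
assert buf.spillable
buf.mark_exposed()        # permanently unspillable from here on
assert buf.exposed
assert not buf.spillable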
import functools from typing import Any, Dict +import cupy as cp from numba import cuda from numba.core.utils import pysignature import cudf from cudf import _lib as libcudf +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import column from cudf.utils import utils from cudf.utils.docutils import docfmt_partial @@ -139,21 +141,25 @@ def __init__( self.cache_key = cache_key self.kernel = self.compile(func, sig.parameters.keys(), kwargs.keys()) + @acquire_spill_lock() def run(self, df, **launch_params): # Get input columns if isinstance(self.incols, dict): inputs = { - v: df[k]._column.data_array_view + v: df[k]._column.data_array_view(mode="read") for (k, v) in self.incols.items() } else: - inputs = {k: df[k]._column.data_array_view for k in self.incols} + inputs = { + k: df[k]._column.data_array_view(mode="read") + for k in self.incols + } # Allocate output columns outputs = {} for k, dt in self.outcols.items(): outputs[k] = column.column_empty( len(df), dt, False - ).data_array_view + ).data_array_view(mode="write") # Bind argument args = {} for dct in [inputs, outputs, self.kwargs]: @@ -174,7 +180,7 @@ def run(self, df, **launch_params): ) if out_mask is not None: outdf._data[k] = outdf[k]._column.set_mask( - out_mask.data_array_view + out_mask.data_array_view(mode="write") ) return outdf @@ -213,11 +219,12 @@ def launch_kernel(self, df, args, chunks, blkct=None, tpb=None): def normalize_chunks(self, size, chunks): if isinstance(chunks, int): # *chunks* is the chunksize - return column.arange(0, size, chunks).data_array_view + return cuda.as_cuda_array( + cp.arange(start=0, stop=size, step=chunks) + ).view("int64") else: # *chunks* is an array of chunk leading offset - chunks = column.as_column(chunks) - return chunks.data_array_view + return cuda.as_cuda_array(cp.asarray(chunks)).view("int64") def _make_row_wise_kernel(func, argnames, extras): diff --git a/python/cudf/cudf/utils/queryutils.py b/python/cudf/cudf/utils/queryutils.py index 25b3d517e1c..4ce89b526d6 100644 --- a/python/cudf/cudf/utils/queryutils.py +++ b/python/cudf/cudf/utils/queryutils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. import ast import datetime @@ -8,6 +8,7 @@ from numba import cuda import cudf +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import column_empty from cudf.utils import applyutils from cudf.utils.dtypes import ( @@ -191,6 +192,7 @@ def _add_prefix(arg): return kernel +@acquire_spill_lock() def query_execute(df, expr, callenv): """Compile & execute the query expression @@ -220,7 +222,7 @@ def query_execute(df, expr, callenv): "or bool dtypes." ) - colarrays = [col.data_array_view for col in colarrays] + colarrays = [col.data_array_view(mode="read") for col in colarrays] kernel = compiled["kernel"] # process env args diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index 99e4046b0b3..80deb881ec8 100644 --- a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. 
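The applyutils and queryutils changes above follow one pattern: any function that hands raw device-array views to a numba kernel is wrapped in `@acquire_spill_lock()` so the views stay valid for the duration of the call. A hedged sketch, assuming a trivial in-place kernel (`add_one` and `run` are hypothetical names).

from numba import cuda
from cudf.core.buffer import acquire_spill_lock

@cuda.jit
def add_one(arr):
    i = cuda.grid(1)
    if i < arr.size:
        arr[i] += 1

@acquire_spill_lock()
def run(col):
    # The write view is only safe to pass to the kernel while the spill
    # lock is held; the lock is released when this function returns.
    add_one.forall(len(col))(col.data_array_view(mode="write"))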
import operator @@ -12,6 +12,8 @@ from numba.cuda.cudadecl import registry as cuda_decl_registry from numba.cuda.cudadrv import nvvm +import rmm + data_layout = nvvm.data_layout # libcudf size_type @@ -112,7 +114,9 @@ def prepare_args(self, ty, val, **kwargs): if isinstance(ty, types.CPointer) and isinstance( ty.dtype, (StringView, UDFString) ): - return types.uint64, val.ptr + return types.uint64, val.ptr if isinstance( + val, rmm._lib.device_buffer.DeviceBuffer + ) else val.get_ptr(mode="read") else: return ty, val
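The conditional added to `prepare_args` above, restated as a sketch for readability: numba may hand the typing layer either an rmm DeviceBuffer, whose `.ptr` is a plain attribute, or a cudf Buffer, whose pointer must be requested through `get_ptr`; only the latter involves the read-mode contract. The helper name below is hypothetical.

import rmm

def _device_pointer(val) -> int:
    if isinstance(val, rmm._lib.device_buffer.DeviceBuffer):
        return val.ptr
    # cudf Buffer / SpillableBuffer: read-only access is sufficient here
    return val.get_ptr(mode="read")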