diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 4e52044ffb1..89f3f3a5976 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -29,6 +29,7 @@ export CONDA_ARTIFACT_PATH="$WORKSPACE/ci/artifacts/cudf/cpu/.conda-bld/" # Parse git describe export GIT_DESCRIBE_TAG=`git describe --tags` export MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'` +unset GIT_DESCRIBE_TAG # Dask & Distributed option to install main(nightly) or `conda-forge` packages. export INSTALL_DASK_MAIN=1 @@ -79,30 +80,11 @@ conda info conda config --show-sources conda list --show-channel-urls -gpuci_logger "Install dependencies" -gpuci_mamba_retry install -y \ - "cudatoolkit=$CUDA_REL" \ - "rapids-build-env=$MINOR_VERSION.*" \ - "rapids-notebook-env=$MINOR_VERSION.*" \ - "dask-cuda=${MINOR_VERSION}" \ - "rmm=$MINOR_VERSION.*" \ - "ucx-py=${UCX_PY_VERSION}" - -# https://docs.rapids.ai/maintainers/depmgmt/ -# gpuci_conda_retry remove --force rapids-build-env rapids-notebook-env -# gpuci_mamba_retry install -y "your-pkg=1.0.0" - - gpuci_logger "Check compiler versions" python --version $CC --version $CXX --version -gpuci_logger "Check conda environment" -conda info -conda config --show-sources -conda list --show-channel-urls - function install_dask { # Install the conda-forge or nightly version of dask and distributed gpuci_logger "Install the conda-forge or nightly version of dask and distributed" @@ -125,6 +107,19 @@ function install_dask { if [[ -z "$PROJECT_FLASH" || "$PROJECT_FLASH" == "0" ]]; then + gpuci_logger "Install dependencies" + gpuci_mamba_retry install -y \ + "cudatoolkit=$CUDA_REL" \ + "rapids-build-env=$MINOR_VERSION.*" \ + "rapids-notebook-env=$MINOR_VERSION.*" \ + "dask-cuda=${MINOR_VERSION}" \ + "rmm=$MINOR_VERSION.*" \ + "ucx-py=${UCX_PY_VERSION}" + + # https://docs.rapids.ai/maintainers/depmgmt/ + # gpuci_conda_retry remove --force rapids-build-env rapids-notebook-env + # gpuci_mamba_retry install -y "your-pkg=1.0.0" + install_dask ################################################################################ @@ -171,8 +166,19 @@ else gpuci_logger "Check GPU usage" nvidia-smi + gpuci_logger "Installing libcudf, libcudf_kafka and libcudf-tests" gpuci_mamba_retry install -y -c ${CONDA_ARTIFACT_PATH} libcudf libcudf_kafka libcudf-tests + gpuci_logger "Building cudf, dask-cudf, cudf_kafka and custreamz" + export CONDA_BLD_DIR="$WORKSPACE/.conda-bld" + gpuci_conda_retry build --croot ${CONDA_BLD_DIR} conda/recipes/cudf --python=$PYTHON -c ${CONDA_ARTIFACT_PATH} + gpuci_conda_retry build --croot ${CONDA_BLD_DIR} conda/recipes/dask-cudf --python=$PYTHON -c ${CONDA_ARTIFACT_PATH} + gpuci_conda_retry build --croot ${CONDA_BLD_DIR} conda/recipes/cudf_kafka --python=$PYTHON -c ${CONDA_ARTIFACT_PATH} + gpuci_conda_retry build --croot ${CONDA_BLD_DIR} conda/recipes/custreamz --python=$PYTHON -c ${CONDA_ARTIFACT_PATH} + + gpuci_logger "Installing cudf, dask-cudf, cudf_kafka and custreamz" + gpuci_mamba_retry install cudf dask-cudf cudf_kafka custreamz -c "${CONDA_BLD_DIR}" -c "${CONDA_ARTIFACT_PATH}" + gpuci_logger "GoogleTests" # Run libcudf and libcudf_kafka gtests from libcudf-tests package for gt in "$CONDA_PREFIX/bin/gtests/libcudf"*/* ; do @@ -209,12 +215,6 @@ else # test-results/*.cs.log are processed in gpuci fi fi - - install_dask - - gpuci_logger "Build python libs from source" - "$WORKSPACE/build.sh" cudf dask_cudf cudf_kafka --ptds - fi # Both regular and Project Flash proceed here diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 
e93b2bf4f25..04dcf51dd40 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -242,6 +242,10 @@ ConfigureBench(PARQUET_WRITER_BENCH io/parquet/parquet_writer.cpp) # * orc writer benchmark -------------------------------------------------------------------------- ConfigureBench(ORC_WRITER_BENCH io/orc/orc_writer.cpp) +# ################################################################################################## +# * orc writer chunks benchmark --------------------------------------------------------------- +ConfigureNVBench(ORC_WRITER_CHUNKS_NVBENCH io/orc/orc_writer_chunks.cpp) + # ################################################################################################## # * csv writer benchmark -------------------------------------------------------------------------- ConfigureBench(CSV_WRITER_BENCH io/csv/csv_writer.cpp) diff --git a/cpp/benchmarks/io/orc/orc_writer_chunks.cpp b/cpp/benchmarks/io/orc/orc_writer_chunks.cpp new file mode 100644 index 00000000000..dc82772fa83 --- /dev/null +++ b/cpp/benchmarks/io/orc/orc_writer_chunks.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include + +// to enable, run cmake with -DBUILD_BENCHMARKS=ON + +constexpr int64_t data_size = 512 << 20; + +namespace cudf_io = cudf::io; + +void nvbench_orc_write(nvbench::state& state) +{ + cudf::size_type num_cols = state.get_int64("num_columns"); + + auto tbl = + create_random_table(cycle_dtypes(get_type_or_group({int32_t(type_group_id::INTEGRAL_SIGNED), + int32_t(type_group_id::FLOATING_POINT), + int32_t(type_group_id::FIXED_POINT), + int32_t(type_group_id::TIMESTAMP), + int32_t(cudf::type_id::STRING), + int32_t(cudf::type_id::STRUCT), + int32_t(cudf::type_id::LIST)}), + num_cols), + table_size_bytes{data_size}); + cudf::table_view view = tbl->view(); + + auto mem_stats_logger = cudf::memory_stats_logger(); + + state.add_global_memory_reads(data_size); + state.add_element_count(view.num_columns() * view.num_rows()); + + size_t encoded_file_size = 0; + + state.exec(nvbench::exec_tag::timer | nvbench::exec_tag::sync, + [&](nvbench::launch& launch, auto& timer) { + cuio_source_sink_pair source_sink(io_type::VOID); + timer.start(); + + cudf_io::orc_writer_options opts = + cudf_io::orc_writer_options::builder(source_sink.make_sink_info(), view); + cudf_io::write_orc(opts); + + timer.stop(); + encoded_file_size = source_sink.size(); + }); + + state.add_buffer_size(mem_stats_logger.peak_memory_usage(), "pmu", "Peak Memory Usage"); + state.add_buffer_size(encoded_file_size, "efs", "Encoded File Size"); + state.add_buffer_size(view.num_rows(), "trc", "Total Rows"); +} + +void nvbench_orc_chunked_write(nvbench::state& state) +{ + cudf::size_type num_cols = state.get_int64("num_columns"); + cudf::size_type num_tables = state.get_int64("num_chunks"); + + std::vector> tables; + for (cudf::size_type idx = 0; idx < 
num_tables; idx++) { + tables.push_back( + create_random_table(cycle_dtypes(get_type_or_group({int32_t(type_group_id::INTEGRAL_SIGNED), + int32_t(type_group_id::FLOATING_POINT), + int32_t(type_group_id::FIXED_POINT), + int32_t(type_group_id::TIMESTAMP), + int32_t(cudf::type_id::STRING), + int32_t(cudf::type_id::STRUCT), + int32_t(cudf::type_id::LIST)}), + num_cols), + table_size_bytes{size_t(data_size / num_tables)})); + } + + auto mem_stats_logger = cudf::memory_stats_logger(); + + auto size_iter = thrust::make_transform_iterator( + tables.begin(), [](auto const& i) { return i->num_columns() * i->num_rows(); }); + auto row_count_iter = + thrust::make_transform_iterator(tables.begin(), [](auto const& i) { return i->num_rows(); }); + auto total_elements = std::accumulate(size_iter, size_iter + num_tables, 0); + auto total_rows = std::accumulate(row_count_iter, row_count_iter + num_tables, 0); + + state.add_global_memory_reads(data_size); + state.add_element_count(total_elements); + + size_t encoded_file_size = 0; + + state.exec( + nvbench::exec_tag::timer | nvbench::exec_tag::sync, [&](nvbench::launch& launch, auto& timer) { + cuio_source_sink_pair source_sink(io_type::VOID); + timer.start(); + + cudf_io::chunked_orc_writer_options opts = + cudf_io::chunked_orc_writer_options::builder(source_sink.make_sink_info()); + cudf_io::orc_chunked_writer writer(opts); + std::for_each(tables.begin(), + tables.end(), + [&writer](std::unique_ptr const& tbl) { writer.write(*tbl); }); + writer.close(); + + timer.stop(); + encoded_file_size = source_sink.size(); + }); + + state.add_buffer_size(mem_stats_logger.peak_memory_usage(), "pmu", "Peak Memory Usage"); + state.add_buffer_size(encoded_file_size, "efs", "Encoded File Size"); + state.add_buffer_size(total_rows, "trc", "Total Rows"); +} + +NVBENCH_BENCH(nvbench_orc_write) + .set_name("orc_write") + .set_min_samples(4) + .add_int64_axis("num_columns", {8, 64}); + +NVBENCH_BENCH(nvbench_orc_chunked_write) + .set_name("orc_chunked_write") + .set_min_samples(4) + .add_int64_axis("num_columns", {8, 64}) + .add_int64_axis("num_chunks", {8, 64}); diff --git a/cpp/src/io/orc/writer_impl.cu b/cpp/src/io/orc/writer_impl.cu index ecd2d6f6ec0..0ad33821dd7 100644 --- a/cpp/src/io/orc/writer_impl.cu +++ b/cpp/src/io/orc/writer_impl.cu @@ -54,6 +54,9 @@ #include #include +#include +#include + #include namespace cudf { @@ -1233,8 +1236,7 @@ writer::impl::encoded_footer_statistics writer::impl::finish_statistic_blobs( auto const num_stripe_blobs = thrust::reduce(stripe_size_iter, stripe_size_iter + per_chunk_stats.stripe_stat_merge.size()); auto const num_file_blobs = num_columns; - auto const num_blobs = single_write_mode ? static_cast(num_stripe_blobs + num_file_blobs) - : static_cast(num_stripe_blobs); + auto const num_blobs = static_cast(num_stripe_blobs + num_file_blobs); if (num_stripe_blobs == 0) { return {}; } @@ -1242,46 +1244,53 @@ writer::impl::encoded_footer_statistics writer::impl::finish_statistic_blobs( rmm::device_uvector stat_chunks(num_blobs, stream); hostdevice_vector stats_merge(num_blobs, stream); - size_t chunk_offset = 0; - size_t merge_offset = 0; + // we need to merge the stat arrays from the persisted data. + // this needs to be done carefully because each array can contain + // a different number of stripes and stripes from each column must be + // located next to each other. We know the total number of stripes and + // we know the size of each array. 
The number of stripes per column in a chunk array can + // be calculated by dividing the number of chunks by the number of columns. + // That many chunks need to be copied at a time to the proper destination. + size_t num_entries_seen = 0; for (size_t i = 0; i < per_chunk_stats.stripe_stat_chunks.size(); ++i) { - auto chunk_bytes = per_chunk_stats.stripe_stat_chunks[i].size() * sizeof(statistics_chunk); - auto merge_bytes = per_chunk_stats.stripe_stat_merge[i].size() * sizeof(statistics_merge_group); - cudaMemcpyAsync(stat_chunks.data() + chunk_offset, - per_chunk_stats.stripe_stat_chunks[i].data(), - chunk_bytes, - cudaMemcpyDeviceToDevice, - stream); - cudaMemcpyAsync(stats_merge.device_ptr() + merge_offset, - per_chunk_stats.stripe_stat_merge[i].device_ptr(), - merge_bytes, - cudaMemcpyDeviceToDevice, - stream); - chunk_offset += per_chunk_stats.stripe_stat_chunks[i].size(); - merge_offset += per_chunk_stats.stripe_stat_merge[i].size(); + auto const stripes_per_col = per_chunk_stats.stripe_stat_chunks[i].size() / num_columns; + + auto const chunk_bytes = stripes_per_col * sizeof(statistics_chunk); + auto const merge_bytes = stripes_per_col * sizeof(statistics_merge_group); + for (size_t col = 0; col < num_columns; ++col) { + cudaMemcpyAsync(stat_chunks.data() + (num_stripes * col) + num_entries_seen, + per_chunk_stats.stripe_stat_chunks[i].data() + col * stripes_per_col, + chunk_bytes, + cudaMemcpyDeviceToDevice, + stream); + cudaMemcpyAsync(stats_merge.device_ptr() + (num_stripes * col) + num_entries_seen, + per_chunk_stats.stripe_stat_merge[i].device_ptr() + col * stripes_per_col, + merge_bytes, + cudaMemcpyDeviceToDevice, + stream); + } + num_entries_seen += stripes_per_col; } - if (single_write_mode) { - std::vector file_stats_merge(num_file_blobs); - for (auto i = 0u; i < num_file_blobs; ++i) { - auto col_stats = &file_stats_merge[i]; - col_stats->col_dtype = per_chunk_stats.col_types[i]; - col_stats->stats_dtype = per_chunk_stats.stats_dtypes[i]; - col_stats->start_chunk = static_cast(i * num_stripes); - col_stats->num_chunks = static_cast(num_stripes); - } + std::vector file_stats_merge(num_file_blobs); + for (auto i = 0u; i < num_file_blobs; ++i) { + auto col_stats = &file_stats_merge[i]; + col_stats->col_dtype = per_chunk_stats.col_types[i]; + col_stats->stats_dtype = per_chunk_stats.stats_dtypes[i]; + col_stats->start_chunk = static_cast(i * num_stripes); + col_stats->num_chunks = static_cast(num_stripes); + } - auto d_file_stats_merge = stats_merge.device_ptr(num_stripe_blobs); - cudaMemcpyAsync(d_file_stats_merge, - file_stats_merge.data(), - num_file_blobs * sizeof(statistics_merge_group), - cudaMemcpyHostToDevice, - stream); + auto d_file_stats_merge = stats_merge.device_ptr(num_stripe_blobs); + cudaMemcpyAsync(d_file_stats_merge, + file_stats_merge.data(), + num_file_blobs * sizeof(statistics_merge_group), + cudaMemcpyHostToDevice, + stream); - auto file_stat_chunks = stat_chunks.data() + num_stripe_blobs; - detail::merge_group_statistics( - file_stat_chunks, stat_chunks.data(), d_file_stats_merge, num_file_blobs, stream); - } + auto file_stat_chunks = stat_chunks.data() + num_stripe_blobs; + detail::merge_group_statistics( + file_stat_chunks, stat_chunks.data(), d_file_stats_merge, num_file_blobs, stream); hostdevice_vector blobs = allocate_and_encode_blobs(stats_merge, stat_chunks, num_blobs, stream); @@ -1295,14 +1304,12 @@ writer::impl::encoded_footer_statistics writer::impl::finish_statistic_blobs( stripe_blobs[i].assign(stat_begin, stat_end); } - std::vector 
file_blobs(single_write_mode ? num_file_blobs : 0); - if (single_write_mode) { - auto file_stat_merge = stats_merge.host_ptr(num_stripe_blobs); - for (auto i = 0u; i < num_file_blobs; i++) { - auto const stat_begin = blobs.host_ptr(file_stat_merge[i].start_chunk); - auto const stat_end = stat_begin + file_stat_merge[i].num_chunks; - file_blobs[i].assign(stat_begin, stat_end); - } + std::vector file_blobs(num_file_blobs); + auto file_stat_merge = stats_merge.host_ptr(num_stripe_blobs); + for (auto i = 0u; i < num_file_blobs; i++) { + auto const stat_begin = blobs.host_ptr(file_stat_merge[i].start_chunk); + auto const stat_end = stat_begin + file_stat_merge[i].num_chunks; + file_blobs[i].assign(stat_begin, stat_end); } return {std::move(stripe_blobs), std::move(file_blobs)}; @@ -1937,6 +1944,91 @@ string_dictionaries allocate_dictionaries(orc_table_view const& orc_table, std::move(is_dict_enabled)}; } +struct string_length_functor { + __device__ inline size_type operator()(int const i) const + { + // we translate from 0 -> num_chunks * 2 because each statistic has a min and max + // string and we need to calculate lengths for both. + if (i >= num_chunks * 2) return 0; + + // min strings are even values, max strings are odd values of i + auto const should_copy_min = i % 2 == 0; + // index of the chunk + auto const idx = i / 2; + auto& str_val = should_copy_min ? stripe_stat_chunks[idx].min_value.str_val + : stripe_stat_chunks[idx].max_value.str_val; + auto const str = stripe_stat_merge[idx].stats_dtype == dtype_string; + return str ? str_val.length : 0; + } + + int const num_chunks; + statistics_chunk const* stripe_stat_chunks; + statistics_merge_group const* stripe_stat_merge; +}; + +__global__ void copy_string_data(char* string_pool, + size_type* offsets, + statistics_chunk* chunks, + statistics_merge_group const* groups) +{ + auto const idx = blockIdx.x / 2; + if (groups[idx].stats_dtype == dtype_string) { + // min strings are even values, max strings are odd values of i + auto const should_copy_min = blockIdx.x % 2 == 0; + auto& str_val = should_copy_min ? chunks[idx].min_value.str_val : chunks[idx].max_value.str_val; + auto dst = &string_pool[offsets[blockIdx.x]]; + auto src = str_val.ptr; + + for (int i = threadIdx.x; i < str_val.length; i += blockDim.x) { + dst[i] = src[i]; + } + if (threadIdx.x == 0) { str_val.ptr = dst; } + } +} + +void writer::impl::persisted_statistics::persist(int num_table_rows, + bool single_write_mode, + intermediate_statistics& intermediate_stats, + rmm::cuda_stream_view stream) +{ + if (not single_write_mode) { + // persist the strings in the chunks into a string pool and update pointers + auto const num_chunks = static_cast(intermediate_stats.stripe_stat_chunks.size()); + // min offset and max offset + 1 for total size + rmm::device_uvector offsets((num_chunks * 2) + 1, stream); + + auto iter = cudf::detail::make_counting_transform_iterator( + 0, + string_length_functor{num_chunks, + intermediate_stats.stripe_stat_chunks.data(), + intermediate_stats.stripe_stat_merge.device_ptr()}); + thrust::exclusive_scan(rmm::exec_policy(stream), iter, iter + offsets.size(), offsets.begin()); + + // pull size back to host + auto const total_string_pool_size = offsets.element(num_chunks * 2, stream); + if (total_string_pool_size > 0) { + rmm::device_uvector string_pool(total_string_pool_size, stream); + + // offsets describes where in the string pool each string goes. 
Going with the simple + // approach for now, but it is possible something fancier with breaking up each thread into + // copying x bytes instead of a single string is the better method since we are dealing in + // min/max strings they almost certainly will not be uniform length. + copy_string_data<<>>( + string_pool.data(), + offsets.data(), + intermediate_stats.stripe_stat_chunks.data(), + intermediate_stats.stripe_stat_merge.device_ptr()); + string_pools.emplace_back(std::move(string_pool)); + } + } + + stripe_stat_chunks.emplace_back(std::move(intermediate_stats.stripe_stat_chunks)); + stripe_stat_merge.emplace_back(std::move(intermediate_stats.stripe_stat_merge)); + stats_dtypes = std::move(intermediate_stats.stats_dtypes); + col_types = std::move(intermediate_stats.col_types); + num_rows = num_table_rows; +} + void writer::impl::write(table_view const& table) { CUDF_EXPECTS(not closed, "Data has already been flushed to out and closed"); @@ -2075,13 +2167,8 @@ void writer::impl::write(table_view const& table) auto intermediate_stats = gather_statistic_blobs(stats_freq_, orc_table, segmentation); if (intermediate_stats.stripe_stat_chunks.size() > 0) { - persisted_stripe_statistics.stripe_stat_chunks.emplace_back( - std::move(intermediate_stats.stripe_stat_chunks)); - persisted_stripe_statistics.stripe_stat_merge.emplace_back( - std::move(intermediate_stats.stripe_stat_merge)); - persisted_stripe_statistics.stats_dtypes = std::move(intermediate_stats.stats_dtypes); - persisted_stripe_statistics.col_types = std::move(intermediate_stats.col_types); - persisted_stripe_statistics.num_rows = orc_table.num_rows(); + persisted_stripe_statistics.persist( + orc_table.num_rows(), single_write_mode, intermediate_stats, stream); } // Write stripes @@ -2141,7 +2228,6 @@ void writer::impl::write(table_view const& table) } out_sink_->host_write(buffer_.data(), buffer_.size()); } - for (auto const& task : write_tasks) { task.wait(); } @@ -2204,7 +2290,7 @@ void writer::impl::close() auto const statistics = finish_statistic_blobs(ff.stripes.size(), persisted_stripe_statistics); // File-level statistics - if (single_write_mode and not statistics.file_level.empty()) { + if (not statistics.file_level.empty()) { buffer_.resize(0); pbw_.put_uint(encode_field_number(1)); pbw_.put_uint(persisted_stripe_statistics.num_rows); diff --git a/cpp/src/io/orc/writer_impl.hpp b/cpp/src/io/orc/writer_impl.hpp index d823c73007f..577c22f8ac3 100644 --- a/cpp/src/io/orc/writer_impl.hpp +++ b/cpp/src/io/orc/writer_impl.hpp @@ -304,7 +304,7 @@ class writer::impl { stats_dtypes(std::move(sdt)), col_types(std::move(sct)){}; - // blobs for the rowgroups and stripes. Not persisted + // blobs for the rowgroups. 
Not persisted std::vector rowgroup_blobs; rmm::device_uvector stripe_stat_chunks; @@ -322,13 +322,20 @@ class writer::impl { { stripe_stat_chunks.clear(); stripe_stat_merge.clear(); + string_pools.clear(); stats_dtypes.clear(); col_types.clear(); num_rows = 0; } + void persist(int num_table_rows, + bool single_write_mode, + intermediate_statistics& intermediate_stats, + rmm::cuda_stream_view stream); + std::vector> stripe_stat_chunks; std::vector> stripe_stat_merge; + std::vector> string_pools; std::vector stats_dtypes; std::vector col_types; int num_rows = 0; diff --git a/cpp/src/join/semi_join.cu b/cpp/src/join/semi_join.cu index 687e553fefd..b7b33000707 100644 --- a/cpp/src/join/semi_join.cu +++ b/cpp/src/join/semi_join.cu @@ -137,19 +137,28 @@ std::unique_ptr> left_semi_anti_join( auto gather_map = std::make_unique>(left_num_rows, stream, mr); - // gather_map_end will be the end of valid data in gather_map - auto gather_map_end = thrust::copy_if( + rmm::device_uvector flagged(left_num_rows, stream, mr); + auto flagged_d = flagged.data(); + + auto counting_iter = thrust::counting_iterator(0); + thrust::for_each( rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(left_num_rows), - gather_map->begin(), - [hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__( - size_type const idx) { - // Look up this row. The hash function used here needs to map a (left) row index to the hash - // of the row, so it's a row hash. The equality check needs to verify - return hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean; + counting_iter, + counting_iter + left_num_rows, + [flagged_d, hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__( + const size_type idx) { + flagged_d[idx] = + hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean; }); + // gather_map_end will be the end of valid data in gather_map + auto gather_map_end = + thrust::copy_if(rmm::exec_policy(stream), + counting_iter, + counting_iter + left_num_rows, + gather_map->begin(), + [flagged_d] __device__(size_type const idx) { return flagged_d[idx]; }); + auto join_size = thrust::distance(gather_map->begin(), gather_map_end); gather_map->resize(join_size, stream); return gather_map; diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index c4ffa7f0fb1..987cd076fd0 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -27,13 +25,8 @@ #include #include #include -#include #include -#include - -#include -#include namespace cudf { namespace strings { @@ -41,51 +34,52 @@ namespace detail { namespace { /** - * @brief This functor handles both contains_re and match_re to minimize the number - * of regex calls to find() to be inlined greatly reducing compile time. + * @brief This functor handles both contains_re and match_re to regex-match a pattern + * to each string in a column. */ -template struct contains_fn { - reprog_device prog; column_device_view const d_strings; - bool const beginning_only; // do not make this a template parameter to keep compile times down + bool const beginning_only; - __device__ bool operator()(size_type idx) + __device__ bool operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) { if (d_strings.is_null(idx)) return false; auto const d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = beginning_only ? 
1 // match only the beginning of the string; - : -1; // match anywhere in the string - return static_cast(prog.find(idx, d_str, begin, end)); + + size_type begin = 0; + size_type end = beginning_only ? 1 // match only the beginning of the string; + : -1; // match anywhere in the string + return static_cast(prog.find(thread_idx, d_str, begin, end)); } }; -struct contains_dispatch_fn { - reprog_device d_prog; - bool const beginning_only; +std::unique_ptr contains_impl(strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + bool const beginning_only, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto results = make_numeric_column(data_type{type_id::BOOL8}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + if (input.is_empty()) { return results; } - template - std::unique_ptr operator()(strings_column_view const& input, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto results = make_numeric_column(data_type{type_id::BOOL8}, - input.size(), - cudf::detail::copy_bitmask(input.parent(), stream, mr), - input.null_count(), - stream, - mr); - - auto const d_strings = column_device_view::create(input.parent(), stream); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.size()), - results->mutable_view().data(), - contains_fn{d_prog, *d_strings, beginning_only}); - return results; - } -}; + auto d_prog = reprog_device::create(pattern, flags, stream); + + auto d_results = results->mutable_view().data(); + auto const d_strings = column_device_view::create(input.parent(), stream); + + launch_transform_kernel( + contains_fn{*d_strings, beginning_only}, *d_prog, d_results, input.size(), stream); + + return results; +} } // namespace @@ -96,10 +90,7 @@ std::unique_ptr contains_re( rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - - return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr); + return contains_impl(input, pattern, flags, false, stream, mr); } std::unique_ptr matches_re( @@ -109,21 +100,18 @@ std::unique_ptr matches_re( rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - - return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr); + return contains_impl(input, pattern, flags, true, stream, mr); } -std::unique_ptr count_re(strings_column_view const& input, - std::string const& pattern, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr count_re( + strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); auto const d_strings = column_device_view::create(input.parent(), stream); diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu 
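// For context, a hedged usage sketch of the public counterparts of the detail functions refactored
// above, assuming the cudf 22.x signatures in <cudf/strings/contains.hpp>: contains_re() matches
// anywhere in each string, matches_re() only at the start (the beginning_only flag), and count_re()
// counts non-overlapping matches.
#include <cudf/column/column.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/strings/strings_column_view.hpp>

#include <memory>

std::unique_ptr<cudf::column> digit_match_summary(cudf::strings_column_view const& input)
{
  auto anywhere = cudf::strings::contains_re(input, "\\d+");  // BOOL8: has any digits
  auto anchored = cudf::strings::matches_re(input, "\\d+");   // BOOL8: starts with digits
  auto counts   = cudf::strings::count_re(input, "\\d+");     // INT32: number of digit runs
  return counts;  // the two BOOL8 columns would be consumed similarly
}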
index a850315dfec..d807482a3a7 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -15,41 +15,35 @@ */ #include -#include -#include +#include #include #include #include -#include - -#include -#include - namespace cudf { namespace strings { namespace detail { namespace { /** - * @brief Functor counts the total matches to the given regex in each string. + * @brief Kernel counts the total matches for the given regex in each string. */ -template -struct count_matches_fn { +struct count_fn { column_device_view const d_strings; - reprog_device prog; - __device__ size_type operator()(size_type idx) + __device__ int32_t operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) { - if (d_strings.is_null(idx)) { return 0; } - size_type count = 0; + if (d_strings.is_null(idx)) return 0; auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); + int32_t count = 0; - int32_t begin = 0; - int32_t end = nchars; - while ((begin < end) && (prog.find(idx, d_str, begin, end) > 0)) { + size_type begin = 0; + size_type end = nchars; + while ((begin < end) && (prog.find(thread_idx, d_str, begin, end) > 0)) { ++count; begin = end + (begin == end); end = nchars; @@ -58,41 +52,26 @@ struct count_matches_fn { } }; -struct count_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(column_device_view const& d_strings, - size_type output_size, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - assert(output_size >= d_strings.size() and "Unexpected output size"); - - auto results = make_numeric_column( - data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr); - - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(d_strings.size()), - results->mutable_view().data(), - count_matches_fn{d_strings, d_prog}); - return results; - } -}; - } // namespace -/** - * @copydoc cudf::strings::detail::count_matches - */ std::unique_ptr count_matches(column_device_view const& d_strings, - reprog_device const& d_prog, + reprog_device& d_prog, size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, output_size, stream, mr); + assert(output_size >= d_strings.size() and "Unexpected output size"); + + auto results = make_numeric_column( + data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr); + + if (d_strings.size() == 0) return results; + + auto d_results = results->mutable_view().data(); + + launch_transform_kernel(count_fn{d_strings}, d_prog, d_results, d_strings.size(), stream); + + return results; } } // namespace detail diff --git a/cpp/src/strings/count_matches.hpp b/cpp/src/strings/count_matches.hpp index efff3958c65..d4bcdaf4042 100644 --- a/cpp/src/strings/count_matches.hpp +++ b/cpp/src/strings/count_matches.hpp @@ -39,10 +39,11 @@ class reprog_device; * @param output_size Number of rows for the output column. * @param stream CUDA stream used for device memory operations and kernel launches. * @param mr Device memory resource used to allocate the returned column's device memory. 
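// The scan loop in count_fn above advances with `begin = end + (begin == end)`: the next search
// starts at the end of the previous match, and the extra +1 when the match was empty prevents an
// infinite loop on patterns such as "a*". A host-side analogue of that counting loop, using
// std::regex purely for illustration:
#include <cstddef>
#include <regex>
#include <string>

int count_matches_host(std::string const& s, std::regex const& re)
{
  int count            = 0;
  std::ptrdiff_t begin = 0;
  auto const nchars    = static_cast<std::ptrdiff_t>(s.size());
  while (begin < nchars) {
    std::smatch m;
    if (!std::regex_search(s.cbegin() + begin, s.cend(), m, re)) break;
    ++count;
    auto const match_begin = begin + m.position(0);
    auto const match_end   = match_begin + m.length(0);
    begin = match_end + (match_begin == match_end);  // skip ahead past an empty match
  }
  return count;
}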
+ * @return Integer column of match counts */ std::unique_ptr count_matches( column_device_view const& d_strings, - reprog_device const& d_prog, + reprog_device& d_prog, size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index 9e987cf5879..59b90952d97 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -14,9 +14,7 @@ * limitations under the License. */ -#include -#include -#include +#include #include #include @@ -31,7 +29,7 @@ #include #include -#include +#include #include #include #include @@ -47,28 +45,26 @@ using string_index_pair = thrust::pair; /** * @brief This functor handles extracting strings by applying the compiled regex pattern * and creating string_index_pairs for all the substrings. - * - * @tparam stack_size Correlates to the regex instructions state to maintain for each string. - * Each instruction requires a fixed amount of overhead data. */ -template struct extract_fn { - reprog_device prog; column_device_view const d_strings; cudf::detail::device_2dspan d_indices; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, + reprog_device const d_prog, + int32_t const prog_idx) { - auto const groups = prog.group_counts(); + auto const groups = d_prog.group_counts(); auto d_output = d_indices[idx]; if (d_strings.is_valid(idx)) { auto const d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = -1; // handles empty strings automatically - if (prog.find(idx, d_str, begin, end) > 0) { + + size_type begin = 0; + size_type end = -1; // handles empty strings automatically + if (d_prog.find(prog_idx, d_str, begin, end) > 0) { for (auto col_idx = 0; col_idx < groups; ++col_idx) { - auto const extracted = prog.extract(idx, d_str, begin, end, col_idx); + auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, col_idx); d_output[col_idx] = [&] { if (!extracted) return string_index_pair{nullptr, 0}; auto const offset = d_str.byte_offset((*extracted).first); @@ -85,33 +81,17 @@ struct extract_fn { } }; -struct extract_dispatch_fn { - reprog_device d_prog; - - template - void operator()(column_device_view const& d_strings, - cudf::detail::device_2dspan& d_indices, - rmm::cuda_stream_view stream) - { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - extract_fn{d_prog, d_strings, d_indices}); - } -}; } // namespace // -std::unique_ptr extract( - strings_column_view const& input, - std::string const& pattern, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr
extract(strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); auto const groups = d_prog->group_counts(); CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern"); @@ -121,7 +101,8 @@ std::unique_ptr
extract( cudf::detail::device_2dspan(indices.data(), input.size(), groups); auto const d_strings = column_device_view::create(input.parent(), stream); - regex_dispatcher(*d_prog, extract_dispatch_fn{*d_prog}, *d_strings, d_indices, stream); + + launch_for_each_kernel(extract_fn{*d_strings, d_indices}, *d_prog, input.size(), stream); // build a result column for each group std::vector> results(groups); diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index 7dce369a24f..95b8a43a9d4 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -30,9 +28,7 @@ #include #include -#include #include -#include #include namespace cudf { @@ -49,14 +45,14 @@ namespace { * The `d_offsets` are pre-computed to identify the location of where each * string's output groups are to be written. */ -template struct extract_fn { column_device_view const d_strings; - reprog_device d_prog; offset_type const* d_offsets; string_index_pair* d_indices; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, + reprog_device const d_prog, + int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } @@ -64,16 +60,17 @@ struct extract_fn { auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); - int32_t begin = 0; - int32_t end = d_str.length(); + size_type begin = 0; + size_type end = nchars; // match the regex - while ((begin < end) && d_prog.find(idx, d_str, begin, end) > 0) { + while ((begin < end) && d_prog.find(prog_idx, d_str, begin, end) > 0) { // extract each group into the output for (auto group_idx = 0; group_idx < groups; ++group_idx) { // result is an optional containing the bounds of the extracted string at group_idx - auto const extracted = d_prog.extract(idx, d_str, begin, end, group_idx); + auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, group_idx); d_output[group_idx + output_idx] = [&] { if (!extracted) { return string_index_pair{nullptr, 0}; } @@ -84,33 +81,12 @@ struct extract_fn { } // continue to next match begin = end; - end = d_str.length(); + end = nchars; output_idx += groups; } } }; -struct extract_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(column_device_view const& d_strings, - size_type total_groups, - offset_type const* d_offsets, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - rmm::device_uvector indices(total_groups, stream); - - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - extract_fn{d_strings, d_prog, d_offsets, indices.data()}); - - return make_strings_column(indices.begin(), indices.end(), stream, mr); - } -}; - } // namespace /** @@ -129,8 +105,7 @@ std::unique_ptr extract_all_record( auto const d_strings = column_device_view::create(input.parent(), stream); // Compile regex into device object. - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto d_prog = reprog_device::create(pattern, flags, stream); // The extract pattern should always include groups. 
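// "Groups" here are the parenthesized capture groups in the pattern: extract() returns one strings
// column per group (first match per row), while extract_all_record() returns a LIST column with
// every match. A hedged usage sketch against the public API in <cudf/strings/extract.hpp>
// (assumed 22.x signatures):
#include <cudf/strings/extract.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table.hpp>

#include <memory>

std::unique_ptr<cudf::table> split_iso_dates(cudf::strings_column_view const& input)
{
  // For a row "2022-05-17" the three groups produce "2022", "05", "17";
  // rows that do not match yield nulls in every group column.
  return cudf::strings::extract(input, "(\\d{4})-(\\d{2})-(\\d{2})");
}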
auto const groups = d_prog->group_counts(); CUDF_EXPECTS(groups > 0, "extract_all requires group indicators in the regex pattern."); @@ -168,8 +143,12 @@ std::unique_ptr extract_all_record( auto const total_groups = cudf::detail::get_value(offsets->view(), strings_count, stream); - auto strings_output = regex_dispatcher( - *d_prog, extract_dispatch_fn{*d_prog}, *d_strings, total_groups, d_offsets, stream, mr); + rmm::device_uvector indices(total_groups, stream); + + launch_for_each_kernel( + extract_fn{*d_strings, d_offsets, indices.data()}, *d_prog, strings_count, stream); + + auto strings_output = make_strings_column(indices.begin(), indices.end(), stream, mr); // Build the lists column from the offsets and the strings. return make_lists_column(strings_count, diff --git a/cpp/src/strings/regex/dispatcher.hpp b/cpp/src/strings/regex/dispatcher.hpp deleted file mode 100644 index 9ff51d1c979..00000000000 --- a/cpp/src/strings/regex/dispatcher.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace cudf { -namespace strings { -namespace detail { - -/** - * The stack is used to keep progress (state) on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are four call types based on the number of regex instructions in the given pattern. - * Small, medium, and large instruction counts can use the stack effectively. - * Smaller stack sizes execute faster. - * - * Patterns with instruction counts bigger than large use global memory rather than the stack - * for managing the evaluation state data. - * - * @tparam Functor The functor to invoke with stack size templated value. - * @tparam Ts Parameter types for the functor call. - */ -template -constexpr decltype(auto) regex_dispatcher(reprog_device d_prog, Functor f, Ts&&... 
args) -{ - auto const num_regex_insts = d_prog.insts_counts(); - if (num_regex_insts <= RX_SMALL_INSTS) { - return f.template operator()(std::forward(args)...); - } - if (num_regex_insts <= RX_MEDIUM_INSTS) { - return f.template operator()(std::forward(args)...); - } - if (num_regex_insts <= RX_LARGE_INSTS) { - return f.template operator()(std::forward(args)...); - } - - return f.template operator()(std::forward(args)...); -} - -} // namespace detail -} // namespace strings -} // namespace cudf diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index bcdd15bceda..5ccc70222d5 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -39,23 +39,9 @@ struct relist; using match_pair = thrust::pair; using match_result = thrust::optional; -constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size -constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size -constexpr int32_t RX_STACK_LARGE = 2560; ///< fast stack size -constexpr int32_t RX_STACK_ANY = 8; ///< slowest: uses global memory - -/** - * @brief Mapping the number of instructions to device code stack memory size. - * - * ``` - * 10128 ≈ 1000 instructions - * Formula is based on relist::data_size_for() calculation; - * Stack ≈ (8+2)*x + (x/8) = 10.125x < 11x where x is number of instructions - * ``` - */ -constexpr int32_t RX_SMALL_INSTS = (RX_STACK_SMALL / 11); -constexpr int32_t RX_MEDIUM_INSTS = (RX_STACK_MEDIUM / 11); -constexpr int32_t RX_LARGE_INSTS = (RX_STACK_LARGE / 11); +constexpr int32_t MAX_SHARED_MEM = 2048; ///< Memory size for storing prog instruction data +constexpr std::size_t MAX_WORKING_MEM = 0x01FFFFFFFF; ///< Memory size for state data +constexpr int32_t MINIMUM_THREADS = 256; // Minimum threads for computing working memory /** * @brief Regex class stored on the device and executed by reprog_device. @@ -75,6 +61,12 @@ struct alignas(16) reclass_device { * * Once created, the find/extract methods are used to evaluate the regex instructions * against a single string. + * + * An instance of the class requires working memory for evaluating the regex + * instructions for the string. Determine the size of the required memory by + * calling either `working_memory_size()` or `compute_strided_working_memory()`. + * Once the buffer is allocated, pass the device pointer to the `set_working_memory()` + * member function. */ class reprog_device { public: @@ -92,33 +84,22 @@ class reprog_device { * regex. * * @param pattern The regex pattern to compile. - * @param codepoint_flags The code point lookup table for character types. - * @param strings_count Number of strings that will be evaluated. * @param stream CUDA stream used for device memory operations and kernel launches. * @return The program device object. */ static std::unique_ptr> create( - std::string const& pattern, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream); + std::string const& pattern, rmm::cuda_stream_view stream); /** * @brief Create the device program instance from a regex pattern. * * @param pattern The regex pattern to compile. * @param re_flags Regex flags for interpreting special characters in the pattern. - * @param codepoint_flags The code point lookup table for character types. - * @param strings_count Number of strings that will be evaluated. * @param stream CUDA stream used for device memory operations and kernel launches * @return The program device object. 
*/ static std::unique_ptr> create( - std::string const& pattern, - regex_flags const re_flags, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream); + std::string const& pattern, regex_flags const re_flags, rmm::cuda_stream_view stream); /** * @brief Called automatically by the unique_ptr returned from create(). @@ -143,12 +124,75 @@ class reprog_device { */ [[nodiscard]] __device__ inline bool is_empty() const; + /** + * @brief Returns the size needed for working memory for the given thread count. + * + * @param num_threads Number of threads to be executed in parallel + * @return Size of working memory in bytes + */ + [[nodiscard]] std::size_t working_memory_size(int32_t num_threads) const; + + /** + * @brief Compute working memory for the given thread count with a maximum size. + * + * The `min_rows` overrules the `requested_max_size`. + * That is, the `requested_max_size` may be + * exceeded to keep the number of rows greater than `min_rows`. + * Also, if `rows < min_rows` then `min_rows` is not enforced. + * + * @param rows Number of rows to execute in parallel + * @param min_rows The least number of rows to meet `max_size` + * @param requested_max_size Requested maximum bytes for the working memory + * @return The size of the working memory and the number of parallel rows it will support + */ + [[nodiscard]] std::pair compute_strided_working_memory( + int32_t rows, + int32_t min_rows = MINIMUM_THREADS, + std::size_t requested_max_size = MAX_WORKING_MEM) const; + + /** + * @brief Set the device working memory buffer to use for the regex execution. + * + * @param buffer Device memory pointer. + * @param thread_count Number of threads the memory buffer will support. + * @param max_insts Set to the maximum instruction count if reusing the + * memory buffer for other regex calls. + */ + void set_working_memory(void* buffer, int32_t thread_count, int32_t max_insts = 0); + + /** + * @brief Returns the size of shared memory required to hold this instance. + * + * This can be called on the CPU for specifying the shared-memory size in the + * kernel launch parameters. + * This may return 0 if the MAX_SHARED_MEM value is exceeded. + */ + [[nodiscard]] int32_t compute_shared_memory_size() const; + + /** + * @brief Returns the thread count passed on `set_working_memory`. + */ + [[nodiscard]] __device__ inline int32_t thread_count() const { return _thread_count; } + + /** + * @brief Store this object into the given device pointer (e.g. shared memory). + * + * No data is stored if MAX_SHARED_MEM is exceeded for this object. + */ + __device__ inline void store(void* buffer) const; + + /** + * @brief Load an instance of this class from a device buffer (e.g. shared memory). + * + * Data is loaded from the given buffer if MAX_SHARED_MEM is not exceeded for the given object. + * Otherwise, a copy of the object is returned. + */ + [[nodiscard]] __device__ static inline reprog_device load(reprog_device const prog, void* buffer); + /** * @brief Does a find evaluation using the compiled expression on the given string. * - * @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`. - * @param idx The string index used for mapping the state memory for this string in global memory - * (if necessary). + * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. * @param[in,out] begin Position index to begin the search. If found, returns the position found * in the string. 
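// A hedged sketch of how a caller wires up the new working-memory API before launching a regex
// kernel; it mirrors the launch helpers added in cpp/src/strings/regex/utilities.cuh later in this
// change. my_regex_kernel is a placeholder, and the include paths assume cudf's internal cpp/src
// layout.
#include <strings/regex/regex.cuh>

#include <cudf/types.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>

__global__ void my_regex_kernel(cudf::strings::detail::reprog_device d_prog, cudf::size_type size);

void setup_and_launch(cudf::strings::detail::reprog_device& d_prog,
                      cudf::size_type strings_count,
                      rmm::cuda_stream_view stream)
{
  // pick a thread count whose two relist buffers fit in the working-memory budget
  auto const [buffer_size, thread_count] = d_prog.compute_strided_working_memory(strings_count);

  auto working_mem = rmm::device_buffer(buffer_size, stream);
  d_prog.set_working_memory(working_mem.data(), thread_count);

  // 0 means the program is too large for shared memory; the kernel's load() then falls back to
  // the global-memory copy of the program
  auto const shmem_size = d_prog.compute_shared_memory_size();
  auto const block_size = 256;
  auto const num_blocks = (thread_count + block_size - 1) / block_size;
  my_regex_kernel<<<num_blocks, block_size, shmem_size, stream.value()>>>(d_prog, strings_count);
}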
@@ -156,8 +200,7 @@ class reprog_device { * matching in the string. * @return Returns 0 if no match is found. */ - template - __device__ inline int32_t find(int32_t idx, + __device__ inline int32_t find(int32_t const thread_idx, string_view const d_str, cudf::size_type& begin, cudf::size_type& end) const; @@ -169,9 +212,7 @@ class reprog_device { * The find() function should be called first to locate the begin/end bounds of the * the matched section. * - * @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`. - * @param idx The string index used for mapping the state memory for this string in global - * memory (if necessary). + * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. * @param begin Position index to begin the search. If found, returns the position found * in the string. @@ -180,8 +221,7 @@ class reprog_device { * @param group_id The specific group to return its matching position values. * @return If valid, returns the character position of the matched group in the given string, */ - template - __device__ inline match_result extract(cudf::size_type idx, + __device__ inline match_result extract(int32_t const thread_idx, string_view const d_str, cudf::size_type begin, cudf::size_type end, @@ -220,8 +260,7 @@ class reprog_device { /** * @brief Utility wrapper to setup state memory structures for calling regexec */ - template - __device__ inline int32_t call_regexec(int32_t idx, + __device__ inline int32_t call_regexec(int32_t const thread_idx, string_view const d_str, cudf::size_type& begin, cudf::size_type& end, @@ -234,13 +273,16 @@ class reprog_device { int32_t _insts_count; // number of instructions int32_t _starts_count; // number of start-insts ids int32_t _classes_count; // number of classes + int32_t _max_insts; // for partitioning working memory uint8_t const* _codepoint_flags{}; // table of character types reinst const* _insts{}; // array of regex instructions int32_t const* _startinst_ids{}; // array of start instruction ids reclass_device const* _classes{}; // array of regex classes - void* _relists_mem{}; // runtime relist memory for regexec() + std::size_t _prog_size{}; // total size of this instance + void* _buffer{}; // working memory buffer + int32_t _thread_count{}; // threads available in working memory }; } // namespace detail diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index 8bb12187d72..8e2194f2094 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -45,10 +45,9 @@ struct alignas(8) relist { /** * @brief Compute the aligned memory allocation size. */ - constexpr inline static std::size_t alloc_size(int32_t insts) + constexpr inline static std::size_t alloc_size(int32_t insts, int32_t num_threads) { - return cudf::util::round_up_unsafe(data_size_for(insts) + sizeof(relist), - sizeof(ranges[0])); + return cudf::util::round_up_unsafe(data_size_for(insts) * num_threads, sizeof(restate)); } struct alignas(16) restate { @@ -57,16 +56,16 @@ struct alignas(8) relist { int32_t reserved; }; - __device__ __forceinline__ relist(int16_t insts, u_char* data = nullptr) - : masksize(cudf::util::div_rounding_up_unsafe(insts, 8)) + __device__ __forceinline__ + relist(int16_t insts, int32_t num_threads, u_char* gp_ptr, int32_t index) + : masksize(cudf::util::div_rounding_up_unsafe(insts, 8)), stride(num_threads) { - auto ptr = data == nullptr ? 
reinterpret_cast(this) + sizeof(relist) : data; - ranges = reinterpret_cast(ptr); - ptr += insts * sizeof(ranges[0]); - inst_ids = reinterpret_cast(ptr); - ptr += insts * sizeof(inst_ids[0]); - mask = ptr; - reset(); + auto const rdata_size = sizeof(ranges[0]); + auto const idata_size = sizeof(inst_ids[0]); + ranges = reinterpret_cast(gp_ptr + (index * rdata_size)); + inst_ids = + reinterpret_cast(gp_ptr + (rdata_size * stride * insts) + (index * idata_size)); + mask = gp_ptr + ((rdata_size + idata_size) * stride * insts) + (index * masksize); } __device__ __forceinline__ void reset() @@ -79,15 +78,15 @@ struct alignas(8) relist { { if (readMask(id)) { return false; } writeMask(id); - inst_ids[size] = static_cast(id); - ranges[size] = int2{begin, end}; + inst_ids[size * stride] = static_cast(id); + ranges[size * stride] = int2{begin, end}; ++size; return true; } __device__ __forceinline__ restate get_state(int16_t idx) const { - return restate{ranges[idx], inst_ids[idx]}; + return restate{ranges[idx * stride], inst_ids[idx * stride]}; } __device__ __forceinline__ int16_t get_size() const { return size; } @@ -95,7 +94,7 @@ struct alignas(8) relist { private: int16_t size{}; int16_t const masksize; - int32_t reserved; + int32_t const stride; int2* __restrict__ ranges; // pair per instruction int16_t* __restrict__ inst_ids; // one per instruction u_char* __restrict__ mask; // bit per instruction @@ -177,6 +176,49 @@ __device__ __forceinline__ bool reprog_device::is_empty() const return insts_counts() == 0 || get_inst(0).type == END; } +__device__ __forceinline__ void reprog_device::store(void* buffer) const +{ + if (_prog_size > MAX_SHARED_MEM) { return; } + + auto ptr = static_cast(buffer); + + // create instance inside the given buffer + auto result = new (ptr) reprog_device(*this); + + // add the insts array + ptr += sizeof(reprog_device); + auto insts = reinterpret_cast(ptr); + result->_insts = insts; + for (int idx = 0; idx < _insts_count; ++idx) + *insts++ = _insts[idx]; + + // add the startinst_ids array + ptr += cudf::util::round_up_unsafe(_insts_count * sizeof(_insts[0]), sizeof(_startinst_ids[0])); + auto ids = reinterpret_cast(ptr); + result->_startinst_ids = ids; + for (int idx = 0; idx < _starts_count; ++idx) + *ids++ = _startinst_ids[idx]; + + // add the classes array + ptr += cudf::util::round_up_unsafe(_starts_count * sizeof(int32_t), sizeof(_classes[0])); + auto classes = reinterpret_cast(ptr); + result->_classes = classes; + // fill in each class + auto d_ptr = reinterpret_cast(classes + _classes_count); + for (int idx = 0; idx < _classes_count; ++idx) { + classes[idx] = _classes[idx]; + classes[idx].literals = d_ptr; + for (int jdx = 0; jdx < _classes[idx].count * 2; ++jdx) + *d_ptr++ = _classes[idx].literals[jdx]; + } +} + +__device__ __forceinline__ reprog_device reprog_device::load(reprog_device const prog, void* buffer) +{ + return (prog._prog_size > MAX_SHARED_MEM) ? reprog_device(prog) + : reinterpret_cast(buffer)[0]; +} + /** * @brief Evaluate a specific string against regex pattern compiled to this instance. 
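// The strided relist layout above stores element k for thread t at
// base + (k * stride + t) * sizeof(element), so threads with adjacent indices touch adjacent
// addresses and the accesses coalesce. A plain-C++ illustration of that indexing scheme
// (names are illustrative only):
#include <cstddef>
#include <vector>

template <typename T>
T& strided_entry(std::vector<T>& pool, std::size_t thread_index, std::size_t stride, std::size_t k)
{
  // thread-local element k; neighbouring threads land on neighbouring slots
  return pool[k * stride + thread_index];
}
// e.g. strided_entry(ranges_pool, thread_idx, num_threads, size) mirrors ranges[size * stride]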
* @@ -352,65 +394,43 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr return match; } -template -__device__ __forceinline__ int32_t reprog_device::find(int32_t idx, +__device__ __forceinline__ int32_t reprog_device::find(int32_t const thread_idx, string_view const dstr, cudf::size_type& begin, cudf::size_type& end) const { - int32_t rtn = call_regexec(idx, dstr, begin, end); + auto const rtn = call_regexec(thread_idx, dstr, begin, end); if (rtn <= 0) begin = end = -1; return rtn; } -template -__device__ __forceinline__ match_result reprog_device::extract(cudf::size_type idx, +__device__ __forceinline__ match_result reprog_device::extract(int32_t const thread_idx, string_view const dstr, cudf::size_type begin, cudf::size_type end, cudf::size_type const group_id) const { end = begin + 1; - return call_regexec(idx, dstr, begin, end, group_id + 1) > 0 - ? match_result({begin, end}) - : thrust::nullopt; + return call_regexec(thread_idx, dstr, begin, end, group_id + 1) > 0 ? match_result({begin, end}) + : thrust::nullopt; } -template -__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t idx, +__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t const thread_idx, string_view const dstr, cudf::size_type& begin, cudf::size_type& end, cudf::size_type const group_id) const { - u_char data1[stack_size], data2[stack_size]; + auto gp_ptr = reinterpret_cast(_buffer); + relist list1(static_cast(_max_insts), _thread_count, gp_ptr, thread_idx); - relist list1(static_cast(_insts_count), data1); - relist list2(static_cast(_insts_count), data2); + gp_ptr += relist::alloc_size(_max_insts, _thread_count); + relist list2(static_cast(_max_insts), _thread_count, gp_ptr, thread_idx); reljunk jnk(&list1, &list2, get_inst(_startinst_id)); return regexec(dstr, jnk, begin, end, group_id); } -template <> -__device__ __forceinline__ int32_t -reprog_device::call_regexec(int32_t idx, - string_view const dstr, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id) const -{ - auto const relists_size = relist::alloc_size(_insts_count); - auto* listmem = reinterpret_cast(_relists_mem); // beginning of relist buffer; - listmem += (idx * relists_size * 2); // two relist ptrs in reljunk: - - auto* list1 = new (listmem) relist(static_cast(_insts_count)); - auto* list2 = new (listmem + relists_size) relist(static_cast(_insts_count)); - - reljunk jnk(list1, list2, get_inst(_startinst_id)); - return regexec(dstr, jnk, begin, end, group_id); -} - } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 70d6079972a..4b58d9d8a88 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -35,27 +36,21 @@ reprog_device::reprog_device(reprog& prog) _num_capturing_groups{prog.groups_count()}, _insts_count{prog.insts_count()}, _starts_count{prog.starts_count()}, - _classes_count{prog.classes_count()} + _classes_count{prog.classes_count()}, + _max_insts{prog.insts_count()}, + _codepoint_flags{get_character_flags_table()} { } std::unique_ptr> reprog_device::create( - std::string const& pattern, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream) + std::string const& pattern, rmm::cuda_stream_view stream) { - return reprog_device::create( - pattern, regex_flags::MULTILINE, codepoint_flags, strings_count, stream); + return 
reprog_device::create(pattern, regex_flags::MULTILINE, stream); } // Create instance of the reprog that can be passed into a device kernel std::unique_ptr> reprog_device::create( - std::string const& pattern, - regex_flags const flags, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream) + std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream) { // compile pattern into host object reprog h_prog = reprog::create_from(pattern, flags); @@ -82,7 +77,7 @@ std::unique_ptr> reprog_devic auto d_buffer = new rmm::device_buffer(memsize, stream); // output device memory; auto d_ptr = reinterpret_cast(d_buffer->data()); // running device pointer - // put everything into a flat host buffer first + // create our device object; this is managed separately and returned to the caller reprog_device* d_prog = new reprog_device(h_prog); // copy the instructions array first (fixed-sized structs) @@ -120,32 +115,58 @@ std::unique_ptr> reprog_devic } // initialize the rest of the elements - d_prog->_codepoint_flags = codepoint_flags; - - // allocate execute memory if needed - rmm::device_buffer* d_relists{}; - if (insts_count > RX_LARGE_INSTS) { - // two relist state structures are needed for execute per string - auto const rlm_size = relist::alloc_size(insts_count) * 2 * strings_count; - d_relists = new rmm::device_buffer(rlm_size, stream); - d_prog->_relists_mem = d_relists->data(); - } + d_prog->_max_insts = insts_count; + d_prog->_prog_size = memsize + sizeof(reprog_device); // copy flat prog to device memory CUDF_CUDA_TRY(cudaMemcpyAsync( d_buffer->data(), h_buffer.data(), memsize, cudaMemcpyHostToDevice, stream.value())); // build deleter to cleanup device memory - auto deleter = [d_buffer, d_relists](reprog_device* t) { + auto deleter = [d_buffer](reprog_device* t) { t->destroy(); delete d_buffer; - delete d_relists; }; + return std::unique_ptr>(d_prog, deleter); } void reprog_device::destroy() { delete this; } +std::size_t reprog_device::working_memory_size(int32_t num_threads) const +{ + return relist::alloc_size(_insts_count, num_threads) * 2; +} + +std::pair reprog_device::compute_strided_working_memory( + int32_t rows, int32_t min_rows, std::size_t requested_max_size) const +{ + auto thread_count = rows; + auto buffer_size = working_memory_size(thread_count); + while ((buffer_size > requested_max_size) && (thread_count > min_rows)) { + thread_count = thread_count / 2; + buffer_size = working_memory_size(thread_count); + } + // clamp to min_rows but only if rows is greater than min_rows + if (rows > min_rows && thread_count < min_rows) { + thread_count = min_rows; + buffer_size = working_memory_size(thread_count); + } + return std::make_pair(buffer_size, thread_count); +} + +void reprog_device::set_working_memory(void* buffer, int32_t thread_count, int32_t max_insts) +{ + _buffer = buffer; + _thread_count = thread_count; + _max_insts = max_insts > 0 ? max_insts : _insts_count; +} + +int32_t reprog_device::compute_shared_memory_size() const +{ + return _prog_size < MAX_SHARED_MEM ? static_cast(_prog_size) : 0; +} + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/regex/utilities.cuh b/cpp/src/strings/regex/utilities.cuh new file mode 100644 index 00000000000..9a80be25b3b --- /dev/null +++ b/cpp/src/strings/regex/utilities.cuh @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include + +#include + +namespace cudf { +namespace strings { +namespace detail { + +constexpr auto regex_launch_kernel_block_size = 256; + +template +__global__ void for_each_kernel(ForEachFunction fn, reprog_device const d_prog, size_type size) +{ + extern __shared__ u_char shmem[]; + if (threadIdx.x == 0) { d_prog.store(shmem); } + __syncthreads(); + auto const s_prog = reprog_device::load(d_prog, shmem); + + auto const thread_idx = threadIdx.x + blockIdx.x * blockDim.x; + auto const stride = s_prog.thread_count(); + for (auto idx = thread_idx; idx < size; idx += stride) { + fn(idx, s_prog, thread_idx); + } +} + +template +void launch_for_each_kernel(ForEachFunction fn, + reprog_device& d_prog, + size_type size, + rmm::cuda_stream_view stream) +{ + auto [buffer_size, thread_count] = d_prog.compute_strided_working_memory(size); + + auto d_buffer = rmm::device_buffer(buffer_size, stream); + d_prog.set_working_memory(d_buffer.data(), thread_count); + + auto const shmem_size = d_prog.compute_shared_memory_size(); + cudf::detail::grid_1d grid{thread_count, regex_launch_kernel_block_size}; + for_each_kernel<<>>( + fn, d_prog, size); +} + +template +__global__ void transform_kernel(TransformFunction fn, + reprog_device const d_prog, + OutputType* d_output, + size_type size) +{ + extern __shared__ u_char shmem[]; + if (threadIdx.x == 0) { d_prog.store(shmem); } + __syncthreads(); + auto const s_prog = reprog_device::load(d_prog, shmem); + + auto const thread_idx = threadIdx.x + blockIdx.x * blockDim.x; + auto const stride = s_prog.thread_count(); + for (auto idx = thread_idx; idx < size; idx += stride) { + d_output[idx] = fn(idx, s_prog, thread_idx); + } +} + +template +void launch_transform_kernel(TransformFunction fn, + reprog_device& d_prog, + OutputType* d_output, + size_type size, + rmm::cuda_stream_view stream) +{ + auto [buffer_size, thread_count] = d_prog.compute_strided_working_memory(size); + + auto d_buffer = rmm::device_buffer(buffer_size, stream); + d_prog.set_working_memory(d_buffer.data(), thread_count); + + auto const shmem_size = d_prog.compute_shared_memory_size(); + cudf::detail::grid_1d grid{thread_count, regex_launch_kernel_block_size}; + transform_kernel<<>>( + fn, d_prog, d_output, size); +} + +template +auto make_strings_children(SizeAndExecuteFunction size_and_exec_fn, + reprog_device& d_prog, + size_type strings_count, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto offsets = make_numeric_column( + data_type{type_id::INT32}, strings_count + 1, mask_state::UNALLOCATED, stream, mr); + auto d_offsets = offsets->mutable_view().template data(); + size_and_exec_fn.d_offsets = d_offsets; + + auto [buffer_size, thread_count] = d_prog.compute_strided_working_memory(strings_count); + + auto d_buffer = rmm::device_buffer(buffer_size, stream); + d_prog.set_working_memory(d_buffer.data(), thread_count); + 
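+  // (Added commentary; not part of the upstream source.) compute_strided_working_memory() may
+  // return a thread_count smaller than strings_count when giving every row its own relist state
+  // would exceed the requested working-memory cap; for_each_kernel() then strides each thread
+  // over multiple rows (idx += s_prog.thread_count()). The grid below is therefore sized from
+  // thread_count rather than strings_count, and set_working_memory() must be called before the
+  // launch so the per-thread relist buffers resolve into d_buffer.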
auto const shmem_size = d_prog.compute_shared_memory_size(); + cudf::detail::grid_1d grid{thread_count, 256}; + + // Compute the output size for each row + if (strings_count > 0) { + for_each_kernel<<>>( + size_and_exec_fn, d_prog, strings_count); + } + + // Convert sizes to offsets + thrust::exclusive_scan( + rmm::exec_policy(stream), d_offsets, d_offsets + strings_count + 1, d_offsets); + + // Now build the chars column + auto const char_bytes = cudf::detail::get_value(offsets->view(), strings_count, stream); + std::unique_ptr chars = create_chars_child_column(char_bytes, stream, mr); + if (char_bytes > 0) { + size_and_exec_fn.d_chars = chars->mutable_view().template data(); + for_each_kernel<<>>( + size_and_exec_fn, d_prog, strings_count); + } + + return std::make_pair(std::move(offsets), std::move(chars)); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 384813d6e3d..107adf07263 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -16,9 +16,7 @@ #include "backref_re.cuh" -#include -#include -#include +#include #include #include @@ -43,7 +41,7 @@ namespace { * @brief Return the capturing group index pattern to use with the given replacement string. * * Only two patterns are supported at this time `\d` and `${d}` where `d` is an integer in - * the range 1-99. The `\d` pattern is returned by default unless no `\d` pattern is found in + * the range 0-99. The `\d` pattern is returned by default unless no `\d` pattern is found in * the `repl` string, * * Reference: https://www.regular-expressions.info/refreplacebackref.html @@ -98,45 +96,15 @@ std::pair> parse_backrefs(std::string con return {rtn, backrefs}; } -template -struct replace_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(strings_column_view const& input, - string_view const& d_repl_template, - Iterator backrefs_begin, - Iterator backrefs_end, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto const d_strings = column_device_view::create(input.parent(), stream); - - auto children = make_strings_children( - backrefs_fn{ - *d_strings, d_prog, d_repl_template, backrefs_begin, backrefs_end}, - input.size(), - stream, - mr); - - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } -}; - } // namespace // -std::unique_ptr replace_with_backrefs( - strings_column_view const& input, - std::string const& pattern, - std::string const& replacement, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr replace_with_backrefs(strings_column_view const& input, + std::string const& pattern, + std::string const& replacement, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { if (input.is_empty()) return make_empty_column(type_id::STRING); @@ -144,8 +112,7 @@ std::unique_ptr replace_with_backrefs( CUDF_EXPECTS(!replacement.empty(), "Parameter replacement must not be empty"); // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); // parse the repl string for back-ref indicators auto group_count = 
std::min(99, d_prog->group_counts()); // group count should NOT exceed 99 @@ -155,15 +122,21 @@ std::unique_ptr replace_with_backrefs( string_scalar repl_scalar(parse_result.first, true, stream); string_view const d_repl_template = repl_scalar.value(); + auto const d_strings = column_device_view::create(input.parent(), stream); + using BackRefIterator = decltype(backrefs.begin()); - return regex_dispatcher(*d_prog, - replace_dispatch_fn{*d_prog}, - input, - d_repl_template, - backrefs.begin(), - backrefs.end(), - stream, - mr); + auto children = make_strings_children( + backrefs_fn{*d_strings, d_repl_template, backrefs.begin(), backrefs.end()}, + *d_prog, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); } } // namespace detail diff --git a/cpp/src/strings/replace/backref_re.cuh b/cpp/src/strings/replace/backref_re.cuh index 13a67e3b4d7..db5b8a1eb17 100644 --- a/cpp/src/strings/replace/backref_re.cuh +++ b/cpp/src/strings/replace/backref_re.cuh @@ -14,13 +14,13 @@ * limitations under the License. */ +#include + #include #include #include #include -#include - #include #include @@ -39,17 +39,16 @@ using backref_type = thrust::pair; * * The logic includes computing the size of each string and also writing the output. */ -template +template struct backrefs_fn { column_device_view const d_strings; - reprog_device prog; string_view const d_repl; // string replacement template Iterator backrefs_begin; Iterator backrefs_end; int32_t* d_offsets{}; char* d_chars{}; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { if (!d_chars) d_offsets[idx] = 0; @@ -65,7 +64,7 @@ struct backrefs_fn { size_type end = nchars; // last character position (exclusive) // copy input to output replacing strings as we go - while (prog.find(idx, d_str, begin, end) > 0) // inits the begin/end vars + while (prog.find(prog_idx, d_str, begin, end) > 0) // inits the begin/end vars { auto spos = d_str.byte_offset(begin); // get offset for the auto epos = d_str.byte_offset(end); // character position values; @@ -84,7 +83,7 @@ struct backrefs_fn { lpos_template += copy_length; } // extract the specific group's string for this backref's index - auto extracted = prog.extract(idx, d_str, begin, end, backref.first - 1); + auto extracted = prog.extract(prog_idx, d_str, begin, end, backref.first - 1); if (!extracted || (extracted.value().second <= extracted.value().first)) { return; // no value for this backref number; that is ok } diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 3189739e492..a3f2631f424 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -14,9 +14,7 @@ * limitations under the License. */ -#include #include -#include #include #include @@ -32,6 +30,7 @@ #include #include +#include #include #include @@ -47,7 +46,6 @@ using found_range = thrust::pair; * @brief This functor handles replacing strings by applying the compiled regex patterns * and inserting the corresponding new string within the matched range of characters. 
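 *
 * (Added commentary, not part of the upstream comment.) The functor is invoked once per row and
 * keeps a per-pattern cache of the last found range in d_ranges: a pattern is only re-evaluated
 * once the current character position moves past its cached match, and a pattern that no longer
 * matches is parked at the end of the string so it is skipped for the rest of the row.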
*/ -template struct replace_multi_regex_fn { column_device_view const d_strings; device_span progs; // array of regex progs @@ -84,9 +82,9 @@ struct replace_multi_regex_fn { continue; // or later in the string reprog_device prog = progs[ptn_idx]; - auto begin = static_cast(ch_pos); - auto end = static_cast(nchars); - if (!prog.is_empty() && prog.find(idx, d_str, begin, end) > 0) + auto begin = ch_pos; + auto end = nchars; + if (!prog.is_empty() && prog.find(idx, d_str, begin, end) > 0) d_ranges[ptn_idx] = found_range{begin, end}; // found a match else d_ranges[ptn_idx] = found_range{nchars, nchars}; // this pattern is done @@ -123,33 +121,6 @@ struct replace_multi_regex_fn { } }; -struct replace_dispatch_fn { - template - std::unique_ptr operator()(strings_column_view const& input, - device_span d_progs, - strings_column_view const& replacements, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto const d_strings = column_device_view::create(input.parent(), stream); - auto const d_repls = column_device_view::create(replacements.parent(), stream); - - auto found_ranges = rmm::device_uvector(d_progs.size() * input.size(), stream); - - auto children = make_strings_children( - replace_multi_regex_fn{*d_strings, d_progs, found_ranges.data(), *d_repls}, - input.size(), - stream, - mr); - - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } -}; - } // namespace std::unique_ptr replace_re( @@ -168,15 +139,12 @@ std::unique_ptr replace_re( CUDF_EXPECTS(!replacements.has_nulls(), "Parameter replacements must not have any nulls"); // compile regexes into device objects - auto const d_char_table = get_character_flags_table(); auto h_progs = std::vector>>( patterns.size()); - std::transform(patterns.begin(), - patterns.end(), - h_progs.begin(), - [flags, d_char_table, input, stream](auto const& ptn) { - return reprog_device::create(ptn, flags, d_char_table, input.size(), stream); - }); + std::transform( + patterns.begin(), patterns.end(), h_progs.begin(), [flags, stream](auto const& ptn) { + return reprog_device::create(ptn, flags, stream); + }); // get the longest regex for the dispatcher auto const max_prog = @@ -184,15 +152,37 @@ std::unique_ptr replace_re( return lhs->insts_counts() < rhs->insts_counts(); }); + auto d_max_prog = **max_prog; + auto const buffer_size = d_max_prog.working_memory_size(input.size()); + auto d_buffer = rmm::device_buffer(buffer_size, stream); + // copy all the reprog_device instances to a device memory array std::vector progs; - std::transform(h_progs.begin(), h_progs.end(), std::back_inserter(progs), [](auto const& d_prog) { - return *d_prog; - }); + std::transform(h_progs.begin(), + h_progs.end(), + std::back_inserter(progs), + [d_buffer = d_buffer.data(), size = input.size()](auto& prog) { + prog->set_working_memory(d_buffer, size); + return *prog; + }); auto d_progs = cudf::detail::make_device_uvector_async(progs, stream); - return regex_dispatcher( - **max_prog, replace_dispatch_fn{}, input, d_progs, replacements, stream, mr); + auto const d_strings = column_device_view::create(input.parent(), stream); + auto const d_repls = column_device_view::create(replacements.parent(), stream); + + auto found_ranges = rmm::device_uvector(d_progs.size() * input.size(), stream); + + auto children = make_strings_children( + replace_multi_regex_fn{*d_strings, d_progs, found_ranges.data(), *d_repls}, + input.size(), + 
stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); } } // namespace detail diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index af74d8bdb92..159f83453bd 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -14,9 +14,7 @@ * limitations under the License. */ -#include -#include -#include +#include #include #include @@ -38,16 +36,14 @@ namespace { * @brief This functor handles replacing strings by applying the compiled regex pattern * and inserting the new string within the matched range of characters. */ -template struct replace_regex_fn { column_device_view const d_strings; - reprog_device prog; string_view const d_repl; size_type const maxrepl; int32_t* d_offsets{}; char* d_chars{}; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { if (!d_chars) d_offsets[idx] = 0; @@ -62,13 +58,13 @@ struct replace_regex_fn { auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) : nullptr; size_type last_pos = 0; - int32_t begin = 0; // these are for calling prog.find - int32_t end = -1; // matches final word-boundary if at the end of the string + size_type begin = 0; // these are for calling prog.find + size_type end = -1; // matches final word-boundary if at the end of the string // copy input to output replacing strings as we go while (mxn-- > 0 && begin <= nchars) { // maximum number of replaces - if (prog.is_empty() || prog.find(idx, d_str, begin, end) <= 0) { + if (prog.is_empty() || prog.find(prog_idx, d_str, begin, end) <= 0) { break; // no more matches } @@ -100,32 +96,6 @@ struct replace_regex_fn { } }; -struct replace_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(strings_column_view const& input, - string_view const& d_replacement, - size_type max_replace_count, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto const d_strings = column_device_view::create(input.parent(), stream); - - auto children = make_strings_children( - replace_regex_fn{*d_strings, d_prog, d_replacement, max_replace_count}, - input.size(), - stream, - mr); - - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } -}; - } // namespace // @@ -144,13 +114,20 @@ std::unique_ptr replace_re( string_view d_repl(replacement.data(), replacement.size()); // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); auto const maxrepl = max_replace_count.value_or(-1); - return regex_dispatcher( - *d_prog, replace_dispatch_fn{*d_prog}, input, d_repl, maxrepl, stream, mr); + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto children = make_strings_children( + replace_regex_fn{*d_strings, d_repl, maxrepl}, *d_prog, input.size(), stream, mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); } } // namespace detail diff --git a/cpp/src/strings/search/findall.cu 
b/cpp/src/strings/search/findall.cu index 323ad2cbc09..64e46d07e25 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -33,7 +31,6 @@ #include #include -#include #include #include @@ -52,14 +49,12 @@ namespace { * For strings with fewer matches, null entries are appended into `d_indices` * up to the maximum column count. */ -template struct findall_fn { column_device_view const d_strings; - reprog_device prog; size_type const* d_counts; ///< match counts for each string indices_span d_indices; ///< 2D-span: output matches added here - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { auto const match_count = d_counts[idx]; @@ -72,7 +67,7 @@ struct findall_fn { int32_t begin = 0; int32_t end = -1; for (auto col_idx = 0; col_idx < match_count; ++col_idx) { - if (prog.find(idx, d_str, begin, end) > 0) { + if (prog.find(prog_idx, d_str, begin, end) > 0) { auto const begin_offset = d_str.byte_offset(begin); auto const end_offset = d_str.byte_offset(end); d_output[col_idx] = @@ -82,28 +77,12 @@ struct findall_fn { end = nchars; } } - // fill the remaining entries for this row with nulls thrust::fill( thrust::seq, d_output.begin() + match_count, d_output.end(), string_index_pair{nullptr, 0}); } }; -struct findall_dispatch_fn { - reprog_device d_prog; - - template - void operator()(column_device_view const& d_strings, - size_type const* d_find_counts, - indices_span& d_indices, - rmm::cuda_stream_view stream) - { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - findall_fn{d_strings, d_prog, d_find_counts, d_indices}); - } -}; } // namespace std::unique_ptr
findall(strings_column_view const& input, @@ -115,11 +94,10 @@ std::unique_ptr
findall(strings_column_view const& input, auto const strings_count = input.size(); // compile regex into device object - auto const d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto const d_prog = reprog_device::create(pattern, flags, stream); auto const d_strings = column_device_view::create(input.parent(), stream); - auto find_counts = count_matches(*d_strings, *d_prog, strings_count + 1, stream); + auto find_counts = count_matches(*d_strings, *d_prog, strings_count, stream); auto d_find_counts = find_counts->view().data(); size_type const columns_count = thrust::reduce( @@ -139,9 +117,8 @@ std::unique_ptr
findall(strings_column_view const& input, } else { // place all matching strings into the indices vector auto d_indices = indices_span(indices.data(), strings_count, columns_count); - regex_dispatcher( - *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, d_find_counts, d_indices, stream); - + launch_for_each_kernel( + findall_fn{*d_strings, d_find_counts, d_indices}, *d_prog, strings_count, stream); results.resize(columns_count); } diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index 46155bd7cf5..2f4b9ce5b24 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -32,8 +30,6 @@ #include #include -#include -#include #include #include @@ -49,55 +45,48 @@ namespace { * @brief This functor handles extracting matched strings by applying the compiled regex pattern * and creating string_index_pairs for all the substrings. */ -template struct findall_fn { column_device_view const d_strings; - reprog_device prog; offset_type const* d_offsets; string_index_pair* d_indices; - __device__ void operator()(size_type const idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - int32_t begin = 0; - int32_t end = d_str.length(); - while ((begin < end) && (prog.find(idx, d_str, begin, end) > 0)) { + size_type begin = 0; + size_type end = nchars; + while ((begin < end) && (prog.find(prog_idx, d_str, begin, end) > 0)) { auto const spos = d_str.byte_offset(begin); // convert auto const epos = d_str.byte_offset(end); // to bytes d_output[output_idx++] = string_index_pair{d_str.data() + spos, (epos - spos)}; begin = end + (begin == end); - end = d_str.length(); + end = nchars; } } }; -struct findall_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(column_device_view const& d_strings, +std::unique_ptr findall_util(column_device_view const& d_strings, + reprog_device& d_prog, size_type total_matches, offset_type const* d_offsets, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) - { - rmm::device_uvector indices(total_matches, stream); +{ + rmm::device_uvector indices(total_matches, stream); - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - findall_fn{d_strings, d_prog, d_offsets, indices.data()}); + launch_for_each_kernel( + findall_fn{d_strings, d_offsets, indices.data()}, d_prog, d_strings.size(), stream); - return make_strings_column(indices.begin(), indices.end(), stream, mr); - } -}; + return make_strings_column(indices.begin(), indices.end(), stream, mr); +} } // namespace @@ -113,8 +102,7 @@ std::unique_ptr findall_record( auto const d_strings = column_device_view::create(input.parent(), stream); // compile regex into device object - auto const d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto const d_prog = reprog_device::create(pattern, flags, stream); // Create lists offsets column auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr); @@ -128,8 +116,7 @@ std::unique_ptr findall_record( auto const total_matches = 
cudf::detail::get_value(offsets->view(), strings_count, stream); - auto strings_output = regex_dispatcher( - *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, total_matches, d_offsets, stream, mr); + auto strings_output = findall_util(*d_strings, *d_prog, total_matches, d_offsets, stream, mr); // Build the lists column from the offsets and the strings return make_lists_column(strings_count, diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index 3ec6df058c6..16edd0606e9 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -28,12 +26,10 @@ #include #include #include -#include #include #include -#include #include #include #include @@ -59,18 +55,17 @@ enum class split_direction { * The `d_token_offsets` specifies the output position within `d_tokens` * for each string. */ -template struct token_reader_fn { column_device_view const d_strings; - reprog_device prog; split_direction const direction; offset_type const* d_token_offsets; string_index_pair* d_tokens; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); auto const token_offset = d_token_offsets[idx]; auto const token_count = d_token_offsets[idx + 1] - token_offset; @@ -78,9 +73,9 @@ struct token_reader_fn { size_type token_idx = 0; size_type begin = 0; // characters - size_type end = d_str.length(); + size_type end = nchars; size_type last_pos = 0; // bytes - while (prog.find(idx, d_str, begin, end) > 0) { + while (prog.find(prog_idx, d_str, begin, end) > 0) { // get the token (characters just before this match) auto const token = string_index_pair{d_str.data() + last_pos, d_str.byte_offset(begin) - last_pos}; @@ -97,7 +92,7 @@ struct token_reader_fn { // setup for next match last_pos = d_str.byte_offset(end); begin = end + (begin == end); - end = d_str.length(); + end = nchars; } // set the last token to the remainder of the string @@ -116,28 +111,6 @@ struct token_reader_fn { } }; -struct generate_dispatch_fn { - reprog_device d_prog; - - template - rmm::device_uvector operator()(column_device_view const& d_strings, - size_type total_tokens, - split_direction direction, - offset_type const* d_offsets, - rmm::cuda_stream_view stream) - { - rmm::device_uvector tokens(total_tokens, stream); - - thrust::for_each_n( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - token_reader_fn{d_strings, d_prog, direction, d_offsets, tokens.data()}); - - return tokens; - } -}; - /** * @brief Call regex to split each input string into tokens. * @@ -176,8 +149,15 @@ rmm::device_uvector generate_tokens(column_device_view const& // the last offset entry is the total number of tokens to be generated auto const total_tokens = cudf::detail::get_value(offsets, strings_count, stream); - return regex_dispatcher( - d_prog, generate_dispatch_fn{d_prog}, d_strings, total_tokens, direction, d_offsets, stream); + rmm::device_uvector tokens(total_tokens, stream); + if (total_tokens == 0) { return tokens; } + + launch_for_each_kernel(token_reader_fn{d_strings, direction, d_offsets, tokens.data()}, + d_prog, + d_strings.size(), + stream); + + return tokens; } /** @@ -221,7 +201,7 @@ std::unique_ptr
split_re(strings_column_view const& input, } // create the regex device prog from the given pattern - auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + auto d_prog = reprog_device::create(pattern, stream); auto d_strings = column_device_view::create(input.parent(), stream); // count the number of delimiters matched in each string @@ -283,7 +263,7 @@ std::unique_ptr split_record_re(strings_column_view const& input, auto const strings_count = input.size(); // create the regex device prog from the given pattern - auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + auto d_prog = reprog_device::create(pattern, stream); auto d_strings = column_device_view::create(input.parent(), stream); // count the number of delimiters matched in each string diff --git a/docs/cudf/source/conf.py b/docs/cudf/source/conf.py index c8b30120924..0ffbdf47d54 100644 --- a/docs/cudf/source/conf.py +++ b/docs/cudf/source/conf.py @@ -197,8 +197,9 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - "python": ("https://docs.python.org/", None), + "python": ("https://docs.python.org/3", None), "cupy": ("https://docs.cupy.dev/en/stable/", None), + "numpy": ("https://numpy.org/doc/stable", None), } # Config numpydoc diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index cc1bc35f951..e871da18966 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -233,8 +233,10 @@ public final ColumnView getChildColumnView(int childIndex) { /** * Get a ColumnView that is the offsets for this list. + * Please note that it is the responsibility of the caller to close this view, and the parent + * column must out live this view. */ - ColumnView getListOffsetsView() { + public ColumnView getListOffsetsView() { assert(getType().equals(DType.LIST)); return new ColumnView(getListOffsetCvPointer(viewHandle)); } diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java b/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java index 763ecc763a5..8b1a9a63131 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java @@ -448,11 +448,8 @@ public HostColumnVector.StructData getStruct(int rowIndex) { * @return true if null else false */ public boolean isNull(long rowIndex) { - assert (rowIndex >= 0 && rowIndex < rows) : "index is out of range 0 <= " + rowIndex + " < " + rows; - if (hasValidityVector()) { - return BitVectorHelper.isNull(offHeap.valid, rowIndex); - } - return false; + return rowIndex < 0 || rowIndex >= rows // unknown, hence NULL + || hasValidityVector() && BitVectorHelper.isNull(offHeap.valid, rowIndex); } /** diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index d0e9e6d94c1..e75cf47bb7c 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -538,7 +538,7 @@ def to_cupy( Parameters ---------- dtype : str or numpy.dtype, optional - The dtype to pass to :meth:`numpy.asarray`. + The dtype to pass to :func:`numpy.asarray`. copy : bool, default False Whether to ensure that the returned value is not a view on another array. Note that ``copy=False`` does not *ensure* that @@ -573,7 +573,7 @@ def to_numpy( Parameters ---------- dtype : str or numpy.dtype, optional - The dtype to pass to :meth:`numpy.asarray`. 
+ The dtype to pass to :func:`numpy.asarray`. copy : bool, default True Whether to ensure that the returned value is not a view on another array. This parameter must be ``True`` since cuDF must copy diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index e9f836d9702..679edefcc83 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -311,7 +311,11 @@ def gen_rand(dtype, size, **kwargs): np.random.randint(low=low, high=high, size=size), unit=time_unit ) elif dtype.kind in ("O", "U"): - return pd._testing.rands_array(10, size) + low = kwargs.get("low", 10) + high = kwargs.get("high", 11) + return pd._testing.rands_array( + np.random.randint(low=low, high=high, size=1)[0], size + ) raise NotImplementedError(f"dtype.kind={dtype.kind}") diff --git a/python/cudf/cudf/tests/test_orc.py b/python/cudf/cudf/tests/test_orc.py index c28358f5fa0..c547c80e48b 100644 --- a/python/cudf/cudf/tests/test_orc.py +++ b/python/cudf/cudf/tests/test_orc.py @@ -301,27 +301,36 @@ def test_orc_read_rows(datadir, skiprows, num_rows): assert_eq(pdf, gdf) -def test_orc_read_skiprows(tmpdir): +def test_orc_read_skiprows(): buff = BytesIO() - df = pd.DataFrame( - {"a": [1, 0, 1, 0, None, 1, 1, 1, 0, None, 0, 0, 1, 1, 1, 1]}, - dtype=pd.BooleanDtype(), - ) + data = [ + True, + False, + True, + False, + None, + True, + True, + True, + False, + None, + False, + False, + True, + True, + True, + True, + ] writer = pyorc.Writer(buff, pyorc.Struct(a=pyorc.Boolean())) - tuples = list( - map( - lambda x: (None,) if x[0] is pd.NA else (bool(x[0]),), - list(df.itertuples(index=False, name=None)), - ) - ) - writer.writerows(tuples) + writer.writerows([(d,) for d in data]) writer.close() + # testing 10 skiprows due to a boolean specific bug fix that didn't + # repro for other sizes of data skiprows = 10 - expected = cudf.read_orc(buff)[skiprows::].reset_index(drop=True) + expected = cudf.read_orc(buff)[skiprows:].reset_index(drop=True) got = cudf.read_orc(buff, skiprows=skiprows) - assert_eq(expected, got) @@ -724,6 +733,105 @@ def test_orc_write_statistics(tmpdir, datadir, nrows, stats_freq): assert stats_num_vals == actual_num_vals +@pytest.mark.parametrize("stats_freq", ["STRIPE", "ROWGROUP"]) +@pytest.mark.parametrize("nrows", [2, 100, 6000000]) +def test_orc_chunked_write_statistics(tmpdir, datadir, nrows, stats_freq): + supported_stat_types = supported_numpy_dtypes + ["str"] + # Can't write random bool columns until issue #6763 is fixed + if nrows == 6000000: + supported_stat_types.remove("bool") + + gdf_fname = tmpdir.join("chunked_stats.orc") + writer = ORCWriter(gdf_fname) + + max_char_length = 1000 if nrows < 10000 else 100 + + # Make a dataframe + gdf = cudf.DataFrame( + { + "col_" + + str(dtype): gen_rand_series( + dtype, + int(nrows / 2), + has_nulls=True, + low=0, + high=max_char_length, + ) + for dtype in supported_stat_types + } + ) + + pdf1 = gdf.to_pandas() + writer.write_table(gdf) + # gdf is specifically being reused here to ensure the data is destroyed + # before the next write_table call to ensure the data is persisted inside + # write and no pointers are saved into the original table + gdf = cudf.DataFrame( + { + "col_" + + str(dtype): gen_rand_series( + dtype, + int(nrows / 2), + has_nulls=True, + low=0, + high=max_char_length, + ) + for dtype in supported_stat_types + } + ) + pdf2 = gdf.to_pandas() + writer.write_table(gdf) + writer.close() + + # pandas is unable to handle min/max of string col with nulls + expect = 
cudf.DataFrame(pd.concat([pdf1, pdf2]).reset_index(drop=True)) + + # Read back written ORC's statistics + orc_file = pa.orc.ORCFile(gdf_fname) + ( + file_stats, + stripes_stats, + ) = cudf.io.orc.read_orc_statistics([gdf_fname]) + + # check file stats + for col in expect: + if "minimum" in file_stats[0][col]: + stats_min = file_stats[0][col]["minimum"] + actual_min = expect[col].min() + assert normalized_equals(actual_min, stats_min) + if "maximum" in file_stats[0][col]: + stats_max = file_stats[0][col]["maximum"] + actual_max = expect[col].max() + assert normalized_equals(actual_max, stats_max) + if "number_of_values" in file_stats[0][col]: + stats_num_vals = file_stats[0][col]["number_of_values"] + actual_num_vals = expect[col].count() + assert stats_num_vals == actual_num_vals + + # compare stripe statistics with actual min/max + for stripe_idx in range(0, orc_file.nstripes): + stripe = orc_file.read_stripe(stripe_idx) + # pandas is unable to handle min/max of string col with nulls + stripe_df = cudf.DataFrame(stripe.to_pandas()) + for col in stripe_df: + if "minimum" in stripes_stats[stripe_idx][col]: + actual_min = stripe_df[col].min() + stats_min = stripes_stats[stripe_idx][col]["minimum"] + assert normalized_equals(actual_min, stats_min) + + if "maximum" in stripes_stats[stripe_idx][col]: + actual_max = stripe_df[col].max() + stats_max = stripes_stats[stripe_idx][col]["maximum"] + assert normalized_equals(actual_max, stats_max) + + if "number_of_values" in stripes_stats[stripe_idx][col]: + stats_num_vals = stripes_stats[stripe_idx][col][ + "number_of_values" + ] + actual_num_vals = stripe_df[col].count() + assert stats_num_vals == actual_num_vals + + @pytest.mark.parametrize("nrows", [1, 100, 6000000]) def test_orc_write_bool_statistics(tmpdir, datadir, nrows): # Make a dataframe