From f543dfa1356f02ae6b581e3e2584fffccfc69c76 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 17 Aug 2023 11:19:19 -0400 Subject: [PATCH] Fix Byte-Pair-Encoding usage of cuco static-map for storing merge-pairs (#13807) Switching to use `cuco::experimental::static_map` for storing the unique merge-pair strings that can be looked up by `string_view`. This takes advantage of a feature of the `static_map` that allows storing with one key (index to a string entry) and lookup with a different type (string). The map uses a hash on the string for storing the index but allows lookup by string since the hash of string can resolve the entry and duplicates can be resolved by comparing the string with row entries. Authors: - David Wendt (https://github.com/davidwendt) - Yunsong Wang (https://github.com/PointKernel) Approvers: - Yunsong Wang (https://github.com/PointKernel) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/13807 --- cpp/include/nvtext/bpe_tokenize.hpp | 15 +---- cpp/src/text/subword/bpe_tokenizer.cu | 43 ++++++--------- cpp/src/text/subword/bpe_tokenizer.cuh | 70 +++++++++++++++++++++--- cpp/src/text/subword/load_merges_file.cu | 48 ++++++---------- 4 files changed, 98 insertions(+), 78 deletions(-) diff --git a/cpp/include/nvtext/bpe_tokenize.hpp b/cpp/include/nvtext/bpe_tokenize.hpp index b93d93b07c6..c67f4bd8b1c 100644 --- a/cpp/include/nvtext/bpe_tokenize.hpp +++ b/cpp/include/nvtext/bpe_tokenize.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,19 +61,6 @@ struct bpe_merge_pairs { rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); ~bpe_merge_pairs(); - - /** - * @brief Returns the number of merge pairs in the table. - * - * @return The number of merge pairs in the table - */ - cudf::size_type get_size(); - /** - * @brief Returns the number of unique merge pairs in the table. - * - * @return The number of unique merge pairs in the table - */ - std::size_t get_map_size(); }; /** diff --git a/cpp/src/text/subword/bpe_tokenizer.cu b/cpp/src/text/subword/bpe_tokenizer.cu index ac55fe76db1..4c4f5b3a4b1 100644 --- a/cpp/src/text/subword/bpe_tokenizer.cu +++ b/cpp/src/text/subword/bpe_tokenizer.cu @@ -80,10 +80,11 @@ __device__ cudf::string_view get_first_token(cudf::string_view const& d_str) * * @see The byte_pair_encoding_fn::operator() function below for details. */ +template struct byte_pair_encoding_fn { cudf::column_device_view const d_merges; cudf::column_device_view const d_strings; - merge_pairs_map_type::device_view const d_map; + MapRefType const d_map; cudf::size_type* d_sizes; // output size of encoded string string_hasher_type const hasher; cudf::size_type* d_byte_indices; @@ -136,17 +137,13 @@ struct byte_pair_encoding_fn { } /** - * @brief Compute the hash over the input strings. + * @brief Look up the pair of strings in the d_map/d_merges * - * The input strings are combined with a space to produce hash for matching - * a merge pair within the `d_map`. - * - * @param lhs First string. - * @param rhs Second string. - * @return The hash value to match with `d_map`. + * @param lhs Left half of the string + * @param rhs Right half of the string + * @return Position of merge pair within d_map */ - __device__ cudf::hash_value_type compute_hash(cudf::string_view const& lhs, - cudf::string_view const& rhs) + __device__ auto get_merge_pair(cudf::string_view const& lhs, cudf::string_view const& rhs) { __shared__ char shmem[48 * 1024]; // max for Pascal auto const total_size = lhs.size_bytes() + rhs.size_bytes() + 1; @@ -154,8 +151,8 @@ struct byte_pair_encoding_fn { // Edge case check. // Empirically found only two merge pair strings that were greater than 70 bytes - // and they both looked like ignorable errors. Double check this analysis with Vibhu. - if (thread_memory_size < total_size) { return 0; } + // and they both looked like ignorable errors. + if (thread_memory_size < total_size) { return d_map.end(); } // build the target string in shared memory char* ptr = &shmem[threadIdx.x * thread_memory_size]; @@ -165,8 +162,8 @@ struct byte_pair_encoding_fn { memcpy(ptr + lhs.size_bytes(), " ", 1); memcpy(ptr + lhs.size_bytes() + 1, rhs.data(), rhs.size_bytes()); - auto const d_hash_str = cudf::string_view(ptr, total_size); - return hasher(d_hash_str); // return the hash for the temp string + auto const d_str = cudf::string_view(ptr, total_size); + return d_map.find(d_str); } /** @@ -233,11 +230,10 @@ struct byte_pair_encoding_fn { auto const rhs = next_substr(itr, end, d_str); if (rhs.empty()) break; // no more adjacent pairs - auto const hash = compute_hash(lhs, rhs); - auto const map_itr = d_map.find(hash, thrust::identity{}); + auto const map_itr = get_merge_pair(lhs, rhs); if (map_itr != d_map.end()) { // found a match; record the rank (and other min_ vars) - auto const rank = static_cast(map_itr->second); + auto const rank = map_itr->second; if (rank < min_rank) { min_rank = rank; min_itr = itr; @@ -354,12 +350,12 @@ std::unique_ptr byte_pair_encoding( bpe_merge_pairs::bpe_merge_pairs_impl const& merge_pairs, rmm::cuda_stream_view stream) { - CUDF_EXPECTS(!merge_pairs.get_merge_pairs().is_empty(), "Merge pairs table must not be empty"); + auto const d_merges = merge_pairs.get_merge_pairs(); + CUDF_EXPECTS(d_merges.size() > 0, "Merge pairs table must not be empty"); // build working vector to hold index values per byte rmm::device_uvector d_byte_indices(input.chars().size(), stream); - auto const d_merges = cudf::column_device_view::create(merge_pairs.get_merge_pairs(), stream); auto const d_strings = cudf::column_device_view::create(input.parent(), stream); auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, @@ -369,12 +365,9 @@ std::unique_ptr byte_pair_encoding( rmm::mr::get_current_device_resource()); auto d_offsets = offsets->mutable_view().data(); - byte_pair_encoding_fn fn{*d_merges, - *d_strings, - merge_pairs.get_merge_pairs_map(), - d_offsets, - string_hasher_type{}, - d_byte_indices.data()}; + auto map_ref = merge_pairs.get_merge_pairs_ref(); + byte_pair_encoding_fn fn{ + d_merges, *d_strings, map_ref, d_offsets, string_hasher_type{}, d_byte_indices.data()}; thrust::for_each_n( rmm::exec_policy(stream), thrust::make_counting_iterator(0), input.size(), fn); diff --git a/cpp/src/text/subword/bpe_tokenizer.cuh b/cpp/src/text/subword/bpe_tokenizer.cuh index 0697a9961c7..83aa22aaae9 100644 --- a/cpp/src/text/subword/bpe_tokenizer.cuh +++ b/cpp/src/text/subword/bpe_tokenizer.cuh @@ -21,7 +21,9 @@ #include #include +#include #include +#include #include #include @@ -30,30 +32,84 @@ #include #include +#include namespace nvtext { namespace detail { +using hash_value_type = uint32_t; +using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32; + +/** + * @brief Hasher function used for building and using the cuco static-map + * + * This takes advantage of heterogeneous lookup feature in cuco static-map which + * allows inserting with one type (index) and looking up with a different type (string). + */ +struct bpe_hasher { + cudf::column_device_view const d_strings; + string_hasher_type hasher{}; + // used by insert + __device__ hash_value_type operator()(cudf::size_type index) const + { + return hasher(d_strings.element(index)); + } + // used by find + __device__ hash_value_type operator()(cudf::string_view const& s) const { return hasher(s); } +}; + +/** + * @brief Equal function used for building and using the cuco static-map + * + * This takes advantage of heterogeneous lookup feature in cuco static-map which + * allows inserting with one type (index) and looking up with a different type (string). + */ +struct bpe_equal { + cudf::column_device_view const d_strings; + // used by insert + __device__ bool operator()(cudf::size_type lhs, cudf::size_type rhs) const noexcept + { + return d_strings.element(lhs) == d_strings.element(rhs); + } + // used by find + __device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept + { + return d_strings.element(lhs) == rhs; + } +}; + using hash_table_allocator_type = rmm::mr::stream_allocator_adaptor>; -using merge_pairs_map_type = cuco::static_map; +using probe_scheme = cuco::experimental::linear_probing<1, bpe_hasher>; -using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32; +using merge_pairs_map_type = cuco::experimental::static_map, + cuda::thread_scope_device, + bpe_equal, + probe_scheme, + hash_table_allocator_type>; } // namespace detail +// since column_device_view::create returns is a little more than +// std::unique_ptr this helper simplifies the return type in a more maintainable +// way +using col_device_view = std::invoke_result_t; + struct bpe_merge_pairs::bpe_merge_pairs_impl { std::unique_ptr const merge_pairs; + col_device_view const d_merge_pairs; std::unique_ptr merge_pairs_map; bpe_merge_pairs_impl(std::unique_ptr&& merge_pairs, + col_device_view&& d_merge_pairs, std::unique_ptr&& merge_pairs_map); - auto get_merge_pairs() const { return merge_pairs->view(); } - auto get_merge_pairs_map() const { return merge_pairs_map->get_device_view(); } + auto const get_merge_pairs() const { return *d_merge_pairs; } + auto get_merge_pairs_ref() const { return merge_pairs_map->ref(cuco::experimental::op::find); } }; } // namespace nvtext diff --git a/cpp/src/text/subword/load_merges_file.cu b/cpp/src/text/subword/load_merges_file.cu index b39413af98f..1f1b90b3f49 100644 --- a/cpp/src/text/subword/load_merges_file.cu +++ b/cpp/src/text/subword/load_merges_file.cu @@ -36,23 +36,8 @@ namespace nvtext { namespace detail { - namespace { -struct make_pair_function { - /** - * @brief Hash the merge pair entry - */ - __device__ cuco::pair operator()(cudf::size_type idx) - { - auto const result = _hasher(d_strings.element(idx)); - return cuco::make_pair(result, idx); - } - - string_hasher_type const _hasher; - cudf::column_device_view const d_strings; -}; - /** * @brief Loads a text file of merge-pairs into a strings column. * @@ -101,26 +86,23 @@ std::unique_ptr load_file_to_column(std::string const& filename_me } std::unique_ptr initialize_merge_pairs_map( - cudf::strings_column_view const& input, rmm::cuda_stream_view stream) + cudf::column_device_view const& input, rmm::cuda_stream_view stream) { // Ensure capacity is at least (size/0.7) as documented here: // https://github.com/NVIDIA/cuCollections/blob/6ec8b6dcdeceea07ab4456d32461a05c18864411/include/cuco/static_map.cuh#L179-L182 auto merge_pairs_map = std::make_unique( static_cast(input.size() * 2), // capacity is 2x; - cuco::empty_key{std::numeric_limits::max()}, + cuco::empty_key{-1}, cuco::empty_value{-1}, // empty value is not used + bpe_equal{input}, + probe_scheme{bpe_hasher{input}}, hash_table_allocator_type{default_allocator{}, stream}, stream.value()); - auto d_strings = cudf::column_device_view::create(input.parent(), stream); - make_pair_function pair_func{string_hasher_type{}, *d_strings}; - auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func); + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); }); - merge_pairs_map->insert(iter, - iter + input.size(), - thrust::identity{}, - thrust::equal_to{}, - stream.value()); + merge_pairs_map->insert_async(iter, iter + input.size(), stream.value()); return merge_pairs_map; } @@ -128,9 +110,10 @@ std::unique_ptr initialize_merge_pairs_map( std::unique_ptr create_bpe_merge_pairs_impl( std::unique_ptr&& input, rmm::cuda_stream_view stream) { - auto merge_pairs = initialize_merge_pairs_map(cudf::strings_column_view(input->view()), stream); - return std::make_unique(std::move(input), - std::move(merge_pairs)); + auto d_input = cudf::column_device_view::create(input->view(), stream); + auto merge_pairs = initialize_merge_pairs_map(*d_input, stream); + return std::make_unique( + std::move(input), std::move(d_input), std::move(merge_pairs)); } std::unique_ptr create_bpe_merge_pairs_impl( @@ -163,8 +146,12 @@ std::unique_ptr load_merge_pairs_file(std::string const& filena bpe_merge_pairs::bpe_merge_pairs_impl::bpe_merge_pairs_impl( std::unique_ptr&& merge_pairs, + std::unique_ptr>&& + d_merge_pairs, std::unique_ptr&& merge_pairs_map) - : merge_pairs(std::move(merge_pairs)), merge_pairs_map(std::move(merge_pairs_map)) + : merge_pairs(std::move(merge_pairs)), + d_merge_pairs(std::move(d_merge_pairs)), + merge_pairs_map(std::move(merge_pairs_map)) { } @@ -184,7 +171,4 @@ bpe_merge_pairs::bpe_merge_pairs(cudf::strings_column_view const& input, bpe_merge_pairs::~bpe_merge_pairs() = default; -cudf::size_type bpe_merge_pairs::get_size() { return impl->merge_pairs->size(); } -std::size_t bpe_merge_pairs::get_map_size() { return impl->merge_pairs_map->get_size(); } - } // namespace nvtext