From 2318548ec821cf8777303966466d5dc24d9c5338 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Mon, 11 Dec 2023 21:37:32 -0500 Subject: [PATCH] Make bpe_merge_pairs_impl member private (#14543) Changes the `impl` member of the public `bpe_merge_pairs` struct to private. Also fixes the `nvtext::byte_pair_encoding` doxygen to remove outdated details. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Karthikeyan (https://github.com/karthikeyann) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/14543 --- cpp/include/nvtext/byte_pair_encoding.hpp | 10 +++++----- cpp/src/text/bpe/byte_pair_encoding.cu | 18 ++++++++++++++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/cpp/include/nvtext/byte_pair_encoding.hpp b/cpp/include/nvtext/byte_pair_encoding.hpp index f9790a1a701..4d6d8335eac 100644 --- a/cpp/include/nvtext/byte_pair_encoding.hpp +++ b/cpp/include/nvtext/byte_pair_encoding.hpp @@ -36,7 +36,6 @@ namespace nvtext { */ struct bpe_merge_pairs { struct bpe_merge_pairs_impl; - bpe_merge_pairs_impl* impl{}; ///< Implementation of the BPE merge pairs table. /** * @brief Construct a new bpe merge pairs object @@ -62,6 +61,10 @@ struct bpe_merge_pairs { ~bpe_merge_pairs(); bpe_merge_pairs(); + + private: + friend bpe_merge_pairs_impl const* get_bpe_merge_pairs_impl(bpe_merge_pairs const&); + bpe_merge_pairs_impl* impl{}; ///< Implementation of the BPE merge pairs table. }; /** @@ -97,12 +100,9 @@ std::unique_ptr load_merge_pairs( /** * @brief Byte pair encode the input strings. * - * This will split each string on whitespace, perform the encoding, - * and then build the output column using the given `separator`. - * * The encoding algorithm rebuilds each string by matching substrings * in the `merge_pairs` table and iteratively removing the minimum ranked pair - * until no pairs are left. Then, a space is inserted between the remaining + * until no pairs are left. Then, the separator is inserted between the remaining * pairs before the result is joined to make the output string. * * @code{.pseudo} diff --git a/cpp/src/text/bpe/byte_pair_encoding.cu b/cpp/src/text/bpe/byte_pair_encoding.cu index 5be35119003..2d53faf548e 100644 --- a/cpp/src/text/bpe/byte_pair_encoding.cu +++ b/cpp/src/text/bpe/byte_pair_encoding.cu @@ -43,6 +43,20 @@ #include namespace nvtext { + +/** + * @brief Access the bpe_merge_pairs impl member + * + * This is used by the encoder to access the impl member functions. + * + * @param bpe The merge pairs struct + * @return The impl object with detailed, internal member data + */ +bpe_merge_pairs::bpe_merge_pairs_impl const* get_bpe_merge_pairs_impl(bpe_merge_pairs const& bpe) +{ + return bpe.impl; +} + namespace detail { namespace { @@ -364,7 +378,7 @@ std::unique_ptr byte_pair_encoding(cudf::strings_column_view const // this kernel locates unpairable sections of strings to create artificial string row // boundaries; the boundary values are recorded as offsets in d_up_offsets auto const d_up_offsets = d_working.data(); // store unpairable offsets here - auto const mp_map = merge_pairs.impl->get_mp_table_ref(); // lookup table + auto const mp_map = get_bpe_merge_pairs_impl(merge_pairs)->get_mp_table_ref(); // lookup table auto const d_chars_span = cudf::device_span(d_input_chars, chars_size); auto up_fn = bpe_unpairable_offsets_fn{d_chars_span, first_offset, mp_map}; thrust::transform(rmm::exec_policy_nosync(stream), chars_begin, chars_end, d_up_offsets, up_fn); @@ -398,7 +412,7 @@ std::unique_ptr byte_pair_encoding(cudf::strings_column_view const // launch the byte-pair-encoding kernel on the temp column rmm::device_uvector d_rerank(chars_size, stream); // more working memory; auto const d_ranks = d_working.data(); // store pair ranks here - auto const pair_map = merge_pairs.impl->get_merge_pairs_ref(); + auto const pair_map = get_bpe_merge_pairs_impl(merge_pairs)->get_merge_pairs_ref(); bpe_parallel_fn<<>>( *d_tmp_strings, pair_map, d_spaces.data(), d_ranks, d_rerank.data()); }