Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make bpe_merge_pairs_impl member private #14543

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cpp/include/nvtext/byte_pair_encoding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ namespace nvtext {
*/
struct bpe_merge_pairs {
struct bpe_merge_pairs_impl;
bpe_merge_pairs_impl* impl{}; ///< Implementation of the BPE merge pairs table.

/**
* @brief Construct a new bpe merge pairs object
Expand All @@ -62,6 +61,10 @@ struct bpe_merge_pairs {

~bpe_merge_pairs();
bpe_merge_pairs();

private:
friend bpe_merge_pairs_impl const* get_bpe_merge_pairs_impl(bpe_merge_pairs const&);
bpe_merge_pairs_impl* impl{}; ///< Implementation of the BPE merge pairs table.
};

/**
Expand Down Expand Up @@ -97,12 +100,9 @@ std::unique_ptr<bpe_merge_pairs> load_merge_pairs(
/**
* @brief Byte pair encode the input strings.
*
* This will split each string on whitespace, perform the encoding,
* and then build the output column using the given `separator`.
*
* The encoding algorithm rebuilds each string by matching substrings
* in the `merge_pairs` table and iteratively removing the minimum ranked pair
* until no pairs are left. Then, a space is inserted between the remaining
* until no pairs are left. Then, the separator is inserted between the remaining
* pairs before the result is joined to make the output string.
*
* @code{.pseudo}
Expand Down
18 changes: 16 additions & 2 deletions cpp/src/text/bpe/byte_pair_encoding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@
#include <thrust/unique.h>

namespace nvtext {

/**
 * @brief Retrieve the implementation object held by a bpe_merge_pairs instance
 *
 * The encoder calls this helper to reach the impl member functions
 * of the given merge pairs object.
 *
 * @param bpe The merge pairs struct whose impl pointer is read
 * @return Read-only pointer to the impl object holding the detailed, internal member data
 */
bpe_merge_pairs::bpe_merge_pairs_impl const* get_bpe_merge_pairs_impl(bpe_merge_pairs const& bpe)
{
  // Hand back the stored pointer unchanged; no ownership is transferred here.
  bpe_merge_pairs::bpe_merge_pairs_impl const* const impl_ptr = bpe.impl;
  return impl_ptr;
}

namespace detail {
namespace {

Expand Down Expand Up @@ -364,7 +378,7 @@ std::unique_ptr<cudf::column> byte_pair_encoding(cudf::strings_column_view const
// this kernel locates unpairable sections of strings to create artificial string row
// boundaries; the boundary values are recorded as offsets in d_up_offsets
auto const d_up_offsets = d_working.data(); // store unpairable offsets here
auto const mp_map = merge_pairs.impl->get_mp_table_ref(); // lookup table
auto const mp_map = get_bpe_merge_pairs_impl(merge_pairs)->get_mp_table_ref(); // lookup table
auto const d_chars_span = cudf::device_span<char const>(d_input_chars, chars_size);
auto up_fn = bpe_unpairable_offsets_fn<decltype(mp_map)>{d_chars_span, first_offset, mp_map};
thrust::transform(rmm::exec_policy_nosync(stream), chars_begin, chars_end, d_up_offsets, up_fn);
Expand Down Expand Up @@ -398,7 +412,7 @@ std::unique_ptr<cudf::column> byte_pair_encoding(cudf::strings_column_view const
// launch the byte-pair-encoding kernel on the temp column
rmm::device_uvector<int8_t> d_rerank(chars_size, stream); // more working memory;
auto const d_ranks = d_working.data(); // store pair ranks here
auto const pair_map = merge_pairs.impl->get_merge_pairs_ref();
auto const pair_map = get_bpe_merge_pairs_impl(merge_pairs)->get_merge_pairs_ref();
bpe_parallel_fn<decltype(pair_map)><<<tmp_size, block_size, 0, stream.value()>>>(
*d_tmp_strings, pair_map, d_spaces.data(), d_ranks, d_rerank.data());
}
Expand Down