Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nvtext::byte_pair_encoding API #10270

Merged
merged 25 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
56758d8
Add nvtext::byte_pair_encoding API
davidwendt Feb 10, 2022
fcb540b
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Feb 10, 2022
ae2baa0
fix call to detail::rsplit_record
davidwendt Feb 11, 2022
5ee29cc
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Feb 11, 2022
aa6f8e8
change algorithm to use cuco::static-map
davidwendt Feb 17, 2022
c215c55
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Feb 17, 2022
85df96e
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Feb 18, 2022
3df89a0
handle sliced input column
davidwendt Feb 18, 2022
ad438f1
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Feb 22, 2022
6eb6171
add leading space to test
davidwendt Feb 23, 2022
9466521
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Feb 23, 2022
84a2cbe
add separator test
davidwendt Feb 23, 2022
1d35f19
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 1, 2022
61195b7
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 2, 2022
fe7ada7
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 7, 2022
d282330
fix typos in and clarify comments
davidwendt Mar 7, 2022
ad398af
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 8, 2022
f9cdc4f
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 14, 2022
2fc267c
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 15, 2022
93b0842
fix grammar and typos
davidwendt Mar 15, 2022
bbc3744
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 15, 2022
845a414
add more entries in load_merge_pairs_file doxygen example
davidwendt Mar 15, 2022
060077b
add check for unexpected data format
davidwendt Mar 15, 2022
cdee746
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 16, 2022
4da9b53
Merge branch 'branch-22.04' into fea-byte-pair-encoder
davidwendt Mar 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,10 @@ add_library(
src/text/normalize.cu
src/text/replace.cu
src/text/stemmer.cu
src/text/subword/bpe_tokenizer.cu
src/text/subword/data_normalizer.cu
src/text/subword/load_hash_file.cu
src/text/subword/load_merges_file.cu
src/text/subword/subword_tokenize.cu
src/text/subword/wordpiece_tokenizer.cu
src/text/tokenize.cu
Expand Down
16 changes: 15 additions & 1 deletion cpp/include/cudf/strings/detail/combine.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,6 +54,20 @@ std::unique_ptr<column> join_strings(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @copydoc join_list_elements(table_view const&,string_scalar const&,string_scalar
* const&,separator_on_nulls,output_if_empty_list,rmm::mr::device_memory_resource*)
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
*
* @param stream CUDA stream used for device memory operations and kernel launches.
*/
std::unique_ptr<column> join_list_elements(lists_column_view const& lists_strings_column,
string_scalar const& separator,
string_scalar const& narep,
separator_on_nulls separate_nulls,
output_if_empty_list empty_list_policy,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);

} // namespace detail
} // namespace strings
} // namespace cudf
113 changes: 113 additions & 0 deletions cpp/include/nvtext/bpe_tokenize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/strings_column_view.hpp>

namespace nvtext {

/**
* @addtogroup nvtext_tokenize
* @{
* @file
*/

/**
* @brief The table of merge pairs for the BPE encoder.
*
* To create an instance, call nvtext::load_merges_table
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
*/
struct bpe_merge_pairs {
struct bpe_merge_pairs_impl;
std::unique_ptr<bpe_merge_pairs_impl> impl{};

bpe_merge_pairs(std::unique_ptr<cudf::column>&& input,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

bpe_merge_pairs(cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

~bpe_merge_pairs();

cudf::size_type get_size();
std::size_t get_map_size();
};

/**
* @brief Create a nvtext::bpe_merge_pairs from an input file.
*
* The file should contain a pair of strings per line separated by
* a single space.
*
* Example:
* @code{.txt}
* e n
* i t
* i s
* ...
* @endcode
*
* The pairs are expected to be ordered in the file by their rank
* relative to each other. A pair earlier in the file has priority over
* any pairs below it.
*
* @param filename_merges Local file path of pairs encoded in UTF-8.
* @param mr Memory resource to allocate any returned objects.
*/
std::unique_ptr<bpe_merge_pairs> load_merge_pairs_file(
std::string const& filename_merges,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Byte pair encode the input strings.
*
* This will split each string on whitespace, perform the encoding,
* and then build the output column using the given `separator`.
*
* The encoding algorithm rebuilds each string by matching substrings
* in the `merge_pairs` table and iteratively removing the minimum ranked pair
* until no pairs are left. Then, a space is inserted between the remaining
* pairs before the result is joined to make the output string.
*
* @code{.pseudo}
* mps = load_merges_file("merges.txt")
* input = ["test sentence", "thisis test"]
* result = byte_pair_encoding(input, mps)
* result is now ["test sent ence", "this is test"]
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
* @endcode
*
* @throw cudf::logic_error if `merge_pairs` is empty
* @throw cudf::logic_error if `separator` is invalid
*
* @param input Strings to encode.
* @param merge_pairs Created by a call to nvtext::load_merges_file.
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
* @param separator String used to build the output after encoding.
* Default is a space.
* @param mr Memory resource to allocate any returned objects.
*/
std::unique_ptr<cudf::column> byte_pair_encoding(
cudf::strings_column_view const& input,
bpe_merge_pairs const& merges_pairs,
cudf::string_scalar const& separator = cudf::string_scalar(" "),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
} // namespace nvtext
Loading