// Review notes for the patch that follows. These lines sit ahead of the first
// `diff --git` header and are commentary only; the patch text itself is unchanged.
//
// Purpose: expose a caller-supplied CUDA stream on the public
// nvtext::byte_pair_encoding API and add a stream test for it.
//
// Files touched, in order:
//   1. cpp/include/nvtext/byte_pair_encoding.hpp
//      - Adds an `@param stream` doc line and a new parameter
//        `rmm::cuda_stream_view stream = cudf::get_default_stream()` placed
//        between `separator` and `mr` in the byte_pair_encoding declaration.
//   2. cpp/src/text/bpe/byte_pair_encoding.cu
//      - The public byte_pair_encoding definition gains the `stream` parameter
//        and forwards it to detail::byte_pair_encoding in place of the
//        previously hard-coded cudf::get_default_stream().
//   3. cpp/src/text/bpe/load_merge_pairs.cu
//      - The " " delimiter handed to cudf::strings::split_record is now built as
//        cudf::string_scalar(" ", true, stream, mr) instead of
//        cudf::string_scalar(" ").
//   4. cpp/tests/CMakeLists.txt
//      - Adds streams/text/bpe_test.cpp to the STREAM_TEXT_TEST target.
//   5. cpp/tests/streams/text/bpe_test.cpp (new file)
//      - Fixture TextBytePairEncoding; one test that loads a 14-entry merge
//        table through nvtext::load_merge_pairs with the test stream, then calls
//        nvtext::byte_pair_encoding on a 6-row strings column (row 4 null),
//        passing the same stream for both the separator scalar and the call.
//        The returned `results` is not inspected in the visible test body.
//
// NOTE(review): this copy of the patch appears to have lost its line breaks and
//   the contents of angle brackets (bare `#include` lines, `std::unique_ptr`
//   with no template argument). Presumably an extraction artifact; confirm
//   against the original patch before attempting to apply it.
// NOTE(review): `stream` is inserted ahead of `mr`, so any caller that passes
//   `mr` positionally as the fourth argument needs updating. Verify call sites
//   and bindings outside this patch.
// NOTE(review): in load_merge_pairs.cu the delimiter scalar is a temporary but
//   is constructed with `mr`. If `mr` is meant only for returned allocations,
//   consider whether this temporary should use it; confirm against project
//   conventions.
diff --git a/cpp/include/nvtext/byte_pair_encoding.hpp b/cpp/include/nvtext/byte_pair_encoding.hpp index ab862df044d..71b68565e77 100644 --- a/cpp/include/nvtext/byte_pair_encoding.hpp +++ b/cpp/include/nvtext/byte_pair_encoding.hpp @@ -122,6 +122,7 @@ std::unique_ptr load_merge_pairs( * @param merges_pairs Created by a call to @ref nvtext::load_merge_pairs. * @param separator String used to build the output after encoding. * Default is a space. + * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Memory resource to allocate any returned objects. * @return An encoded column of strings. */ @@ -129,6 +130,7 @@ std::unique_ptr byte_pair_encoding( cudf::strings_column_view const& input, bpe_merge_pairs const& merges_pairs, cudf::string_scalar const& separator = cudf::string_scalar(" "), + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); /** @} */ // end of group diff --git a/cpp/src/text/bpe/byte_pair_encoding.cu b/cpp/src/text/bpe/byte_pair_encoding.cu index f46f49ddc0e..0aacfd16f67 100644 --- a/cpp/src/text/bpe/byte_pair_encoding.cu +++ b/cpp/src/text/bpe/byte_pair_encoding.cu @@ -459,10 +459,11 @@ std::unique_ptr byte_pair_encoding(cudf::strings_column_view const std::unique_ptr byte_pair_encoding(cudf::strings_column_view const& input, bpe_merge_pairs const& merges_table, cudf::string_scalar const& separator, + rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { CUDF_FUNC_RANGE(); - return detail::byte_pair_encoding(input, merges_table, separator, cudf::get_default_stream(), mr); + return detail::byte_pair_encoding(input, merges_table, separator, stream, mr); } } // namespace nvtext diff --git a/cpp/src/text/bpe/load_merge_pairs.cu b/cpp/src/text/bpe/load_merge_pairs.cu index cd68566bdec..a13a435a271 100644 --- a/cpp/src/text/bpe/load_merge_pairs.cu +++ b/cpp/src/text/bpe/load_merge_pairs.cu @@ -103,7 +103,8 @@ 
std::unique_ptr create_bpe_merge_pairs_im rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto pairs = cudf::strings::split_record(input, cudf::string_scalar(" "), 1, stream, mr); + auto pairs = + cudf::strings::split_record(input, cudf::string_scalar(" ", true, stream, mr), 1, stream, mr); auto content = pairs->release(); return create_bpe_merge_pairs_impl(std::move(content.children.back()), stream); } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 8928d27a871..adf512811cc 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -742,6 +742,7 @@ ConfigureTest( ) ConfigureTest( STREAM_TEXT_TEST + streams/text/bpe_test.cpp streams/text/edit_distance_test.cpp streams/text/ngrams_test.cpp streams/text/replace_test.cpp diff --git a/cpp/tests/streams/text/bpe_test.cpp b/cpp/tests/streams/text/bpe_test.cpp new file mode 100644 index 00000000000..0510edc122a --- /dev/null +++ b/cpp/tests/streams/text/bpe_test.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include + +#include + +struct TextBytePairEncoding : public cudf::test::BaseFixture {}; + +TEST_F(TextBytePairEncoding, BytePairEncoding) +{ + auto stream = cudf::test::get_default_stream(); + // partial table based on values from https://huggingface.co/gpt2/raw/main/merges.txt + auto mpt = cudf::test::strings_column_wrapper({ + "e n", // 14 + "i t", // 16 + "i s", // 17 + "e s", // 20 + "en t", // 44 + "c e", // 90 + "es t", // 141 + "en ce", // 340 + "t h", // 146 + "h i", // 5049 + "th is", // 5407 + "t est", // 9034 + "s i", // 13142 + "s ent" // 33832 + }); + + auto merge_pairs = nvtext::load_merge_pairs(cudf::strings_column_view(mpt), stream); + + auto validity = cudf::test::iterators::null_at(4); + cudf::test::strings_column_wrapper input( + {"thisisit", "thisis test-sentence-1", "thisistestsentence-2", "this-istestsentence 3", "", ""}, + validity); + auto sv = cudf::strings_column_view(input); + + auto results = + nvtext::byte_pair_encoding(sv, *merge_pairs, cudf::string_scalar(" ", true, stream), stream); +}