Skip to content

Commit

Permalink
Change nvtext::load_vocabulary_file to return a unique ptr (#7424)
Browse files Browse the repository at this point in the history
Reference #5868 

This PR changes `nvtext::load_vocabulary_file` to return a unique pointer, making the result easier to manage in a Python/Cython class object. The original signature returned a flat structure that contained unique pointers, which made it difficult to copy and manage.

The corresponding gtests and gbenchmarks were updated for this API change.

Authors:
  - David (@davidwendt)

Approvers:
  - Conor Hoekstra (@codereport)
  - Karthikeyan (@karthikeyann)

URL: #7424
  • Loading branch information
davidwendt authored Feb 24, 2021
1 parent dcf949c commit b0e5aef
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 19 deletions.
4 changes: 2 additions & 2 deletions cpp/benchmarks/text/subword_benchmark.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,7 +65,7 @@ static void BM_cuda_tokenizer_cudf(benchmark::State& state)
auto vocab = nvtext::load_vocabulary_file(hash_file);
for (auto _ : state) {
auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings},
vocab,
*vocab,
max_sequence_length,
stride,
do_lower,
Expand Down
9 changes: 5 additions & 4 deletions cpp/include/nvtext/detail/load_hash_file.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,9 +40,10 @@ namespace detail {
* @param mr Memory resource to allocate any returned objects.
* @return vocabulary hash-table elements
*/
hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabulary,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
std::string const& filename_hashed_vocabulary,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);

} // namespace detail
} // namespace nvtext
4 changes: 2 additions & 2 deletions cpp/include/nvtext/subword_tokenize.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,7 +59,7 @@ struct hashed_vocabulary {
* @param mr Memory resource to allocate any returned objects.
* @return vocabulary hash-table elements
*/
hashed_vocabulary load_vocabulary_file(
std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
std::string const& filename_hashed_vocabulary,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
13 changes: 7 additions & 6 deletions cpp/src/text/subword/load_hash_file.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,10 @@ uint64_t str_to_uint64(std::string const& str, uint64_t line_no)
* @param filename_hashed_vocabulary Path to text file containing hashed vocabulary
* @return object containing hash table elements for the wordpiece tokenizer
*/
hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabulary,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
std::string const& filename_hashed_vocabulary,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
hashed_vocabulary result;
std::ifstream hash_file(filename_hashed_vocabulary);
Expand Down Expand Up @@ -276,13 +277,13 @@ hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabu
detail::get_codepoint_metadata(stream);
detail::get_aux_codepoint_data(stream);

return result;
return std::make_unique<hashed_vocabulary>(std::move(result));
}

} // namespace detail

hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabulary,
rmm::mr::device_memory_resource* mr)
std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
std::string const& filename_hashed_vocabulary, rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::load_vocabulary_file(filename_hashed_vocabulary, rmm::cuda_stream_default, mr);
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/text/subword/subword_tokenize.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -256,10 +256,10 @@ tokenizer_result subword_tokenize(cudf::strings_column_view const& strings,
uint32_t max_rows_tensor,
rmm::mr::device_memory_resource* mr)
{
hashed_vocabulary vocab_table = load_vocabulary_file(filename_hashed_vocabulary, mr);
auto vocab_table = load_vocabulary_file(filename_hashed_vocabulary, mr);
CUDF_FUNC_RANGE();
return detail::subword_tokenize(strings,
vocab_table,
*vocab_table,
max_sequence_length,
stride,
do_lower_case,
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/text/subword_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ TEST(TextSubwordTest, TokenizeFromVocabStruct)
cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end());
auto vocab = nvtext::load_vocabulary_file(hash_file);
auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings},
vocab,
*vocab,
8,
6,
true, // do_lower_case
Expand Down Expand Up @@ -307,7 +307,7 @@ TEST(TextSubwordTest, TokenizeWithSpecialTokens)
cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end());
auto vocab = nvtext::load_vocabulary_file(hash_file);
auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings},
vocab,
*vocab,
8,
6,
true, // do_lower_case
Expand Down

0 comments on commit b0e5aef

Please sign in to comment.