Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Byte-Pair-Encoding usage of cuco static-map for storing merge-pairs #13807

Merged
merged 31 commits into from
Aug 17, 2023
Merged
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
11a8dd0
Fix Byte-Pair-Encoding usage of cuco static-map for storing merge-pairs
davidwendt Aug 2, 2023
890c0f6
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 2, 2023
b3e01f0
Use linear probing + use proper key sentinel
PointKernel Aug 2, 2023
ac73b10
Use != operator for map iterators
PointKernel Aug 2, 2023
326c2fa
Use access operator
PointKernel Aug 2, 2023
5d4223e
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 3, 2023
fceefc4
use linear-probe scheme
davidwendt Aug 3, 2023
b93b1cc
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
PointKernel Aug 3, 2023
bfa1a13
Merge branch 'branch-23.10' into fix-bpe-static-map
PointKernel Aug 3, 2023
860b1b8
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 3, 2023
148fc83
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 5, 2023
91ab4fb
keep merge-pairs column-device-view alive
davidwendt Aug 6, 2023
6a4fcf6
cleanup col-device-view declaration in impl class
davidwendt Aug 6, 2023
0008bd2
cleanup code; add comments
davidwendt Aug 7, 2023
40755df
Update cuco git tag
PointKernel Aug 7, 2023
cda11b8
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
PointKernel Aug 7, 2023
353971f
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 9, 2023
baf54e5
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 10, 2023
440b8b5
reuse already created col-dev-view
davidwendt Aug 11, 2023
d6234c6
Revert temporary CMake changes
PointKernel Aug 11, 2023
caf5e68
remove unneeded include
davidwendt Aug 11, 2023
38171e3
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
davidwendt Aug 11, 2023
42f1b92
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 11, 2023
06136dd
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 13, 2023
d884817
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 14, 2023
f262d74
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 15, 2023
7ffb1ab
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
davidwendt Aug 15, 2023
708d8fd
fix some comments
davidwendt Aug 16, 2023
7054430
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 16, 2023
24548cd
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 16, 2023
8b5aba7
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions cpp/include/nvtext/bpe_tokenize.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -68,12 +68,6 @@ struct bpe_merge_pairs {
* @return The number of merge pairs in the table
*/
cudf::size_type get_size();
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Returns the number of unique merge pairs in the table.
*
* @return The number of unique merge pairs in the table
*/
std::size_t get_map_size();
};

/**
38 changes: 12 additions & 26 deletions cpp/src/text/subword/bpe_tokenizer.cu
Original file line number Diff line number Diff line change
@@ -80,10 +80,11 @@ __device__ cudf::string_view get_first_token(cudf::string_view const& d_str)
*
* @see The byte_pair_encoding_fn::operator() function below for details.
*/
template <typename MapRefType>
struct byte_pair_encoding_fn {
cudf::column_device_view const d_merges;
cudf::column_device_view const d_strings;
merge_pairs_map_type::device_view const d_map;
MapRefType const d_map;
cudf::size_type* d_sizes; // output size of encoded string
string_hasher_type const hasher;
cudf::size_type* d_byte_indices;
@@ -135,18 +136,7 @@ struct byte_pair_encoding_fn {
return cudf::string_view(d_str.data() + *begin, size);
}

/**
* @brief Compute the hash over the input strings.
*
* The input strings are combined with a space to produce hash for matching
* a merge pair within the `d_map`.
*
* @param lhs First string.
* @param rhs Second string.
* @return The hash value to match with `d_map`.
*/
__device__ cudf::hash_value_type compute_hash(cudf::string_view const& lhs,
cudf::string_view const& rhs)
__device__ auto get_merge_pair(cudf::string_view const& lhs, cudf::string_view const& rhs)
{
__shared__ char shmem[48 * 1024]; // max for Pascal
auto const total_size = lhs.size_bytes() + rhs.size_bytes() + 1;
@@ -155,7 +145,7 @@ struct byte_pair_encoding_fn {
// Edge case check.
// Empirically found only two merge pair strings that were greater than 70 bytes
// and they both looked like ignorable errors. Double check this analysis with Vibhu.
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
if (thread_memory_size < total_size) { return 0; }
if (thread_memory_size < total_size) { return d_map.end(); }

// build the target string in shared memory
char* ptr = &shmem[threadIdx.x * thread_memory_size];
@@ -165,8 +155,8 @@ struct byte_pair_encoding_fn {
memcpy(ptr + lhs.size_bytes(), " ", 1);
memcpy(ptr + lhs.size_bytes() + 1, rhs.data(), rhs.size_bytes());

auto const d_hash_str = cudf::string_view(ptr, total_size);
return hasher(d_hash_str); // return the hash for the temp string
auto const d_str = cudf::string_view(ptr, total_size);
return d_map.find(d_str);
}

/**
@@ -233,11 +223,10 @@ struct byte_pair_encoding_fn {
auto const rhs = next_substr(itr, end, d_str);
if (rhs.empty()) break; // no more adjacent pairs

auto const hash = compute_hash(lhs, rhs);
auto const map_itr = d_map.find(hash, thrust::identity<cudf::hash_value_type>{});
if (map_itr != d_map.end()) {
auto const map_itr = get_merge_pair(lhs, rhs);
if (!(map_itr == d_map.end())) {
// found a match; record the rank (and other min_ vars)
auto const rank = static_cast<cudf::size_type>(map_itr->second);
auto const rank = static_cast<cudf::size_type>((*map_itr).second);
if (rank < min_rank) {
min_rank = rank;
min_itr = itr;
@@ -369,12 +358,9 @@ std::unique_ptr<cudf::column> byte_pair_encoding(
rmm::mr::get_current_device_resource());
auto d_offsets = offsets->mutable_view().data<cudf::offset_type>();

byte_pair_encoding_fn fn{*d_merges,
*d_strings,
merge_pairs.get_merge_pairs_map(),
d_offsets,
string_hasher_type{},
d_byte_indices.data()};
auto map_ref = merge_pairs.get_merge_pairs_map();
byte_pair_encoding_fn<decltype(map_ref)> fn{
*d_merges, *d_strings, map_ref, d_offsets, string_hasher_type{}, d_byte_indices.data()};
thrust::for_each_n(
rmm::exec_policy(stream), thrust::make_counting_iterator<cudf::size_type>(0), input.size(), fn);

52 changes: 46 additions & 6 deletions cpp/src/text/subword/bpe_tokenizer.cuh
Original file line number Diff line number Diff line change
@@ -21,27 +21,67 @@
#include <hash/hash_allocator.cuh>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/hashing/detail/murmurhash3_x86_32.cuh>
#include <cudf/strings/string_view.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/polymorphic_allocator.hpp>

#include <cuco/operator.hpp>
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
#include <cuco/static_map.cuh>

#include <cstdint>

namespace nvtext {
namespace detail {

using hash_value_type = uint32_t;
using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>;

struct bpe_hasher {
cudf::column_device_view const d_strings;
string_hasher_type hasher{};
// used by insert
__device__ hash_value_type operator()(cudf::size_type index) const
{
return hasher(d_strings.element<cudf::string_view>(index));
}
// used by find
__device__ hash_value_type operator()(cudf::string_view const& s) const { return hasher(s); }
};

struct bpe_equal {
cudf::column_device_view const d_strings;
// used by insert
__device__ bool operator()(cudf::size_type lhs, cudf::size_type rhs) const noexcept
{
return d_strings.element<cudf::string_view>(lhs) == d_strings.element<cudf::string_view>(rhs);
}
// used by find
__device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept
{
return d_strings.element<cudf::string_view>(lhs) == rhs;
}
};

using hash_table_allocator_type = rmm::mr::stream_allocator_adaptor<default_allocator<char>>;

using merge_pairs_map_type = cuco::static_map<cudf::hash_value_type,
cudf::size_type,
cuda::thread_scope_device,
hash_table_allocator_type>;
using probe_scheme = cuco::experimental::double_hashing<1, bpe_hasher>;
PointKernel marked this conversation as resolved.
Show resolved Hide resolved

using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>;
using merge_pairs_map_type = cuco::experimental::static_map<cudf::size_type,
cudf::size_type,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
bpe_equal,
probe_scheme,
hash_table_allocator_type>;

// using merge_pairs_map_type = cuco::static_map<cudf::hash_value_type,
// cudf::size_type,
// cuda::thread_scope_device,
// hash_table_allocator_type>;

} // namespace detail

@@ -53,7 +93,7 @@ struct bpe_merge_pairs::bpe_merge_pairs_impl {
std::unique_ptr<detail::merge_pairs_map_type>&& merge_pairs_map);

auto get_merge_pairs() const { return merge_pairs->view(); }
auto get_merge_pairs_map() const { return merge_pairs_map->get_device_view(); }
auto get_merge_pairs_map() const { return merge_pairs_map->ref(cuco::experimental::op::find); }
};

} // namespace nvtext
24 changes: 7 additions & 17 deletions cpp/src/text/subword/load_merges_file.cu
Original file line number Diff line number Diff line change
@@ -36,20 +36,13 @@

namespace nvtext {
namespace detail {

namespace {

struct make_pair_function {
/**
* @brief Hash the merge pair entry
*/
__device__ cuco::pair<cudf::hash_value_type, cudf::size_type> operator()(cudf::size_type idx)
__device__ cuco::pair<cudf::size_type, cudf::size_type> operator()(cudf::size_type idx)
{
auto const result = _hasher(d_strings.element<cudf::string_view>(idx));
return cuco::make_pair(result, idx);
return cuco::make_pair(idx, idx);
}

string_hasher_type const _hasher;
cudf::column_device_view const d_strings;
};

@@ -103,24 +96,22 @@ std::unique_ptr<cudf::column> load_file_to_column(std::string const& filename_me
std::unique_ptr<detail::merge_pairs_map_type> initialize_merge_pairs_map(
cudf::strings_column_view const& input, rmm::cuda_stream_view stream)
{
auto d_strings = cudf::column_device_view::create(input.parent(), stream);
// Ensure capacity is at least (size/0.7) as documented here:
// https://github.com/NVIDIA/cuCollections/blob/6ec8b6dcdeceea07ab4456d32461a05c18864411/include/cuco/static_map.cuh#L179-L182
auto merge_pairs_map = std::make_unique<merge_pairs_map_type>(
static_cast<size_t>(input.size() * 2), // capacity is 2x;
cuco::empty_key{std::numeric_limits<cudf::hash_value_type>::max()},
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
cuco::empty_value{-1}, // empty value is not used
bpe_equal{*d_strings},
probe_scheme{bpe_hasher{*d_strings}},
hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());

auto d_strings = cudf::column_device_view::create(input.parent(), stream);
make_pair_function pair_func{string_hasher_type{}, *d_strings};
make_pair_function pair_func{*d_strings};
auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func);

merge_pairs_map->insert(iter,
iter + input.size(),
thrust::identity<cudf::hash_value_type>{},
thrust::equal_to<cudf::hash_value_type>{},
stream.value());
merge_pairs_map->insert(iter, iter + input.size(), stream.value());
davidwendt marked this conversation as resolved.
Show resolved Hide resolved

return merge_pairs_map;
}
@@ -185,6 +176,5 @@ bpe_merge_pairs::bpe_merge_pairs(cudf::strings_column_view const& input,
bpe_merge_pairs::~bpe_merge_pairs() = default;

cudf::size_type bpe_merge_pairs::get_size() { return impl->merge_pairs->size(); }
std::size_t bpe_merge_pairs::get_map_size() { return impl->merge_pairs_map->get_size(); }

} // namespace nvtext