Skip to content

Commit

Permalink
Python ops for new RoundRobinTrimmer kernels.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 513312038
  • Loading branch information
broken authored and tf-text-github-robot committed Mar 1, 2023
1 parent c0d91f7 commit cc439a2
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 76 deletions.
6 changes: 5 additions & 1 deletion tensorflow_text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1252,9 +1252,11 @@ py_library(

# tokenization_layers_py_test

py_library(
py_tf_text_library(
name = "trimmer_ops",
srcs = ["python/ops/trimmer_ops.py"],
cc_op_defs = ["//tensorflow_text/core/ops:trimmer_ops.cc"],
cc_op_kernels = ["//tensorflow_text/core/kernels:round_robin_trimmer_kernel"],
deps = [
":item_selector_ops",
# python:array_ops tensorflow dep,
Expand All @@ -1273,8 +1275,10 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":tensorflow_text",
":trimmer_ops",
"@absl_py//absl/testing:parameterized",
# tensorflow package dep,
# python:client_testlib tensorflow dep,
# python:constant_op tensorflow dep,
# python:framework_test_lib tensorflow dep,
Expand Down
4 changes: 3 additions & 1 deletion tensorflow_text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@
tflite_registrar.AddFastWordpieceDetokenize,
tflite_registrar.AddNgramsStringJoin,
tflite_registrar.AddRaggedTensorToTensor,
tflite_registrar.AddRoundRobinGenerateMasks,
tflite_registrar.AddRoundRobinTrim,
tflite_registrar.AddSentenceFragmenterV2,
tflite_registrar.AddWhitespaceTokenize
tflite_registrar.AddWhitespaceTokenize,
]

remove_undocumented(__name__, _allowed_symbols)
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_text/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,7 @@ tflite_cc_library(
"fast_wordpiece_tokenizer_tflite.h",
"ngrams_tflite.h",
"ragged_tensor_to_tensor_tflite.h",
"round_robin_trimmer_tflite.h",
"sentence_fragmenter_v2_tflite.h",
"whitespace_tokenizer_tflite.h",
"//tensorflow_text/core/kernels/sentencepiece:sp_headers",
Expand All @@ -1045,6 +1046,7 @@ tflite_cc_library(
":fast_wordpiece_tokenizer_tflite",
":ngrams_tflite",
":ragged_tensor_to_tensor_tflite",
":round_robin_trimmer_tflite",
":sentence_fragmenter_v2_tflite",
":whitespace_tokenizer_tflite",
# lite:mutable_op_resolver tensorflow dep,
Expand Down
57 changes: 28 additions & 29 deletions tensorflow_text/core/kernels/round_robin_trimmer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ namespace text {

template <typename T, typename Tsplits = int32_t>
class RoundRobinTrimmer : Trimmer<T>, BatchTrimmer<T, Tsplits> {
using Mask = std::vector<bool>;
using Values = Values<T>;
using ValuesSpan = ValuesSpan<T>;
using RowSplits = RowSplits<Tsplits>;
using RowSplitsSpan = RowSplitsSpan<Tsplits>;
using Values_ = Values<T>;
using ValuesSpan_ = ValuesSpan<T>;
using RowSplits_ = RowSplits<Tsplits>;
using RowSplitsSpan_ = RowSplitsSpan<Tsplits>;

public:
RoundRobinTrimmer(int max_sequence_length)
Expand All @@ -40,7 +39,7 @@ class RoundRobinTrimmer : Trimmer<T>, BatchTrimmer<T, Tsplits> {

// Generates masks for a single batch of values.
std::vector<Mask> GenerateMasks(
const std::vector<Values>& values) const;
const std::vector<Values_>& values) const;

// Generates masks for a batch of values row splits.
//
Expand All @@ -51,12 +50,12 @@ class RoundRobinTrimmer : Trimmer<T>, BatchTrimmer<T, Tsplits> {
// The returned value is a flattened list of mask values which can be split
// into batches using the same input row splits.
std::vector<Mask> GenerateMasksBatch(
const std::vector<RowSplits>& row_splits) const;
const std::vector<RowSplits_>& row_splits) const;
std::vector<Mask> GenerateMasksBatch(
const std::vector<RowSplitsSpan>& row_splits) const;
const std::vector<RowSplitsSpan_>& row_splits) const;

// Trims a single batch of values.
void Trim(std::vector<Values>* values) const;
void Trim(std::vector<Values_>* values) const;

// Trims a batch of values given their flattened values and row splits.
//
Expand All @@ -66,12 +65,12 @@ class RoundRobinTrimmer : Trimmer<T>, BatchTrimmer<T, Tsplits> {
//
// Returns:
// The returned values are the flattened trimmed values and new row splits.
std::pair<std::vector<Values>, std::vector<RowSplits>> TrimBatch(
const std::vector<Values>& flat_values,
const std::vector<RowSplits>& row_splits) const;
std::pair<std::vector<Values>, std::vector<RowSplits>> TrimBatch(
const std::vector<ValuesSpan>& flat_values,
const std::vector<RowSplitsSpan>& row_splits) const;
std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimBatch(
const std::vector<Values_>& flat_values,
const std::vector<RowSplits_>& row_splits) const;
std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimBatch(
const std::vector<ValuesSpan_>& flat_values,
const std::vector<RowSplitsSpan_>& row_splits) const;

protected:
// Used for holding data about value sizes and how much of it is used.
Expand All @@ -89,7 +88,7 @@ class RoundRobinTrimmer : Trimmer<T>, BatchTrimmer<T, Tsplits> {

// Internal execution to share code for Span & Vector row_splits.
template <typename ValuesIterator, typename RowSplitsIterator>
std::pair<std::vector<Values>, std::vector<RowSplits>> TrimInternal(
std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimInternal(
ValuesIterator flat_values_begin,
ValuesIterator flat_values_end,
RowSplitsIterator row_splits_begin,
Expand Down Expand Up @@ -119,7 +118,7 @@ class RoundRobinTrimmer : Trimmer<T>, BatchTrimmer<T, Tsplits> {

template <typename T, typename Tsplits>
std::vector<Mask> RoundRobinTrimmer<T, Tsplits>::GenerateMasks(
const std::vector<Values>& values) const {
const std::vector<Values_>& values) const {
std::vector<Mask> masks(values.size());
ProcessBatch(values.begin(), values.end(),
[&masks](std::vector<Row>* value_row_sizes) {
Expand All @@ -136,13 +135,13 @@ std::vector<Mask> RoundRobinTrimmer<T, Tsplits>::GenerateMasks(

template <typename T, typename Tsplits>
std::vector<Mask> RoundRobinTrimmer<T, Tsplits>::GenerateMasksBatch(
const std::vector<RowSplits>& row_splits) const {
const std::vector<RowSplits_>& row_splits) const {
return GenerateMasksInternal(row_splits.begin(), row_splits.end());
}

template <typename T, typename Tsplits>
std::vector<Mask> RoundRobinTrimmer<T, Tsplits>::GenerateMasksBatch(
const std::vector<RowSplitsSpan>& row_splits) const {
const std::vector<RowSplitsSpan_>& row_splits) const {
return GenerateMasksInternal(row_splits.begin(), row_splits.end());
}

Expand All @@ -169,7 +168,7 @@ std::vector<Mask> RoundRobinTrimmer<T, Tsplits>::GenerateMasksInternal(
}

template <typename T, typename Tsplits>
void RoundRobinTrimmer<T, Tsplits>::Trim(std::vector<Values>* values) const {
void RoundRobinTrimmer<T, Tsplits>::Trim(std::vector<Values_>* values) const {
ProcessBatch(values->begin(), values->end(),
[values] (std::vector<Row>* value_row_sizes) {
for (int s = 0; s < values->size(); ++s) {
Expand All @@ -181,8 +180,8 @@ void RoundRobinTrimmer<T, Tsplits>::Trim(std::vector<Values>* values) const {
template <typename T, typename Tsplits>
std::pair<std::vector<Values<T>>, std::vector<RowSplits<Tsplits>>>
RoundRobinTrimmer<T, Tsplits>::TrimBatch(
const std::vector<Values>& flat_values,
const std::vector<RowSplits>& row_splits) const {
const std::vector<Values_>& flat_values,
const std::vector<RowSplits_>& row_splits) const {
return TrimInternal(
flat_values.begin(), flat_values.end(),
row_splits.begin(), row_splits.end());
Expand All @@ -191,8 +190,8 @@ RoundRobinTrimmer<T, Tsplits>::TrimBatch(
template <typename T, typename Tsplits>
std::pair<std::vector<Values<T>>, std::vector<RowSplits<Tsplits>>>
RoundRobinTrimmer<T, Tsplits>::TrimBatch(
const std::vector<ValuesSpan>& flat_values,
const std::vector<RowSplitsSpan>& row_splits) const {
const std::vector<ValuesSpan_>& flat_values,
const std::vector<RowSplitsSpan_>& row_splits) const {
return TrimInternal(
flat_values.begin(), flat_values.end(),
row_splits.begin(), row_splits.end());
Expand All @@ -206,9 +205,9 @@ RoundRobinTrimmer<T, Tsplits>::TrimInternal(
ValuesIterator flat_values_end,
RowSplitsIterator splits_begin,
RowSplitsIterator splits_end) const {
std::pair<std::vector<Values>, std::vector<RowSplits>> trimmed(
{std::vector<Values>(flat_values_end - flat_values_begin),
std::vector<RowSplits>(splits_end - splits_begin)});
std::pair<std::vector<Values_>, std::vector<RowSplits_>> trimmed(
{std::vector<Values_>(flat_values_end - flat_values_begin),
std::vector<RowSplits_>(splits_end - splits_begin)});
// All row splits start at index 0
for (int i = 0; i < trimmed.second.size(); ++i) {
trimmed.second[i].push_back({0});
Expand All @@ -219,8 +218,8 @@ RoundRobinTrimmer<T, Tsplits>::TrimInternal(
auto values_it = flat_values_begin;
auto splits_it = splits_begin;
for (int s = 0; s < values_row->size(); ++s, ++values_it, ++splits_it) {
Values* vals = &trimmed.first[s];
RowSplits* splits = &trimmed.second[s];
Values_* vals = &trimmed.first[s];
RowSplits_* splits = &trimmed.second[s];
auto start = values_it->begin() + (*splits_it)[splits->size()-1];
vals->insert(vals->end(), start, start + (*values_row)[s].used);
splits->insert(splits->end(), splits->back() + (*values_row)[s].used);
Expand Down
30 changes: 15 additions & 15 deletions tensorflow_text/core/kernels/trimmer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,25 @@ using RowSplitsSpan = absl::Span<Tsplits>;

template <typename T>
class Trimmer {
using Values = Values<T>;
using ValuesT = Values<T>;

public:
// Generates masks for a single batch of values.
virtual std::vector<Mask> GenerateMasks(
const std::vector<Values>& values) const = 0;
const std::vector<ValuesT>& values) const = 0;

// Trims a single batch of values.
virtual void Trim(std::vector<Values>* values) const = 0;
virtual void Trim(std::vector<ValuesT>* values) const = 0;

virtual ~Trimmer() = default;
};

template <typename T, typename Tsplits>
class BatchTrimmer {
using Values = Values<T>;
using ValuesSpan = ValuesSpan<T>;
using RowSplits = RowSplits<Tsplits>;
using RowSplitsSpan = RowSplitsSpan<Tsplits>;
using Values_ = Values<T>;
using ValuesSpan_ = ValuesSpan<T>;
using RowSplits_ = RowSplits<Tsplits>;
using RowSplitsSpan_ = RowSplitsSpan<Tsplits>;

public:
// Generates masks for a batch of value row splits.
Expand All @@ -64,9 +64,9 @@ class BatchTrimmer {
// The returned value is a flattened list of mask values which can be split
// into batches using the same input row splits.
virtual std::vector<Mask> GenerateMasksBatch(
const std::vector<RowSplits>& row_splits) const = 0;
const std::vector<RowSplits_>& row_splits) const = 0;
virtual std::vector<Mask> GenerateMasksBatch(
const std::vector<RowSplitsSpan>& row_splits) const = 0;
const std::vector<RowSplitsSpan_>& row_splits) const = 0;

// Trims a batch of values given their flattened values and row splits.
//
Expand All @@ -76,12 +76,12 @@ class BatchTrimmer {
//
// Returns:
// The returned values are the flattened trimmed values and new row splits.
virtual std::pair<std::vector<Values>, std::vector<RowSplits>> TrimBatch(
const std::vector<Values>& flat_values,
const std::vector<RowSplits>& row_splits) const = 0;
virtual std::pair<std::vector<Values>, std::vector<RowSplits>> TrimBatch(
const std::vector<ValuesSpan>& flat_values,
const std::vector<RowSplitsSpan>& row_splits) const = 0;
virtual std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimBatch(
const std::vector<Values_>& flat_values,
const std::vector<RowSplits_>& row_splits) const = 0;
virtual std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimBatch(
const std::vector<ValuesSpan_>& flat_values,
const std::vector<RowSplitsSpan_>& row_splits) const = 0;

virtual ~BatchTrimmer() = default;
};
Expand Down
11 changes: 9 additions & 2 deletions tensorflow_text/core/pybinds/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@ pybind_extension(
"tflite_registrar.cc",
],
additional_exported_symbols = [
"AddByteSplit",
"AddByteSplitByOffsets",
"AddFastBertNormalize",
"AddFastSentencepieceTokenize",
"AddFastSentencepieceDetokenize",
"AddFastWordpieceDetokenize",
"AddFastSentencepieceTokenize",
"AddFastWordpieceTokenize",
"AddFastWordpieceDetokenize",
"AddNgramsStringJoin",
"AddRaggedTensorToTensor",
"AddRoundRobinGenerateMasks",
"AddRoundRobinTrim",
"AddSentenceFragmenterV2",
"AddWhitespaceTokenize",
"SELECT_TFTEXT_OPS",
],
module_name = "tflite_registrar",
deps = [
Expand Down
41 changes: 27 additions & 14 deletions tensorflow_text/core/pybinds/tflite_registrar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "tensorflow_text/core/kernels/byte_splitter_tflite.h"
#include "tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h"
#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h"
#include "tensorflow_text/core/kernels/ngrams_tflite.h"
#include "tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.h"
#include "tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h"
#include "tensorflow_text/core/kernels/round_robin_trimmer_tflite.h"
#include "tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h"
#include "tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h"
#include "tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h"

PYBIND11_MODULE(tflite_registrar, m) {
Expand All @@ -30,18 +30,12 @@ PYBIND11_MODULE(tflite_registrar, m) {
A module with a Python wrapper for TFLite TFText ops.
)pbdoc";
m.attr("_allowed_symbols") = pybind11::make_tuple(
"AddByteSplit",
"AddByteSplitByOffsets",
"AddFastBertNormalize",
"AddFastSentencepieceDetokenize",
"AddFastSentencepieceTokenize",
"AddFastWordpieceTokenize",
"AddFastWordpieceDetokenize",
"AddNgramsStringJoin",
"AddRaggedTensorToTensor",
"AddSentenceFragmenterV2",
"AddWhitespaceTokenize",
"SELECT_TFTEXT_OPS");
"AddByteSplit", "AddByteSplitByOffsets", "AddFastBertNormalize",
"AddFastSentencepieceDetokenize", "AddFastSentencepieceTokenize",
"AddFastWordpieceTokenize", "AddFastWordpieceDetokenize",
"AddNgramsStringJoin", "AddRaggedTensorToTensor",
"AddRoundRobinGenerateMasks", "AddRoundRobinTrim",
"AddSentenceFragmenterV2", "AddWhitespaceTokenize", "SELECT_TFTEXT_OPS");
m.def(
"AddByteSplit",
[](uintptr_t resolver) {
Expand Down Expand Up @@ -123,6 +117,25 @@ PYBIND11_MODULE(tflite_registrar, m) {
R"pbdoc(
The function that adds AddRaggedTensorToTensor to the TFLite interpreter.
)pbdoc");
m.def(
"AddRoundRobinGenerateMasks",
[](uintptr_t resolver) {
tflite::ops::custom::text::AddRoundRobinGenerateMasks(
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
},
R"pbdoc(
The function that adds AddRoundRobinGenerateMasks to the TFLite
interpreter.
)pbdoc");
m.def(
"AddRoundRobinTrim",
[](uintptr_t resolver) {
tflite::ops::custom::text::AddRoundRobinTrim(
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
},
R"pbdoc(
The function that adds AddRoundRobinTrim to the TFLite interpreter.
)pbdoc");
m.def(
"AddSentenceFragmenterV2",
[](uintptr_t resolver) {
Expand Down
Loading

0 comments on commit cc439a2

Please sign in to comment.