Skip to content

Commit

Permalink
Implement Tokenizer op (#31)
Browse files Browse the repository at this point in the history
* Implement separator tokenizer with TST.
  TODO: Clarify what to do if the output is empty and no start/end text
  markers required. Also see if the current search algo is acceptable.

* Add utf8 util test

* For empty output produce [C] -> [C][0], [N][C] -> [N][C][0]

* Augument TST search with match conflict resolution in favor of
  earlier specified pattern matches.

* Address MAcOS build error.

* Adjust error message

* Address review comments.

* Remove nested loops.

* Remove 3rd party utf8 validation code.

* Address review comments part I.

* Move padding outside start/end markers.
  Split unit tests for invidividual test cases.

* Fix a common prefix bug reported by Xavier.
  • Loading branch information
yuslepukhin authored Dec 6, 2018
1 parent a68f5cc commit c52636e
Show file tree
Hide file tree
Showing 6 changed files with 1,440 additions and 0 deletions.
32 changes: 32 additions & 0 deletions onnxruntime/contrib_ops/contrib_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@ Sample echo operator.)DOC");
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetDoc(R"DOC(Returns which elements of the input are NaN.)DOC");

ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "X", "Strings to tokenize", "T")
.Output(0, "Y", "Tokenized strings", "T")
.TypeConstraint(
"T",
{"tensor(string)"},
"Input/Output is a string tensor")
.Attr(
"mark",
"Boolean whether to mark the beginning/end character with start of text character (0x02)/end of text character (0x03).",
AttributeProto::INT)
.Attr(
"pad_value",
"The string used to pad output tensors when the tokens extracted doesn't match the maximum number of tokens found. If start/end markers are needed, padding will appear outside the markers.",
AttributeProto::STRING)
.Attr(
"separators",
"The list of separators, two consecutive segments in X connected by a separator would be divided into two tokens.",
AttributeProto::STRINGS)
.Attr(
"mincharnum",
"Minimum number of characters allowed in the output. For example, if mincharnum is 2, tokens such as \"A\" and \"B\" would be ignored",
AttributeProto::INT)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
})
.SetDoc(R"DOC(Tokenizer divides each string in X into a vector of strings along the last axis. All input strings including attributes are UTF-8 encoded.)DOC");

// Operators for linear 8 bit quanitzation support.
ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear)
.SetDomain(kMSDomain)
Expand Down Expand Up @@ -491,6 +521,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
Expand All @@ -505,6 +536,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());
Expand Down
Loading

0 comments on commit c52636e

Please sign in to comment.