From a9b4212eb374efed368416d34bd035988865dc69 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed <10833894+tarekgh@users.noreply.github.com> Date: Sat, 26 Oct 2024 16:07:20 -0700 Subject: [PATCH] Address the feedback regarding Bert tokenizer (#7280) * Address the feedback regarding Bert tokenizer * Small fix --- .../Model/BertTokenizer.cs | 167 +++++++++++++++--- .../Model/WordPieceTokenizer.cs | 4 +- .../Normalizer/BertNormalizer.cs | 22 +-- 3 files changed, 151 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs index cc2004b19b..41a5a71eeb 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs @@ -290,9 +290,25 @@ public IReadOnlyList BuildInputsWithSpecialTokens(IEnumerable tokenIds throw new ArgumentNullException(nameof(tokenIds0)); } - // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. - int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); - List ids = new List(capacity: capacity) { ClsTokenId }; + List ids; + + if (tokenIds0 is ICollection c1) + { + int capacity = c1.Count + 2; // Add 2 for [CLS] and two [SEP] tokens. + + if (tokenIds1 is not null) + { + capacity += tokenIds1 is ICollection c2 ? c2.Count + 1 : c1.Count + 1; + } + + ids = new(capacity) { ClsTokenId }; + } + else + { + // slow path + ids = new List(10) { ClsTokenId }; + } + ids.AddRange(tokenIds0); ids.Add(SepTokenId); @@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable tokenIds0, throw new ArgumentNullException(nameof(tokenIds0)); } - // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. - int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 
0 : tokenIds1.Count() + 1); - if (buffer.Length < capacity) + written = 0; + if (buffer.Length < 1) { - written = 0; return OperationStatus.DestinationTooSmall; } - written = 0; buffer[written++] = ClsTokenId; foreach (int id in tokenIds0) { + if (buffer.Length <= written) + { + written = 0; + return OperationStatus.DestinationTooSmall; + } + buffer[written++] = id; } + + if (buffer.Length <= written) + { + written = 0; + return OperationStatus.DestinationTooSmall; + } buffer[written++] = SepTokenId; if (tokenIds1 is not null) { foreach (int id in tokenIds1) { + if (buffer.Length <= written) + { + written = 0; + return OperationStatus.DestinationTooSmall; + } buffer[written++] = id; } + if (buffer.Length <= written) + { + written = 0; + return OperationStatus.DestinationTooSmall; + } buffer[written++] = SepTokenId; } @@ -367,11 +402,22 @@ public IReadOnlyList GetSpecialTokensMask(IEnumerable tokenIds0, IEnum throw new ArgumentNullException(nameof(tokenIds0)); } - int capacity = alreadyHasSpecialTokens ? - tokenIds0.Count() + (tokenIds1?.Count() ?? 0) : - tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. + List mask; + if (tokenIds0 is ICollection c1) + { + int capacity = c1.Count + 2; + + if (tokenIds1 is not null) + { + capacity += tokenIds1 is ICollection c2 ? c2.Count + 1 : c1.Count + 1; + } - List mask = new List(capacity: capacity); + mask = new List(capacity); + } + else + { + mask = new List(10); + } if (!alreadyHasSpecialTokens) { @@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable tokenIds0, Span tokenIds0, Span tokenIds0, Span CreateTokenTypeIdsFromSequences(IEnumerable token throw new ArgumentNullException(nameof(tokenIds0)); } - // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. - int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 
0 : tokenIds1.Count() + 1); + List typeIds; + if (tokenIds0 is ICollection c1) + { + int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens. + + if (tokenIds1 is not null) + { + capacity += tokenIds1 is ICollection c2 ? c2.Count + 1 : c1.Count + 1; + } - List typeIds = new List(capacity); - for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens. + typeIds = new List(capacity); + } + else + { + typeIds = new List(10); + } + + foreach (var id in tokenIds0) { typeIds.Add(0); } + typeIds.Add(0); // [CLS] + typeIds.Add(0); // [SEP] if (tokenIds1 is not null) { - for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token. + foreach (int id in tokenIds1) { typeIds.Add(1); } + + typeIds.Add(1); // [SEP] } return typeIds; @@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable tokenIds // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); - if (buffer.Length < capacity) + if (buffer.Length < 2) { return OperationStatus.DestinationTooSmall; } + buffer[written++] = 0; // [CLS] + buffer[written++] = 0; // [SEP] - for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens. + foreach (int id in tokenIds0) { + if (buffer.Length <= written) + { + written = 0; + return OperationStatus.DestinationTooSmall; + } buffer[written++] = 0; } if (tokenIds1 is not null) { - for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token. 
+ foreach (int id in tokenIds1) { + if (buffer.Length <= written) + { + written = 0; + return OperationStatus.DestinationTooSmall; + } buffer[written++] = 1; } + + if (buffer.Length <= written) + { + return OperationStatus.DestinationTooSmall; + } + buffer[written++] = 1; // [SEP] } return OperationStatus.Done; diff --git a/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs index 511891a7ac..4357ce086d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs @@ -233,7 +233,7 @@ await CreateAsync( continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, - disposeStream: true); + disposeStream: true).ConfigureAwait(false); /// /// Create a new instance of the class asynchronously. /// @@ -259,7 +259,7 @@ public static async Task CreateAsync( string continuingSubwordPrefix = DefaultContinuingSubwordPrefix, int maxInputCharsPerWord = DefaultMaxInputCharsPerWord, CancellationToken cancellationToken = default) => - await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false); + await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false).ConfigureAwait(false); private static async Task CreateAsync( Stream vocabStream, diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs index 51e7e98e4f..7bdff506f5 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs @@ -69,7 +69,7 @@ public override string Normalize(string original) if (category == UnicodeCategory.SpaceSeparator) { - InsertChar(ref buffer, ref index, ' '); + AddChar(ref buffer, ref index, ' '); i += inc; 
continue; } @@ -85,7 +85,7 @@ public override string Normalize(string original) int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer); Debug.Assert(length > 0); - InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length)); + AddSpan(ref buffer, ref index, casingBuffer.Slice(0, length)); i += inc; continue; @@ -93,22 +93,22 @@ public override string Normalize(string original) if (_tokenizeChineseChars && IsChineseChar(codePoint)) { - InsertChar(ref buffer, ref index, ' '); - InsertChar(ref buffer, ref index, c); + AddChar(ref buffer, ref index, ' '); + AddChar(ref buffer, ref index, c); if (inc > 0) { - InsertChar(ref buffer, ref index, original[i + 1]); + AddChar(ref buffer, ref index, original[i + 1]); } - InsertChar(ref buffer, ref index, ' '); + AddChar(ref buffer, ref index, ' '); i += inc; continue; } - InsertChar(ref buffer, ref index, c); + AddChar(ref buffer, ref index, c); if (inc > 0) { - InsertChar(ref buffer, ref index, original[i + 1]); + AddChar(ref buffer, ref index, original[i + 1]); } i += inc; } @@ -147,7 +147,7 @@ public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAcc } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void InsertChar(ref char[] buffer, ref int index, char c) + private static void AddChar(ref char[] buffer, ref int index, char c) { if (index >= buffer.Length) { @@ -158,9 +158,9 @@ private static void InsertChar(ref char[] buffer, ref int index, char c) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void InsertSpan(ref char[] buffer, ref int index, Span chars) + private static void AddSpan(ref char[] buffer, ref int index, Span chars) { - if (index + buffer.Length >= buffer.Length) + if (index + chars.Length >= buffer.Length) { Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10); }