Skip to content

Commit

Permalink
Address the feedback regarding Bert tokenizer (#7280)
Browse files Browse the repository at this point in the history
* Address the feedback regarding Bert tokenizer

* Small fix
  • Loading branch information
tarekgh authored Oct 26, 2024
1 parent a7a6d88 commit a9b4212
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 42 deletions.
167 changes: 138 additions & 29 deletions src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,25 @@ public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds
throw new ArgumentNullException(nameof(tokenIds0));
}

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
List<int> ids;

if (tokenIds0 is ICollection<int> c1)
{
int capacity = c1.Count + 2; // Add 2 for the [CLS] and [SEP] tokens.

if (tokenIds1 is not null)
{
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
}

ids = new(capacity) { ClsTokenId };
}
else
{
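// Slow path: tokenIds0 is not an ICollection<int>, so its count is not available without enumerating it; start with a small default capacity and let the list grow.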
ids = new List<int>(10) { ClsTokenId };
}

ids.AddRange(tokenIds0);
ids.Add(SepTokenId);

Expand Down Expand Up @@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0,
throw new ArgumentNullException(nameof(tokenIds0));
}

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
if (buffer.Length < capacity)
written = 0;
if (buffer.Length < 1)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}

written = 0;
buffer[written++] = ClsTokenId;
foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}

buffer[written++] = id;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = SepTokenId;

if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = id;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = SepTokenId;
}

Expand All @@ -367,11 +402,22 @@ public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnum
throw new ArgumentNullException(nameof(tokenIds0));
}

int capacity = alreadyHasSpecialTokens ?
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
List<int> mask;
if (tokenIds0 is ICollection<int> c1)
{
int capcity = c1.Count + 2;

if (tokenIds1 is not null)
{
capcity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
}

List<int> mask = new List<int>(capacity: capacity);
mask = new List<int>(capcity);
}
else
{
mask = new List<int>(10);
}

if (!alreadyHasSpecialTokens)
{
Expand Down Expand Up @@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
throw new ArgumentNullException(nameof(tokenIds0));
}

int capacity = alreadyHasSpecialTokens ?
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.

written = 0;
if (buffer.Length < capacity)
{
return OperationStatus.DestinationTooSmall;
}

if (!alreadyHasSpecialTokens)
{
if (buffer.Length < 1)
{
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // CLS

foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // SEP

if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // SEP
}

Expand All @@ -453,13 +517,23 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int

foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
}

if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
}
}
Expand All @@ -484,21 +558,38 @@ public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> token
throw new ArgumentNullException(nameof(tokenIds0));
}

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
List<int> typeIds;
if (tokenIds0 is ICollection<int> c1)
{
int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.

if (tokenIds1 is not null)
{
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
}

List<int> typeIds = new List<int>(capacity);
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
typeIds = new List<int>(capacity);
}
else
{
typeIds = new List<int>(10);
}

foreach (var id in tokenIds0)
{
typeIds.Add(0);
}
typeIds.Add(0); // [CLS]
typeIds.Add(0); // [SEP]

if (tokenIds1 is not null)
{
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
foreach (int id in tokenIds1)
{
typeIds.Add(1);
}

typeIds.Add(1); // [SEP]
}

return typeIds;
Expand All @@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
if (buffer.Length < capacity)
if (buffer.Length < 2)
{
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0; // [CLS]
buffer[written++] = 0; // [SEP]

for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0;
}

if (tokenIds1 is not null)
{
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1;
}

if (buffer.Length < written)
{
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // [SEP]
}

return OperationStatus.Done;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ await CreateAsync(
continuingSubwordPrefix,
maxInputCharsPerWord,
cancellationToken,
disposeStream: true);
disposeStream: true).ConfigureAwait(false);

/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
Expand All @@ -259,7 +259,7 @@ public static async Task<WordPieceTokenizer> CreateAsync(
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
CancellationToken cancellationToken = default) =>
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false).ConfigureAwait(false);

private static async Task<WordPieceTokenizer> CreateAsync(
Stream vocabStream,
Expand Down
22 changes: 11 additions & 11 deletions src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public override string Normalize(string original)

if (category == UnicodeCategory.SpaceSeparator)
{
InsertChar(ref buffer, ref index, ' ');
AddChar(ref buffer, ref index, ' ');
i += inc;
continue;
}
Expand All @@ -85,30 +85,30 @@ public override string Normalize(string original)
int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
Debug.Assert(length > 0);

InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
AddSpan(ref buffer, ref index, casingBuffer.Slice(0, length));

i += inc;
continue;
}

if (_tokenizeChineseChars && IsChineseChar(codePoint))
{
InsertChar(ref buffer, ref index, ' ');
InsertChar(ref buffer, ref index, c);
AddChar(ref buffer, ref index, ' ');
AddChar(ref buffer, ref index, c);
if (inc > 0)
{
InsertChar(ref buffer, ref index, original[i + 1]);
AddChar(ref buffer, ref index, original[i + 1]);
}
InsertChar(ref buffer, ref index, ' ');
AddChar(ref buffer, ref index, ' ');

i += inc;
continue;
}

InsertChar(ref buffer, ref index, c);
AddChar(ref buffer, ref index, c);
if (inc > 0)
{
InsertChar(ref buffer, ref index, original[i + 1]);
AddChar(ref buffer, ref index, original[i + 1]);
}
i += inc;
}
Expand Down Expand Up @@ -147,7 +147,7 @@ public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAcc
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void InsertChar(ref char[] buffer, ref int index, char c)
private static void AddChar(ref char[] buffer, ref int index, char c)
{
if (index >= buffer.Length)
{
Expand All @@ -158,9 +158,9 @@ private static void InsertChar(ref char[] buffer, ref int index, char c)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
private static void AddSpan(ref char[] buffer, ref int index, Span<char> chars)
{
if (index + buffer.Length >= buffer.Length)
if (index + chars.Length >= buffer.Length)
{
Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10);
}
Expand Down

0 comments on commit a9b4212

Please sign in to comment.