diff --git a/NuGet.config b/NuGet.config index 5f023aa721..c60a5b8571 100644 --- a/NuGet.config +++ b/NuGet.config @@ -15,6 +15,7 @@ + @@ -47,6 +48,9 @@ + + + diff --git a/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj b/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj index 628cbe5293..464a2cedd7 100644 --- a/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj +++ b/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj @@ -8,6 +8,7 @@ None + true diff --git a/eng/Versions.props b/eng/Versions.props index 12eda87457..48c8bb2e1c 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -41,6 +41,7 @@ 3.27.1 3.3.5 1.1.1 + 9.0.0-rc.1.24431.7 3.3.4 4.9.2 1.0.0-beta.24375.2 diff --git a/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj b/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj index c391b0a00b..2ae1ca8467 100644 --- a/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj +++ b/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj @@ -4,9 +4,10 @@ net6.0 false $(NoWarn) - + None + true diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 0486831b27..59cc59edc7 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -7,6 +7,10 @@ preview + + true + + diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index c368378337..13c598b4ec 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -255,7 +255,7 @@ public virtual IEnumerable GenerateStreaming( return tokens // Skip the first _ token automatically added by tokenizer - .Where(t => t.Offset != (0, 0)) + .Where(t => !t.Offset.Equals(new Range(0, 0))) .Select(t => t.Id) .ToArray(); })); @@ -268,13 +268,13 @@ public virtual IEnumerable GenerateStreaming( var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); var duplicateTokenString = this.Tokenizer switch { - SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"), + SentencePieceTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"), _ => this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"), }; var tokenString = this.Tokenizer switch { - SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"), + SentencePieceTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"), _ => this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"), }; diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs index ea6f49edf7..489acb6524 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs @@ -49,7 +49,7 @@ public static TiktokenTokenizer FromPretrained( string modelFile = "tokenizer.model") { var modelFilePath = Path.Join(modelWeightFolder, modelFile); - var preTokenizer = new TiktokenPreTokenizer(new Regex(_re), _specialTokens); + var preTokenizer = new RegexPreTokenizer(new Regex(_re), _specialTokens); return TiktokenTokenizer.Create(File.OpenRead(modelFilePath), preTokenizer, normalizer: null, specialTokens: _specialTokens); } } diff --git a/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj b/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj index 9fd5d267ac..81b334564e 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj +++ b/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj @@ -7,6 +7,10 @@ true + + true + + diff --git a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj index 6dbf9f1aa5..4d0a2fb4b1 100644 --- a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj +++ b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj @@ -7,6 +7,10 @@ true + + true + + diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index b614d2f73a..0e2f8021a2 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -7,6 +7,10 @@ true + + true + + @@ -23,5 +27,5 @@ - + diff --git a/src/Microsoft.ML.Tokenizers/EncodedToken.cs b/src/Microsoft.ML.Tokenizers/EncodedToken.cs index 06a00c9126..e6f3411b14 100644 --- a/src/Microsoft.ML.Tokenizers/EncodedToken.cs +++ b/src/Microsoft.ML.Tokenizers/EncodedToken.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; + namespace Microsoft.ML.Tokenizers { /// @@ -23,7 +25,7 @@ public readonly struct EncodedToken /// /// Gets the offset mapping to the original string. /// - public (int Index, int Length) Offset { get; } + public Range Offset { get; } /// /// Construct a new Token object using the token value, Id, and the offset mapping to the original string. @@ -31,7 +33,7 @@ public readonly struct EncodedToken /// The Id value associated to the token. /// The token string value. /// The offset mapping to the original string. - public EncodedToken(int id, string value, (int, int) offset) + public EncodedToken(int id, string value, Range offset) { Id = id; Offset = offset; diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj index 93a6cbb644..56686641b6 100644 --- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj +++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj @@ -23,6 +23,7 @@ + diff --git a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs index d85464ba39..6b6ec7a234 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs @@ -29,6 +29,13 @@ public sealed class BpeTokenizer : Tokenizer private int? _unknownTokenId; private readonly PreTokenizer? _preTokenizer; private readonly Normalizer? _normalizer; + private readonly Dictionary? _addedTokens; + private readonly Dictionary? _addedTokensReverse; + + /// + /// Gets the added tokens. + /// + public IReadOnlyDictionary? AddedTokens { get; } /// /// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char @@ -80,7 +87,7 @@ private set /// The JSON file path containing the dictionary of string keys and their ids. /// The file path containing the tokens's pairs list. public static BpeTokenizer Create(string vocabFile, string? mergesFile) - => Create(vocabFile, mergesFile, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false); + => Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false); /// /// Create a new Bpe tokenizer object to use for text encoding. @@ -89,6 +96,7 @@ public static BpeTokenizer Create(string vocabFile, string? mergesFile) /// The file path containing the tokens's pairs list. /// The pre-tokenizer to use. /// The normalizer to use. + /// The additional tokens to add to the vocabulary. /// The unknown token to be used by the model. /// The prefix to attach to sub-word units that don’t represent a beginning of word. /// The suffix to attach to sub-word units that represent an end of word. @@ -98,6 +106,7 @@ public static BpeTokenizer Create( string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, + IReadOnlyDictionary? addedTokens = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, @@ -113,7 +122,7 @@ public static BpeTokenizer Create( (Dictionary? vocab, Vec<(string, string)> merges) result = ReadModelDataAsync(vocabStream, mergesStream, useAsync: false).GetAwaiter().GetResult(); - return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens); + return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, addedTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens); } /// @@ -122,7 +131,7 @@ public static BpeTokenizer Create( /// The JSON stream containing the dictionary of string keys and their ids. /// The stream containing the tokens's pairs list. public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream) - => Create(vocabStream, mergesStream, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false); + => Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false); /// /// Create a new Bpe tokenizer object to use for text encoding. @@ -131,6 +140,7 @@ public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream) /// The stream containing the tokens's pairs list. /// The pre-tokenizer to use. /// The normalizer to use. + /// The additional tokens to add to the vocabulary. /// The unknown token to be used by the model. /// The prefix to attach to sub-word units that don’t represent a beginning of word. /// The suffix to attach to sub-word units that represent an end of word. @@ -140,6 +150,7 @@ public static BpeTokenizer Create( Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, + IReadOnlyDictionary? addedTokens = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, @@ -152,7 +163,7 @@ public static BpeTokenizer Create( (Dictionary? vocab, Vec<(string, string)> merges) result = ReadModelDataAsync(vocabStream, mergesStream, useAsync: false).GetAwaiter().GetResult(); - return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens); + return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, addedTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens); } /// @@ -162,6 +173,7 @@ public static BpeTokenizer Create( /// The stream containing the tokens's pairs list. /// The pre-tokenizer to use. /// The normalizer to use. + /// The additional tokens to add to the vocabulary. /// The unknown token to be used by the model. /// The prefix to attach to sub-word units that don’t represent a beginning of word. /// The suffix to attach to sub-word units that represent an end of word. @@ -171,6 +183,7 @@ public static async Task CreateAsync( Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, + IReadOnlyDictionary? addedTokens = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, @@ -183,7 +196,7 @@ public static async Task CreateAsync( (Dictionary? vocab, Vec<(string, string)> merges) result = await ReadModelDataAsync(vocabStream, mergesStream, useAsync: true).ConfigureAwait(false); - return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens); + return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, addedTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens); } /// @@ -193,16 +206,26 @@ public static async Task CreateAsync( /// The pairs list help in merging tokens during the encoding process. /// The pre-tokenizer to use. /// The normalizer to use. + /// The additional tokens to add to the vocabulary. /// The unknown token to be used by the model. /// The prefix to attach to sub-word units that don’t represent a beginning of word. /// The suffix to attach to sub-word units that represent an end of word. /// Indicate whether allowing multiple unknown tokens get fused. - private BpeTokenizer(Dictionary? vocab, Vec<(string, string)> merges, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens) + private BpeTokenizer( + Dictionary? vocab, + Vec<(string, string)> merges, + PreTokenizer? preTokenizer, + Normalizer? normalizer, + IReadOnlyDictionary? addedTokens, + string? unknownToken, + string? continuingSubwordPrefix, + string? endOfWordSuffix, + bool fuseUnknownTokens) { FuseUnknownTokens = fuseUnknownTokens; ContinuingSubwordPrefix = continuingSubwordPrefix; EndOfWordSuffix = endOfWordSuffix; - _preTokenizer = preTokenizer ?? WhiteSpacePreTokenizer.Instance; // Default to WhiteSpace pre-tokenizer + _preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpace(); // Default to WhiteSpace pre-tokenizer _normalizer = normalizer; _vocab = vocab ?? new Dictionary(); @@ -215,6 +238,13 @@ private BpeTokenizer(Dictionary? vocab, Vec<(string, VocabReverse.Add(kvp.Value, kvp.Key.Data!); } + if (addedTokens is not null) + { + AddedTokens = addedTokens; + _addedTokens = addedTokens.ToDictionary(kvp => new StringSpanOrdinalKey(kvp.Key), kvp => (kvp.Value, kvp.Key)); + _addedTokensReverse = addedTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); + } + UnknownToken = unknownToken; int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length; @@ -568,7 +598,7 @@ private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenC /// /// The list of ids that we want to decode. /// The decoded string. - public override string? Decode(IEnumerable ids) => Decode(ids, considerSpecialTokens: true); + public override string Decode(IEnumerable ids) => Decode(ids, considerSpecialTokens: true); /// /// Decode the given ids, back to a String. @@ -576,7 +606,7 @@ private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenC /// The list of ids that we want to decode. /// Indicate whether to consider special tokens or not. /// The decoded string. - public string? Decode(IEnumerable ids, bool considerSpecialTokens) + public string Decode(IEnumerable ids, bool considerSpecialTokens) { if (ids is null) { @@ -936,6 +966,12 @@ internal Word MergeWord(ReadOnlySpan w, ref PriorityQueue? priority internal void EncodeWithCache(ReadOnlySpan text, List tokens, int offset, ref PriorityQueue? priorityQueue) { + if (_addedTokens?.TryGetValue(text, out (int addedTokenId, string addedToken) value) is true) + { + tokens.Add(new EncodedToken(value.addedTokenId, value.addedToken, new Range(offset, offset + text.Length))); + return; + } + Word word; if (Cache is not null) { @@ -1004,6 +1040,13 @@ internal int WordToIdsFromEnd(ref Word word, IList? accumulatedIds, out int private int EncodeToIdsWithCache(ReadOnlySpan text, List? accumulatedIds, int maxTokens, out int charsConsumed, ref PriorityQueue? priorityQueue) { + if (_addedTokens?.TryGetValue(text, out (int addedTokenId, string addedToken) value) is true && maxTokens > 0) + { + accumulatedIds?.Add(value.addedTokenId); + charsConsumed = text.Length; + return 1; + } + Word word; if (Cache is not null) @@ -1032,6 +1075,13 @@ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan text, IList? ac { Word word; + if (_addedTokens?.TryGetValue(text, out (int addedTokenId, string addedToken) value) is true && maxTokens > 0) + { + accumulatedIds?.Add(value.addedTokenId); + textIndex = 0; + return 1; + } + if (Cache is not null) { if (Cache.TryGetValue(text, out Word hit)) diff --git a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs index fbfbba7f7e..c1fd6bb1ca 100644 --- a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs @@ -376,7 +376,7 @@ private EncodeResults EncodeToTokens(string? text, scoped ReadOnly List tokens = new(); if (addBos && BeginningOfSentenceId.HasValue) { - tokens.Add(new EncodedToken(BeginningOfSentenceId.Value, BeginningOfSentenceToken!, (0, 0))); + tokens.Add(new EncodedToken(BeginningOfSentenceId.Value, BeginningOfSentenceToken!, new Range(0, 0))); } PriorityQueue agenda = new(textSpanToEncode.Length); @@ -395,7 +395,8 @@ private EncodeResults EncodeToTokens(string? text, scoped ReadOnly if (addEos && EndOfSentenceId.HasValue) { - tokens.Add(new EncodedToken(EndOfSentenceId.Value, EndOfSentenceToken!, (addPrefixSpace ? Math.Max(0, textSpanToEncode.Length - 1) : textSpanToEncode.Length, 0))); + int index = addPrefixSpace ? Math.Max(0, textSpanToEncode.Length - 1) : textSpanToEncode.Length; + tokens.Add(new EncodedToken(EndOfSentenceId.Value, EndOfSentenceToken!, new Range(index, index))); } return new EncodeResults { Tokens = tokens, NormalizedText = normalizedString, CharsConsumed = textSpanToEncode.Length }; @@ -427,7 +428,8 @@ private void EncodeInternal(string? text, scoped ReadOnlySpan textSpan, Li if (_addedTokens is not null && _addedTokens.TryGetValue(textSpan, out (int addedTokenId, string addedToken) value)) { - tokens.Add(new EncodedToken(value.addedTokenId, value.addedToken, ((addPrefixSpace && offset > 0) ? offset - 1 : offset, (addPrefixSpace && offset == 0) ? textSpan.Length - 1 : textSpan.Length))); + int index = (addPrefixSpace && offset > 0) ? offset - 1 : offset; + tokens.Add(new EncodedToken(value.addedTokenId, value.addedToken, new Range(index, index + ((addPrefixSpace && offset == 0) ? textSpan.Length - 1 : textSpan.Length)))); return; } @@ -1027,11 +1029,11 @@ private int EncodeToIdsResult(List tokens, IList? accumulated for (tokenCount = 0; tokenCount < maxTokens; tokenCount++) { // maxTokens is less than tokens.Count, so it is safe to index maxTokens. - if (tokens[tokenCount].Offset.Index == tokens[tokenCount + 1].Offset.Index) + if (tokens[tokenCount].Offset.Start.Value == tokens[tokenCount + 1].Offset.Start.Value) { // Ensure we'll not break the text in the middle of a code-point int j = tokenCount + 2; - while (j < tokens.Count && tokens[j].Offset.Index == tokens[tokenCount].Offset.Index) + while (j < tokens.Count && tokens[j].Offset.Start.Value == tokens[tokenCount].Offset.Start.Value) { j++; } @@ -1042,7 +1044,7 @@ private int EncodeToIdsResult(List tokens, IList? accumulated for (int k = tokenCount; k < j; k++) { accumulatedIds?.Add(tokens[k].Id); - charsConsumed += tokens[k].Offset.Length; + charsConsumed += tokens[k].Offset.End.Value - tokens[k].Offset.Start.Value; } tokenCount = j - 1; } @@ -1054,7 +1056,7 @@ private int EncodeToIdsResult(List tokens, IList? accumulated else { accumulatedIds?.Add(tokens[tokenCount].Id); - charsConsumed += tokens[tokenCount].Offset.Length; + charsConsumed += tokens[tokenCount].Offset.End.Value - tokens[tokenCount].Offset.Start.Value; } } @@ -1082,7 +1084,7 @@ private int EncodeToIdsFromEndResult(List tokens, IList? accu int index = tokens.Count - maxTokens; // avoid breaking the text in the middle of a code-point - while (index < tokens.Count && tokens[index].Offset.Index == tokens[index - 1].Offset.Index) + while (index < tokens.Count && tokens[index].Offset.Start.Value == tokens[index - 1].Offset.Start.Value) { index++; } @@ -1090,7 +1092,7 @@ private int EncodeToIdsFromEndResult(List tokens, IList? accu for (int i = index; i < tokens.Count; i++) { accumulatedIds?.Add(tokens[i].Id); - textIndex -= tokens[i].Offset.Length; + textIndex -= tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value; } return tokens.Count - index; @@ -1229,7 +1231,7 @@ private int EncodeToIdsFromEndInternal(string? text, scoped ReadOnlySpan t /// /// The list of ids that we want to decode. /// The decoded string. - public override string? Decode(IEnumerable ids) => Decode(ids, hasPrefixSpace: AddPrefixSpace, considerSpecialTokens: false); + public override string Decode(IEnumerable ids) => Decode(ids, hasPrefixSpace: AddPrefixSpace, considerSpecialTokens: false); /// /// Decode the given ids, back to a String. @@ -1238,7 +1240,7 @@ private int EncodeToIdsFromEndInternal(string? text, scoped ReadOnlySpan t /// Indicate whether the encoded string has a leading space. /// Indicate whether to consider special tokens during decoding. /// The decoded string. - public string? Decode(IEnumerable ids, bool hasPrefixSpace, bool considerSpecialTokens) + public string Decode(IEnumerable ids, bool hasPrefixSpace, bool considerSpecialTokens) { if (ids is null) { @@ -1590,11 +1592,12 @@ private static void AppendTokenWithOffsetAdjusting(IReadOnlyList t { if (tokensToAdd.Count > 0) { - tokens.Add(new EncodedToken(tokensToAdd[0].Id, tokensToAdd[0].Value, (offset == 0 ? tokensToAdd[0].Offset.Index : tokensToAdd[0].Offset.Index + offset - 1, offset == 0 ? tokensToAdd[0].Offset.Length - 1 : tokensToAdd[0].Offset.Length))); + (int s, int e) r = offset == 0 ? (tokensToAdd[0].Offset.Start.Value, tokensToAdd[0].Offset.End.Value - 1) : (tokensToAdd[0].Offset.Start.Value + offset - 1, tokensToAdd[0].Offset.End.Value + offset - 1); + tokens.Add(new EncodedToken(tokensToAdd[0].Id, tokensToAdd[0].Value, new Range(r.s, r.e))); for (int i = 1; i < tokensToAdd.Count; i++) { - tokens.Add(new EncodedToken(tokensToAdd[i].Id, tokensToAdd[i].Value, (tokensToAdd[i].Offset.Index + offset - 1, tokensToAdd[i].Offset.Length))); + tokens.Add(new EncodedToken(tokensToAdd[i].Id, tokensToAdd[i].Value, new Range(tokensToAdd[i].Offset.Start.Value + offset - 1, tokensToAdd[i].Offset.End.Value + offset - 1))); } } } @@ -1602,7 +1605,7 @@ private static void AppendTokenWithOffsetAdjusting(IReadOnlyList t { foreach (EncodedToken t in tokensToAdd) { - tokens.Add(new EncodedToken(t.Id, t.Value, (t.Offset.Index + offset, t.Offset.Length))); + tokens.Add(new EncodedToken(t.Id, t.Value, new Range(t.Offset.Start.Value + offset, t.Offset.End.Value + offset))); } } } @@ -1622,7 +1625,7 @@ private List EncodeToTokens(Span text, Span mapping, Re char c = text[0]; string[] charToString = ByteToUnicodeEncoding.Instance.CharToString; string tokenValue = (uint)c < charToString.Length ? charToString[c] : c.ToString(); - return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)].Id, tokenValue, (mapping[0], 1)) }; + return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)].Id, tokenValue, new Range(mapping[0], mapping[0] + 1)) }; } BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); @@ -1694,9 +1697,8 @@ private List EncodeToTokens(Span text, Span mapping, Re static EncodedToken GetToken(int id, string token, int index, int length, ReadOnlySpan originalText, Span mapping) { - int tokenStartIndex = mapping[index]; - int tokenLength = (index + length < mapping.Length ? mapping[index + length] - tokenStartIndex : originalText.Length - tokenStartIndex); - return new EncodedToken(id, token, (tokenStartIndex, tokenLength)); + int endIndex = index + length < mapping.Length ? mapping[index + length] : originalText.Length; + return new EncodedToken(id, token, new Range(mapping[index], endIndex)); } void TryMerge(int left, int right, ReadOnlySpan textSpan) @@ -1892,7 +1894,7 @@ public static CodeGenTokenizer Create( return new CodeGenTokenizer( vocabStream, mergesStream, - new TiktokenPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), + new RegexPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), normalizer: null, CodeGenTokenizer.CodeGenAddedTokens, addPrefixSpace: addPrefixSpace, diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs index e1cc47e13f..85f921ff0f 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs @@ -325,7 +325,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read { foreach (EncodedToken t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length))) { - tokens.Add(new EncodedToken(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length))); + tokens.Add(new EncodedToken(t.Id, t.Value, new Range(split.Offset + t.Offset.Start.Value, split.Offset + t.Offset.End.Value))); } } @@ -597,14 +597,14 @@ private int EncodeToIdsResult(List tokens, IList? accumulated for (int i = 0; i < maxTokens; i++) { accumulatedIds.Add(tokens[i].Id); - charsConsumed += tokens[i].Offset.Length; + charsConsumed += tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value; } } else { for (int i = 0; i < maxTokens; i++) { - charsConsumed += tokens[i].Offset.Length; + charsConsumed += tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value; } } @@ -634,14 +634,14 @@ private int EncodeToIdsFromEndResult(List tokens, IList? accu for (int i = tokens.Count - maxTokens; i < tokens.Count; i++) { accumulatedIds.Add(tokens[i].Id); - textIndex -= tokens[i].Offset.Length; + textIndex -= tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value; } } else { for (int i = tokens.Count - maxTokens; i < tokens.Count; i++) { - textIndex -= tokens[i].Offset.Length; + textIndex -= tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value; } } @@ -750,7 +750,7 @@ private int EncodeToIdsFromEndInternal(ReadOnlySpan text, IList? accu /// /// The list of ids that we want to decode. /// The decoded string. - public override string? Decode(IEnumerable ids) + public override string Decode(IEnumerable ids) { if (ids is null) { @@ -905,7 +905,7 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList list = new List(tokens.Count); for (int j = 0; j < i; j++) @@ -915,7 +915,7 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList EncodeToTokens(Span token, Span indexMappi { Debug.Assert(token[0] < charToString.Length); string tokenValue = charToString[token[0]]; - return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) }; + return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, new Range(indexMapping[0], indexMapping[0] + 1)) }; } List word = new(token.Length); @@ -1036,7 +1036,7 @@ private List EncodeToTokens(Span token, Span indexMappi foreach (string w in word) { - tokens.Add(new EncodedToken(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length))); + tokens.Add(new EncodedToken(_vocab[new StringSpanOrdinalKey(w)], w, new Range(indexMapping[index], indexMapping[index] + w.Length))); index += w.Length; } diff --git a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs index 2406ab50fb..fe58b7bde1 100644 --- a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs @@ -12,16 +12,16 @@ namespace Microsoft.ML.Tokenizers // SentencePiece is under the Apache License 2.0 https://github.com/google/sentencepiece/blob/master/LICENSE /// - /// LlamaTokenizer is SentencePieceBpeTokenizer which is implemented based on https://github.com/google/sentencepiece. + /// LlamaTokenizer is SentencePieceTokenizer which is implemented based on https://github.com/google/sentencepiece. /// - public sealed class LlamaTokenizer : SentencePieceBpeTokenizer + public sealed class LlamaTokenizer : SentencePieceTokenizer { internal LlamaTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? addedTokens = null) : base(modelProto, addBos, addEos, addedTokens) { } /// - /// Create from the given model stream a LlamaTokenizer which is based on SentencePieceBpeTokenizer. The model stream should contain the SentencePiece Bpe model according to + /// Create from the given model stream a LlamaTokenizer which is based on SentencePieceTokenizer. The model stream should contain the SentencePiece Bpe model according to /// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto specification. /// /// The stream containing the SentencePiece Bpe model. diff --git a/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs index 64985bcc9d..b2229482fa 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs @@ -113,7 +113,7 @@ internal Phi2Tokenizer( } return new Phi2Tokenizer( - vocabStream, mergesStream, new TiktokenPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), normalizer: null, + vocabStream, mergesStream, new RegexPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), normalizer: null, CodeGenTokenizer.CodeGenAddedTokens, addPrefixSpace: addPrefixSpace, addBeginningOfSentence: addBeginOfSentence, addEndOfSentence: addEndOfSentence); } } diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs similarity index 98% rename from src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs rename to src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs index 45a58c84a4..b89606ba8d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs @@ -22,7 +22,7 @@ namespace Microsoft.ML.Tokenizers /// /// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model. /// - public class SentencePieceBpeTokenizer : Tokenizer + public class SentencePieceTokenizer : Tokenizer { private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id. private readonly Dictionary _vocab = new(); @@ -36,14 +36,14 @@ public class SentencePieceBpeTokenizer : Tokenizer private readonly Dictionary? _specialTokens; private readonly Dictionary? _specialTokensReverse; - internal SentencePieceBpeTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) : + internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) : this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto, specialTokens) { AddBeginningOfSentence = addBos; AddEndOfSentence = addEos; } - private SentencePieceBpeTokenizer(ModelProto modelProto, IReadOnlyDictionary? specialTokens) + private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary? specialTokens) { for (int i = 0; i < modelProto.Pieces.Count; i++) { @@ -272,7 +272,7 @@ private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSen if (addBeginOfSentence) { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, (0, 0))); + tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); } int currentOffset = 0; @@ -286,7 +286,7 @@ private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSen if (_specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) { - tokens.Add(new EncodedToken(id, _specialTokensReverse![id], (Offset, Length))); + tokens.Add(new EncodedToken(id, _specialTokensReverse![id], new Range(Offset, Offset + Length))); } currentOffset = Offset + Length; @@ -299,7 +299,7 @@ private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSen if (addEndOfSentence) { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, (text.Length, 0))); + tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length))); } } @@ -319,7 +319,7 @@ private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bo if (addBeginOfSentence) { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, (0, 0))); + tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); } for (int index = 0; (uint)index < (uint)symbols.Length; index = symbols[index].next) @@ -352,7 +352,7 @@ private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bo tokens.Add(new EncodedToken( id, GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text), - (symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length))); + new Range(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Index + symbols[index].pieceSpan.Length))); } continue; } @@ -364,7 +364,7 @@ private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bo if (addEndOfSentence) { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, (text.Length, 0))); + tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length))); } return; @@ -381,7 +381,7 @@ void EncodeAsBytes(ReadOnlySpan text, int index) if (_vocabReverse.TryGetValue(id, out string? token)) { - tokens.Add(new EncodedToken(id, token, (index + i, 1))); + tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + 1))); } } else @@ -405,7 +405,7 @@ void EncodeAsBytes(ReadOnlySpan text, int index) if (_vocabReverse.TryGetValue(id, out string? token)) { - tokens.Add(new EncodedToken(id, token, (index + i, length))); + tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + length))); } length = 0; @@ -433,7 +433,7 @@ void Segment((int Index, int Length) pieceSpan, ReadOnlySpan text) revMerge is null || !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) { - tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), (pieceSpan.Index, pieceSpan.Length))); + tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), new Range(pieceSpan.Index, pieceSpan.Index + pieceSpan.Length))); return; } @@ -1526,7 +1526,7 @@ revMerge is null || /// /// The list of ids that we want to decode. /// The decoded string. - public override string? Decode(IEnumerable ids) + public override string Decode(IEnumerable ids) => Decode(ids, considerSpecialTokens: false); /// @@ -1535,7 +1535,7 @@ revMerge is null || /// The list of ids that we want to decode. /// Indicate whether to consider special tokens during decoding. /// The decoded string. - public string? Decode(IEnumerable ids, bool considerSpecialTokens) + public string Decode(IEnumerable ids, bool considerSpecialTokens) { if (ids is null) { @@ -1735,7 +1735,7 @@ static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitesp prefixRemoved = true; } - static void TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb) + static void TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb) { if (!considerSpecialTokens) { @@ -1979,7 +1979,7 @@ public OperationStatus Decode(IEnumerable ids, Span destination, bool return OperationStatus.Done; - static OperationStatus TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, Span buffer, ref int charsWritten) + static OperationStatus TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, Span buffer, ref int charsWritten) { string? specialToken = null; diff --git a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs index 47fc5971c0..b169b2234f 100644 --- a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs @@ -307,7 +307,7 @@ private void EncodeToTokens(ReadOnlySpan text, List tokens, tokens.Add(new EncodedToken( value[i].Id, value[i].TokenLength == 0 ? string.Empty : text.Slice(value[i].TokenIndex, value[i].TokenLength).ToString(), - (value[i].TokenIndex + offset, value[i].TokenLength))); + new Range(value[i].TokenIndex + offset, value[i].TokenIndex + offset + value[i].TokenLength))); } return; @@ -316,7 +316,7 @@ private void EncodeToTokens(ReadOnlySpan text, List tokens, // cache miss if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId)) { - tokens.Add(new EncodedToken(mappedId.Id, mappedId.Token, (offset, mappedId.Token.Length))); + tokens.Add(new EncodedToken(mappedId.Id, mappedId.Token, new Range(offset, offset + mappedId.Token.Length))); return; } @@ -348,7 +348,7 @@ private void EncodeToTokens(ReadOnlySpan text, List tokens, tokens.Add(new EncodedToken( encodedTokens[i].Id, encodedTokens[i].TokenLength == 0 ? string.Empty : text.Slice(encodedTokens[i].TokenIndex, encodedTokens[i].TokenLength).ToString(), - (encodedTokens[i].TokenIndex + offset, encodedTokens[i].TokenLength))); + new Range(encodedTokens[i].TokenIndex + offset, encodedTokens[i].TokenIndex + offset + encodedTokens[i].TokenLength))); } } @@ -792,7 +792,7 @@ private int EncodeToIdsFromEndResult((int Id, int TokenIndex, int TokenLength)[] /// /// The list of ids that we want to decode. /// The decoded string. - public override string? Decode(IEnumerable ids) + public override string Decode(IEnumerable ids) { // Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words. // Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively. @@ -824,10 +824,6 @@ private int EncodeToIdsFromEndResult((int Id, int TokenIndex, int TokenLength)[] tokenBytes.Span.CopyTo(utf8Bytes.Slice(utf8ByteCount)); utf8ByteCount += tokenBytes.Length; } - else - { - return null; - } } return Helpers.GetString(utf8Bytes.Slice(0, utf8ByteCount)); @@ -1029,6 +1025,7 @@ private enum ModelEncoding private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding = [ // chat + ( "o1-", ModelEncoding.O200kBase ), // e.g. o1-mini ( "gpt-4o-", ModelEncoding.O200kBase), // e.g., gpt-4o-2024-05-13 ( "gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k ( "gpt-3.5-", ModelEncoding.Cl100kBase), // e.g, gpt-3.5-turbo-0301, -0401, etc. @@ -1040,6 +1037,7 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo { // chat { "gpt-4o", ModelEncoding.O200kBase }, + { "o1", ModelEncoding.O200kBase }, { "gpt-4", ModelEncoding.Cl100kBase }, { "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, { "gpt-3.5-turbo-16k", ModelEncoding.Cl100kBase }, @@ -1239,7 +1237,7 @@ private static TiktokenTokenizer CreateForModel( cache.encoder, cache.decoder, cache.vocab, - new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + new RegexPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), tiktokenConfiguration.SpecialTokens, normalizer, LruCache.DefaultCacheSize); @@ -1367,7 +1365,7 @@ public static TiktokenTokenizer CreateForModel( } return new TiktokenTokenizer(vocabStream, - new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + new RegexPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), tiktokenConfiguration.SpecialTokens, normalizer, cacheSize); @@ -1407,7 +1405,7 @@ public static async Task CreateForModelAsync( } return await CreateAsync(vocabStream, - new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + new RegexPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), normalizer, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false); diff --git a/src/Microsoft.ML.Tokenizers/Model/Word.cs b/src/Microsoft.ML.Tokenizers/Model/Word.cs index 5acfd9ae4b..003243934c 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Word.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Word.cs @@ -296,7 +296,7 @@ public void ToTokens(SortedDictionary vocabReverse, List - public abstract class PreTokenizer + public abstract partial class PreTokenizer { /// /// Get the offsets and lengths of the tokens relative to the . @@ -40,6 +40,32 @@ public abstract class PreTokenizer } } + private const string WhiteSpacePattern = /*lang=regex*/ @"\w+|[^\w\s]+"; + private static PreTokenizer? _whiteSpacePreTokenizer; +#if NET7_0_OR_GREATER + [GeneratedRegex(WhiteSpacePattern)] + private static partial Regex WhiteSpaceRegex(); +#else + private static Regex WhiteSpaceRegex() => new Regex(WhiteSpacePattern, RegexOptions.Compiled); +#endif + + /// + /// Create a new instance of the class which split the text at the word boundary. + /// The word is a set of alphabet, numeric, and underscore characters. + /// + /// The dictionary containing the special tokens and their corresponding ids. + /// The pre-tokenizer that splits the text at the word boundary. + public static PreTokenizer CreateWhiteSpace(IReadOnlyDictionary? specialTokensEncoder = null) + { + if (specialTokensEncoder is null) + { + // return a singleton instance of the WhiteSpace pre-tokenizer + return _whiteSpacePreTokenizer ??= new RegexPreTokenizer(WhiteSpaceRegex(), null); + } + + return new RegexPreTokenizer(WhiteSpaceRegex(), specialTokensEncoder); + } + internal static IEnumerable<(int Offset, int Length)> SplitText(ReadOnlySpan text, Regex regex) { #if NET7_0_OR_GREATER diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/RegexPreTokenizer.cs similarity index 95% rename from src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs rename to src/Microsoft.ML.Tokenizers/PreTokenizer/RegexPreTokenizer.cs index 4050f75d07..9685e370b7 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/RegexPreTokenizer.cs @@ -13,18 +13,18 @@ namespace Microsoft.ML.Tokenizers /// /// The pre-tokenizer for Tiktoken tokenizer. /// - public sealed class TiktokenPreTokenizer : PreTokenizer + public sealed partial class RegexPreTokenizer : PreTokenizer { private readonly Regex? _specialTokensRegex; private readonly Regex _regex; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The regex to use for splitting the text into smaller tokens in the pre-tokenization process. /// The dictionary containing the special tokens and their corresponding ids. /// When regex is null - public TiktokenPreTokenizer(Regex regex, IReadOnlyDictionary? specialTokensEncoder) + public RegexPreTokenizer(Regex regex, IReadOnlyDictionary? specialTokensEncoder) { if (regex is null) { diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs deleted file mode 100644 index 4ba737d1bb..0000000000 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs +++ /dev/null @@ -1,61 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Text.RegularExpressions; - -namespace Microsoft.ML.Tokenizers -{ - /// - /// The pre-tokenizer which split the text at the word boundary. - /// The word is a set of alphabet, numeric, and underscore characters. - /// - public sealed partial class WhiteSpacePreTokenizer : PreTokenizer - { - /// - /// Gets a singleton instance of the WhiteSpace pre-tokenizer.. - /// - public static WhiteSpacePreTokenizer Instance { get; } = new WhiteSpacePreTokenizer(); - - private const string PretokenizePattern = /*lang=regex*/ @"\w+|[^\w\s]+"; -#if NET7_0_OR_GREATER - [GeneratedRegex(PretokenizePattern)] - private static partial Regex PretokenizeRegex(); -#else - private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled); - private static Regex PretokenizeRegex() => _regex; -#endif - - /// - /// Get the offsets and lengths of the tokens relative to the . - /// - /// The string to split into tokens. - /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. - public override IEnumerable<(int Offset, int Length)> PreTokenize(string text) - { - if (string.IsNullOrEmpty(text)) - { - return []; - } - - return SplitText(text, PretokenizeRegex()); - } - - /// - /// Get the offsets and lengths of the tokens relative to the . - /// - /// The string to split into tokens. - /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. - public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text) - { - if (text.IsEmpty) - { - return []; - } - - return SplitText(text, PretokenizeRegex()); - } - } -} diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 4821a91984..f9e47707b0 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -241,7 +241,7 @@ protected virtual int GetIndexByTokenCount(string? text, ReadOnlySpan text if (tokenCount > 0) { var token = tokens.Tokens[tokenCount - 1]; - return token.Offset.Index + token.Offset.Length; + return token.Offset.End.Value; } return 0; @@ -251,7 +251,7 @@ protected virtual int GetIndexByTokenCount(string? text, ReadOnlySpan text if (tokenCount > 0) { var token = tokens.Tokens[tokens.Tokens.Count - tokenCount]; - return token.Offset.Index; + return token.Offset.Start.Value; } return tokens.NormalizedText?.Length ?? textSpan.Length; @@ -361,7 +361,7 @@ public int GetIndexByTokenCountFromEnd(ReadOnlySpan text, int maxTokenCoun /// Types derived from may override this implementation to provide a more efficient implementation. /// By default, it uses . /// - public virtual string? Decode(IEnumerable ids) + public virtual string Decode(IEnumerable ids) { if (ids is null) { diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs index fdcbc070c8..081decbf07 100644 --- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs +++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs @@ -103,18 +103,18 @@ private static Tensor GenerateAnchors(int baseSize = 16, double[] ratios = null, var anchors = torch.zeros(new long[] { numAnchors, 4 }, dtype: torch.float32); // scale base_size - anchors[.., 2..] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0); + anchors[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(2..)] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0); // compute areas of anchors - var areas = torch.mul(anchors[.., 2], anchors[.., 3]); + var areas = torch.mul(anchors[RangeUtil.ToTensorIndex(..), 2], anchors[RangeUtil.ToTensorIndex(..), 3]); // correct for ratios - anchors[.., 2] = torch.sqrt(areas / torch.repeat_interleave(ratios, new long[] { scales.Length })); - anchors[.., 3] = torch.mul(anchors[.., 2], torch.repeat_interleave(ratios, new long[] { scales.Length })); + anchors[RangeUtil.ToTensorIndex(..), 2] = torch.sqrt(areas / torch.repeat_interleave(ratios, new long[] { scales.Length })); + anchors[RangeUtil.ToTensorIndex(..), 3] = torch.mul(anchors[RangeUtil.ToTensorIndex(..), 2], torch.repeat_interleave(ratios, new long[] { scales.Length })); // transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2) - anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[.., 2] * 0.5, new long[] { 2, 1 }).T; - anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[.., 3] * 0.5, new long[] { 2, 1 }).T; + anchors[RangeUtil.ToTensorIndex(..), torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[RangeUtil.ToTensorIndex(..), 2] * 0.5, new long[] { 2, 1 }).T; + anchors[RangeUtil.ToTensorIndex(..), torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[RangeUtil.ToTensorIndex(..), 3] * 0.5, new long[] { 2, 1 }).T; return anchors.MoveToOuterDisposeScope(); } diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs index a44d64c506..d50791a965 100644 --- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs +++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs @@ -113,7 +113,7 @@ public override Tensor forward(Tensor x, Tensor mask) k = k.permute(0, 2, 1, 3); v = v.permute(0, 2, 1, 3); - var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[.., this.attention_bias_idxs]; + var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[RangeUtil.ToTensorIndex(..), this.attention_bias_idxs]; if (!(mask is null)) { long nW = mask.shape[0]; diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs index 6bba3fc596..28b9a948d9 100644 --- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs +++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs @@ -127,7 +127,7 @@ public override Tensor forward(Tensor x, int h, int w, Tensor maskMatrix) } else { - x = x[.., ..h, ..w].contiguous(); + x = x[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..h), RangeUtil.ToTensorIndex(..w)].contiguous(); } } diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs index 6f3732c72b..735e135691 100644 --- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs @@ -384,7 +384,7 @@ private bool TrainStep(IHost host, var padW = 32 - (image.Width % 32); var padH = 32 - (image.Height % 32); using var transMidTensor = torch.zeros(1, 3, image.Height + padH, image.Width + padW, device: Device); - transMidTensor[.., .., ..image.Height, ..image.Width] = reMidTensor / 255.0; + transMidTensor[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..image.Height), RangeUtil.ToTensorIndex(..image.Width)] = reMidTensor / 255.0; var imageTensor = Normalize(transMidTensor, Device); VBuffer labels = default; @@ -407,11 +407,11 @@ private bool TrainStep(IHost host, long y1 = (long)boxValues[b++]; // Our labels are 1 based, the TorchSharp model is 0 based so subtract 1 to they align correctly. long cl = labelValues[i] - 1; - labelTensor[.., i, 0] = x0; - labelTensor[.., i, 1] = y0; - labelTensor[.., i, 2] = x1; - labelTensor[.., i, 3] = y1; - labelTensor[.., i, 4] = cl; + labelTensor[RangeUtil.ToTensorIndex(..), i, 0] = x0; + labelTensor[RangeUtil.ToTensorIndex(..), i, 1] = y0; + labelTensor[RangeUtil.ToTensorIndex(..), i, 2] = x1; + labelTensor[RangeUtil.ToTensorIndex(..), i, 3] = y1; + labelTensor[RangeUtil.ToTensorIndex(..), i, 4] = cl; } return (imageTensor.MoveToOuterDisposeScope(), labelTensor.MoveToOuterDisposeScope()); } @@ -919,7 +919,7 @@ private Tensor PrepInputTensors(ref MLImage image, ValueGetter imageGet var padW = 32 - (image.Width % 32); var padH = 32 - (image.Height % 32); var transMidTensor = torch.zeros(1, 3, image.Height + padH, image.Width + padW, device: _parent.Device); - transMidTensor[.., .., ..image.Height, ..image.Width] = reMidTensor / 255.0; + transMidTensor[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..image.Height), RangeUtil.ToTensorIndex(..image.Width)] = reMidTensor / 255.0; var imageTensor = ObjectDetectionTrainer.Trainer.Normalize(transMidTensor, _parent.Device); return imageTensor.MoveToOuterDisposeScope(); } diff --git a/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs b/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs index 3954677526..45ebeb4aae 100644 --- a/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs +++ b/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs @@ -40,20 +40,20 @@ public override Tensor forward(Tensor classifications, Tensor regressions, Tenso var classificationLosses = new List(); var regressionLosses = new List(); - var anchor = anchors[0, .., ..]; + var anchor = anchors[0, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)]; - var anchorWidths = anchor[.., 2] - anchor[.., 0]; - var anchorHeights = anchor[.., 3] - anchor[.., 1]; - var anchorCtrX = anchor[.., 0] + (0.5 * anchorWidths); - var anchorCtrY = anchor[.., 1] + (0.5 * anchorHeights); + var anchorWidths = anchor[RangeUtil.ToTensorIndex(..), 2] - anchor[RangeUtil.ToTensorIndex(..), 0]; + var anchorHeights = anchor[RangeUtil.ToTensorIndex(..), 3] - anchor[RangeUtil.ToTensorIndex(..), 1]; + var anchorCtrX = anchor[RangeUtil.ToTensorIndex(..), 0] + (0.5 * anchorWidths); + var anchorCtrY = anchor[RangeUtil.ToTensorIndex(..), 1] + (0.5 * anchorHeights); for (int j = 0; j < batchSize; ++j) { - var classification = classifications[j, .., ..]; - var regression = regressions[j, .., ..]; + var classification = classifications[j, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)]; + var regression = regressions[j, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)]; - var bboxAnnotation = annotations[j, .., ..]; - bboxAnnotation = bboxAnnotation[bboxAnnotation[.., 4] != -1]; + var bboxAnnotation = annotations[j, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)]; + bboxAnnotation = bboxAnnotation[bboxAnnotation[RangeUtil.ToTensorIndex(..), 4] != -1]; classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4); @@ -73,7 +73,7 @@ public override Tensor forward(Tensor classifications, Tensor regressions, Tenso } else { - var iou = CalcIou(anchors[0, .., ..], bboxAnnotation[.., ..4]); // num_anchors x num_annotations + var iou = CalcIou(anchors[0, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)], bboxAnnotation[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..4)]); // num_anchors x num_annotations var (iou_max, iou_argmax) = torch.max(iou, dim: 1); // num_anchors x 1 @@ -125,10 +125,10 @@ public override Tensor forward(Tensor classifications, Tensor regressions, Tenso var anchorCtrXPi = anchorCtrX[positiveIndices]; var anchorCtrYPi = anchorCtrY[positiveIndices]; - var gtWidths = assignedAnnotations[.., 2] - assignedAnnotations[.., 0]; - var gtHeights = assignedAnnotations[.., 3] - assignedAnnotations[.., 1]; - var gtCtrX = assignedAnnotations[.., 0] + (0.5 * gtWidths); - var gtCtrY = assignedAnnotations[.., 1] + (0.5 * gtHeights); + var gtWidths = assignedAnnotations[RangeUtil.ToTensorIndex(..), 2] - assignedAnnotations[RangeUtil.ToTensorIndex(..), 0]; + var gtHeights = assignedAnnotations[RangeUtil.ToTensorIndex(..), 3] - assignedAnnotations[RangeUtil.ToTensorIndex(..), 1]; + var gtCtrX = assignedAnnotations[RangeUtil.ToTensorIndex(..), 0] + (0.5 * gtWidths); + var gtCtrY = assignedAnnotations[RangeUtil.ToTensorIndex(..), 1] + (0.5 * gtHeights); // clip widths to 1 gtWidths = torch.clamp(gtWidths, min: 1); @@ -178,17 +178,17 @@ private object ToTensorIndex() private static Tensor CalcIou(Tensor a, Tensor b) { - var area = (b[.., 2] - b[.., 0]) * (b[.., 3] - b[.., 1]); + var area = (b[RangeUtil.ToTensorIndex(..), 2] - b[RangeUtil.ToTensorIndex(..), 0]) * (b[RangeUtil.ToTensorIndex(..), 3] - b[RangeUtil.ToTensorIndex(..), 1]); - var iw = torch.minimum(input: torch.unsqueeze(a[.., 2], dim: 1), b[.., 2]) - - torch.maximum(input: torch.unsqueeze(a[.., 0], 1), b[.., 0]); - var ih = torch.minimum(input: torch.unsqueeze(a[.., 3], dim: 1), b[.., 3]) - - torch.maximum(input: torch.unsqueeze(a[.., 1], 1), b[.., 1]); + var iw = torch.minimum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 2], dim: 1), b[RangeUtil.ToTensorIndex(..), 2]) - + torch.maximum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 0], 1), b[RangeUtil.ToTensorIndex(..), 0]); + var ih = torch.minimum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 3], dim: 1), b[RangeUtil.ToTensorIndex(..), 3]) - + torch.maximum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 1], 1), b[RangeUtil.ToTensorIndex(..), 1]); iw = torch.clamp(iw, min: 0); ih = torch.clamp(ih, min: 0); - var ua = torch.unsqueeze((a[.., 2] - a[.., 0]) * (a[.., 3] - a[.., 1]), dim: 1) + area - (iw * ih); + var ua = torch.unsqueeze((a[RangeUtil.ToTensorIndex(..), 2] - a[RangeUtil.ToTensorIndex(..), 0]) * (a[RangeUtil.ToTensorIndex(..), 3] - a[RangeUtil.ToTensorIndex(..), 1]), dim: 1) + area - (iw * ih); ua = torch.clamp(ua, min: 1e-8); var intersection = iw * ih; diff --git a/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj b/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj index 698dbfd623..c347333d27 100644 --- a/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj +++ b/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj @@ -19,6 +19,7 @@ + @@ -32,13 +33,13 @@ dict.txt - + encoder.json - + vocab.bpe - + diff --git a/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs b/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs index 7d2e0d3850..cd158fa5d8 100644 --- a/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs +++ b/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs @@ -50,7 +50,7 @@ public static void Postprocess(Tensor imgBatch, Tensor classification, Tensor re for (int i = 0; i < classification.shape[2]; ++i) { - var scores1 = torch.squeeze(classification[.., .., i], null); + var scores1 = torch.squeeze(classification[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), i], null); var scoresOverThresh = scores1 > 0.05; if (scoresOverThresh.sum().ToSingle() == 0) { @@ -108,16 +108,16 @@ private static Tensor Nms(Tensor boxes, Tensor scores, double iouThreshold = 0.5 using (var nmsScope = torch.NewDisposeScope()) { // boxes: Tensor [N,4],scores: Tensor [N,] - var x1 = boxes[.., 0]; - var y1 = boxes[.., 1]; - var x2 = boxes[.., 2]; - var y2 = boxes[.., 3]; + var x1 = boxes[RangeUtil.ToTensorIndex(..), 0]; + var y1 = boxes[RangeUtil.ToTensorIndex(..), 1]; + var x2 = boxes[RangeUtil.ToTensorIndex(..), 2]; + var y2 = boxes[RangeUtil.ToTensorIndex(..), 3]; var areas = (x2 - x1) * (y2 - y1); // [N,] var (_, _order) = scores.sort(0, descending: true); var keep = new List(); - var order = _order[..]; + var order = _order[RangeUtil.ToTensorIndex(..)]; while (order.numel() > 0) { long i; @@ -133,13 +133,13 @@ private static Tensor Nms(Tensor boxes, Tensor scores, double iouThreshold = 0.5 keep.Add(i); } - var xx1 = x1[order[1..]].clamp(min: x1[i]); // [N - 1,] - var yy1 = y1[order[1..]].clamp(min: y1[i]); - var xx2 = x2[order[1..]].clamp(max: x2[i]); - var yy2 = y2[order[1..]].clamp(max: y2[i]); + var xx1 = x1[order[RangeUtil.ToTensorIndex(1..)]].clamp(min: x1[i]); // [N - 1,] + var yy1 = y1[order[RangeUtil.ToTensorIndex(1..)]].clamp(min: y1[i]); + var xx2 = x2[order[RangeUtil.ToTensorIndex(1..)]].clamp(max: x2[i]); + var yy2 = y2[order[RangeUtil.ToTensorIndex(1..)]].clamp(max: y2[i]); var inter = (xx2 - xx1).clamp(min: 0) * (yy2 - yy1).clamp(min: 0); // [N - 1,] - var iou = inter / (areas[i] + areas[order[1..]] - inter); // [N-1, ] + var iou = inter / (areas[i] + areas[order[RangeUtil.ToTensorIndex(1..)]] - inter); // [N-1, ] var idx = (iou <= iouThreshold).nonzero().squeeze(); // idx: [N - 1,] and order:[N,] if (idx.numel() == 0) { @@ -167,15 +167,15 @@ private static Tensor TransformBbox(Tensor boxes, Tensor deltas) var mean = torch.from_array(new double[] { 0, 0, 0, 0 }).to_type(ScalarType.Float32).to(boxes.device); var std = torch.from_array(new double[] { 0.1, 0.1, 0.2, 0.2 }).to_type(ScalarType.Float32).to(boxes.device); - var widths = boxes[.., .., 2] - boxes[.., .., 0]; - var heights = boxes[.., .., 3] - boxes[.., .., 1]; - var ctrX = boxes[.., .., 0] + (0.5 * widths); - var ctrY = boxes[.., .., 1] + (0.5 * heights); + var widths = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 2] - boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0]; + var heights = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 3] - boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1]; + var ctrX = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0] + (0.5 * widths); + var ctrY = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1] + (0.5 * heights); - var dx = (deltas[.., .., 0] * std[0]) + mean[0]; - var dy = (deltas[.., .., 1] * std[1]) + mean[1]; - var dw = (deltas[.., .., 2] * std[2]) + mean[2]; - var dh = (deltas[.., .., 3] * std[3]) + mean[3]; + var dx = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0] * std[0]) + mean[0]; + var dy = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1] * std[1]) + mean[1]; + var dw = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 2] * std[2]) + mean[2]; + var dh = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 3] * std[3]) + mean[3]; var predCtrX = ctrX + (dx * widths); var predCtrY = ctrY + (dy * heights); @@ -210,11 +210,11 @@ private static Tensor ClipBoxes(Tensor boxes, Tensor img) var height = img.shape[2]; var width = img.shape[3]; - var clippedBoxesX0 = torch.clamp(boxes[.., .., 0], min: 0); - var clippedBoxesY0 = torch.clamp(boxes[.., .., 1], min: 0); + var clippedBoxesX0 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0], min: 0); + var clippedBoxesY0 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1], min: 0); - var clippedBoxesX1 = torch.clamp(boxes[.., .., 2], max: width); - var clippedBoxesY1 = torch.clamp(boxes[.., .., 3], max: height); + var clippedBoxesX1 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 2], max: width); + var clippedBoxesY1 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 3], max: height); var clippedBoxes = torch.stack( new List { clippedBoxesX0, clippedBoxesY0, clippedBoxesX1, clippedBoxesY1 }, diff --git a/src/Microsoft.ML.TorchSharp/Utils/Index.cs b/src/Microsoft.ML.TorchSharp/Utils/Index.cs deleted file mode 100644 index 20f59a2e50..0000000000 --- a/src/Microsoft.ML.TorchSharp/Utils/Index.cs +++ /dev/null @@ -1,145 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Runtime.CompilerServices; - -namespace System -{ - /// Represent a type can be used to index a collection either from the start or the end. - /// - /// Index is used by the C# compiler to support the new index syntax - /// - /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ; - /// int lastElement = someArray[^1]; // lastElement = 5 - /// - /// - internal readonly struct Index : IEquatable - { - private readonly int _value; - - /// Construct an Index using a value and indicating if the index is from the start or from the end. - /// The index value. it has to be zero or positive number. - /// Indicating if the index is from the start or from the end. - /// - /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Index(int value, bool fromEnd = false) - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "Non-negative number required."); - } - - if (fromEnd) - _value = ~value; - else - _value = value; - } - - // The following private constructors mainly created for perf reason to avoid the checks - private Index(int value) - { - _value = value; - } - - /// Create an Index pointing at first element. - public static Index Start => new Index(0); - - /// Create an Index pointing at beyond last element. - public static Index End => new Index(~0); - - /// Create an Index from the start at the position indicated by the value. - /// The index value from the start. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Index FromStart(int value) - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "Non-negative number required."); - } - - return new Index(value); - } - - /// Create an Index from the end at the position indicated by the value. - /// The index value from the end. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Index FromEnd(int value) - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "Non-negative number required."); - } - - return new Index(~value); - } - - /// Returns the index value. - public int Value - { - get - { - if (_value < 0) - return ~_value; - else - return _value; - } - } - - /// Indicates whether the index is from the start or the end. - public bool IsFromEnd => _value < 0; - - /// Calculate the offset from the start using the giving collection length. - /// The length of the collection that the Index will be used with. length has to be a positive value - /// - /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values. - /// we don't validate either the returned offset is greater than the input length. - /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and - /// then used to index a collection will get out of range exception which will be same affect as the validation. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public int GetOffset(int length) - { - var offset = _value; - if (IsFromEnd) - { - // offset = length - (~value) - // offset = length + (~(~value) + 1) - // offset = length + value + 1 - - offset += length + 1; - } - return offset; - } - - /// Indicates whether the current Index object is equal to another object of the same type. - /// An object to compare with this object - public override bool Equals(object value) => value is Index && _value == ((Index)value)._value; - - /// Indicates whether the current Index object is equal to another Index object. - /// An object to compare with this object - public bool Equals(Index other) => _value == other._value; - - /// Returns the hash code for this instance. - public override int GetHashCode() => _value; - - /// Converts integer number to an Index. - public static implicit operator Index(int value) => FromStart(value); - - /// Converts the value of the current Index object to its equivalent string representation. - public override string ToString() - { - if (IsFromEnd) - return ToStringFromEnd(); - - return ((uint)Value).ToString(); - } - - private string ToStringFromEnd() - { - return '^' + Value.ToString(); - } - } -} diff --git a/src/Microsoft.ML.TorchSharp/Utils/Range.cs b/src/Microsoft.ML.TorchSharp/Utils/Range.cs deleted file mode 100644 index b372aed591..0000000000 --- a/src/Microsoft.ML.TorchSharp/Utils/Range.cs +++ /dev/null @@ -1,141 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Diagnostics; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using Microsoft.ML.TorchSharp.Utils; -using static TorchSharp.torch; - -namespace System -{ - /// Represent a range has start and end indexes. - /// - /// Range is used by the C# compiler to support the range syntax. - /// - /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 }; - /// int[] subArray1 = someArray[0..2]; // { 1, 2 } - /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 } - /// - /// - internal readonly struct Range : IEquatable - { - /// Represent the inclusive start index of the Range. - public Index Start { get; } - - /// Represent the exclusive end index of the Range. - public Index End { get; } - - /// Construct a Range object using the start and end indexes. - /// Represent the inclusive start index of the range. - /// Represent the exclusive end index of the range. - public Range(Index start, Index end) - { - Start = start; - End = end; - } - - /// Indicates whether the current Range object is equal to another object of the same type. - /// An object to compare with this object - public override bool Equals(object value) => - value is Range r && - r.Start.Equals(Start) && - r.End.Equals(End); - - /// Indicates whether the current Range object is equal to another Range object. - /// An object to compare with this object - public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End); - - /// Returns the hash code for this instance. - public override int GetHashCode() - { -#if (!NETSTANDARD2_0 && !NETFRAMEWORK) - return HashCode.Combine(Start.GetHashCode(), End.GetHashCode()); -#else - return HashHelpers.Combine(Start.GetHashCode(), End.GetHashCode()); -#endif - } - - /// Converts the value of the current Range object to its equivalent string representation. - public override string ToString() - { -#if (!NETSTANDARD2_0 && !NETFRAMEWORK) - Span span = stackalloc char[2 + (2 * 11)]; // 2 for "..", then for each index 1 for '^' and 10 for longest possible uint - int pos = 0; - - if (Start.IsFromEnd) - { - span[0] = '^'; - pos = 1; - } - bool formatted = ((uint)Start.Value).TryFormat(span.Slice(pos), out int charsWritten); - Debug.Assert(formatted); - pos += charsWritten; - - span[pos++] = '.'; - span[pos++] = '.'; - - if (End.IsFromEnd) - { - span[pos++] = '^'; - } - formatted = ((uint)End.Value).TryFormat(span.Slice(pos), out charsWritten); - Debug.Assert(formatted); - pos += charsWritten; - - return new string(span.Slice(0, pos)); -#else - return Start.ToString() + ".." + End.ToString(); -#endif - } - - /// Create a Range object starting from start index to the end of the collection. - public static Range StartAt(Index start) => new Range(start, Index.End); - - /// Create a Range object starting from first element in the collection to the end Index. - public static Range EndAt(Index end) => new Range(Index.Start, end); - - /// Create a Range object starting from first element to the end. - public static Range All => new Range(Index.Start, Index.End); - - /// Calculate the start offset and length of range object using a collection length. - /// The length of the collection that the range will be used with. length has to be a positive value. - /// - /// For performance reason, we don't validate the input length parameter against negative values. - /// It is expected Range will be used with collections which always have non negative length/count. - /// We validate the range is inside the length scope though. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public (int Offset, int Length) GetOffsetAndLength(int length) - { - int start; - var startIndex = Start; - if (startIndex.IsFromEnd) - start = length - startIndex.Value; - else - start = startIndex.Value; - - int end; - var endIndex = End; - if (endIndex.IsFromEnd) - end = length - endIndex.Value; - else - end = endIndex.Value; - - if ((uint)end > (uint)length || (uint)start > (uint)end) - { - throw new ArgumentOutOfRangeException(nameof(length)); - } - - return (start, end - start); - } - - public static implicit operator TensorIndex(Range range) - { - long? start = !range.Start.IsFromEnd ? range.Start.Value : -1 * range.Start.Value; - var stop = !range.End.IsFromEnd ? new long?(range.End.Value) : range.End.Value == 0 ? null : new long?(-1 * range.End.Value); - return TensorIndex.Slice(start, stop); - } - } -} diff --git a/src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs b/src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs new file mode 100644 index 0000000000..50f10eb431 --- /dev/null +++ b/src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using static TorchSharp.torch; + +namespace Microsoft.ML.TorchSharp +{ + internal static class RangeUtil + { + public static TensorIndex ToTensorIndex(this Range range) + { + long? start = !range.Start.IsFromEnd ? range.Start.Value : -1 * range.Start.Value; + var stop = !range.End.IsFromEnd ? new long?(range.End.Value) : range.End.Value == 0 ? null : new long?(-1 * range.End.Value); + return TensorIndex.Slice(start, stop); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj index 8c65cf0621..149962617d 100644 --- a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj +++ b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj @@ -5,6 +5,10 @@ None + + true + + diff --git a/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj b/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj index 4bff917a66..af3f6b1d13 100644 --- a/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj +++ b/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj @@ -5,6 +5,10 @@ None + + true + + diff --git a/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj b/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj index 09faf80224..ab5b0aba34 100644 --- a/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj +++ b/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj @@ -5,6 +5,10 @@ $(NoWarn);MSML_ParameterLocalVarName;MSML_PrivateFieldName;MSML_ExtendBaseTestClass;MSML_GeneralName + + true + + diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index 2ded80987a..f07f80089e 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -6,6 +6,7 @@ $(NoWarn);MSML_ExtendBaseTestClass enable true + true diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj b/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj index d135f09bbb..62d0fed2fd 100644 --- a/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj +++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj @@ -8,6 +8,10 @@ true + + true + + diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj index 4715947431..6852856a4e 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj @@ -8,6 +8,10 @@ true + + true + + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index dec8dbbb25..d86f06c8a0 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -6,6 +6,7 @@ $(NoWarn);MSML_ExtendBaseTestClass enable true + true diff --git a/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj b/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj index fe4dce9c2e..0bb5927412 100644 --- a/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj +++ b/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj @@ -7,6 +7,10 @@ enable + + true + + diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index 1fbb56128f..6fb5619660 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -251,7 +251,7 @@ public void SimpleTestWithUnknownToken( try { - BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: unknownToken, + BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: unknownToken, continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken); Tokenizer tokenizer = bpe; IReadOnlyList encoding = tokenizer.EncodeToTokens(sentence, out _); @@ -274,7 +274,7 @@ public void SimpleTestWithUnknownToken( for (int i = 0; i < encoding.Count; i++) { Assert.Equal(expectedTokens[i], encoding[i].Value); - Assert.Equal(offsets[i], encoding[i].Offset); + Assert.Equal(offsets[i], (encoding[i].Offset.Start.Value, encoding[i].Offset.End.Value - encoding[i].Offset.Start.Value)); Assert.Equal(ids[i], encoding[i].Id); Assert.Equal(ids[i], idsList[i]); Assert.Equal(encoding[i].Value, reverseVocabulary[encodingIds[i]]); @@ -430,11 +430,11 @@ public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, i IReadOnlyList encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); @@ -472,6 +472,62 @@ public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, i Assert.Equal(3, tokenCount); } + [Fact] + public void TestWithAddedTokens() + { + // Picked from https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/raw/main/tokenizer.json + IReadOnlyDictionary addedTokens = new Dictionary() + { + {"<|endoftext|>", 0 }, + {"<|im_start|>", 1 }, + {"<|im_end|>", 2 }, + {"", 3 }, + {"", 4 }, + {"", 5 }, + {"", 6 }, + {"", 7 }, + {"", 8 }, + {"", 9 }, + {"", 10 }, + {"", 11 }, + {"", 12 }, + {"", 13 }, + {"", 14 }, + {"", 15 }, + {"", 16 }, + }; + + using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json")); + using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt")); + + var bpeTokenizer = BpeTokenizer.Create(vocabStream, mergesStream, PreTokenizer.CreateWhiteSpace(addedTokens), normalizer: null, addedTokens: addedTokens, unknownToken: "<|endoftext|>"); + + string input = "Hello, y'all! How are you 😁 ?<|endoftext|>"; + + IReadOnlyList tokens = bpeTokenizer.EncodeToTokens(input, out _); + + EncodedToken[] expectedTokens = [ + new EncodedToken(15496, "Hello", new Range(0, 5)), + new EncodedToken(11, ",", new Range(5, 6)), + new EncodedToken(88, "y", new Range(7, 8)), + new EncodedToken(6, "'", new Range(8, 9)), + new EncodedToken(439, "all", new Range(9, 12)), + new EncodedToken(0, "!", new Range(12, 13)), + new EncodedToken(9, "", new Range(14, 29)), + new EncodedToken(2437, "How", new Range(29, 32)), + new EncodedToken(533, "are", new Range(33, 36)), + new EncodedToken(5832, "you", new Range(37, 40)), + new EncodedToken(50256, "<|endoftext|>", new Range(41, 43)), + new EncodedToken(30, "?", new Range(44, 45)), + new EncodedToken(0, "<|endoftext|>", new Range(45, 58)) + ]; + + Assert.Equal(expectedTokens, tokens); + + IReadOnlyList ids = bpeTokenizer.EncodeToIds(input); + Assert.Equal(expectedTokens.Select(t => t.Id).ToArray(), ids); + } + private static string WriteToMergeFile((string, string)[] mergeEntries) { string fileName = Utils.CreateTemporaryFile("txt"); @@ -500,7 +556,7 @@ internal static BpeTokenizer CreateEmptyBpe(PreTokenizer? preTokenizer = null, N emptyVocabStream.Position = 0; return BpeTokenizer.Create( - vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpacePreTokenizer.Instance, normalizer: normalizer, unknownToken: "Ukn"); + vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? PreTokenizer.CreateWhiteSpace(), normalizer: normalizer, unknownToken: "Ukn"); } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs index a4273f040c..4965ce064a 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs @@ -235,13 +235,13 @@ private void ValidateEncoding(IReadOnlyList encoding, bool addPref { Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsetsWithSpace, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsetsWithSpace, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); } else { Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); } } @@ -555,22 +555,22 @@ public void TestBegginingAndEndOfSentenceEncoding( tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); + Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value)); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); + Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value)); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); + Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value)); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); + Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value)); idList = new List(expectedIdsWithSpace); idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value); @@ -579,32 +579,32 @@ public void TestBegginingAndEndOfSentenceEncoding( encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); + Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value)); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); + Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value)); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); IReadOnlyList ids = codeGenTokenizer.EncodeToIds(text); Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]); @@ -688,22 +688,22 @@ public void TestBegginingAndEndOfSentenceEncoding( tokensList.Add(codeGenTokenizer.EndOfSentenceToken!); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); idList = new List(expectedIdsWithSpace); idList.Add(codeGenTokenizer.EndOfSentenceId!.Value); @@ -712,32 +712,32 @@ public void TestBegginingAndEndOfSentenceEncoding( encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); ids = codeGenTokenizer.EncodeToIds(text); Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]); @@ -823,26 +823,26 @@ public void TestBegginingAndEndOfSentenceEncoding( tokensList.Add(codeGenTokenizer.EndOfSentenceToken!); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(0, 0), encoding[0].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(0, 0), encoding[0].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(0, 0), encoding[0].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(0, 0), encoding[0].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); idList = new List(expectedIdsWithSpace); idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value); @@ -853,38 +853,38 @@ public void TestBegginingAndEndOfSentenceEncoding( encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(0, 0), encoding[0].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _); Assert.Equal(idList, encoding.Select(t => t.Id).ToArray()); Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray()); - Assert.Equal((0, 0), encoding[0].Offset); - Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.Equal(new Range(0, 0), encoding[0].Offset); + Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _); Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray()); - Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0)); - Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset); + Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0))); + Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset); ids = codeGenTokenizer.EncodeToIds(text); Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]); diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs index 9014d208e1..56dec4f144 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs @@ -182,11 +182,11 @@ public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Ind IReadOnlyList encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); @@ -264,7 +264,7 @@ private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = Call } int[] encodingIds = encoding.Select(t => t.Id).ToArray(); - (int, int)[] offsets = encoding.Select(t => t.Offset).ToArray(); + (int, int)[] offsets = encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray(); string[] tokens = encoding.Select(t => t.Value).ToArray(); Assert.Equal(p[1], encodingIds); diff --git a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs index 6d7178ac2d..7bd41bda45 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs @@ -66,7 +66,7 @@ private static Tokenizer CreateLPhi3Tokenizer(bool treatWhitespaceAsSuffix = fal if (treatWhitespaceAsSuffix) { - PropertyInfo? propertyInfo = typeof(SentencePieceBpeTokenizer).GetProperty("TreatWhitespaceAsSuffix", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); + PropertyInfo? propertyInfo = typeof(SentencePieceTokenizer).GetProperty("TreatWhitespaceAsSuffix", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); if (propertyInfo != null) { propertyInfo.SetValue(tokenizer, true); @@ -244,7 +244,7 @@ public void TestLlamaTokenizer(Tokenizer tokenizer, string input, int[] ids, str IReadOnlyList result = llamaTokenizer.EncodeToTokens(input, out _); Assert.Equal(ids, result.Select(t => t.Id).ToArray()); Assert.Equal(tokens, result.Select(t => t.Value).ToArray()); - Assert.Equal(offsets, result.Select(t => t.Offset).ToArray()); + Assert.Equal(offsets, result.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(input, llamaTokenizer.Decode(ids)); TestDecodingWithSpan(bpe, ids, input); Assert.Equal(ids, llamaTokenizer.EncodeToIds(input)); @@ -501,14 +501,14 @@ public void TestTokenizerEncoding(string text, string normalizedText, string[] e IReadOnlyList encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); - SentencePieceBpeTokenizer sentencePieceBpe = (tokenizer as SentencePieceBpeTokenizer)!; + SentencePieceTokenizer sentencePieceBpe = (tokenizer as SentencePieceTokenizer)!; foreach (bool considerNormalization in new[] { true, false }) foreach (bool addBeginningOfSentence in new[] { true, false }) foreach (bool addEndOfSentence in new[] { true, false }) @@ -539,7 +539,7 @@ public void TestTokenizerEncoding(string text, string normalizedText, string[] e expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1; Assert.Equal(expectedTokens1, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets1, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets1, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds1, encoding.Select(t => t.Id).ToArray()); } } @@ -562,7 +562,7 @@ public void TestTokenizerEncodingToIds(string text, string normalizedText, strin Assert.Equal(normalizedText, normalizedString); Assert.Equal(normalizedText.Length, length); - SentencePieceBpeTokenizer sentencePieceBpe = (tokenizer as SentencePieceBpeTokenizer)!; + SentencePieceTokenizer sentencePieceBpe = (tokenizer as SentencePieceTokenizer)!; foreach (bool considerNormalization in new[] { true, false }) foreach (bool addBeginningOfSentence in new[] { true, false }) foreach (bool addEndOfSentence in new[] { true, false }) diff --git a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj index b4a386bc40..e0d08c93aa 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj +++ b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj @@ -10,6 +10,10 @@ + + true + + diff --git a/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs index f048a6a209..3d77179dfd 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs @@ -18,14 +18,14 @@ public static IEnumerable PreTokenizerData { yield return new object[] { - WhiteSpacePreTokenizer.Instance, + PreTokenizer.CreateWhiteSpace(), "How are you doing?", new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), } }; yield return new object[] { - WhiteSpacePreTokenizer.Instance, + PreTokenizer.CreateWhiteSpace(), "I_am_Just_Fine!", new (int Offset, int Length)[] { (0, 14), (14, 1) } }; @@ -63,7 +63,7 @@ public void TestPreTokenizer(PreTokenizer preTokenizer, string text, (int Offset [Fact] public void TestWhiteSpacePreTokenizer() { - Assert.Empty(WhiteSpacePreTokenizer.Instance.PreTokenize((string)null!)); + Assert.Empty(PreTokenizer.CreateWhiteSpace().PreTokenize((string)null!)); } public class SpacePreTokenizer : PreTokenizer diff --git a/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs index bf75e51ec0..a8df1cc982 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs @@ -145,10 +145,10 @@ private void TestGPT4TokenizationEncoding(Tokenizer tokenizer) int[] ids = result.Select(token => token.Id).ToArray(); string[] tokens = result.Select(token => token.Value).ToArray(); - (int, int)[] offsets = result.Select(token => token.Offset).ToArray(); + Range[] offsets = result.Select(token => token.Offset).ToArray(); Assert.Equal(encoded, ids); Assert.Equal(new string[] { "Hello", " World" }, tokens); - Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, offsets); + Assert.Equal(new List { new Range(0, 5), new Range(5, 11) }, offsets); Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded, ids); @@ -198,7 +198,7 @@ public void TestEncode1() int[] ids = result.Select(token => token.Id).ToArray(); string[] tokens = result.Select(token => token.Value).ToArray(); - (int, int)[] offsets = result.Select(token => token.Offset).ToArray(); + (int, int)[] offsets = result.Select(token => (token.Offset.Start.Value, token.Offset.End.Value - token.Offset.Start.Value)).ToArray(); Assert.Equal(encoded, ids); Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, tokens); @@ -239,7 +239,7 @@ public void TestEncode3() IReadOnlyList result = GPT4.EncodeToTokens(text, out string? normalizedString); int[] ids = result.Select(token => token.Id).ToArray(); string[] tokens = result.Select(token => token.Value).ToArray(); - (int, int)[] offsets = result.Select(token => token.Offset).ToArray(); + (int, int)[] offsets = result.Select(token => (token.Offset.Start.Value, token.Offset.End.Value - token.Offset.Start.Value)).ToArray(); int idsCount = GPT4.CountTokens(text); Assert.Equal(encoded, ids); @@ -275,7 +275,7 @@ public void TestEncode5() Assert.Equal(encoded, result.Select(token => token.Id).ToArray()); Assert.Equal(encoded.Count, idsCount); Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "⭐", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray()); - Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (18, 1), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray()); + Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (18, 1), (19, 6), (25, 10) }, result.Select(token => (token.Offset.Start.Value, token.Offset.End.Value - token.Offset.Start.Value)).ToArray()); } [Fact] @@ -310,7 +310,7 @@ public void TestEncodeGpt4o() Assert.Equal(encoded, result.Select(token => token.Id).ToArray()); Assert.Equal(encoded.Count, idsCount); Assert.Equal(new string[] { "<|endoftext|>", "Hello", " ⭐", " World", "<|endofprompt|>" }, result.Select(token => token.Value).ToArray()); - Assert.Equal(new List<(int, int)> { (0, 13), (13, 5), (18, 2), (20, 6), (26, 15) }, result.Select(token => token.Offset).ToArray()); + Assert.Equal(new List<(int, int)> { (0, 13), (13, 5), (18, 2), (20, 6), (26, 15) }, result.Select(token => (token.Offset.Start.Value, token.Offset.End.Value - token.Offset.Start.Value)).ToArray()); TokenizerTests.TestTokenLimits(GPT4o); } @@ -392,6 +392,8 @@ public void TestEncodeR50kBase() } [Theory] + [InlineData("o1")] + [InlineData("o1-")] [InlineData("gpt-4o")] [InlineData("gpt-4o-")] [InlineData("gpt-4")] @@ -493,6 +495,7 @@ public void TestEncodingNamesNegativeCases() [InlineData("gpt-4")] [InlineData("gpt-4o")] + [InlineData("o1")] [InlineData("text-davinci-003")] [InlineData("text-curie-001")] [InlineData("text-davinci-edit-001")] @@ -566,11 +569,11 @@ public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Ind IReadOnlyList encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _); Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); - Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray()); Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); @@ -679,7 +682,7 @@ public void TestPreciseTokenLimits(string text, string[] expectedTokens, (int In { IReadOnlyList result = GPT4.EncodeToTokens(text, out _); int[] ids = result.Select(r => r.Id).ToArray(); - (int Index, int Length)[] offsets = result.Select(r => r.Offset).ToArray(); + (int Index, int Length)[] offsets = result.Select(r => (r.Offset.Start.Value, r.Offset.End.Value - r.Offset.Start.Value)).ToArray(); Assert.Equal(expectedTokens, result.Select(r => r.Value)); Assert.Equal(expectedIds, ids); Assert.Equal(expectedOffsets, offsets); diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs index df02916d7d..a982e7303f 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -112,7 +112,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read if (count >= settings.MaxTokenCount) break; - tokens.Add(new EncodedToken(c - 'a', c.ToString(), (count, 1))); + tokens.Add(new EncodedToken(c - 'a', c.ToString(), new Range(count, count + 1))); count++; } @@ -152,7 +152,7 @@ internal static void TestTokenLimits(Tokenizer tokenizer) { string prefixString = (processedText1 ?? input).Substring(0, index1); - if (tokenizer is SentencePieceBpeTokenizer) + if (tokenizer is SentencePieceTokenizer) { // SentencePieceBpe model normalize the text and insert more characters. // We call the model directly to bypass the normalization step @@ -170,7 +170,7 @@ internal static void TestTokenLimits(Tokenizer tokenizer) { string suffixString = (processedText2 ?? input).Substring(index2); - if (tokenizer is SentencePieceBpeTokenizer) + if (tokenizer is SentencePieceTokenizer) { // SentencePieceBpe model normalize the text and insert more characters. // We call the model directly to bypass the normalization step diff --git a/test/Microsoft.ML.TorchSharp.Tests/Microsoft.ML.TorchSharp.Tests.csproj b/test/Microsoft.ML.TorchSharp.Tests/Microsoft.ML.TorchSharp.Tests.csproj index 138d001b98..0d5f6541a8 100644 --- a/test/Microsoft.ML.TorchSharp.Tests/Microsoft.ML.TorchSharp.Tests.csproj +++ b/test/Microsoft.ML.TorchSharp.Tests/Microsoft.ML.TorchSharp.Tests.csproj @@ -8,6 +8,10 @@ + + true + +