Skip to content

Commit

Permalink
.Net: Added logprobs property to OpenAIPromptExecutionSettings (#6300)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Fixes: #6277


https://platform.openai.com/docs/api-reference/chat/create#chat-create-logprobs

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
dmytrostruk authored May 20, 2024
1 parent 3e19785 commit 5e0b757
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ private static CompletionsOptions CreateCompletionsOptions(string text, OpenAIPr
Echo = false,
ChoicesPerPrompt = executionSettings.ResultsPerPrompt,
GenerationSampleCount = executionSettings.ResultsPerPrompt,
LogProbabilityCount = null,
LogProbabilityCount = executionSettings.TopLogprobs,
User = executionSettings.User,
DeploymentName = deploymentOrModelName
};
Expand Down Expand Up @@ -1102,7 +1102,9 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(
ChoiceCount = executionSettings.ResultsPerPrompt,
DeploymentName = deploymentOrModelName,
Seed = executionSettings.Seed,
User = executionSettings.User
User = executionSettings.User,
LogProbabilitiesPerToken = executionSettings.TopLogprobs,
EnableLogProbabilities = executionSettings.Logprobs
};

switch (executionSettings.ResponseFormat)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,39 @@ public string? User
}
}

/// <summary>
/// Whether to return log probabilities of the output tokens or not.
/// If true, returns the log probabilities of each output token returned in the `content` of `message`.
/// </summary>
[Experimental("SKEXP0010")]
[JsonPropertyName("logprobs")]
public bool? Logprobs
{
get => this._logprobs;

set
{
this.ThrowIfFrozen();
this._logprobs = value;
}
}

/// <summary>
/// An integer specifying the number of most likely tokens to return at each token position, each with an associated log probability.
/// </summary>
[Experimental("SKEXP0010")]
[JsonPropertyName("top_logprobs")]
public int? TopLogprobs
{
get => this._topLogprobs;

set
{
this.ThrowIfFrozen();
this._topLogprobs = value;
}
}

/// <inheritdoc/>
public override void Freeze()
{
Expand Down Expand Up @@ -294,7 +327,9 @@ public override PromptExecutionSettings Clone()
TokenSelectionBiases = this.TokenSelectionBiases is not null ? new Dictionary<int, int>(this.TokenSelectionBiases) : null,
ToolCallBehavior = this.ToolCallBehavior,
User = this.User,
ChatSystemPrompt = this.ChatSystemPrompt
ChatSystemPrompt = this.ChatSystemPrompt,
Logprobs = this.Logprobs,
TopLogprobs = this.TopLogprobs
};
}

Expand Down Expand Up @@ -370,6 +405,8 @@ public static OpenAIPromptExecutionSettings FromExecutionSettingsWithData(Prompt
private ToolCallBehavior? _toolCallBehavior;
private string? _user;
private string? _chatSystemPrompt;
private bool? _logprobs;
private int? _topLogprobs;

#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
ResultsPerPrompt = 5,
Seed = 567,
TokenSelectionBiases = new Dictionary<int, int> { { 2, 3 } },
StopSequences = ["stop_sequence"]
StopSequences = ["stop_sequence"],
Logprobs = true,
TopLogprobs = 5
};

var chatHistory = new ChatHistory();
Expand Down Expand Up @@ -218,6 +220,8 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
Assert.Equal(567, content.GetProperty("seed").GetInt32());
Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32());
Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString());
Assert.True(content.GetProperty("logprobs").GetBoolean());
Assert.Equal(5, content.GetProperty("top_logprobs").GetInt32());
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public void ItCreatesOpenAIExecutionSettingsWithCorrectDefaults()
Assert.Equal(1, executionSettings.ResultsPerPrompt);
Assert.Null(executionSettings.StopSequences);
Assert.Null(executionSettings.TokenSelectionBiases);
Assert.Null(executionSettings.TopLogprobs);
Assert.Null(executionSettings.Logprobs);
Assert.Equal(128, executionSettings.MaxTokens);
}

Expand All @@ -47,6 +49,8 @@ public void ItUsesExistingOpenAIExecutionSettings()
StopSequences = new string[] { "foo", "bar" },
ChatSystemPrompt = "chat system prompt",
MaxTokens = 128,
Logprobs = true,
TopLogprobs = 5,
TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
};

Expand Down Expand Up @@ -97,6 +101,8 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
{ "max_tokens", 128 },
{ "token_selection_biases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 },
}
};

Expand All @@ -105,7 +111,6 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()

// Assert
AssertExecutionSettings(executionSettings);
Assert.Equal(executionSettings.Seed, 123456);
}

[Fact]
Expand All @@ -124,7 +129,10 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings()
{ "stop_sequences", new [] { "foo", "bar" } },
{ "chat_system_prompt", "chat system prompt" },
{ "max_tokens", "128" },
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } }
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } },
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 }
}
};

Expand All @@ -149,7 +157,10 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase()
"stop_sequences": [ "foo", "bar" ],
"chat_system_prompt": "chat system prompt",
"token_selection_biases": { "1": 2, "3": 4 },
"max_tokens": 128
"max_tokens": 128,
"seed": 123456,
"logprobs": true,
"top_logprobs": 5
}
""";
var actualSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
Expand Down Expand Up @@ -255,5 +266,8 @@ private static void AssertExecutionSettings(OpenAIPromptExecutionSettings execut
Assert.Equal("chat system prompt", executionSettings.ChatSystemPrompt);
Assert.Equal(new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } }, executionSettings.TokenSelectionBiases);
Assert.Equal(128, executionSettings.MaxTokens);
Assert.Equal(123456, executionSettings.Seed);
Assert.Equal(true, executionSettings.Logprobs);
Assert.Equal(5, executionSettings.TopLogprobs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync()
PresencePenalty = 1.2,
ResultsPerPrompt = 5,
TokenSelectionBiases = new Dictionary<int, int> { { 2, 3 } },
StopSequences = ["stop_sequence"]
StopSequences = ["stop_sequence"],
TopLogprobs = 5
};

this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
Expand Down Expand Up @@ -154,6 +155,7 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync()
Assert.Equal(5, content.GetProperty("best_of").GetInt32());
Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32());
Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString());
Assert.Equal(5, content.GetProperty("logprobs").GetInt32());
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http.Resilience;
Expand Down Expand Up @@ -504,6 +505,38 @@ public async Task SemanticKernelVersionHeaderIsSentAsync()
Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values));
}

[Theory(Skip = "This test is for manual verification.")]
[InlineData(null, null)]
[InlineData(false, null)]
[InlineData(true, 2)]
[InlineData(true, 5)]
public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs)
{
// Arrange
var settings = new OpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs };

this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
var builder = this._kernelBuilder;
this.ConfigureAzureOpenAIChatAsText(builder);
Kernel target = builder.Build();

// Act
var result = await target.InvokePromptAsync("Hi, can you help me today?", new(settings));

var logProbabilityInfo = result.Metadata?["LogProbabilityInfo"] as ChatChoiceLogProbabilityInfo;

// Assert
if (logprobs is true)
{
Assert.NotNull(logProbabilityInfo);
Assert.Equal(topLogprobs, logProbabilityInfo.TokenLogProbabilityResults[0].TopLogProbabilityEntries.Count);
}
else
{
Assert.Null(logProbabilityInfo);
}
}

#region internals

private readonly XunitLogger<Kernel> _logger = new(output);
Expand Down

0 comments on commit 5e0b757

Please sign in to comment.