Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Ollama support for streaming function calling and native structured output #5730

Merged
57 changes: 44 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace Microsoft.Extensions.AI;
public sealed class OllamaChatClient : IChatClient
{
private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;
private static readonly JsonElement _defaultJsonSchema = JsonDocument.Parse("\"json\"").RootElement;
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>The api/chat endpoint URI.</summary>
private readonly Uri _apiChatEndpoint;
Expand Down Expand Up @@ -111,15 +112,6 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
{
_ = Throw.IfNull(chatMessages);

if (options?.Tools is { Count: > 0 })
{
// We can actually make it work by using the /generate endpoint like the eShopSupport sample does,
// but it's complicated. Really it should be Ollama's job to support this.
throw new NotSupportedException(
"Currently, Ollama does not support function calls in streaming mode. " +
"See Ollama docs at https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1 to see whether support has since been added.");
}

using HttpRequestMessage request = new(HttpMethod.Post, _apiChatEndpoint)
{
Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest)
Expand Down Expand Up @@ -158,7 +150,22 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs

if (chunk.Message is { } message)
{
update.Contents.Add(new TextContent(message.Content));
if (message.ToolCalls is { Length: > 0 })
{
foreach (var toolCall in message.ToolCalls)
{
if (toolCall.Function is { } function)
{
update.Contents.Add(ToFunctionCallContent(function));
}
}
}

// Equivalent rule to the nonstreaming case
if (message.Content?.Length > 0 || update.Contents.Count == 0)
{
update.Contents.Insert(0, new TextContent(message.Content));
}
}

if (ParseOllamaChatResponseUsage(chunk) is { } usage)
Expand Down Expand Up @@ -231,8 +238,7 @@ private static ChatMessage FromOllamaMessage(OllamaChatResponseMessage message)
{
if (toolCall.Function is { } function)
{
var id = Guid.NewGuid().ToString().Substring(0, 8);
contents.Add(new FunctionCallContent(id, function.Name, function.Arguments));
contents.Add(ToFunctionCallContent(function));
}
}
}
Expand All @@ -247,11 +253,36 @@ private static ChatMessage FromOllamaMessage(OllamaChatResponseMessage message)
return new ChatMessage(new(message.Role), contents);
}

/// <summary>Creates a <see cref="FunctionCallContent"/> from an Ollama tool call.</summary>
/// <param name="function">The function invocation reported by Ollama.</param>
/// <returns>A <see cref="FunctionCallContent"/> carrying the function's name and arguments.</returns>
/// <remarks>
/// Ollama does not supply a call id of its own, so a short random identifier
/// (the first eight characters of a new GUID) is synthesized for each call.
/// </remarks>
private static FunctionCallContent ToFunctionCallContent(OllamaFunctionToolCall function) =>
    new(Guid.NewGuid().ToString().Substring(0, 8), function.Name, function.Arguments);

/// <summary>Maps a <see cref="ChatResponseFormat"/> onto the Ollama "format" request field.</summary>
/// <param name="format">The requested response format, or <see langword="null"/> for freeform text.</param>
/// <returns>
/// The caller-supplied JSON schema when one is present; the default literal <c>"json"</c> element
/// when JSON output is requested without a schema; otherwise <see langword="null"/> so the field
/// is omitted from the request.
/// </returns>
private static JsonElement? ToOllamaChatResponseFormat(ChatResponseFormat? format)
{
    if (format is ChatResponseFormatJson jsonFormat)
    {
        if (!string.IsNullOrEmpty(jsonFormat.Schema))
        {
            // Dispose the transient JsonDocument and return a Clone() of its root so the
            // element remains valid independently of the document. The original returned
            // RootElement from an undisposed document, which leaks the document's pooled buffers.
            using JsonDocument schemaDocument = JsonDocument.Parse(jsonFormat.Schema!);
            return schemaDocument.RootElement.Clone();
        }

        return _defaultJsonSchema;
    }

    return null;
}

private OllamaChatRequest ToOllamaChatRequest(IList<ChatMessage> chatMessages, ChatOptions? options, bool stream)
{
OllamaChatRequest request = new()
{
Format = options?.ResponseFormat is ChatResponseFormatJson ? "json" : null,
Format = ToOllamaChatResponseFormat(options?.ResponseFormat),
Messages = chatMessages.SelectMany(ToOllamaChatRequestMessages).ToArray(),
Model = options?.ModelId ?? Metadata.ModelId ?? string.Empty,
Stream = stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Text.Json;

namespace Microsoft.Extensions.AI;

internal sealed class OllamaChatRequest
{
public required string Model { get; set; }
public required OllamaChatRequestMessage[] Messages { get; set; }
public string? Format { get; set; }
public JsonElement? Format { get; set; }
public bool Stream { get; set; }
public IEnumerable<OllamaTool>? Tools { get; set; }
public OllamaRequestOptions? Options { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,33 @@ public virtual async Task CompleteAsync_StructuredOutput_WithFunctions()
Assert.Equal(expectedPerson.Job, response.Result.Job);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutput_Native()
{
    SkipIfNotEnabled();

    // Capture every message list that reaches the underlying client so we can
    // inspect exactly what prompt text was sent.
    List<IList<ChatMessage>> observedCalls = [];
    var capturingClient = _chatClient.AsBuilder()
        .Use((messages, options, nextAsync, cancellationToken) =>
        {
            observedCalls.Add([.. messages]);
            return nextAsync(messages, options, cancellationToken);
        })
        .Build();

    var response = await capturingClient.CompleteAsync<Person>("""
        Supply a JSON object to represent Jimbo Smith from Cardiff.
        """, useNativeJsonSchema: true);

    Assert.Equal("Jimbo Smith", response.Result.FullName);
    Assert.Contains("Cardiff", response.Result.HomeTown);

    // Verify it used *native* structured output, i.e., no prompt augmentation:
    // with native mode, no message should have had a schema injected into its text.
    var messagesSent = Assert.Single(observedCalls);
    Assert.All(messagesSent, message => Assert.DoesNotContain("schema", message.Text));
}

private class Person
{
#pragma warning disable S1144, S3459 // Unassigned members should be removed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests
new OllamaChatClient(endpoint, "llama3.1") :
null;

public override Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_Streaming() =>
throw new SkipTestException("Ollama does not currently support function invocation with streaming.");

public override Task Logging_LogsFunctionCalls_Streaming() =>
throw new SkipTestException("Ollama does not currently support function invocation with streaming.");

public override Task FunctionInvocation_RequireAny() =>
throw new SkipTestException("Ollama does not currently support requiring function invocation.");

Expand Down
Loading