[.Net] feature: Ollama integration #2693

Merged (13 commits) on May 15, 2024
14 changes: 14 additions & 0 deletions dotnet/AutoGen.sln
@@ -37,6 +37,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Test
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama", "src\Autogen.Ollama\Autogen.Ollama.csproj", "{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama.Tests", "test\Autogen.Ollama.Tests\Autogen.Ollama.Tests.csproj", "{C24FDE63-952D-4F8E-A807-AF31D43AD675}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -91,6 +95,14 @@ Global
{15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.Build.0 = Release|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.Build.0 = Debug|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.ActiveCfg = Release|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.Build.0 = Release|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -116,6 +128,8 @@ Global
{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{C24FDE63-952D-4F8E-A807-AF31D43AD675} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection
216 changes: 216 additions & 0 deletions dotnet/src/Autogen.Ollama/Agent/OllamaAgent.cs
@@ -0,0 +1,216 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaAgent.cs

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;

namespace Autogen.Ollama;

/// <summary>
/// An agent that can interact with Ollama models.
/// </summary>
public class OllamaAgent : IStreamingAgent
{
private readonly HttpClient _httpClient;
public string Name { get; }
private readonly string _modelName;
private readonly string _systemMessage;
private readonly OllamaReplyOptions? _replyOptions;

public OllamaAgent(HttpClient httpClient, string name, string modelName,
string systemMessage = "You are a helpful AI assistant",
OllamaReplyOptions? replyOptions = null)
{
Name = name;
_httpClient = httpClient;
_modelName = modelName;
_systemMessage = systemMessage;
_replyOptions = replyOptions;
}
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellation = default)
{
ChatRequest request = await BuildChatRequest(messages, options);
request.Stream = false;
using (HttpResponseMessage? response = await _httpClient
.SendAsync(BuildRequestMessage(request), HttpCompletionOption.ResponseContentRead, cancellation))
{
response.EnsureSuccessStatusCode();
Stream? streamResponse = await response.Content.ReadAsStreamAsync();
ChatResponse chatResponse = await JsonSerializer.DeserializeAsync<ChatResponse>(streamResponse, cancellationToken: cancellation)
?? throw new Exception("Failed to deserialize response");
var output = new MessageEnvelope<ChatResponse>(chatResponse, from: Name);
return output;
}
}
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
ChatRequest request = await BuildChatRequest(messages, options);
request.Stream = true;
HttpRequestMessage message = BuildRequestMessage(request);
using (HttpResponseMessage? response = await _httpClient.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
{
response.EnsureSuccessStatusCode();
using Stream? stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
using var reader = new StreamReader(stream);

while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
{
string? line = await reader.ReadLineAsync();
if (string.IsNullOrWhiteSpace(line)) continue;

// Each streamed line is a JSON object; intermediate chunks deserialize to
// ChatResponseUpdate, and the final chunk (Done == true) also carries the
// full ChatResponse with timing metrics.
ChatResponseUpdate? update = JsonSerializer.Deserialize<ChatResponseUpdate>(line);
if (update != null)
{
yield return new MessageEnvelope<ChatResponseUpdate>(update, from: Name);
}

if (update is { Done: false }) continue;

// Final chunk: surface the complete ChatResponse as well.
ChatResponse? chatMessage = JsonSerializer.Deserialize<ChatResponse>(line);
if (chatMessage == null) continue;
yield return new MessageEnvelope<ChatResponse>(chatMessage, from: Name);
}
}
}
private async Task<ChatRequest> BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
{
var request = new ChatRequest
{
Model = _modelName,
Messages = await BuildChatHistory(messages)
};

if (options is OllamaReplyOptions replyOptions)
{
BuildChatRequestOptions(replyOptions, request);
return request;
}

if (_replyOptions != null)
{
BuildChatRequestOptions(_replyOptions, request);
return request;
}
return request;
}
private void BuildChatRequestOptions(OllamaReplyOptions replyOptions, ChatRequest request)
{
request.Format = replyOptions.Format == FormatType.Json ? OllamaConsts.JsonFormatType : null;
request.Template = replyOptions.Template;
request.KeepAlive = replyOptions.KeepAlive;

if (replyOptions.Temperature != null
|| replyOptions.MaxToken != null
|| replyOptions.StopSequence != null
|| replyOptions.Seed != null
|| replyOptions.MiroStat != null
|| replyOptions.MiroStatEta != null
|| replyOptions.MiroStatTau != null
|| replyOptions.NumCtx != null
|| replyOptions.NumGqa != null
|| replyOptions.NumGpu != null
|| replyOptions.NumThread != null
|| replyOptions.RepeatLastN != null
|| replyOptions.RepeatPenalty != null
|| replyOptions.TopK != null
|| replyOptions.TopP != null
|| replyOptions.TfsZ != null)
{
request.Options = new ModelReplyOptions
{
Temperature = replyOptions.Temperature,
NumPredict = replyOptions.MaxToken,
// ModelReplyOptions takes a single stop string, so only the first
// configured stop sequence is applied; FirstOrDefault avoids an
// IndexOutOfRangeException when the array is empty.
Stop = replyOptions.StopSequence?.FirstOrDefault(),
Seed = replyOptions.Seed,
MiroStat = replyOptions.MiroStat,
MiroStatEta = replyOptions.MiroStatEta,
MiroStatTau = replyOptions.MiroStatTau,
NumCtx = replyOptions.NumCtx,
NumGqa = replyOptions.NumGqa,
NumGpu = replyOptions.NumGpu,
NumThread = replyOptions.NumThread,
RepeatLastN = replyOptions.RepeatLastN,
RepeatPenalty = replyOptions.RepeatPenalty,
TopK = replyOptions.TopK,
TopP = replyOptions.TopP,
TfsZ = replyOptions.TfsZ
};
}
}
private async Task<List<Message>> BuildChatHistory(IEnumerable<IMessage> messages)
{
if (!messages.Any(m => m.IsSystemMessage()))
{
var systemMessage = new TextMessage(Role.System, _systemMessage, from: Name);
messages = new[] { systemMessage }.Concat(messages);
}

var collection = new List<Message>();
foreach (IMessage? message in messages)
{
Message item;
switch (message)
{
case TextMessage tm:
item = new Message { Role = tm.Role.ToString(), Value = tm.Content };
break;
case ImageMessage im:
string base64Image = await ImageUrlToBase64(im.Url!);
item = new Message { Role = im.Role.ToString(), Images = [base64Image] };
break;
case MultiModalMessage mm:
var textsGroupedByRole = mm.Content.OfType<TextMessage>().GroupBy(tm => tm.Role)
.ToDictionary(g => g.Key, g => string.Join(Environment.NewLine, g.Select(tm => tm.Content)));

string content = string.Join(Environment.NewLine, textsGroupedByRole
.Select(g => $"{g.Key}:{Environment.NewLine}{g.Value}"));

IEnumerable<Task<string>> imagesConversionTasks = mm.Content
.OfType<ImageMessage>()
.Select(async im => await ImageUrlToBase64(im.Url!));

string[]? imagesBase64 = await Task.WhenAll(imagesConversionTasks);
item = new Message { Role = mm.Role.ToString(), Value = content, Images = imagesBase64 };
break;
default:
throw new NotSupportedException();
}

collection.Add(item);
}

return collection;
}
private static HttpRequestMessage BuildRequestMessage(ChatRequest request)
{
string serialized = JsonSerializer.Serialize(request);
return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.ChatCompletionEndpoint)
{
Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
};
}
private async Task<string> ImageUrlToBase64(string imageUrl)
{
if (string.IsNullOrWhiteSpace(imageUrl))
{
throw new ArgumentException("Image URL is required.", nameof(imageUrl));
}
byte[] imageBytes = await _httpClient.GetByteArrayAsync(imageUrl);
return imageBytes != null
? Convert.ToBase64String(imageBytes)
: throw new InvalidOperationException("no image byte array");
}
}
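
For context, here is a minimal usage sketch of the new agent. It assumes an Ollama server listening on the default http://localhost:11434 with a locally pulled model named "llama3"; the endpoint, model name, and prompts are illustrative and not part of this diff.

using System;
using System.Net.Http;
using System.Threading.Tasks;
using AutoGen.Core;
using Autogen.Ollama;

public static class OllamaAgentExample
{
    public static async Task RunAsync()
    {
        // Assumption: a local Ollama server on the default port, with "llama3" pulled.
        using var httpClient = new HttpClient { BaseAddress = new Uri("http://localhost:11434") };

        var agent = new OllamaAgent(
            httpClient,
            name: "assistant",
            modelName: "llama3");

        // One-shot reply: the envelope wraps the full ChatResponse.
        IMessage reply = await agent.GenerateReplyAsync(
            new IMessage[] { new TextMessage(Role.User, "Why is the sky blue?") });

        // Streaming reply: ChatResponseUpdate envelopes arrive first, then a
        // final ChatResponse envelope once Done is true.
        await foreach (IStreamingMessage update in agent.GenerateStreamingReplyAsync(
            new IMessage[] { new TextMessage(Role.User, "Tell me a short joke.") }))
        {
            Console.WriteLine(update);
        }
    }
}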
12 changes: 12 additions & 0 deletions dotnet/src/Autogen.Ollama/Autogen.Ollama.csproj
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>

</Project>
54 changes: 54 additions & 0 deletions dotnet/src/Autogen.Ollama/DTOs/ChatRequest.cs
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatRequest.cs

using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Autogen.Ollama;

public class ChatRequest
{
/// <summary>
/// (required) the model name
/// </summary>
[JsonPropertyName("model")]
public string Model { get; set; } = string.Empty;

/// <summary>
/// the messages of the chat; these can be used to keep a chat memory
/// </summary>
[JsonPropertyName("messages")]
public IList<Message> Messages { get; set; } = Array.Empty<Message>();

/// <summary>
/// the format to return a response in. Currently, the only accepted value is json
/// </summary>
[JsonPropertyName("format")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Format { get; set; }

/// <summary>
/// additional model parameters listed in the documentation for the Modelfile such as temperature
/// </summary>
[JsonPropertyName("options")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ModelReplyOptions? Options { get; set; }
/// <summary>
/// the prompt template to use (overrides what is defined in the Modelfile)
/// </summary>
[JsonPropertyName("template")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Template { get; set; }
/// <summary>
/// if false, the response will be returned as a single response object rather than as a stream of objects
/// </summary>
[JsonPropertyName("stream")]
public bool Stream { get; set; }
/// <summary>
/// controls how long the model will stay loaded into memory following the request (default: 5m)
/// </summary>
[JsonPropertyName("keep_alive")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? KeepAlive { get; set; }
}
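
As a sketch of the wire format, this is roughly how a request built from these DTOs serializes; the Message shape (Role/Value) is taken from its usage in OllamaAgent.cs above, and the field values are illustrative.

using System.Text.Json;
using Autogen.Ollama;

public static class ChatRequestExample
{
    public static string Build()
    {
        var request = new ChatRequest
        {
            Model = "llama3",
            Messages = new[]
            {
                new Message { Role = "system", Value = "You are a helpful AI assistant" },
                new Message { Role = "user", Value = "Why is the sky blue?" },
            },
            Stream = false,
        };

        // Null Format/Options/Template/KeepAlive fields are dropped by the
        // JsonIgnore(WhenWritingNull) attributes, so the payload is roughly:
        // {"model":"llama3","messages":[...],"stream":false}
        return JsonSerializer.Serialize(request);
    }
}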
45 changes: 45 additions & 0 deletions dotnet/src/Autogen.Ollama/DTOs/ChatResponse.cs
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatResponse.cs

using System.Text.Json.Serialization;

namespace Autogen.Ollama;

public class ChatResponse : ChatResponseUpdate
{
/// <summary>
/// time spent in nanoseconds generating the response
/// </summary>
[JsonPropertyName("total_duration")]
public long TotalDuration { get; set; }

/// <summary>
/// time spent in nanoseconds loading the model
/// </summary>
[JsonPropertyName("load_duration")]
public long LoadDuration { get; set; }

/// <summary>
/// number of tokens in the prompt
/// </summary>
[JsonPropertyName("prompt_eval_count")]
public int PromptEvalCount { get; set; }

/// <summary>
/// time spent in nanoseconds evaluating the prompt
/// </summary>
[JsonPropertyName("prompt_eval_duration")]
public long PromptEvalDuration { get; set; }

/// <summary>
/// number of tokens in the response
/// </summary>
[JsonPropertyName("eval_count")]
public int EvalCount { get; set; }

/// <summary>
/// time in nanoseconds spent generating the response
/// </summary>
[JsonPropertyName("eval_duration")]
public long EvalDuration { get; set; }
}
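
Because the duration fields are nanosecond counts, derived metrics are straightforward; a minimal sketch computing generation throughput from a ChatResponse (the helper class and method names are hypothetical):

using Autogen.Ollama;

public static class ChatResponseMetrics
{
    // Tokens generated per second, from the nanosecond-resolution eval fields.
    public static double TokensPerSecond(ChatResponse response)
    {
        double evalSeconds = response.EvalDuration / 1_000_000_000.0;
        return evalSeconds > 0 ? response.EvalCount / evalSeconds : 0;
    }
}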