From f5c97babf97539baf2c9afb72250cba5598597a9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jochen=20Kirst=C3=A4tter?=
Date: Thu, 7 Mar 2024 11:56:09 +0100
Subject: [PATCH] =?UTF-8?q?=E2=9C=94=20Implement=20Chat?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/Mscc.GenerativeAI/GenerativeModel.cs   |  8 +-
 src/Mscc.GenerativeAI/Types/ChatSession.cs | 53 +++++++++-
 src/Mscc.GenerativeAI/Types/Content.cs     |  2 +
 .../Types/GenerateContentResponse.cs       |  7 +-
 src/Mscc.GenerativeAI/Types/Part.cs        | 16 ++-
 tests/GoogleAi_GeminiPro_Should.cs         | 99 ++++++++++++++++++-
 tests/VertexAi_GeminiPro_Should.cs         | 51 ++++++++++
 7 files changed, 220 insertions(+), 16 deletions(-)

diff --git a/src/Mscc.GenerativeAI/GenerativeModel.cs b/src/Mscc.GenerativeAI/GenerativeModel.cs
index 82f535d..0deb0b2 100644
--- a/src/Mscc.GenerativeAI/GenerativeModel.cs
+++ b/src/Mscc.GenerativeAI/GenerativeModel.cs
@@ -383,10 +383,12 @@ public string Name()
         /// <summary>
         /// 
         /// </summary>
-        public ChatSession StartChat(List<Content>? history = null, GenerationConfig? generationConfig = null, List<SafetySetting>? safetySettings = null, List<Tool>? tools = null)
+        public ChatSession StartChat(List<ContentResponse>? history = null, GenerationConfig? generationConfig = null, List<SafetySetting>? safetySettings = null, List<Tool>? tools = null)
         {
-            this.tools = tools ?? new List<Tool>();
-            return new ChatSession();
+            //this.generationConfig = generationConfig ?? new GenerationConfig();
+            //this.safetySettings = safetySettings ?? new List<SafetySetting>();
+            //this.tools = tools ?? new List<Tool>();
+            return new ChatSession(this, history);
         }
 
         /// <summary>
diff --git a/src/Mscc.GenerativeAI/Types/ChatSession.cs b/src/Mscc.GenerativeAI/Types/ChatSession.cs
index 16530ed..a9f4464 100644
--- a/src/Mscc.GenerativeAI/Types/ChatSession.cs
+++ b/src/Mscc.GenerativeAI/Types/ChatSession.cs
@@ -1,19 +1,62 @@
-using System.Collections.Generic;
+#if NET472_OR_GREATER || NETSTANDARD2_0
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+#endif
 
 namespace Mscc.GenerativeAI
 {
     public class ChatSession
     {
-        public List<Content>? History { get; set; }
+        private readonly GenerativeModel model;
+
+        public List<ContentResponse> History { get; set; } = [];
         public List<Tool>? Tools { get; set; }
+
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="model"></param>
+        /// <param name="history"></param>
+        public ChatSession(GenerativeModel model, List<ContentResponse>? history = null)
+        {
+            this.model = model;
+            history ??= [];
+            History = history;
+        }
 
-        public GenerateContentResponse SendMessage(string content, GenerationConfig? generationConfig, List<SafetySetting>? safetySettings)
+        /// <summary>
+        /// </summary>
+        /// <param name="prompt"></param>
+        /// <returns></returns>
+        public async Task<GenerateContentResponse> SendMessage(string prompt)
         {
-            return new GenerateContentResponse();
+            if (prompt == null) throw new ArgumentNullException(nameof(prompt));
+            if (string.IsNullOrEmpty(prompt)) throw new ArgumentException(prompt, nameof(prompt));
+
+            History.Add(new ContentResponse { Role = "user", Parts = new List<Part> { new Part { Text = prompt } } });
+            var request = new GenerateContentRequest
+            {
+                Contents = History.Select(x => new Content { Role = x.Role, PartTypes = x.Parts }).ToList()
+                //GenerationConfig = model.GenerationConfig,
+                //SafetySettings = model.SafetySettings,
+                //Tools = model.Tools
+            };
+
+            var response = await model.GenerateContent(request);
+            History.Add(new ContentResponse { Role = "model", Parts = new List<Part> { new Part { Text = response.Text } } });
+
+            return response;
         }
 
-        public List<GenerateContentResponse> SendMessageStream(string content, GenerationConfig? generationConfig, List<SafetySetting>? safetySettings)
+        /// <summary>
+        /// </summary>
+        /// <param name="prompt"></param>
+        /// <returns></returns>
+        public async Task<List<GenerateContentResponse>> SendMessageStream(string prompt)
         {
             return new List<GenerateContentResponse>();
         }
diff --git a/src/Mscc.GenerativeAI/Types/Content.cs b/src/Mscc.GenerativeAI/Types/Content.cs
index e608a3c..1a11462 100644
--- a/src/Mscc.GenerativeAI/Types/Content.cs
+++ b/src/Mscc.GenerativeAI/Types/Content.cs
@@ -17,6 +17,8 @@ public class Content
 
         internal void SynchronizeParts()
         {
+            if (Parts == null) return;
+
             PartTypes = new List<object>();
             foreach (var part in Parts)
             {
diff --git a/src/Mscc.GenerativeAI/Types/GenerateContentResponse.cs b/src/Mscc.GenerativeAI/Types/GenerateContentResponse.cs
index 2fe0a64..5f348f8 100644
--- a/src/Mscc.GenerativeAI/Types/GenerateContentResponse.cs
+++ b/src/Mscc.GenerativeAI/Types/GenerateContentResponse.cs
@@ -1,4 +1,7 @@
-using System.Collections.Generic;
+#if NET472_OR_GREATER || NETSTANDARD2_0
+using System.Collections.Generic;
+using System.Linq;
+#endif
 
 namespace Mscc.GenerativeAI
 {
@@ -9,7 +12,7 @@ public class GenerateContentResponse
         /// </summary>
         public string? Text
         {
-            get { return Candidates?[0].Content?.Parts?[0].Text; }
+            get { return Candidates?.FirstOrDefault().Content?.Parts?.FirstOrDefault().Text; }
         }
 
         /// <summary>
diff --git a/src/Mscc.GenerativeAI/Types/Part.cs b/src/Mscc.GenerativeAI/Types/Part.cs
index e7c47ff..699f029 100644
--- a/src/Mscc.GenerativeAI/Types/Part.cs
+++ b/src/Mscc.GenerativeAI/Types/Part.cs
@@ -15,12 +15,24 @@ namespace Mscc.GenerativeAI
     [DebuggerDisplay("{Text}")]
     public class Part
     {
+        public Part() { }
+
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="text"></param>
+        public Part(string text)
+        {
+            Text = text;
+        }
+
         /// <summary>
         /// A text part of a conversation with the model.
         /// </summary>
-        public string Text {
+        public string Text
+        {
             get { return TextData?.Text; }
-            set { TextData = new TextData { Text = value }; }
+            set { TextData = new TextData { Text = value }; }
         }
         /// <remarks/>
         [DebuggerHidden]
diff --git a/tests/GoogleAi_GeminiPro_Should.cs b/tests/GoogleAi_GeminiPro_Should.cs
index d941489..07870d5 100644
--- a/tests/GoogleAi_GeminiPro_Should.cs
+++ b/tests/GoogleAi_GeminiPro_Should.cs
@@ -297,17 +297,108 @@ public async void Count_Tokens_Request(string prompt, int expected)
         }
 
         [Fact]
+        public async void Start_Chat()
+        {
+            // Arrange
+            var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
+            var chat = model.StartChat();
+            var prompt = "How can I learn more about C#?";
+
+            // Act
+            var response = await chat.SendMessage(prompt);
+
+            // Assert
+            response.Should().NotBeNull();
+            response.Candidates.Should().NotBeNull().And.HaveCount(1);
+            response.Text.Should().NotBeEmpty();
+            output.WriteLine(response?.Text);
+        }
+
+        [Fact]
+        public async void Start_Chat_With_History()
+        {
+            // Arrange
+            var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
+            var history = new List<ContentResponse>
+            {
+                new ContentResponse { Role = "user", Parts = new List<Part> { new Part("Hello") } },
+                new ContentResponse { Role = "model", Parts = new List<Part> { new Part("Hello! How can I assist you today?") } }
+            };
+            var chat = model.StartChat(history);
+            var prompt = "How does electricity work?";
+
+            // Act
+            var response = await chat.SendMessage(prompt);
+
+            // Assert
+            response.Should().NotBeNull();
+            response.Candidates.Should().NotBeNull().And.HaveCount(1);
+            response.Text.Should().NotBeEmpty();
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+            //output.WriteLine(response?.PromptFeedback);
+        }
+
+        [Fact]
+        // Refs:
+        // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-chat-prompts-gemini
+        public async void Start_Chat_Multiple_Prompts()
+        {
+            // Arrange
+            var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
+            var chat = model.StartChat();
+
+            // Act
+            var prompt = "Hello, let's talk a bit about nature.";
+            var response = await chat.SendMessage(prompt);
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+            prompt = "What are all the colors in a rainbow?";
+            response = await chat.SendMessage(prompt);
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+            prompt = "Why does it appear when it rains?";
+            response = await chat.SendMessage(prompt);
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+
+            // Assert
+            response.Should().NotBeNull();
+            response.Candidates.Should().NotBeNull().And.HaveCount(1);
+            response.Text.Should().NotBeEmpty();
+        }
+
+        [Fact]
+        // Refs:
+        // https://ai.google.dev/tutorials/python_quickstart#chat_conversations
+        public async void Start_Chat_Conversations()
+        {
+            // Arrange
+            var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
+            var chat = model.StartChat();
+
+            // Act
+            _ = await chat.SendMessage("Hello, fancy brainstorming about IT?");
+            _ = await chat.SendMessage("In one sentence, explain how a computer works to a young child.");
+            _ = await chat.SendMessage("Okay, how about a more detailed explanation to a high schooler?");
+            _ = await chat.SendMessage("Lastly, give a thorough definition for a CS graduate.");
+
+            // Assert
+            chat.History.ForEach(c => output.WriteLine($"{c.Role}: {c.Parts[0].Text}"));
+        }
+
+        [Fact(Skip = "Incomplete")]
         public async void Start_Chat_Streaming()
         {
             // Arrange
             var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
             var chat = model.StartChat();
-            var chatInput1 = "How can I learn more about C#?";
+            var prompt = "How can I learn more about C#?";
 
             // Act
-            //var response = await chat.SendMessageStream(chatInput1);
+            var response = await chat.SendMessageStream(prompt);
 
-            //// Assert
+            // Assert
             //response.Should().NotBeNull().And.HaveCountGreaterThanOrEqualTo(1);
             //response.FirstOrDefault().Should().NotBeNull();
             //response.ForEach(x => output.WriteLine(x.Text));
@@ -326,7 +417,7 @@ public async void Function_Calling_Chat()
             var chatInput1 = "What is the weather in Boston?";
 
             // Act
-            //var result1 = await chat.SendMessageStream(chatInput1);
+            //var result1 = await chat.SendMessageStream(prompt);
             //var response1 = await result1.Response;
             //var result2 = await chat.SendMessageStream(new List { new FunctionResponse() });
             //var response2 = await result2.Response;
diff --git a/tests/VertexAi_GeminiPro_Should.cs b/tests/VertexAi_GeminiPro_Should.cs
index e3d9fd6..43aae11 100644
--- a/tests/VertexAi_GeminiPro_Should.cs
+++ b/tests/VertexAi_GeminiPro_Should.cs
@@ -260,6 +260,57 @@ public async void Count_Tokens_Request(string prompt, int expected)
             output.WriteLine($"Tokens: {response?.TotalTokens}");
         }
 
+        [Fact]
+        public async void Start_Chat()
+        {
+            // Arrange
+            var vertex = new VertexAI(projectId: fixture.ProjectId, region: fixture.Region);
+            var model = vertex.GenerativeModel(model: this.model);
+            model.AccessToken = fixture.AccessToken;
+            var chat = model.StartChat();
+            var prompt = "How can I learn more about C#?";
+
+            // Act
+            var response = await chat.SendMessage(prompt);
+
+            // Assert
+            response.Should().NotBeNull();
+            response.Candidates.Should().NotBeNull().And.HaveCount(1);
+            response.Text.Should().NotBeEmpty();
+            output.WriteLine(response?.Text);
+        }
+
+        [Fact]
+        // Refs:
+        // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-chat-prompts-gemini
+        public async void Start_Chat_Multiple_Prompts()
+        {
+            // Arrange
+            var vertex = new VertexAI(projectId: fixture.ProjectId, region: fixture.Region);
+            var model = vertex.GenerativeModel(model: this.model);
+            model.AccessToken = fixture.AccessToken;
+            var chat = model.StartChat();
+
+            // Act
+            var prompt = "Hello.";
+            var response = await chat.SendMessage(prompt);
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+            prompt = "What are all the colors in a rainbow?";
+            response = await chat.SendMessage(prompt);
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+            prompt = "Why does it appear when it rains?";
+            response = await chat.SendMessage(prompt);
+            output.WriteLine(prompt);
+            output.WriteLine(response?.Text);
+
+            // Assert
+            response.Should().NotBeNull();
+            response.Candidates.Should().NotBeNull().And.HaveCount(1);
+            response.Text.Should().NotBeEmpty();
+        }
+
         [Fact]
         public async void Start_Chat_Streaming()
         {