Skip to content

Commit

Permalink
✔ Implement Chat
Browse files Browse the repository at this point in the history
  • Loading branch information
Jochen Kirstätter committed Mar 7, 2024
1 parent 0417ec5 commit f5c97ba
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 16 deletions.
8 changes: 5 additions & 3 deletions src/Mscc.GenerativeAI/GenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,12 @@ public string Name()
/// <summary>
/// Starts a multi-turn chat session bound to this model.
/// </summary>
/// <param name="history">Previous conversation turns to seed the session with, or null for a fresh session.</param>
/// <param name="generationConfig">Optional generation configuration. NOTE(review): currently accepted but not forwarded to the session — confirm intended.</param>
/// <param name="safetySettings">Optional safety settings. NOTE(review): currently accepted but not forwarded to the session — confirm intended.</param>
/// <param name="tools">Optional tool declarations. NOTE(review): currently accepted but not forwarded to the session — confirm intended.</param>
/// <returns>A new <see cref="ChatSession"/> that replays <paramref name="history"/> on each request.</returns>
public ChatSession StartChat(List<ContentResponse>? history = null, GenerationConfig? generationConfig = null, List<SafetySetting>? safetySettings = null, List<Tool>? tools = null)
{
    // The commented assignments below mark the intended wiring of the optional
    // parameters; they are not applied yet.
    //this.generationConfig = generationConfig ?? new GenerationConfig();
    //this.safetySettings = safetySettings ?? new List<SafetySetting>();
    //this.tools = tools ?? new List<Tool>();
    return new ChatSession(this, history);
}

/// <summary>
Expand Down
53 changes: 48 additions & 5 deletions src/Mscc.GenerativeAI/Types/ChatSession.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,62 @@
using System.Collections.Generic;
#if NET472_OR_GREATER || NETSTANDARD2_0
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
#endif

namespace Mscc.GenerativeAI
{
public class ChatSession
{
// Model used to issue GenerateContent requests for this session.
private readonly GenerativeModel model;

/// <summary>
/// Conversation turns accumulated so far (alternating "user" and "model" roles).
/// </summary>
public List<ContentResponse> History { get; set; } = [];

/// <summary>
/// Tool declarations available to this session.
/// NOTE(review): not yet forwarded in requests — confirm intended.
/// </summary>
public List<Tool>? Tools { get; set; }

/// <summary>
/// Creates a chat session bound to the given model, optionally seeded with prior turns.
/// </summary>
/// <param name="model">The generative model used to send messages.</param>
/// <param name="history">Existing conversation history, or null to start empty.</param>
public ChatSession(GenerativeModel model, List<ContentResponse>? history = null)
{
    this.model = model;
    History = history ?? [];
}

/// <summary>
/// Sends a user message to the model, appending both the prompt and the model's
/// reply to <see cref="History"/>, and returns the full response.
/// </summary>
/// <param name="prompt">The user's message text.</param>
/// <returns>The model's response to the prompt.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="prompt"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="prompt"/> is empty.</exception>
public async Task<GenerateContentResponse> SendMessage(string prompt)
{
    if (prompt == null) throw new ArgumentNullException(nameof(prompt));
    // Bug fix: ArgumentException's first argument is the message, not the value;
    // the original passed the (empty) prompt as the message.
    if (string.IsNullOrEmpty(prompt)) throw new ArgumentException("Value cannot be empty.", nameof(prompt));

    History.Add(new ContentResponse { Role = "user", Parts = new List<Part> { new Part { Text = prompt } } });
    var request = new GenerateContentRequest
    {
        // Replay the entire conversation so the model has full context.
        Contents = History.Select(x => new Content { Role = x.Role, PartTypes = x.Parts }).ToList()
        //GenerationConfig = model.GenerationConfig,
        //SafetySettings = model.SafetySettings,
        //Tools = model.Tools
    };

    var response = await model.GenerateContent(request);
    History.Add(new ContentResponse { Role = "model", Parts = new List<Part> { new Part { Text = response.Text } } });

    return response;
}

/// <summary>
/// Sends a user message and streams the model's reply.
/// NOTE(review): streaming is not implemented yet; this stub returns an empty list
/// and does not touch <see cref="History"/>.
/// </summary>
/// <param name="prompt">The user's message text.</param>
/// <returns>The streamed response chunks (currently always empty).</returns>
public async Task<List<GenerateContentResponse>> SendMessageStream(string prompt)
{
    // TODO: implement via the model's streaming endpoint; suppresses CS1998 until then.
    await Task.CompletedTask;
    return new List<GenerateContentResponse>();
}
Expand Down
2 changes: 2 additions & 0 deletions src/Mscc.GenerativeAI/Types/Content.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public class Content

internal void SynchronizeParts()
{
if (Parts == null) return;

PartTypes = new List<Part>();
foreach (var part in Parts)
{
Expand Down
7 changes: 5 additions & 2 deletions src/Mscc.GenerativeAI/Types/GenerateContentResponse.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System.Collections.Generic;
#if NET472_OR_GREATER || NETSTANDARD2_0
using System.Collections.Generic;
using System.Linq;
#endif

namespace Mscc.GenerativeAI
{
Expand All @@ -9,7 +12,7 @@ public class GenerateContentResponse
/// </summary>
/// <summary>
/// Convenience accessor for the text of the first part of the first candidate,
/// or null when no candidate or part is present.
/// </summary>
public string? Text
{
    // Bug fix: FirstOrDefault() returns null for an empty sequence; the original
    // chained ".Content"/".Text" without "?." and would throw NullReferenceException.
    get { return Candidates?.FirstOrDefault()?.Content?.Parts?.FirstOrDefault()?.Text; }
}

/// <summary>
Expand Down
16 changes: 14 additions & 2 deletions src/Mscc.GenerativeAI/Types/Part.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,24 @@ namespace Mscc.GenerativeAI
[DebuggerDisplay("{Text}")]
public class Part
{
/// <summary>Creates an empty part.</summary>
public Part() { }

/// <summary>
/// Creates a part holding the given text.
/// </summary>
/// <param name="text">The text content of the part.</param>
public Part(string text) => Text = text;

/// <summary>
/// A text part of a conversation with the model.
/// Reading proxies <c>TextData?.Text</c>; writing wraps the value in a new <c>TextData</c>.
/// </summary>
public string Text
{
    get { return TextData?.Text; }
    set { TextData = new TextData { Text = value }; }
}
/// <remarks/>
[DebuggerHidden]
Expand Down
99 changes: 95 additions & 4 deletions tests/GoogleAi_GeminiPro_Should.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,108 @@ public async void Count_Tokens_Request(string prompt, int expected)
}

[Fact]
// async Task (not async void) so xUnit awaits the test and observes failures.
public async Task Start_Chat()
{
    // Arrange
    var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
    var chat = model.StartChat();
    var prompt = "How can I learn more about C#?";

    // Act
    var response = await chat.SendMessage(prompt);

    // Assert
    response.Should().NotBeNull();
    response.Candidates.Should().NotBeNull().And.HaveCount(1);
    response.Text.Should().NotBeEmpty();
    output.WriteLine(response?.Text);
}

[Fact]
// async Task (not async void) so xUnit awaits the test and observes failures.
public async Task Start_Chat_With_History()
{
    // Arrange
    var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
    var history = new List<ContentResponse>
    {
        new ContentResponse { Role = "user", Parts = new List<Part> { new Part("Hello") } },
        new ContentResponse { Role = "model", Parts = new List<Part> { new Part("Hello! How can I assist you today?") } }
    };
    var chat = model.StartChat(history);
    var prompt = "How does electricity work?";

    // Act
    var response = await chat.SendMessage(prompt);

    // Assert
    response.Should().NotBeNull();
    response.Candidates.Should().NotBeNull().And.HaveCount(1);
    response.Text.Should().NotBeEmpty();
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);
    //output.WriteLine(response?.PromptFeedback);
}

[Fact]
// Refs:
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-chat-prompts-gemini
// async Task (not async void) so xUnit awaits the test and observes failures.
public async Task Start_Chat_Multiple_Prompts()
{
    // Arrange
    var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
    var chat = model.StartChat();

    // Act
    var prompt = "Hello, let's talk a bit about nature.";
    var response = await chat.SendMessage(prompt);
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);
    prompt = "What are all the colors in a rainbow?";
    response = await chat.SendMessage(prompt);
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);
    prompt = "Why does it appear when it rains?";
    response = await chat.SendMessage(prompt);
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);

    // Assert (only the final turn is asserted; earlier turns are logged above)
    response.Should().NotBeNull();
    response.Candidates.Should().NotBeNull().And.HaveCount(1);
    response.Text.Should().NotBeEmpty();
}

[Fact]
// Refs:
// https://ai.google.dev/tutorials/python_quickstart#chat_conversations
// async Task (not async void) so xUnit awaits the test and observes failures.
public async Task Start_Chat_Conversations()
{
    // Arrange
    var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
    var chat = model.StartChat();

    // Act
    _ = await chat.SendMessage("Hello, fancy brainstorming about IT?");
    _ = await chat.SendMessage("In one sentence, explain how a computer works to a young child.");
    _ = await chat.SendMessage("Okay, how about a more detailed explanation to a high schooler?");
    _ = await chat.SendMessage("Lastly, give a thorough definition for a CS graduate.");

    // Assert (history holds both user and model turns after four exchanges)
    chat.History.ForEach(c => output.WriteLine($"{c.Role}: {c.Parts[0].Text}"));
}

[Fact(Skip = "Incomplete")]
public async void Start_Chat_Streaming()
{
// Arrange
var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
var chat = model.StartChat();
var chatInput1 = "How can I learn more about C#?";
var prompt = "How can I learn more about C#?";

// Act
//var response = await chat.SendMessageStream(chatInput1);
var response = await chat.SendMessageStream(prompt);

//// Assert
// Assert
//response.Should().NotBeNull().And.HaveCountGreaterThanOrEqualTo(1);
//response.FirstOrDefault().Should().NotBeNull();
//response.ForEach(x => output.WriteLine(x.Text));
Expand All @@ -326,7 +417,7 @@ public async void Function_Calling_Chat()
var chatInput1 = "What is the weather in Boston?";

// Act
//var result1 = await chat.SendMessageStream(chatInput1);
//var result1 = await chat.SendMessageStream(prompt);
//var response1 = await result1.Response;
//var result2 = await chat.SendMessageStream(new List<IPart> { new FunctionResponse() });
//var response2 = await result2.Response;
Expand Down
51 changes: 51 additions & 0 deletions tests/VertexAi_GeminiPro_Should.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,57 @@ public async void Count_Tokens_Request(string prompt, int expected)
output.WriteLine($"Tokens: {response?.TotalTokens}");
}

[Fact]
// async Task (not async void) so xUnit awaits the test and observes failures.
public async Task Start_Chat()
{
    // Arrange
    var vertex = new VertexAI(projectId: fixture.ProjectId, region: fixture.Region);
    var model = vertex.GenerativeModel(model: this.model);
    model.AccessToken = fixture.AccessToken;
    var chat = model.StartChat();
    var prompt = "How can I learn more about C#?";

    // Act
    var response = await chat.SendMessage(prompt);

    // Assert
    response.Should().NotBeNull();
    response.Candidates.Should().NotBeNull().And.HaveCount(1);
    response.Text.Should().NotBeEmpty();
    output.WriteLine(response?.Text);
}

[Fact]
// Refs:
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-chat-prompts-gemini
// async Task (not async void) so xUnit awaits the test and observes failures.
public async Task Start_Chat_Multiple_Prompts()
{
    // Arrange
    var vertex = new VertexAI(projectId: fixture.ProjectId, region: fixture.Region);
    var model = vertex.GenerativeModel(model: this.model);
    model.AccessToken = fixture.AccessToken;
    var chat = model.StartChat();

    // Act
    var prompt = "Hello.";
    var response = await chat.SendMessage(prompt);
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);
    prompt = "What are all the colors in a rainbow?";
    response = await chat.SendMessage(prompt);
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);
    prompt = "Why does it appear when it rains?";
    response = await chat.SendMessage(prompt);
    output.WriteLine(prompt);
    output.WriteLine(response?.Text);

    // Assert (only the final turn is asserted; earlier turns are logged above)
    response.Should().NotBeNull();
    response.Candidates.Should().NotBeNull().And.HaveCount(1);
    response.Text.Should().NotBeEmpty();
}

[Fact]
public async void Start_Chat_Streaming()
{
Expand Down

0 comments on commit f5c97ba

Please sign in to comment.