Skip to content

Commit

Permalink
start feature enableAutomaticFunctionCalling (work in progress)
Browse files Browse the repository at this point in the history
  • Loading branch information
jochenkirstaetter committed Mar 29, 2024
1 parent e95579d commit ba34bb5
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/Mscc.GenerativeAI/GenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -784,12 +784,13 @@ public async Task<CountTokensResponse> CountTokens(List<IPart>? parts)
public ChatSession StartChat(List<ContentResponse>? history = null,
GenerationConfig? generationConfig = null,
List<SafetySetting>? safetySettings = null,
List<Tool>? tools = null)
List<Tool>? tools = null,
bool enableAutomaticFunctionCalling = false)
{
var config = generationConfig ?? _generationConfig;
var safety = safetySettings ?? _safetySettings;
var tool = tools ?? _tools;
return new ChatSession(this, history, config, safety, tool);
return new ChatSession(this, history, config, safety, tool, enableAutomaticFunctionCalling);
}

#region "PaLM 2" methods
Expand Down
77 changes: 76 additions & 1 deletion src/Mscc.GenerativeAI/Types/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class ChatSession
private readonly GenerationConfig? _generationConfig;
private readonly List<SafetySetting>? _safetySettings;
private readonly List<Tool>? _tools;
private readonly bool _enableAutomaticFunctionCalling;
private ContentResponse? _lastSent;
private ContentResponse? _lastReceived;

Expand All @@ -40,17 +41,20 @@ public class ChatSession
/// <param name="generationConfig">Optional. Configuration options for model generation and outputs.</param>
/// <param name="safetySettings">Optional. A list of unique SafetySetting instances for blocking unsafe content.</param>
/// <param name="tools">Optional. </param>
/// <param name="enableAutomaticFunctionCalling">Optional. </param>
public ChatSession(GenerativeModel model,
List<ContentResponse>? history = null,
GenerationConfig? generationConfig = null,
List<SafetySetting>? safetySettings = null,
List<Tool>? tools = null)
List<Tool>? tools = null,
bool enableAutomaticFunctionCalling = false)
{
_model = model;
History = history ?? new List<ContentResponse>();
_generationConfig = generationConfig;
_safetySettings = safetySettings;
_tools = tools;
_enableAutomaticFunctionCalling = enableAutomaticFunctionCalling;
}

/// <summary>
Expand All @@ -61,6 +65,9 @@ public ChatSession(GenerativeModel model,
/// <param name="generationConfig">Optional. Overrides for the model's generation config.</param>
/// <param name="safetySettings">Optional. Overrides for the model's safety settings.</param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="BlockedPromptException"></exception>
/// <exception cref="StopCandidateException"></exception>
public async Task<GenerateContentResponse> SendMessage(string prompt,
GenerationConfig? generationConfig = null,
List<SafetySetting>? safetySettings = null)
Expand Down Expand Up @@ -89,6 +96,18 @@ public async Task<GenerateContentResponse> SendMessage(string prompt,
};

var response = await _model.GenerateContent(request);

response.CheckResponse();

if (_enableAutomaticFunctionCalling)
{
var result = HandleAutomaticFunctionCalling(response,
History,
generationConfig ?? _generationConfig,
safetySettings ?? _safetySettings,
_tools);
}

_lastReceived = new ContentResponse
{
Role = Role.Model, Parts = new List<Part> { new Part { Text = response.Text ?? string.Empty } }
Expand All @@ -105,6 +124,8 @@ public async Task<GenerateContentResponse> SendMessage(string prompt,
/// <param name="safetySettings">Optional. Overrides for the model's safety settings.</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="BlockedPromptException"></exception>
public async IAsyncEnumerable<GenerateContentResponse> SendMessageStream(object content,
GenerationConfig? generationConfig = null,
List<SafetySetting>? safetySettings = null,
Expand Down Expand Up @@ -145,8 +166,11 @@ public async IAsyncEnumerable<GenerateContentResponse> SendMessageStream(object
var response = _model.GenerateContentStream(request, cancellationToken);
await foreach (var item in response)
{
item.CheckResponse(true);

if (cancellationToken.IsCancellationRequested)
yield break;

fullText.Append(item.Text);
yield return item;
}
Expand Down Expand Up @@ -181,5 +205,56 @@ public async IAsyncEnumerable<GenerateContentResponse> SendMessageStream(object
History.RemoveRange(position, 2);
return result;
}

private (List<ContentResponse> history, ContentResponse content, GenerateContentResponse response) HandleAutomaticFunctionCalling(GenerateContentResponse response, List<ContentResponse> history, GenerationConfig generationConfig, List<SafetySetting> safetySettings, List<Tool> tools)
{
throw new NotImplementedException();
var functionResponseParts = new List<Part>();
var functionCalls = GetFunctionCalls(response);
history.Add(response.Candidates[0].Content);

foreach (var functionCall in functionCalls)
{
var functionResponse = tools.Find(x =>
x.FunctionDeclarations.Where(fd => fd.Name == functionCall.Name).Any());
if (functionResponse is not null)
{
// functionResponseParts.Add(functionResponse);
}
}

var send = new ContentResponse() { Role = Role.User, Parts = functionResponseParts };
history.Add(send);

var request = new GenerateContentRequest
{
Contents = history.Select(x =>
new Content { Role = x.Role, PartTypes = x.Parts }
).ToList(),
GenerationConfig = generationConfig,
SafetySettings = safetySettings,
Tools = tools
};
response = _model.GenerateContent(request).Result;
response.CheckResponse();

return (history, send, response);
}

private List<FunctionCall> GetFunctionCalls(GenerateContentResponse response)
{
if (response.Candidates.Count != 1)
{
throw new ValueErrorException(
$"Automatic function calling only works with 1 candidate, got {response.Candidates.Count}");
}

var parts = response.Candidates[0].Content.Parts;
var functionCalls = parts
.Where(x => x.FunctionCall != null)
.Select(x => x.FunctionCall)
.ToList();
return functionCalls;
}
}
}

0 comments on commit ba34bb5

Please sign in to comment.