Skip to content

Commit

Permalink
Add regenerate last response feature (#46)
Browse files Browse the repository at this point in the history
Co-authored-by: Loïc Carrère <[email protected]>
  • Loading branch information
BeepBeepBopBop and lcarrere authored Dec 7, 2024
1 parent 84401b3 commit dbb6b57
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 207 deletions.
8 changes: 0 additions & 8 deletions LM-Kit-Maestro/Models/Message.cs

This file was deleted.

2 changes: 1 addition & 1 deletion LM-Kit-Maestro/Services/Enumerations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ public enum LMKitTextGenerationStatus
{
Undefined,
Cancelled,
UnknownError
GenericError
}
43 changes: 26 additions & 17 deletions LM-Kit-Maestro/Services/LMKitService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using LMKit.TextGeneration.Sampling;
using LMKit.TextGeneration.Chat;
using LMKit.Translation;
using System.Diagnostics;

namespace LMKit.Maestro.Services;

Expand Down Expand Up @@ -171,10 +172,10 @@ public async Task<LMKitResult> SubmitPrompt(Conversation conversation, string pr

/// <summary>
/// Schedules a request to regenerate the response associated with <paramref name="message"/>
/// in <paramref name="conversation"/>, then awaits the request's completion.
/// </summary>
/// <param name="conversation">Conversation owning the chat history to operate on.</param>
/// <param name="message">The chat-history message whose response should be regenerated.</param>
/// <returns>The <see cref="LMKitResult"/> produced by the request pipeline.</returns>
public async Task<LMKitResult> RegenerateResponse(Conversation conversation, ChatHistory.Message message)
{
    // Wrap both the conversation and the targeted message in RegenerateResponseParameters:
    // HandleLmKitRequest casts Parameters to that type to recover the conversation for
    // RegenerateResponse requests, so passing a raw prompt string here would break that cast.
    var regenerateResponseRequest = new LMKitRequest(LMKitRequest.LMKitRequestType.RegenerateResponse,
        new LMKitRequest.RegenerateResponseParameters(conversation, message), LMKitConfig.RequestTimeout);

    ScheduleRequest(regenerateResponseRequest);

    return await HandleLmKitRequest(regenerateResponseRequest);
}
Expand All @@ -184,14 +185,17 @@ public async Task CancelPrompt(Conversation conversation, bool shouldAwaitTermin
var conversationPrompt = _requestSchedule.Unschedule(conversation);

if (conversationPrompt != null)
{
{
_lmKitServiceSemaphore.Wait();
conversationPrompt.CancellationTokenSource.Cancel();
conversationPrompt.ResponseTask.TrySetCanceled();
_lmKitServiceSemaphore.Release();

if (shouldAwaitTermination)
{
await conversationPrompt.ResponseTask.Task.WaitAsync(TimeSpan.FromSeconds(10));
}

}
}

Expand All @@ -205,14 +209,14 @@ private void ScheduleRequest(LMKitRequest request)
}
}

private async Task<LMKitResult> HandleLmKitRequest(LMKitRequest promptRequest)
private async Task<LMKitResult> HandleLmKitRequest(LMKitRequest request)
{
// Ensuring we don't touch anything until Lm-Kit objects' state has been set to handle this prompt request.
// Ensuring we don't touch anything until Lm-Kit objects' state has been set to handle this request.
_lmKitServiceSemaphore.Wait();

LMKitResult result;

if (promptRequest.CancellationTokenSource.IsCancellationRequested || ModelLoadingState == LMKitModelLoadingState.Unloaded)
if (request.CancellationTokenSource.IsCancellationRequested || ModelLoadingState == LMKitModelLoadingState.Unloaded)
{
result = new LMKitResult()
{
Expand All @@ -223,22 +227,27 @@ private async Task<LMKitResult> HandleLmKitRequest(LMKitRequest promptRequest)
}
else
{
if (promptRequest.RequestType == LMKitRequest.LMKitRequestType.Prompt)
if (request.RequestType == LMKitRequest.LMKitRequestType.Prompt || request.RequestType == LMKitRequest.LMKitRequestType.RegenerateResponse)
{
BeforeSubmittingPrompt(((LMKitRequest.PromptRequestParameters)promptRequest.Parameters!).Conversation);
var conversation = request.RequestType == LMKitRequest.LMKitRequestType.Prompt ?
((LMKitRequest.PromptRequestParameters)request.Parameters!).Conversation :
((LMKitRequest.RegenerateResponseParameters)request.Parameters!).Conversation;

BeforeSubmittingPrompt(conversation);
}

_lmKitServiceSemaphore.Release();

result = await SubmitRequest(promptRequest);

result = await SubmitRequest(request);
}

if (_requestSchedule.Contains(promptRequest))
if (_requestSchedule.Contains(request))
{
_requestSchedule.Remove(promptRequest);
_requestSchedule.Remove(request);
}

promptRequest.ResponseTask.TrySetResult(result);
request.ResponseTask.TrySetResult(result);

return result;

Expand Down Expand Up @@ -271,7 +280,7 @@ private async Task<LMKitResult> SubmitRequest(LMKitRequest request)
}
}
catch (Exception exception)
{
{
result.Exception = exception;

if (result.Exception is OperationCanceledException)
Expand All @@ -280,7 +289,7 @@ private async Task<LMKitResult> SubmitRequest(LMKitRequest request)
}
else
{
result.Status = LMKitTextGenerationStatus.UnknownError;
result.Status = LMKitTextGenerationStatus.GenericError;
}
}

Expand All @@ -291,7 +300,7 @@ private async Task<LMKitResult> SubmitRequest(LMKitRequest request)
parameter.Conversation.ChatHistory = _multiTurnConversation.ChatHistory;
parameter.Conversation.LatestChatHistoryData = _multiTurnConversation.ChatHistory.Serialize();

if (parameter.Conversation.GeneratedTitleSummary == null && result.Status == LMKitTextGenerationStatus.Undefined
if (parameter.Conversation.GeneratedTitleSummary == null && result.Status == LMKitTextGenerationStatus.Undefined
&& !string.IsNullOrEmpty(((TextGenerationResult)result.Result!).Completion))
{
GenerateConversationSummaryTitle(parameter.Conversation, parameter.Prompt);
Expand All @@ -310,7 +319,7 @@ private async Task<LMKitResult> SubmitRequest(LMKitRequest request)
return new LMKitResult()
{
Exception = exception,
Status = LMKitTextGenerationStatus.UnknownError
Status = LMKitTextGenerationStatus.GenericError
};
}
finally
Expand Down
14 changes: 14 additions & 0 deletions LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LMKit.TextGeneration;
using LMKit.TextGeneration.Chat;
using System.ComponentModel;

namespace LMKit.Maestro.Services;
Expand Down Expand Up @@ -48,6 +49,19 @@ public PromptRequestParameters(Conversation conversation, string prompt)
}
}

/// <summary>
/// Parameter payload for a RegenerateResponse request: identifies the conversation to
/// operate on and the chat-history message targeted for regeneration.
/// </summary>
public sealed class RegenerateResponseParameters
{
// Conversation whose response should be regenerated; consumers recover it by
// casting LMKitRequest.Parameters to this type.
public Conversation Conversation { get; set; }

// The message targeted for regeneration. NOTE(review): the current RegenerateResponse
// flow appears to regenerate only the latest response — confirm whether this message
// is honored or ignored downstream before relying on it.
public ChatHistory.Message Message { get; set; }

/// <summary>Creates the payload from the conversation and the targeted message.</summary>
public RegenerateResponseParameters(Conversation conversation, ChatHistory.Message message)
{
Conversation = conversation;
Message = message;
}
}

public sealed class TranslationRequestParameters
{
public string InputText { get; set; }
Expand Down
7 changes: 6 additions & 1 deletion LM-Kit-Maestro/Services/LmKitService.RequestSchedule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ public void Remove(LMKitRequest scheduledPrompt)
{
foreach (var scheduledPrompt in _scheduledPrompts)
{
if (scheduledPrompt.Parameters is LMKitRequest.PromptRequestParameters parameter && parameter.Conversation == conversation)
if (scheduledPrompt.Parameters is LMKitRequest.PromptRequestParameters promptParameters && promptParameters.Conversation == conversation)
{
prompt = scheduledPrompt;
break;
}
else if (scheduledPrompt.Parameters is LMKitRequest.RegenerateResponseParameters regenerateParameters && regenerateParameters.Conversation == conversation)
{
prompt = scheduledPrompt;
break;
Expand Down
Loading

0 comments on commit dbb6b57

Please sign in to comment.