Skip to content

Commit

Permalink
progress
Browse files Browse the repository at this point in the history
  • Loading branch information
BeepBeepBopBop committed Nov 29, 2024
1 parent d368939 commit 90c6d31
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
26 changes: 15 additions & 11 deletions LM-Kit-Maestro/Services/LMKitService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ public async Task<LMKitResult> SubmitPrompt(Conversation conversation, string pr

public async Task<LMKitResult> RegenerateResponse(Conversation conversation, ChatHistory.Message message)
{
// Ignoring message parameter, only regenerate the latest response for now.
var prompt = conversation.ChatHistory!.Messages[conversation.ChatHistory.Messages.Count - 1];
var regenerateResponseRequest = new LMKitRequest(LMKitRequest.LMKitRequestType.RegenerateResponse,
message, LMKitConfig.RequestTimeout);

var regenerateResponseRequest = new LMKitRequest(LMKitRequest.LMKitRequestType.RegenerateResponse, prompt.Content, LMKitConfig.RequestTimeout);
ScheduleRequest(regenerateResponseRequest);

return await HandleLmKitRequest(regenerateResponseRequest);
}
Expand Down Expand Up @@ -205,14 +205,14 @@ private void ScheduleRequest(LMKitRequest request)
}
}

private async Task<LMKitResult> HandleLmKitRequest(LMKitRequest promptRequest)
private async Task<LMKitResult> HandleLmKitRequest(LMKitRequest request)
{
// Ensuring we don't touch anything until Lm-Kit objects' state has been set to handle this prompt request.
_lmKitServiceSemaphore.Wait();

LMKitResult result;

if (promptRequest.CancellationTokenSource.IsCancellationRequested || ModelLoadingState == LMKitModelLoadingState.Unloaded)
if (request.CancellationTokenSource.IsCancellationRequested || ModelLoadingState == LMKitModelLoadingState.Unloaded)
{
result = new LMKitResult()
{
Expand All @@ -223,22 +223,26 @@ private async Task<LMKitResult> HandleLmKitRequest(LMKitRequest promptRequest)
}
else
{
if (promptRequest.RequestType == LMKitRequest.LMKitRequestType.Prompt)
if (request.RequestType == LMKitRequest.LMKitRequestType.Prompt)
{
BeforeSubmittingPrompt(((LMKitRequest.PromptRequestParameters)promptRequest.Parameters!).Conversation);
BeforeSubmittingPrompt(((LMKitRequest.PromptRequestParameters)request.Parameters!).Conversation);
}
else if (request.RequestType == LMKitRequest.LMKitRequestType.RegenerateResponse)
{
BeforeSubmittingPrompt(((LMKitRequest.RegenerateResponseParameters)request.Parameters!).Conversation);
}

_lmKitServiceSemaphore.Release();

result = await SubmitRequest(promptRequest);
result = await SubmitRequest(request);
}

if (_requestSchedule.Contains(promptRequest))
if (_requestSchedule.Contains(request))
{
_requestSchedule.Remove(promptRequest);
_requestSchedule.Remove(request);
}

promptRequest.ResponseTask.TrySetResult(result);
request.ResponseTask.TrySetResult(result);

return result;

Expand Down
14 changes: 14 additions & 0 deletions LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LMKit.TextGeneration;
using LMKit.TextGeneration.Chat;
using System.ComponentModel;

namespace LMKit.Maestro.Services;
Expand Down Expand Up @@ -48,6 +49,19 @@ public PromptRequestParameters(Conversation conversation, string prompt)
}
}

public sealed class RegenerateResponseParameters
{
public Conversation Conversation { get; set; }

public ChatHistory.Message Message { get; set; }

public RegenerateResponseParameters(Conversation conversation, ChatHistory.Message message)
{
Conversation = conversation;
Message = message;
}
}

public sealed class TranslationRequestParameters
{
public string InputText { get; set; }
Expand Down
6 changes: 1 addition & 5 deletions LM-Kit-Maestro/ViewModels/ConversationViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ private void OnLMKitChatHistoryChanged(object? sender, NotifyCollectionChangedEv
{
foreach (var item in e.NewItems!)
{
_mainThread.BeginInvokeOnMainThread(() => Messages.Add(new MessageViewModel(this, (ChatHistory.Message)item)));
Messages.Add(new MessageViewModel(this, (ChatHistory.Message)item));
}
}
else if (e.Action == NotifyCollectionChangedAction.Remove)
Expand All @@ -307,10 +307,6 @@ private void OnLMKitChatHistoryChanged(object? sender, NotifyCollectionChangedEv
count++;
}
}
else if (e.Action == NotifyCollectionChangedAction.Replace)
{

}
}

private void OnMessagesCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)
Expand Down

0 comments on commit 90c6d31

Please sign in to comment.