diff --git a/LM-Kit-Maestro/Services/LMKitService.cs b/LM-Kit-Maestro/Services/LMKitService.cs index 3f35082..1f8c173 100644 --- a/LM-Kit-Maestro/Services/LMKitService.cs +++ b/LM-Kit-Maestro/Services/LMKitService.cs @@ -171,10 +171,10 @@ public async Task SubmitPrompt(Conversation conversation, string pr public async Task RegenerateResponse(Conversation conversation, ChatHistory.Message message) { - // Ignoring message parameter, only regenerate the latest response for now. - var prompt = conversation.ChatHistory!.Messages[conversation.ChatHistory.Messages.Count - 1]; + var regenerateResponseRequest = new LMKitRequest(LMKitRequest.LMKitRequestType.RegenerateResponse, + message, LMKitConfig.RequestTimeout); - var regenerateResponseRequest = new LMKitRequest(LMKitRequest.LMKitRequestType.RegenerateResponse, prompt.Content, LMKitConfig.RequestTimeout); + ScheduleRequest(regenerateResponseRequest); return await HandleLmKitRequest(regenerateResponseRequest); } @@ -205,14 +205,14 @@ private void ScheduleRequest(LMKitRequest request) } } - private async Task HandleLmKitRequest(LMKitRequest promptRequest) + private async Task HandleLmKitRequest(LMKitRequest request) { // Ensuring we don't touch anything until Lm-Kit objects' state has been set to handle this prompt request. _lmKitServiceSemaphore.Wait(); LMKitResult result; - if (promptRequest.CancellationTokenSource.IsCancellationRequested || ModelLoadingState == LMKitModelLoadingState.Unloaded) + if (request.CancellationTokenSource.IsCancellationRequested || ModelLoadingState == LMKitModelLoadingState.Unloaded) { result = new LMKitResult() { @@ -223,22 +223,26 @@ private async Task HandleLmKitRequest(LMKitRequest promptRequest) } else { - if (promptRequest.RequestType == LMKitRequest.LMKitRequestType.Prompt) + if (request.RequestType == LMKitRequest.LMKitRequestType.Prompt) { - BeforeSubmittingPrompt(((LMKitRequest.PromptRequestParameters)promptRequest.Parameters!).Conversation); + BeforeSubmittingPrompt(((LMKitRequest.PromptRequestParameters)request.Parameters!).Conversation); + } + else if (request.RequestType == LMKitRequest.LMKitRequestType.RegenerateResponse) + { + BeforeSubmittingPrompt(((LMKitRequest.RegenerateResponseParameters)request.Parameters!).Conversation); } _lmKitServiceSemaphore.Release(); - result = await SubmitRequest(promptRequest); + result = await SubmitRequest(request); } - if (_requestSchedule.Contains(promptRequest)) + if (_requestSchedule.Contains(request)) { - _requestSchedule.Remove(promptRequest); + _requestSchedule.Remove(request); } - promptRequest.ResponseTask.TrySetResult(result); + request.ResponseTask.TrySetResult(result); return result; diff --git a/LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs b/LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs index 5366d50..e815259 100644 --- a/LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs +++ b/LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs @@ -1,4 +1,5 @@ using LMKit.TextGeneration; +using LMKit.TextGeneration.Chat; using System.ComponentModel; namespace LMKit.Maestro.Services; @@ -48,6 +49,19 @@ public PromptRequestParameters(Conversation conversation, string prompt) } } + public sealed class RegenerateResponseParameters + { + public Conversation Conversation { get; set; } + + public ChatHistory.Message Message { get; set; } + + public RegenerateResponseParameters(Conversation conversation, ChatHistory.Message message) + { + Conversation = conversation; + Message = message; + } + } + public sealed class TranslationRequestParameters { public string InputText { get; set; } diff --git a/LM-Kit-Maestro/ViewModels/ConversationViewModel.cs b/LM-Kit-Maestro/ViewModels/ConversationViewModel.cs index 1ca9986..64a8b9e 100644 --- a/LM-Kit-Maestro/ViewModels/ConversationViewModel.cs +++ b/LM-Kit-Maestro/ViewModels/ConversationViewModel.cs @@ -294,7 +294,7 @@ private void OnLMKitChatHistoryChanged(object? sender, NotifyCollectionChangedEv { foreach (var item in e.NewItems!) { - _mainThread.BeginInvokeOnMainThread(() => Messages.Add(new MessageViewModel(this, (ChatHistory.Message)item))); + Messages.Add(new MessageViewModel(this, (ChatHistory.Message)item)); } } else if (e.Action == NotifyCollectionChangedAction.Remove) @@ -307,10 +307,6 @@ private void OnLMKitChatHistoryChanged(object? sender, NotifyCollectionChangedEv count++; } } - else if (e.Action == NotifyCollectionChangedAction.Replace) - { - - } } private void OnMessagesCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)