Skip to content

Commit

Permalink
Issue 42: LmKitService requests handling refactoring (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
BeepBeepBopBop authored Nov 29, 2024
1 parent 03b8e23 commit 1ccd9c6
Show file tree
Hide file tree
Showing 12 changed files with 375 additions and 360 deletions.
8 changes: 5 additions & 3 deletions LM-Kit-Maestro/Services/LMKitService.Conversation.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections.Specialized;
using System.ComponentModel;
using LMKit.TextGeneration;
using LMKit.TextGeneration.Chat;

namespace LMKit.Maestro.Services;
Expand Down Expand Up @@ -107,13 +108,14 @@ private void OnChatHistoryMessageCollectionChanged(object? sender, NotifyCollect
ChatHistoryChanged?.Invoke(this, e);
}

public void SetGeneratedTitle(PromptResult textGenerationResult)
public void SetGeneratedTitle(LMKitResult result)
{
string? conversationTopic = null;

if (textGenerationResult.TextGenerationResult != null && !string.IsNullOrEmpty(textGenerationResult.TextGenerationResult.Completion))
if (result.Result != null && result.Result is TextGenerationResult textGenerationResult &&
!string.IsNullOrEmpty(textGenerationResult.Completion))
{
foreach (var sentence in textGenerationResult.TextGenerationResult.Completion.Split('\n'))
foreach (var sentence in textGenerationResult.Completion.Split('\n'))
{
if (sentence.ToLower().StartsWith("topic"))
{
Expand Down
416 changes: 88 additions & 328 deletions LM-Kit-Maestro/Services/LMKitService.cs

Large diffs are not rendered by default.

64 changes: 64 additions & 0 deletions LM-Kit-Maestro/Services/LmKitService.LmKitRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using LMKit.TextGeneration;
using System.ComponentModel;

namespace LMKit.Maestro.Services;

public partial class LMKitService : INotifyPropertyChanged
{
private sealed class LMKitRequest
{
public ManualResetEvent CanBeExecutedSignal { get; } = new ManualResetEvent(false);
public CancellationTokenSource CancellationTokenSource { get; }
public TaskCompletionSource<LMKitResult> ResponseTask { get; } = new TaskCompletionSource<LMKitResult>();
public object? Parameters { get; }

public LMKitRequestType RequestType { get; }

public LMKitRequest(LMKitRequestType requestType, object? parameter, int requestTimeout)
{
RequestType = requestType;
Parameters = parameter;
CancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(requestTimeout));
}

public void CancelAndAwaitTermination()
{
CancellationTokenSource.Cancel();
ResponseTask.Task.Wait();
}

public enum LMKitRequestType
{
Prompt,
RegenerateResponse,
GenerateTitle,
Translate
}

public sealed class PromptRequestParameters
{
public Conversation Conversation { get; set; }

public string Prompt { get; set; }

public PromptRequestParameters(Conversation conversation, string prompt)
{
Conversation = conversation;
Prompt = prompt;
}
}

public sealed class TranslationRequestParameters
{
public string InputText { get; set; }

public Language Language { get; set; }

public TranslationRequestParameters(string inputText, Language language)
{
InputText = inputText;
Language = language;
}
}
}
}
114 changes: 114 additions & 0 deletions LM-Kit-Maestro/Services/LmKitService.RequestSchedule.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using System.ComponentModel;
using LMKit.TextGeneration;

namespace LMKit.Maestro.Services;

public partial class LMKitService : INotifyPropertyChanged
{
private sealed class RequestSchedule
{
private readonly object _locker = new object();

private List<LMKitRequest> _scheduledPrompts = new List<LMKitRequest>();

public int Count
{
get
{
lock (_locker)
{
return _scheduledPrompts.Count;
}
}
}

public LMKitRequest? Next
{
get
{
lock (_locker)
{
if (_scheduledPrompts.Count > 0)
{
var scheduledPrompt = _scheduledPrompts[0];

return scheduledPrompt;
}

return null;
}
}
}

public LMKitRequest? RunningPromptRequest { get; set; }

public void Schedule(LMKitRequest promptRequest)
{
lock (_locker)
{
_scheduledPrompts.Add(promptRequest);

if (Count == 1)
{
promptRequest.CanBeExecutedSignal.Set();
}
}
}

public bool Contains(LMKitRequest scheduledPrompt)
{
lock (_locker)
{
return _scheduledPrompts.Contains(scheduledPrompt);
}
}

public void Remove(LMKitRequest scheduledPrompt)
{
lock (_locker)
{
HandleScheduledPromptRemoval(scheduledPrompt);
}
}

public LMKitRequest? Unschedule(Conversation conversation)
{
LMKitRequest? prompt = null;

lock (_locker)
{
foreach (var scheduledPrompt in _scheduledPrompts)
{
if (scheduledPrompt.Parameters is LMKitRequest.PromptRequestParameters parameter && parameter.Conversation == conversation)
{
prompt = scheduledPrompt;
break;
}
}

if (prompt != null)
{
HandleScheduledPromptRemoval(prompt);
}
}

return prompt;
}

private void HandleScheduledPromptRemoval(LMKitRequest scheduledPrompt)
{
bool wasFirstInLine = scheduledPrompt == _scheduledPrompts[0];

_scheduledPrompts.Remove(scheduledPrompt);

if (wasFirstInLine && Next != null)
{
Next.CanBeExecutedSignal.Set();
}
else
{
scheduledPrompt.CanBeExecutedSignal.Set();
}
}
}
}
9 changes: 9 additions & 0 deletions LM-Kit-Maestro/UI/Razor/Components/ChatMessage.razor
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
<button @onclick="OnCopyMessageButtonClicked" class="chatActionButton material-icons">
@(MessageJustCopied ? "check" : "content_copy")
</button>

<button @onclick="OnRegenerateResponseButtonClicked" class="chatActionButton material-icons">
sync
</button>
</div>

@if (MessageViewModel.Sender == MessageSender.Assistant)
Expand Down Expand Up @@ -286,6 +290,11 @@
}
}

private async Task OnRegenerateResponseButtonClicked()
{
MessageViewModel.ParentConversation.RegenerateResponseCommand.Execute(MessageViewModel);
}

private static string GenerateAssistantResponseHtml(MessageViewModel messageViewModel)
{
string html = !string.IsNullOrEmpty(messageViewModel.Text) ? Markdig.Markdown.ToHtml(messageViewModel.Text) : string.Empty;
Expand Down
70 changes: 49 additions & 21 deletions LM-Kit-Maestro/ViewModels/ConversationViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void LoadConversationLogs()
{
if (message.AuthorRole == AuthorRole.Assistant || message.AuthorRole == AuthorRole.User)
{
Messages.Add(new MessageViewModel(message) { MessageInProgress = false });
Messages.Add(new MessageViewModel(this, message) { MessageInProgress = false });
}
}

Expand All @@ -152,12 +152,20 @@ public async Task Delete()
}
}

[RelayCommand]
private async Task RegenerateResponse(MessageViewModel message)
{
var response = await _lmKitService.RegenerateResponse(_lmKitConversation, message.LMKitMessage!);

OnResponseRegenerated();
}

protected override void HandleSubmit()
{
string prompt = InputText;
OnNewlySubmittedPrompt(prompt);

LMKitService.PromptResult? promptResult = null;
LMKitService.LMKitResult? promptResult = null;

Task.Run(async () =>
{
Expand All @@ -173,7 +181,34 @@ protected override void HandleSubmit()
});
}

private void OnPromptResult(LMKitService.PromptResult? promptResult, Exception? submitPromptException = null)
private void OnResponseRegenerating(MessageViewModel message)
{
message.Text = string.Empty;
message.MessageInProgress = true;
AwaitingResponse = true;
_awaitingLMKitAssistantMessage = true;
}

private void OnNewlySubmittedPrompt(string prompt)
{
InputText = string.Empty;
UsedDifferentModel &= false;
LatestPromptStatus = LMKitTextGenerationStatus.Undefined;
AwaitingResponse = true;
_awaitingLMKitUserMessage = true;
_awaitingLMKitAssistantMessage = true;
_pendingPrompt = new MessageViewModel(this, new Message() { Sender = MessageSender.User, Text = prompt });
_pendingResponse = new MessageViewModel(this, new Message() { Sender = MessageSender.Assistant }) { MessageInProgress = true };

Messages.Add(_pendingPrompt);
Messages.Add(_pendingResponse);
}

private void OnResponseRegenerated()
{

}
private void OnPromptResult(LMKitService.LMKitResult? promptResult, Exception? submitPromptException = null)
{
AwaitingResponse = false;

Expand All @@ -199,9 +234,9 @@ private void OnPromptResult(LMKitService.PromptResult? promptResult, Exception?
_pendingResponse!.Status = LatestPromptStatus;
_pendingPrompt!.Status = LatestPromptStatus;

if (promptResult.Status == LMKitTextGenerationStatus.Undefined && promptResult.TextGenerationResult != null)
if (promptResult.Status == LMKitTextGenerationStatus.Undefined && promptResult.Result is TextGenerationResult textGenerationResult)
{
OnTextGenerationSuccess(promptResult.TextGenerationResult);
OnTextGenerationSuccess(textGenerationResult);
}
else
{
Expand Down Expand Up @@ -251,20 +286,6 @@ private void SaveConversation()
});
}

private void OnNewlySubmittedPrompt(string prompt)
{
InputText = string.Empty;
UsedDifferentModel &= false;
LatestPromptStatus = LMKitTextGenerationStatus.Undefined;
AwaitingResponse = true;
_awaitingLMKitUserMessage = true;
_awaitingLMKitAssistantMessage = true;
_pendingPrompt = new MessageViewModel(new Message() { Sender = MessageSender.User, Text = prompt });
_pendingResponse = new MessageViewModel(new Message() { Sender = MessageSender.Assistant }) { MessageInProgress = true };

Messages.Add(_pendingPrompt);
Messages.Add(_pendingResponse);
}

private void OnTextGenerationSuccess(TextGenerationResult result)
{
Expand Down Expand Up @@ -319,18 +340,25 @@ private void OnLMKitChatHistoryChanged(object? sender, NotifyCollectionChangedEv
}
else
{
MessageViewModel messageViewModel = new MessageViewModel(message);
MessageViewModel messageViewModel = new MessageViewModel(this, message);
_mainThread.BeginInvokeOnMainThread(() => Messages.Add(messageViewModel));
}
}
}
else if (e.Action == NotifyCollectionChangedAction.Remove)
{
int count = 0;

foreach (var item in e.OldItems!)
{
_mainThread.BeginInvokeOnMainThread(() => Messages.RemoveAt(e.OldStartingIndex));
_mainThread.BeginInvokeOnMainThread(() => Messages.RemoveAt(e.OldStartingIndex - e.OldItems.Count + count));
count++;
}
}
else if (e.Action == NotifyCollectionChangedAction.Replace)
{

}
}

private void OnMessagesCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e)
Expand Down
8 changes: 6 additions & 2 deletions LM-Kit-Maestro/ViewModels/MessageViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace LMKit.Maestro.ViewModels;

public partial class MessageViewModel : ViewModelBase
{
public ConversationViewModel ParentConversation { get; }

private ChatHistory.Message? _lmKitMessage;

public ChatHistory.Message? LMKitMessage
Expand Down Expand Up @@ -59,15 +61,17 @@ public ChatHistory.Message? LMKitMessage

public event EventHandler? MessageContentUpdated;

public MessageViewModel(ChatHistory.Message message)
public MessageViewModel(ConversationViewModel parentConversation, ChatHistory.Message message)
{
ParentConversation = parentConversation;
MessageModel = new Message();
Text = message.Content;
LMKitMessage = message;
}

public MessageViewModel(Message message)
public MessageViewModel(ConversationViewModel parentConversation, Message message)
{
ParentConversation = parentConversation;
MessageModel = message;
Sender = message.Sender;
Text = message.Text ?? string.Empty;
Expand Down
2 changes: 1 addition & 1 deletion tests/ChatHistoryPersistenceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public async Task MessageCorrectlyPersistedWhenUnloadingModelDuringGeneration()
public async Task ChatHistoryIsRestored()
{
MaestroTestsService testService = new();
bool loadingSuccess = await testService.LoadModel(MaestroTestsService.Model2);
bool loadingSuccess = await testService.LoadModel();
Assert.True(loadingSuccess);

ConversationLog dummyConversationLog = new()
Expand Down
Loading

0 comments on commit 1ccd9c6

Please sign in to comment.