Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
BeepBeepBopBop committed Nov 28, 2024
1 parent 7071e63 commit 9edeea7
Showing 1 changed file with 16 additions and 103 deletions.
119 changes: 16 additions & 103 deletions LM-Kit-Maestro/Services/LMKitService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ public void UnloadModel()
public async Task<LMKitResult> SubmitPrompt(Conversation conversation, string prompt)
{
var promptRequest = new LMKitRequest(LMKitRequestType.Prompt,
new LMKitPromptRequestParameters(conversation, prompt),
new PromptRequestParameters(conversation, prompt),
LMKitConfig.RequestTimeout);

_promptSchedule.Schedule(promptRequest);
Expand All @@ -180,7 +180,7 @@ public async Task<LMKitResult> SubmitPrompt(Conversation conversation, string pr
return await HandlePromptRequest(promptRequest);
}

public async Task<PromptResult> RegenerateResponse(Conversation conversation)
public async Task<LMKitResult> RegenerateResponse(Conversation conversation)
{
//var message = conversation.ChatHistory.Messages.Last();

Expand Down Expand Up @@ -258,7 +258,7 @@ private async Task<LMKitResult> HandlePromptRequest(LMKitRequest promptRequest)
}
else
{
BeforeSubmittingPrompt(((LMKitPromptRequestParameters)promptRequest.Parameters!).Conversation);
BeforeSubmittingPrompt(((PromptRequestParameters)promptRequest.Parameters!).Conversation);
_lmKitServiceSemaphore.Release();

promptResult = await SubmitPromptRequest(promptRequest);
Expand All @@ -277,7 +277,7 @@ private async Task<LMKitResult> HandlePromptRequest(LMKitRequest promptRequest)

private async Task<LMKitResult> SubmitPromptRequest(LMKitRequest promptRequest)
{
LMKitPromptRequestParameters parameter = (promptRequest.Parameters as LMKitPromptRequestParameters)!;
PromptRequestParameters parameter = (promptRequest.Parameters as PromptRequestParameters)!;

try
{
Expand Down Expand Up @@ -390,7 +390,7 @@ private async Task<LMKitResult> SubmitTranslationRequest(LMKitRequest translatio
private void GenerateConversationSummaryTitle(Conversation conversation, string prompt)
{
LMKitRequest titleGenerationRequest = new LMKitRequest(LMKitRequestType.GenerateTitle,
new LMKitPromptRequestParameters(conversation, prompt), 60);
new PromptRequestParameters(conversation, prompt), 60);

_titleGenerationSchedule.Schedule(titleGenerationRequest);

Expand Down Expand Up @@ -525,8 +525,10 @@ private static TokenSampling GetTokenSampling(LMKitConfig config)
TargetEntropy = config.Mirostat2SamplingConfig.TargetEntropy
};
}
}

}

#region Data structures

private sealed class PromptSchedule
{
private readonly object _locker = new object();
Expand Down Expand Up @@ -601,7 +603,7 @@ public void Remove(LMKitRequest scheduledPrompt)
{
foreach (var scheduledPrompt in _scheduledPrompts)
{
if (scheduledPrompt.Parameters is LMKitPromptRequestParameters parameter && parameter.Conversation == conversation)
if (scheduledPrompt.Parameters is PromptRequestParameters parameter && parameter.Conversation == conversation)
{
prompt = scheduledPrompt;
break;
Expand Down Expand Up @@ -632,39 +634,8 @@ private void HandleScheduledPromptRemoval(LMKitRequest scheduledPrompt)
scheduledPrompt.CanBeExecutedSignal.Set();
}
}
}

private sealed class PromptRequest : LMKitPromptRequestBase
{
public PromptRequest(Conversation conversation, string prompt, int requestTimeout) : base(conversation, prompt, requestTimeout)
{
}
}

private sealed class TitleGenerationRequest : LMKitPromptRequestBase
{
public string Response { get; }

public TitleGenerationRequest(Conversation conversation, string prompt, string response, int requestTimeout) : base(conversation, prompt, requestTimeout)
{
Response = response;
}
}

private sealed class TranslationRequest : LMKitRequestBase
{
public TaskCompletionSource<TranslationResult> TranslationTask { get; } = new TaskCompletionSource<TranslationResult>();

public TranslationRequest(string prompt, int requestTimeout) : base(prompt, requestTimeout)
{
}

protected override void AwaitResult()
{
TranslationTask.Task.Wait();
}
}

}

private sealed class LMKitRequest
{
public ManualResetEvent CanBeExecutedSignal { get; } = new ManualResetEvent(false);
Expand Down Expand Up @@ -696,13 +667,13 @@ private enum LMKitRequestType
Translate
}

private sealed class LMKitPromptRequestParameters
private sealed class PromptRequestParameters
{
public Conversation Conversation { get; set; }

public string Prompt { get; set; }

public LMKitPromptRequestParameters(Conversation conversation, string prompt)
public PromptRequestParameters(Conversation conversation, string prompt)
{
Conversation = conversation;
Prompt = prompt;
Expand All @@ -722,45 +693,6 @@ public TranslationRequestParameters(string inputText, Language language)
}
}

private abstract class LMKitRequestBase
{
public string Prompt { get; }

public ManualResetEvent CanBeExecutedSignal { get; } = new ManualResetEvent(false);

public CancellationTokenSource CancellationTokenSource { get; }

public void CancelAndAwaitTermination()
{
CancellationTokenSource.Cancel();
AwaitResult();
}

protected abstract void AwaitResult();

protected LMKitRequestBase(string prompt, int requestTimeout)
{
Prompt = prompt;
CancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(requestTimeout));
}
}

private abstract class LMKitPromptRequestBase : LMKitRequestBase
{
public Conversation Conversation { get; }
public TaskCompletionSource<PromptResult> PromptResult { get; } = new TaskCompletionSource<PromptResult>();

protected LMKitPromptRequestBase(Conversation conversation, string prompt, int requestTimeout) : base(prompt, requestTimeout)
{
Conversation = conversation;
}

protected override void AwaitResult()
{
PromptResult.Task.Wait();
}
}

public class NotifyModelStateChangedEventArgs : EventArgs
{
public Uri FileUri { get; }
Expand Down Expand Up @@ -791,32 +723,13 @@ public ModelLoadingFailedEventArgs(Uri fileUri, Exception exception) : base(file
}
}

public sealed class PromptResult
{
public Exception? Exception { get; set; }

public LMKitTextGenerationStatus Status { get; set; }

public TextGenerationResult? TextGenerationResult { get; set; }
}

public sealed class LMKitResult
{
public Exception? Exception { get; set; }

public LMKitTextGenerationStatus Status { get; set; }

public object? Result { get; set; }
}

public sealed class TranslationResult
{
public Exception? Exception { get; set; }

public string? Result { get; set; }

public LMKitTextGenerationStatus Status { get; set; }
}


}
#endregion
}

0 comments on commit 9edeea7

Please sign in to comment.