Skip to content

Commit

Permalink
more overloads and extensions!
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenHodgson committed Nov 19, 2023
1 parent 77162d1 commit b292c2b
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 53 deletions.
2 changes: 1 addition & 1 deletion OpenAI-DotNet-Tests/TestFixture_12_Threads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public async Task Test_04_03_ModifyMessage()
{
["test"] = "04_03"
};
var modified = await testMessage.ModifyMessageAsync(metadata);
var modified = await testMessage.ModifyAsync(metadata);
Assert.IsNotNull(modified);
Assert.IsNotNull(modified.Metadata);
Assert.IsTrue(modified.Metadata["test"].Equals("04_03"));
Expand Down
79 changes: 31 additions & 48 deletions OpenAI-DotNet-Tests/TestFixture_13_ThreadRuns.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace OpenAI.Tests
/// </summary>
internal class TestFixture_13_ThreadRuns : AbstractTestFixture
{
private static string assistantId;
private static string threadId;
private static string runId;
private static AssistantResponse testAssistant;
private static ThreadResponse testThread;
private static RunResponse testRun;

[Test]
public async Task Test_01_CreateRun()
Expand All @@ -26,67 +26,69 @@ public async Task Test_01_CreateRun()
instructions: "You are a personal math tutor. Answer questions briefly, in a sentence or less.",
model: "gpt-4-1106-preview"));
Assert.NotNull(assistant);
assistantId = assistant.Id;
testAssistant = assistant;
var thread = await OpenAIClient.ThreadsEndpoint.CreateThreadAsync();
Assert.NotNull(thread);
threadId = thread.Id;

try
{
var message = OpenAIClient.ThreadsEndpoint.CreateMessageAsync(thread, "I need to solve the equation `3x + 11 = 14`. Can you help me?");
var message = thread.CreateMessageAsync("I need to solve the equation `3x + 11 = 14`. Can you help me?");
Assert.NotNull(message);
var run = await OpenAIClient.ThreadsEndpoint.CreateRunAsync(thread, new CreateRunRequest(assistant));
var run = await thread.CreateRunAsync(assistant);
Assert.IsNotNull(run);
}
finally
{
await OpenAIClient.ThreadsEndpoint.DeleteThreadAsync(thread);
await thread.DeleteAsync();
}
}

[Test]
public async Task Test_02_CreateThreadAndRun()
{
var request = new CreateThreadAndRunRequest(assistantId);
var run = await OpenAIClient.ThreadsEndpoint.CreateThreadAndRunAsync(request);
Assert.NotNull(testAssistant);
Assert.NotNull(OpenAIClient.ThreadsEndpoint);
var run = await testAssistant.CreateThreadAndRunAsync();
Assert.IsNotNull(run);
runId = run.Id;
Assert.IsFalse(string.IsNullOrWhiteSpace(run.ThreadId));
threadId = run.ThreadId;
testRun = run;
var thread = await run.GetThreadAsync();
Assert.NotNull(thread);
testThread = thread;
}

[Test]
public async Task Test_04_ModifyRun()
{
Assert.NotNull(testRun);
Assert.NotNull(OpenAIClient.ThreadsEndpoint);
// run in Queued and InProgress can't be modified
var run = await WaitOnRunAsync(threadId, runId, RunStatus.Queued, RunStatus.InProgress);

var modified = await OpenAIClient.ThreadsEndpoint.ModifyRunAsync(run,
new Dictionary<string, string>
{
["key"] = "value"
});

var run = await testRun.WaitForStatusAsync();
Assert.IsNotNull(run);
Assert.IsTrue(run.Status == RunStatus.Completed);
var metadata = new Dictionary<string, string>
{
["test"] = nameof(Test_04_ModifyRun)
};
var modified = await run.ModifyAsync(metadata);
Assert.IsNotNull(modified);
Assert.AreEqual(run.Id, modified.Id);
Assert.IsNotNull(modified.Metadata);
Assert.Contains("key", modified.Metadata.Keys.ToList());
Assert.AreEqual("value", modified.Metadata["key"]);
Assert.Contains("test", modified.Metadata.Keys.ToList());
Assert.AreEqual(nameof(Test_04_ModifyRun), modified.Metadata["test"]);
}

[Test]
public async Task Test_03_ListRuns()
{
var request = new CreateRunRequest(assistantId);
var run = await OpenAIClient.ThreadsEndpoint.CreateRunAsync(threadId, request);
Assert.IsNotNull(run);
var list = await OpenAIClient.ThreadsEndpoint.ListRunsAsync(threadId);
Assert.NotNull(testThread);
Assert.NotNull(OpenAIClient.ThreadsEndpoint);
var list = await testThread.ListRunsAsync();
Assert.IsNotNull(list);
Assert.IsNotEmpty(list.Items);

foreach (var threadRun in list.Items)
foreach (var run in list.Items)
{
var retrievedRun = await OpenAIClient.ThreadsEndpoint.RetrieveRunAsync(threadId, threadRun);
var retrievedRun = await run.UpdateAsync();
Assert.IsNotNull(retrievedRun);
}
}
Expand Down Expand Up @@ -197,24 +199,5 @@ public async Task Test_03_ListRuns()
// Console.WriteLine($"[{retrieved.ThreadId}] -> {retrieved.Id}");
// }
//}

private async Task<RunResponse> WaitOnRunAsync(string thread, string run, params RunStatus[] statuses)
{
var loopCounter = 0;
RunResponse runResponse;

do
{
if (++loopCounter > 10)
{
Assert.Fail($"Spent too much in long in {string.Join(',', statuses)} statuses");
}

await Task.Delay(2000);
runResponse = await OpenAIClient.ThreadsEndpoint.RetrieveRunAsync(thread, run);
} while (statuses.Contains(runResponse.Status));

return runResponse;
}
}
}
105 changes: 103 additions & 2 deletions OpenAI-DotNet/Threads/ThreadExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using OpenAI.Assistants;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -65,7 +67,7 @@ public static async Task<ListResponse<MessageResponse>> ListMessagesAsync(this T
/// <param name="message"><see cref="MessageResponse"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="MessageResponse"/>.</returns>
public static async Task<MessageResponse> RetrieveMessageAsync(this MessageResponse message, CancellationToken cancellationToken = default)
public static async Task<MessageResponse> RetrieveAsync(this MessageResponse message, CancellationToken cancellationToken = default)
=> await message.Client.ThreadsEndpoint.RetrieveMessageAsync(message, cancellationToken);

/// <summary>
Expand All @@ -91,7 +93,7 @@ public static async Task<MessageResponse> RetrieveMessageAsync(this ThreadRespon
/// </param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="MessageResponse"/>.</returns>
public static async Task<MessageResponse> ModifyMessageAsync(this MessageResponse message, IReadOnlyDictionary<string, string> metadata, CancellationToken cancellationToken = default)
public static async Task<MessageResponse> ModifyAsync(this MessageResponse message, IReadOnlyDictionary<string, string> metadata, CancellationToken cancellationToken = default)
=> await message.Client.ThreadsEndpoint.ModifyMessageAsync(message, metadata, cancellationToken);

/// <summary>
Expand Down Expand Up @@ -147,5 +149,104 @@ public static async Task<MessageFileResponse> RetrieveFileAsync(this MessageResp
=> await message.Client.ThreadsEndpoint.RetrieveFileAsync(message, fileId, cancellationToken);

#endregion Files

#region Runs

/// <summary>
/// Create a run.
/// </summary>
/// <param name="thread"><see cref="ThreadResponse"/>.</param>
/// <param name="request"><see cref="CreateRunRequest"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> CreateRunAsync(this ThreadResponse thread, CreateRunRequest request, CancellationToken cancellationToken = default)
=> await thread.Client.ThreadsEndpoint.CreateRunAsync(thread, request, cancellationToken);

/// <summary>
/// Create a run.
/// </summary>
/// <param name="thread"><see cref="ThreadResponse"/>.</param>
/// <param name="assistantId">Id of the assistant to use for the run.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> CreateRunAsync(this ThreadResponse thread, string assistantId, CancellationToken cancellationToken = default)
=> await thread.Client.ThreadsEndpoint.CreateRunAsync(thread, new CreateRunRequest(assistantId), cancellationToken);

/// <summary>
/// Create a thread and run it.
/// </summary>
/// <param name="assistant"><see cref="AssistantResponse"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> CreateThreadAndRunAsync(this AssistantResponse assistant, CancellationToken cancellationToken = default)
=> await assistant.Client.ThreadsEndpoint.CreateThreadAndRunAsync(new CreateThreadAndRunRequest(assistant.Id), cancellationToken);

/// <summary>
/// Gets the thread associated to the <see cref="RunResponse"/>.
/// </summary>
/// <param name="run"><see cref="RunResponse"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="ThreadResponse"/>.</returns>
public static async Task<ThreadResponse> GetThreadAsync(this RunResponse run, CancellationToken cancellationToken = default)
=> await run.Client.ThreadsEndpoint.RetrieveThreadAsync(run.ThreadId, cancellationToken);

/// <summary>
/// List all of the runs associated to a thread.
/// </summary>
/// <param name="thread"><see cref="ThreadResponse"/>.</param>
/// <param name="query"><see cref="ListQuery"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="ListResponse{RunResponse}"/></returns>
public static async Task<ListResponse<RunResponse>> ListRunsAsync(this ThreadResponse thread, ListQuery query = null, CancellationToken cancellationToken = default)
=> await thread.Client.ThreadsEndpoint.ListRunsAsync(thread.Id, query, cancellationToken);

/// <summary>
/// Get the latest status of the <see cref="RunResponse"/>.
/// </summary>
/// <param name="run"><see cref="RunResponse"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> UpdateAsync(this RunResponse run, CancellationToken cancellationToken = default)
=> await run.Client.ThreadsEndpoint.RetrieveRunAsync(run, cancellationToken);

private static RunStatus[] DefaultStatusChecks { get; } = { RunStatus.Queued, RunStatus.InProgress };

/// <summary>
/// Waits for run status to change from the provided <see cref="statusChecks"/>.
/// </summary>
/// <param name="run"></param>
/// <param name="statusChecks"><see cref="RunStatus"/> to wait for.</param>
/// <param name="pollingInterval">Optional, time in milliseconds to wait before polling status.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> WaitForStatusAsync(this RunResponse run, RunStatus[] statusChecks = null, int? pollingInterval = null, CancellationToken cancellationToken = default)
{
statusChecks ??= DefaultStatusChecks;
pollingInterval ??= 500;
RunResponse result;
do
{
await Task.Delay(pollingInterval.Value, cancellationToken).ConfigureAwait(false);
result = await run.UpdateAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
} while (statusChecks.Contains(result.Status));
return result;
}

/// <summary>
/// Modifies a run.
/// </summary>
/// <remarks>
/// Only the <see cref="RunResponse.Metadata"/> can be modified.
/// </remarks>
/// <param name="run"><see cref="RunResponse"/> to modify.</param>
/// <param name="metadata">Set of 16 key-value pairs that can be attached to an object.
/// This can be useful for storing additional information about the object in a structured format.
/// Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> ModifyAsync(this RunResponse run, IReadOnlyDictionary<string, string> metadata, CancellationToken cancellationToken = default)
=> await run.Client.ThreadsEndpoint.ModifyRunAsync(run, metadata, cancellationToken);

#endregion Runs
}
}
4 changes: 2 additions & 2 deletions OpenAI-DotNet/Threads/ThreadsEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ public async Task<RunResponse> CreateRunAsync(string threadId, CreateRunRequest
/// <summary>
/// Create a thread and run it in one request.
/// </summary>
/// <param name="request"></param>
/// <param name="request"><see cref="CreateThreadAndRunRequest"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public async Task<RunResponse> CreateThreadAndRunAsync(CreateThreadAndRunRequest request, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -281,7 +281,7 @@ public async Task<RunResponse> RetrieveRunAsync(string threadId, string runId, C
/// <param name="threadId">The id of the thread the run belongs to.</param>
/// <param name="query"><see cref="ListQuery"/>.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns>A list of run objects.</returns>
/// <returns><see cref="ListResponse{RunResponse}"/></returns>
public async Task<ListResponse<RunResponse>> ListRunsAsync(string threadId, ListQuery query = null, CancellationToken cancellationToken = default)
{
var response = await Api.Client.GetAsync(GetUrl($"/{threadId}/runs", query), cancellationToken).ConfigureAwait(false);
Expand Down

0 comments on commit b292c2b

Please sign in to comment.