diff --git a/Protos/V1/agent_common.proto b/Protos/V1/agent_common.proto index 0a35ea048..520e399eb 100644 --- a/Protos/V1/agent_common.proto +++ b/Protos/V1/agent_common.proto @@ -47,41 +47,17 @@ message CreateTaskReply { string communication_token = 4; /** Communication token received by the worker during task processing */ } +// Request to retrieve data message DataRequest { string communication_token = 1; /** Communication token received by the worker during task processing */ - string key = 2; + // Id of the result that will be retrieved + string result_id = 2; } -message DataReply { - message Init { - string key = 1; - oneof has_result { - DataChunk data = 2; - string error = 3; - } - } - string communication_token = 1; /** Communication token received by the worker during task processing */ - oneof type { - Init init = 2; - DataChunk data = 3; - string error = 4; - } -} - -message Result { - oneof type { - InitKeyedDataStream init = 1; - DataChunk data = 2; - } - string communication_token = 3; /** Communication token received by the worker during task processing */ -} - -message ResultReply { - string communication_token = 3; /** Communication token received by the worker during task processing */ - oneof type { - Empty Ok = 1; - string Error = 2; - } +// Response when data is available in the shared folder +message DataResponse { + // Id of the result whose data is available in the shared folder + string result_id = 2; } /* @@ -156,7 +132,7 @@ message SubmitTasksResponse { } /* -* Request for creating results without data +* Request for creating results with data */ message CreateResultsRequest { /** @@ -180,11 +156,9 @@ message CreateResultsResponse { } /* -* Request for uploading results data through stream. -* Data must be sent in multiple chunks. -* Only one result can be uploaded. +* Request for notifying that results data are available in files. */ -message UploadResultDataRequest { +message NotifyResultDataRequest { /** * The metadata to identify the result to update. 
*/ @@ -195,23 +169,15 @@ message UploadResultDataRequest { /** * The possible messages that constitute a UploadResultDataRequest - * They should be sent in the following order: - * - id - * - data_chunk (stream can have multiple data_chunk messages that represent data divided in several parts) - * - * Data chunk cannot exceed the size returned by the GetServiceConfiguration rpc method */ - oneof type { - ResultIdentifier id = 1; /** The identifier of the result to which add data. */ - bytes data_chunk = 2; /** A chunk of data. */ - } + repeated ResultIdentifier ids = 1; /** The identifiers of the results to which data are added. */ string communication_token = 4; /** Communication token received by the worker during task processing */ } /* -* Response for uploading data with stream for result +* Response for notifying data file availability for result +* Received when data are successfully copied to the ObjectStorage */ -message UploadResultDataResponse { - string result_id = 1; /** The Id of the result to which data were added */ - string communication_token = 2; /** Communication token received by the worker during task processing */ +message NotifyResultDataResponse { + repeated string result_ids = 1; /** The Ids of the results to which data were added */ } diff --git a/Protos/V1/agent_service.proto b/Protos/V1/agent_service.proto index 7badffc2b..739396d2e 100644 --- a/Protos/V1/agent_service.proto +++ b/Protos/V1/agent_service.proto @@ -7,6 +7,8 @@ import "agent_common.proto"; option csharp_namespace = "ArmoniK.Api.gRPC.V1.Agent"; service Agent { + rpc CreateTask(stream CreateTaskRequest) returns (CreateTaskReply); + /** * Create the metadata of multiple results at once * Data have to be uploaded separately @@ -19,18 +21,35 @@ service Agent { rpc CreateResults(CreateResultsRequest) returns (CreateResultsResponse) {} /** - * Upload data for result with stream + * Notify Agent that a data file representing the Result to upload is available in the shared folder + * The 
name of the file should be the result id + * Blocks until data are stored in Object Storage */ - rpc UploadResultData(stream UploadResultDataRequest) returns (UploadResultDataResponse) {} + rpc NotifyResultData(NotifyResultDataRequest) returns (NotifyResultDataResponse) {} /** * Create tasks metadata and submit task for processing. */ rpc SubmitTasks(SubmitTasksRequest) returns (SubmitTasksResponse) {} - rpc CreateTask(stream CreateTaskRequest) returns (CreateTaskReply); - rpc GetResourceData(DataRequest) returns (stream DataReply); - rpc GetCommonData(DataRequest) returns (stream DataReply); - rpc GetDirectData(DataRequest) returns (stream DataReply); - rpc SendResult(stream Result) returns (ResultReply); + /** + * Retrieve Resource Data from the Agent + * Data is stored in the shared folder between Agent and Worker as a file with the result id as name + * Blocks until data are available in the shared folder + */ + rpc GetResourceData(DataRequest) returns (DataResponse); + + /** + * Retrieve Common Data from the Agent + * Data is stored in the shared folder between Agent and Worker as a file with the result id as name + * Blocks until data are available in the shared folder + */ + rpc GetCommonData(DataRequest) returns (DataResponse); + + /** + * Retrieve Direct Data from the Agent + * Data is stored in the shared folder between Agent and Worker as a file with the result id as name + * Blocks until data are available in the shared folder + */ + rpc GetDirectData(DataRequest) returns (DataResponse); } diff --git a/Protos/V1/worker_common.proto b/Protos/V1/worker_common.proto index a4a704040..9ac8469e9 100644 --- a/Protos/V1/worker_common.proto +++ b/Protos/V1/worker_common.proto @@ -7,35 +7,19 @@ import "objects.proto"; option csharp_namespace = "ArmoniK.Api.gRPC.V1.Worker"; message ProcessRequest { - message ComputeRequest { - message InitRequest { - Configuration configuration = 1; - string session_id = 2; - string task_id = 3; - TaskOptions task_options = 4; 
- repeated string expected_output_keys = 5; - DataChunk payload = 6; - } - message InitData { - oneof type { - string key = 1; - bool last_data = 2; - } - } - oneof type { - InitRequest init_request = 1; - DataChunk payload = 2; - InitData init_data = 3; - DataChunk data = 4; - } - } string communication_token = 1; - ComputeRequest compute = 2; + string session_id = 2; + string task_id = 3; + TaskOptions task_options = 4; + repeated string expected_output_keys = 5; + string payload_id = 6; + repeated string data_dependencies = 7; + string data_folder = 8; + Configuration configuration = 9; } message ProcessReply { - string communication_token = 1; - Output output = 2; + Output output = 1; } message HealthCheckReply { diff --git a/Protos/V1/worker_service.proto b/Protos/V1/worker_service.proto index 31465afce..2a5a3fb1f 100644 --- a/Protos/V1/worker_service.proto +++ b/Protos/V1/worker_service.proto @@ -8,6 +8,6 @@ import "worker_common.proto"; option csharp_namespace = "ArmoniK.Api.gRPC.V1.Worker"; service Worker { - rpc Process(stream ProcessRequest) returns (ProcessReply); + rpc Process(ProcessRequest) returns (ProcessReply); rpc HealthCheck(Empty) returns (HealthCheckReply); } diff --git a/packages/csharp/ArmoniK.Api.Mock/Services/Agent.cs b/packages/csharp/ArmoniK.Api.Mock/Services/Agent.cs index 14c59c69d..22075a27c 100644 --- a/packages/csharp/ArmoniK.Api.Mock/Services/Agent.cs +++ b/packages/csharp/ArmoniK.Api.Mock/Services/Agent.cs @@ -14,9 +14,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+using System.Linq; using System.Threading.Tasks; -using ArmoniK.Api.gRPC.V1; using ArmoniK.Api.gRPC.V1.Agent; using Grpc.Core; @@ -36,63 +36,6 @@ public override Task CreateTask(IAsyncStreamReader - [Count] - public override async Task GetCommonData(DataRequest request, - IServerStreamWriter responseStream, - ServerCallContext context) - => await responseStream.WriteAsync(new DataReply - { - Data = new DataChunk - { - DataComplete = true, - }, - }) - .ConfigureAwait(false); - - /// - [Count] - public override async Task GetDirectData(DataRequest request, - IServerStreamWriter responseStream, - ServerCallContext context) - => await responseStream.WriteAsync(new DataReply - { - Data = new DataChunk - { - DataComplete = true, - }, - }) - .ConfigureAwait(false); - - /// - [Count] - public override async Task GetResourceData(DataRequest request, - IServerStreamWriter responseStream, - ServerCallContext context) - => await responseStream.WriteAsync(new DataReply - { - Data = new DataChunk - { - DataComplete = true, - }, - }) - .ConfigureAwait(false); - - /// - [Count] - public override async Task SendResult(IAsyncStreamReader requestStream, - ServerCallContext context) - { - await foreach (var _ in requestStream.ReadAllAsync()) - { - } - - return new ResultReply - { - Ok = new Empty(), - }; - } - /// [Count] public override Task CreateResultsMetaData(CreateResultsMetaDataRequest request, @@ -111,21 +54,6 @@ public override Task SubmitTasks(SubmitTasksRequest request CommunicationToken = request.CommunicationToken, }); - /// - [Count] - public override async Task UploadResultData(IAsyncStreamReader requestStream, - ServerCallContext context) - { - await foreach (var _ in requestStream.ReadAllAsync()) - { - } - - return new UploadResultDataResponse - { - ResultId = "result-id", - CommunicationToken = "communication-token", - }; - } /// [Count] @@ -135,4 +63,43 @@ public override Task CreateResults(CreateResultsRequest r { CommunicationToken = request.CommunicationToken, 
}); + + /// + [Count] + public override Task GetCommonData(DataRequest request, + ServerCallContext context) + => Task.FromResult(new DataResponse + { + ResultId = request.ResultId, + }); + + /// + [Count] + public override Task GetDirectData(DataRequest request, + ServerCallContext context) + => Task.FromResult(new DataResponse + { + ResultId = request.ResultId, + }); + + /// + [Count] + public override Task GetResourceData(DataRequest request, + ServerCallContext context) + => Task.FromResult(new DataResponse + { + ResultId = request.ResultId, + }); + + /// + [Count] + public override Task NotifyResultData(NotifyResultDataRequest request, + ServerCallContext context) + => Task.FromResult(new NotifyResultDataResponse + { + ResultIds = + { + request.Ids.Select(identifier => identifier.ResultId), + }, + }); } diff --git a/packages/csharp/ArmoniK.Api.Tests/TaskHandlerTest.cs b/packages/csharp/ArmoniK.Api.Tests/TaskHandlerTest.cs index da55dd12d..af4ff1992 100644 --- a/packages/csharp/ArmoniK.Api.Tests/TaskHandlerTest.cs +++ b/packages/csharp/ArmoniK.Api.Tests/TaskHandlerTest.cs @@ -25,6 +25,7 @@ using System.Collections; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Text; using System.Threading; @@ -35,8 +36,6 @@ using ArmoniK.Api.gRPC.V1.Worker; using ArmoniK.Api.Worker.Worker; -using Google.Protobuf; - using Grpc.Core; using Microsoft.Extensions.Logging; @@ -48,32 +47,6 @@ namespace ArmoniK.Api.Worker.Tests; [TestFixture] public class TaskHandlerTest { - [SetUp] - public void SetUp() - { - } - - [TearDown] - public virtual void TearDown() - { - } - - private class MyAsyncStreamReader : IAsyncStreamReader - { - private readonly IAsyncEnumerator asyncEnumerator_; - - public MyAsyncStreamReader(IEnumerable requests) - => asyncEnumerator_ = requests.ToAsyncEnumerable() - .GetAsyncEnumerator(); - - public async Task MoveNext(CancellationToken cancellationToken) - => await 
asyncEnumerator_.MoveNextAsync(cancellationToken) - .ConfigureAwait(false); - - public ProcessRequest Current - => asyncEnumerator_.Current; - } - private class MyClientStreamWriter : IClientStreamWriter { public readonly ConcurrentBag Messages = new(); @@ -104,26 +77,8 @@ public Task CompleteAsync() private class MyAgent : Agent.AgentClient { - private readonly MyClientStreamWriter resultStream_; - private readonly MyClientStreamWriter taskStream_; - - public MyAgent() - { - resultStream_ = new MyClientStreamWriter(); - taskStream_ = new MyClientStreamWriter(); - } + private readonly MyClientStreamWriter taskStream_ = new(); - public override AsyncClientStreamingCall SendResult(Metadata headers = null, - DateTime? deadline = null, - CancellationToken cancellationToken = default) - => new(resultStream_, - Task.FromResult(new ResultReply()), - Task.FromResult(new Metadata()), - () => Status.DefaultSuccess, - () => new Metadata(), - () => - { - }); public override AsyncClientStreamingCall CreateTask(Metadata headers = null, DateTime? 
deadline = null, @@ -137,9 +92,6 @@ public override AsyncClientStreamingCall Cre { }); - public List GetResults() - => resultStream_.Messages.ToList(); - public List GetTaskRequests() => taskStream_.Messages.ToList(); } @@ -147,452 +99,96 @@ public List GetTaskRequests() [Test] [TestCaseSource(typeof(TaskHandlerTest), - nameof(TaskHandlerCreateShouldThrowTestCases))] - public void TaskHandlerCreateShouldThrow(IEnumerable requests) + nameof(InvalidRequests))] + public void NewTaskHandlerShouldThrow(ProcessRequest request) { - var stream = new MyAsyncStreamReader(requests); - var agent = new MyAgent(); - Assert.ThrowsAsync(async () => await TaskHandler.Create(stream, - agent, - new LoggerFactory(), - CancellationToken.None) - .ConfigureAwait(false)); + Assert.Throws(() => new TaskHandler(request, + agent, + new LoggerFactory(), + CancellationToken.None)); } - [Test] - [TestCaseSource(typeof(TaskHandlerTest), - nameof(TaskHandlerCreateShouldSucceedTestCases))] - public async Task TaskHandlerCreateShouldSucceed(IEnumerable requests) + public static IEnumerable InvalidRequests { - var stream = new MyAsyncStreamReader(requests); - - var agent = new MyAgent(); - - var taskHandler = await TaskHandler.Create(stream, - agent, - new LoggerFactory(), - CancellationToken.None) - .ConfigureAwait(false); - - Assert.NotNull(taskHandler.Token); - Assert.IsNotEmpty(taskHandler.Token); - Assert.IsNotEmpty(taskHandler.Payload); - Assert.IsNotEmpty(taskHandler.SessionId); - Assert.IsNotEmpty(taskHandler.TaskId); + get { yield return new TestCaseData(new ProcessRequest()).SetArgDisplayNames("Empty request"); } } [Test] - public async Task CheckTaskHandlerDataAreCorrect() + public async Task NewTaskHandlerShouldSucceed() { - var stream = new MyAsyncStreamReader(WorkingRequest1); - var agent = new MyAgent(); - var taskHandler = await TaskHandler.Create(stream, - agent, - new LoggerFactory(), - CancellationToken.None) - .ConfigureAwait(false); - - Assert.IsNotEmpty(taskHandler.Payload); - 
Assert.AreEqual("testPayload1Payload2", - ByteString.CopyFrom(taskHandler.Payload) - .ToStringUtf8()); - Assert.AreEqual(2, - taskHandler.DataDependencies.Count); - Assert.AreEqual("Data1Data2", - ByteString.CopyFrom(taskHandler.DataDependencies.Values.First()) - .ToStringUtf8()); - Assert.AreEqual("Data1Data2Data2Data2", - ByteString.CopyFrom(taskHandler.DataDependencies.Values.Last()) - .ToStringUtf8()); - Assert.AreEqual("TaskId", - taskHandler.TaskId); - Assert.AreEqual("SessionId", - taskHandler.SessionId); - Assert.AreEqual("Token", - taskHandler.Token); - - await taskHandler.SendResult("test", - Encoding.ASCII.GetBytes("TestData")); - - var results = agent.GetResults(); - foreach (var r in results) - { - Console.WriteLine(r); - } - - Assert.AreEqual(4, - results.Count); - - Assert.AreEqual(Result.TypeOneofCase.Init, - results[0] - .TypeCase); - Assert.AreEqual(true, - results[0] - .Init.LastResult); - - Assert.AreEqual(Result.TypeOneofCase.Data, - results[1] - .TypeCase); - Assert.AreEqual(true, - results[1] - .Data.DataComplete); - - Assert.AreEqual(Result.TypeOneofCase.Data, - results[2] - .TypeCase); - Assert.AreEqual("TestData", - results[2] - .Data.Data); - - Assert.AreEqual(Result.TypeOneofCase.Init, - results[3] - .TypeCase); - Assert.AreEqual("test", - results[3] - .Init.Key); - - - await taskHandler.CreateTasksAsync(new List - { - new() - { - Payload = ByteString.CopyFromUtf8("Payload"), - DataDependencies = - { - "DD", - }, - ExpectedOutputKeys = - { - "EOK", - }, - }, - }); - - var tasks = agent.GetTaskRequests(); - Console.WriteLine(); - foreach (var t in tasks) - { - Console.WriteLine(t); - } - - Assert.AreEqual(5, - tasks.Count); - - Assert.AreEqual(CreateTaskRequest.TypeOneofCase.InitTask, - tasks[0] - .TypeCase); - Assert.AreEqual(true, - tasks[0] - .InitTask.LastTask); - - Assert.AreEqual(CreateTaskRequest.TypeOneofCase.TaskPayload, - tasks[1] - .TypeCase); - Assert.AreEqual(true, - tasks[1] - .TaskPayload.DataComplete); - - 
Assert.AreEqual(CreateTaskRequest.TypeOneofCase.TaskPayload, - tasks[2] - .TypeCase); - Assert.AreEqual("Payload", - tasks[2] - .TaskPayload.Data); - - Assert.AreEqual(CreateTaskRequest.TypeOneofCase.InitTask, - tasks[3] - .TypeCase); - Assert.AreEqual("DD", - tasks[3] - .InitTask.Header.DataDependencies.Single()); - Assert.AreEqual("EOK", - tasks[3] - .InitTask.Header.ExpectedOutputKeys.Single()); - - Assert.AreEqual(CreateTaskRequest.TypeOneofCase.InitRequest, - tasks[4] - .TypeCase); - } - - private static readonly ProcessRequest InitData1 = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - InitData = new ProcessRequest.Types.ComputeRequest.Types.InitData - { - Key = "DataKey1", - }, - }, - }; - - private static readonly ProcessRequest InitData2 = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - InitData = new ProcessRequest.Types.ComputeRequest.Types.InitData - { - Key = "DataKey2", - }, - }, - }; - - private static readonly ProcessRequest LastDataTrue = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - InitData = new ProcessRequest.Types.ComputeRequest.Types.InitData - { - LastData = true, - }, - }, - }; - - private static readonly ProcessRequest LastDataFalse = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - InitData = new ProcessRequest.Types.ComputeRequest.Types.InitData - { - LastData = false, - }, - }, - }; - - private static readonly ProcessRequest InitRequestPayload = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - InitRequest = new ProcessRequest.Types.ComputeRequest.Types.InitRequest - { - Payload = new DataChunk - { - Data = ByteString.CopyFromUtf8("test"), - }, - Configuration = new Configuration - { - DataChunkMaxSize = 100, - }, - ExpectedOutputKeys = - { - "EOK", - }, - SessionId = "SessionId", - 
TaskId = "TaskId", - }, - }, - }; - - private static readonly ProcessRequest InitRequestEmptyPayload = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - InitRequest = new ProcessRequest.Types.ComputeRequest.Types.InitRequest - { - Configuration = new Configuration - { - DataChunkMaxSize = 100, - }, - ExpectedOutputKeys = - { - "EOK", - }, - SessionId = "SessionId", - TaskId = "TaskId", - }, - }, - }; - - private static readonly ProcessRequest Payload1 = new() + var payloadId = Guid.NewGuid() + .ToString(); + var taskId = Guid.NewGuid() + .ToString(); + var token = Guid.NewGuid() + .ToString(); + var sessionId = Guid.NewGuid() + .ToString(); + var dd1 = Guid.NewGuid() + .ToString(); + var eok1 = Guid.NewGuid() + .ToString(); + + var folder = Path.Combine(Path.GetTempPath(), + token); + + Directory.CreateDirectory(folder); + + var payloadBytes = Encoding.ASCII.GetBytes("payload"); + var dd1Bytes = Encoding.ASCII.GetBytes("DataDependency1"); + var eok1Bytes = Encoding.ASCII.GetBytes("ExpectedOutput1"); + + await File.WriteAllBytesAsync(Path.Combine(folder, + payloadId), + payloadBytes); + await File.WriteAllBytesAsync(Path.Combine(folder, + dd1), + dd1Bytes); + + var handler = new TaskHandler(new ProcessRequest + { + CommunicationToken = token, + DataFolder = folder, + PayloadId = payloadId, + SessionId = sessionId, + Configuration = new Configuration { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - Payload = new DataChunk - { - Data = ByteString.CopyFromUtf8("Payload1"), - }, - }, - }; - - private static readonly ProcessRequest Payload2 = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - Payload = new DataChunk - { - Data = ByteString.CopyFromUtf8("Payload2"), - }, - }, - }; - - private static readonly ProcessRequest PayloadComplete = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest 
- { - Payload = new DataChunk - { - DataComplete = true, - }, - }, - }; - - private static readonly ProcessRequest Data1 = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - Data = new DataChunk - { - Data = ByteString.CopyFromUtf8("Data1"), - }, - }, - }; - - private static readonly ProcessRequest Data2 = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - Data = new DataChunk - { - Data = ByteString.CopyFromUtf8("Data2"), - }, - }, - }; - - private static readonly ProcessRequest DataComplete = new() - { - CommunicationToken = "Token", - Compute = new ProcessRequest.Types.ComputeRequest - { - Data = new DataChunk - { - DataComplete = true, - }, - }, - }; - - public static IEnumerable TaskHandlerCreateShouldThrowTestCases - { - get - { - yield return new TestCaseData(new ProcessRequest[] - { - }.AsEnumerable()); - yield return new TestCaseData(new[] - { - InitData1, - }.AsEnumerable()); - yield return new TestCaseData(new[] - { - InitData2, - }.AsEnumerable()); - yield return new TestCaseData(new[] - { - LastDataTrue, - }.AsEnumerable()); - yield return new TestCaseData(new[] - { - LastDataFalse, - }.AsEnumerable()); - yield return new TestCaseData(new[] - { - InitRequestPayload, - }.AsEnumerable()).SetArgDisplayNames(nameof(InitRequestPayload)); - yield return new TestCaseData(new[] - { - DataComplete, - }.AsEnumerable()).SetArgDisplayNames(nameof(DataComplete)); - yield return new TestCaseData(new[] - { - InitRequestEmptyPayload, - }.AsEnumerable()).SetArgDisplayNames(nameof(InitRequestEmptyPayload)); - yield return new TestCaseData(new[] + DataChunkMaxSize = 84, + }, + DataDependencies = { - InitRequestPayload, - PayloadComplete, - InitData1, - Data1, - LastDataTrue, - }.AsEnumerable()).SetArgDisplayNames("NotWorkingRequest1"); - yield return new TestCaseData(new[] + dd1, + }, + ExpectedOutputKeys = { - InitRequestPayload, - InitData1, - Data1, - DataComplete, - 
LastDataTrue, - }.AsEnumerable()).SetArgDisplayNames("NotWorkingRequest2"); - yield return new TestCaseData(new[] - { - InitRequestPayload, - PayloadComplete, - Data1, - DataComplete, - LastDataTrue, - }.AsEnumerable()).SetArgDisplayNames("NotWorkingRequest3"); - } - } - - private static readonly IEnumerable WorkingRequest1 = new[] - { - InitRequestPayload, - Payload1, - Payload2, - PayloadComplete, - InitData1, - Data1, - Data2, - DataComplete, - InitData2, - Data1, - Data2, - Data2, - Data2, - DataComplete, - LastDataTrue, - }.AsEnumerable(); - - private static readonly IEnumerable WorkingRequest2 = new[] - { - InitRequestPayload, - Payload1, - PayloadComplete, - InitData1, - Data1, - DataComplete, - LastDataTrue, - }.AsEnumerable(); - - private static readonly IEnumerable WorkingRequest3 = new[] - { - InitRequestPayload, - PayloadComplete, - InitData1, - Data1, - DataComplete, - LastDataTrue, - }.AsEnumerable(); - - public static IEnumerable TaskHandlerCreateShouldSucceedTestCases - { - get - { - yield return new TestCaseData(WorkingRequest1).SetArgDisplayNames(nameof(WorkingRequest1)); - yield return new TestCaseData(WorkingRequest2).SetArgDisplayNames(nameof(WorkingRequest2)); - yield return new TestCaseData(WorkingRequest3).SetArgDisplayNames(nameof(WorkingRequest3)); - } + eok1, + }, + TaskId = taskId, + }, + agent, + new LoggerFactory(), + CancellationToken.None); + + Assert.ThrowsAsync(() => handler.SendResult(eok1, + eok1Bytes)); + + Assert.Multiple(() => + { + Assert.AreEqual(payloadBytes, + handler.Payload); + Assert.AreEqual(sessionId, + handler.SessionId); + Assert.AreEqual(taskId, + handler.TaskId); + Assert.AreEqual(dd1Bytes, + handler.DataDependencies[dd1]); + Assert.AreEqual(eok1Bytes, + File.ReadAllBytes(Path.Combine(folder, + eok1))); + }); } } diff --git a/packages/csharp/ArmoniK.Api.Worker/Worker/ITaskHandler.cs b/packages/csharp/ArmoniK.Api.Worker/Worker/ITaskHandler.cs index 91abcef06..8d0ff3e89 100644 --- 
a/packages/csharp/ArmoniK.Api.Worker/Worker/ITaskHandler.cs +++ b/packages/csharp/ArmoniK.Api.Worker/Worker/ITaskHandler.cs @@ -32,6 +32,9 @@ namespace ArmoniK.Api.Worker.Worker; +/// +/// Higher level interface to implement to create tasks and populate results +/// [PublicAPI] public interface ITaskHandler : IAsyncDisposable { @@ -68,7 +71,7 @@ public interface ITaskHandler : IAsyncDisposable /// /// The configuration parameters for the interaction with ArmoniK. /// - Configuration? Configuration { get; } + Configuration Configuration { get; } /// /// This method allows to create subtasks. @@ -140,15 +143,4 @@ Task SubmitTasksAsync(IEnumerable Task CreateResultsAsync(IEnumerable results); - - /// - /// Upload data to an existing result - /// - /// The result Id - /// The data to submit for the given result - /// - /// The upload data response - /// - Task UploadResultData(string key, - byte[] data); } diff --git a/packages/csharp/ArmoniK.Api.Worker/Worker/TaskHandler.cs b/packages/csharp/ArmoniK.Api.Worker/Worker/TaskHandler.cs index 62158608a..c13820e6e 100644 --- a/packages/csharp/ArmoniK.Api.Worker/Worker/TaskHandler.cs +++ b/packages/csharp/ArmoniK.Api.Worker/Worker/TaskHandler.cs @@ -22,7 +22,10 @@ // limitations under the License. 
using System; +using System.Collections; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -31,77 +34,147 @@ using ArmoniK.Api.gRPC.V1.Agent; using ArmoniK.Api.gRPC.V1.Worker; -using Google.Protobuf; - -using Grpc.Core; - using Microsoft.Extensions.Logging; namespace ArmoniK.Api.Worker.Worker; +internal class ReadFromFolderDict : IReadOnlyDictionary +{ + private readonly Dictionary data_ = new(); + private readonly IList dataDependencies_; + private readonly string folder_; + + public ReadFromFolderDict(string folder, + IList dataDependencies) + { + folder_ = folder; + dataDependencies_ = dataDependencies; + } + + /// + public IEnumerator> GetEnumerator() + => dataDependencies_.Select(key => new KeyValuePair(key, + this[key])) + .GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + + /// + public int Count + => dataDependencies_.Count; + + /// + public bool ContainsKey(string key) + => dataDependencies_.Contains(key); + + /// + public bool TryGetValue(string key, + [MaybeNullWhen(false)] out byte[] value) + { + var r = ContainsKey(key); + if (r) + { + value = this[key]; + return r; + } + + value = null; + return r; + } + + /// + public byte[] this[string key] + { + get + { + if (data_.TryGetValue(key, + out var value)) + { + return value; + } + + var bytes = File.ReadAllBytes(Path.Combine(folder_, + key)); + data_.Add(key, + bytes); + return bytes; + } + } + + /// + public IEnumerable Keys + => dataDependencies_; + + /// + public IEnumerable Values + => dataDependencies_.Select(key => this[key]); +} + public class TaskHandler : ITaskHandler { private readonly CancellationToken cancellationToken_; private readonly Agent.AgentClient client_; + private readonly string folder_; private readonly ILogger logger_; private readonly ILoggerFactory loggerFactory_; - private readonly IAsyncStreamReader requestStream_; - - 
private IReadOnlyDictionary? dataDependencies_; - private IList? expectedResults_; - - private bool isInitialized_; - private byte[]? payload_; - private string? sessionId_; - private string? taskId_; - private TaskOptions? taskOptions_; - private string? token_; - - - private TaskHandler(IAsyncStreamReader requestStream, - Agent.AgentClient client, - CancellationToken cancellationToken, - ILoggerFactory loggerFactory) + public TaskHandler(ProcessRequest processRequest, + Agent.AgentClient client, + ILoggerFactory loggerFactory, + CancellationToken cancellationToken) { - requestStream_ = requestStream; client_ = client; cancellationToken_ = cancellationToken; loggerFactory_ = loggerFactory; logger_ = loggerFactory.CreateLogger(); + folder_ = processRequest.DataFolder; + + Token = processRequest.CommunicationToken; + SessionId = processRequest.SessionId; + TaskId = processRequest.TaskId; + TaskOptions = processRequest.TaskOptions; + ExpectedResults = processRequest.ExpectedOutputKeys; + DataDependencies = new ReadFromFolderDict(processRequest.DataFolder, + processRequest.DataDependencies); + Configuration = processRequest.Configuration; + + + try + { + Payload = File.ReadAllBytes(Path.Combine(processRequest.DataFolder, + processRequest.PayloadId)); + } + catch (ArgumentException e) + { + throw new InvalidOperationException("Payload not found", + e); + } } - public string Token - => token_ ?? throw TaskHandlerException(nameof(Token)); + public string Token { get; } /// - public string SessionId - => sessionId_ ?? throw TaskHandlerException(nameof(SessionId)); + public Configuration Configuration { get; } /// - public string TaskId - => taskId_ ?? throw TaskHandlerException(nameof(TaskId)); + public string SessionId { get; } /// - public TaskOptions TaskOptions - => taskOptions_ ?? throw TaskHandlerException(nameof(TaskOptions)); + public string TaskId { get; } /// - public byte[] Payload - => payload_ ?? 
throw TaskHandlerException(nameof(Payload)); + public TaskOptions TaskOptions { get; } /// - public IReadOnlyDictionary DataDependencies - => dataDependencies_ ?? throw TaskHandlerException(nameof(DataDependencies)); + public byte[] Payload { get; } /// - public IList ExpectedResults - => expectedResults_ ?? throw TaskHandlerException(nameof(ExpectedResults)); + public IReadOnlyDictionary DataDependencies { get; } - // this ? was added due to the initialization pattern with the Create method /// - public Configuration? Configuration { get; private set; } + public IList ExpectedResults { get; } /// public async Task CreateTasksAsync(IEnumerable tasks, @@ -144,7 +217,7 @@ public async Task CreateResultsMetaDataAsync(IEnu { results, }, - SessionId = sessionId_, + SessionId = SessionId, }) .ConfigureAwait(false); @@ -153,70 +226,30 @@ public async Task CreateResultsMetaDataAsync(IEnu public async Task SendResult(string key, byte[] data) { - using var stream = client_.SendResult(); - - await stream.RequestStream.WriteAsync(new Result - { - CommunicationToken = Token, - Init = new InitKeyedDataStream - { - Key = key, - }, - }) - .ConfigureAwait(false); - var start = 0; - - while (start < data.Length) + await using (var fs = new FileStream(Path.Combine(folder_, + key), + FileMode.OpenOrCreate)) { - var chunkSize = Math.Min(Configuration!.DataChunkMaxSize, - data.Length - start); - - await stream.RequestStream.WriteAsync(new Result - { - CommunicationToken = Token, - Data = new DataChunk - { - Data = UnsafeByteOperations.UnsafeWrap(data.AsMemory() - .Slice(start, - chunkSize)), - }, - }) - .ConfigureAwait(false); - - start += chunkSize; + await using var w = new BinaryWriter(fs); + w.Write(data); } - await stream.RequestStream.WriteAsync(new Result - { - CommunicationToken = Token, - Data = new DataChunk - { - DataComplete = true, - }, - }) - .ConfigureAwait(false); - - await stream.RequestStream.WriteAsync(new Result + await client_.NotifyResultDataAsync(new 
NotifyResultDataRequest + { + CommunicationToken = Token, + Ids = { - CommunicationToken = Token, - Init = new InitKeyedDataStream - { - LastResult = true, - }, - }) - .ConfigureAwait(false); - - await stream.RequestStream.CompleteAsync() - .ConfigureAwait(false); - - var reply = await stream.ResponseAsync.ConfigureAwait(false); - if (reply.TypeCase == ResultReply.TypeOneofCase.Error) - { - logger_.LogError(reply.Error); - throw new InvalidOperationException($"Cannot send result id={key}"); - } + new NotifyResultDataRequest.Types.ResultIdentifier + { + SessionId = SessionId, + ResultId = key, + }, + }, + }) + .ConfigureAwait(false); } + /// public ValueTask DisposeAsync() => ValueTask.CompletedTask; @@ -226,7 +259,7 @@ public async Task SubmitTasksAsync(IEnumerable await client_.SubmitTasksAsync(new SubmitTasksRequest { CommunicationToken = Token, - SessionId = sessionId_, + SessionId = SessionId, TaskCreations = { taskCreations, @@ -240,217 +273,11 @@ public async Task CreateResultsAsync(IEnumerable await client_.CreateResultsAsync(new CreateResultsRequest { CommunicationToken = Token, - SessionId = sessionId_, + SessionId = SessionId, Results = { results, }, }) .ConfigureAwait(false); - - public async Task UploadResultData(string key, - byte[] data) - { - var stream = client_.UploadResultData(); - - await stream.RequestStream.WriteAsync(new UploadResultDataRequest - { - Id = new UploadResultDataRequest.Types.ResultIdentifier - { - ResultId = key, - SessionId = sessionId_, - }, - CommunicationToken = Token, - }) - .ConfigureAwait(false); - - var start = 0; - while (start < data.Length) - { - var chunkSize = Math.Min(Configuration!.DataChunkMaxSize, - data.Length - start); - - await stream.RequestStream.WriteAsync(new UploadResultDataRequest - { - CommunicationToken = Token, - DataChunk = UnsafeByteOperations.UnsafeWrap(data.AsMemory() - .Slice(start, - chunkSize)), - }) - .ConfigureAwait(false); - - start += chunkSize; - } - - await 
stream.RequestStream.CompleteAsync() - .ConfigureAwait(false); - - return await stream.ResponseAsync.ConfigureAwait(false); - } - - public static async Task Create(IAsyncStreamReader requestStream, - Agent.AgentClient agentClient, - ILoggerFactory loggerFactory, - CancellationToken cancellationToken) - { - var output = new TaskHandler(requestStream, - agentClient, - cancellationToken, - loggerFactory); - await output.Init() - .ConfigureAwait(false); - return output; - } - - private async Task Init() - { - if (!await requestStream_.MoveNext() - .ConfigureAwait(false)) - { - throw new InvalidOperationException("Request stream ended unexpectedly."); - } - - if (requestStream_.Current.Compute.TypeCase != ProcessRequest.Types.ComputeRequest.TypeOneofCase.InitRequest) - { - throw new InvalidOperationException("Expected a Compute request type with InitRequest to start the stream."); - } - - var initRequest = requestStream_.Current.Compute.InitRequest; - sessionId_ = initRequest.SessionId; - taskId_ = initRequest.TaskId; - taskOptions_ = initRequest.TaskOptions; - expectedResults_ = initRequest.ExpectedOutputKeys; - Configuration = initRequest.Configuration; - token_ = requestStream_.Current.CommunicationToken; - - if (initRequest.Payload is null) - { - throw new InvalidOperationException("Payload from InitRequest should not be null"); - } - - - if (initRequest.Payload.DataComplete) - { - payload_ = initRequest.Payload.Data.ToByteArray(); - } - else - { - var chunks = new List(); - var dataChunk = initRequest.Payload; - - chunks.Add(dataChunk.Data); - - while (!dataChunk.DataComplete) - { - if (!await requestStream_.MoveNext(cancellationToken_) - .ConfigureAwait(false)) - { - throw new InvalidOperationException("Request stream ended unexpectedly."); - } - - if (requestStream_.Current.Compute.TypeCase != ProcessRequest.Types.ComputeRequest.TypeOneofCase.Payload) - { - throw new InvalidOperationException("Expected a Compute request type with Payload to continue the 
stream."); - } - - dataChunk = requestStream_.Current.Compute.Payload; - - chunks.Add(dataChunk.Data); - } - - - var size = chunks.Sum(s => s.Length); - - var payload = new byte[size]; - - var start = 0; - - foreach (var chunk in chunks) - { - chunk.CopyTo(payload, - start); - start += chunk.Length; - } - - payload_ = payload; - } - - var dataDependencies = new Dictionary(); - - ProcessRequest.Types.ComputeRequest.Types.InitData initData; - do - { - if (!await requestStream_.MoveNext(cancellationToken_) - .ConfigureAwait(false)) - { - throw new InvalidOperationException("Request stream ended unexpectedly."); - } - - - if (requestStream_.Current.Compute.TypeCase != ProcessRequest.Types.ComputeRequest.TypeOneofCase.InitData) - { - throw new InvalidOperationException("Expected a Compute request type with InitData to continue the stream."); - } - - initData = requestStream_.Current.Compute.InitData; - if (!string.IsNullOrEmpty(initData.Key)) - { - var chunks = new List(); - - while (true) - { - if (!await requestStream_.MoveNext(cancellationToken_) - .ConfigureAwait(false)) - { - throw new InvalidOperationException("Request stream ended unexpectedly."); - } - - if (requestStream_.Current.Compute.TypeCase != ProcessRequest.Types.ComputeRequest.TypeOneofCase.Data) - { - throw new InvalidOperationException("Expected a Compute request type with Data to continue the stream."); - } - - var dataChunk = requestStream_.Current.Compute.Data; - - if (dataChunk.TypeCase == DataChunk.TypeOneofCase.Data) - { - chunks.Add(dataChunk.Data); - } - - if (dataChunk.TypeCase == DataChunk.TypeOneofCase.None) - { - throw new InvalidOperationException("Expected a Compute request type with a DataChunk Payload to continue the stream."); - } - - if (dataChunk.TypeCase == DataChunk.TypeOneofCase.DataComplete) - { - break; - } - } - - var size = chunks.Sum(s => s.Length); - - var data = new byte[size]; - - var start = 0; - - foreach (var chunk in chunks) - { - chunk.CopyTo(data, - start); - start 
+= chunk.Length; - } - - dataDependencies[initData.Key] = data; - } - } while (!string.IsNullOrEmpty(initData.Key)); - - dataDependencies_ = dataDependencies; - isInitialized_ = true; - } - - private Exception TaskHandlerException(string argumentName) - => isInitialized_ - ? new InvalidOperationException($"Error in initalization: {argumentName} is null") - : new InvalidOperationException(""); } diff --git a/packages/csharp/ArmoniK.Api.Worker/Worker/WorkerStreamWrapper.cs b/packages/csharp/ArmoniK.Api.Worker/Worker/WorkerStreamWrapper.cs index 0edca8218..efd4b17aa 100644 --- a/packages/csharp/ArmoniK.Api.Worker/Worker/WorkerStreamWrapper.cs +++ b/packages/csharp/ArmoniK.Api.Worker/Worker/WorkerStreamWrapper.cs @@ -57,20 +57,22 @@ public WorkerStreamWrapper(ILoggerFactory loggerFactory, client_ = new Agent.AgentClient(channel_); } + /// public async ValueTask DisposeAsync() => await channel_.ShutdownAsync() .ConfigureAwait(false); - public sealed override async Task Process(IAsyncStreamReader requestStream, - ServerCallContext context) + + /// + public sealed override async Task Process(ProcessRequest request, + ServerCallContext context) { Output output; { - await using var taskHandler = await TaskHandler.Create(requestStream, - client_, - loggerFactory_, - context.CancellationToken) - .ConfigureAwait(false); + await using var taskHandler = new TaskHandler(request, + client_, + loggerFactory_, + context.CancellationToken); using var _ = logger_.BeginNamedScope("Execute task", ("taskId", taskHandler.TaskId), @@ -89,6 +91,7 @@ public virtual Task Process(ITaskHandler taskHandler) => throw new RpcException(new Status(StatusCode.Unimplemented, "")); + /// public override Task HealthCheck(Empty request, ServerCallContext context) => Task.FromResult(new HealthCheckReply