Skip to content

Commit

Permalink
feat: Task execution pipelining (#484)
Browse files: browse the repository at this point in the history
  • Loading branch information
aneojgurhem authored Sep 13, 2023
2 parents d526475 + 3f3a399 commit 8b5ea41
Show file tree
Hide file tree
Showing 20 changed files with 805 additions and 240 deletions.
20 changes: 14 additions & 6 deletions Adaptors/Memory/src/TaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,20 @@ public Task StartTask(TaskData taskData,
taskId2TaskData_.AddOrUpdate(taskData.TaskId,
_ => throw new TaskNotFoundException($"Key '{taskData.TaskId}' not found"),
(_,
data) => data with
{
Status = TaskStatus.Processing,
StartDate = taskData.StartDate,
PodTtl = taskData.PodTtl,
});
data) =>
{
if (data.Status is TaskStatus.Error or TaskStatus.Completed or TaskStatus.Retried or TaskStatus.Cancelled)
{
throw new TaskAlreadyInFinalStateException($"{taskData.TaskId} is already in a final state : {data.Status}");
}

return data with
{
Status = TaskStatus.Processing,
StartDate = taskData.StartDate,
PodTtl = taskData.PodTtl,
};
});
return Task.CompletedTask;
}

Expand Down
20 changes: 11 additions & 9 deletions Adaptors/MongoDB/src/TaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Utils;

using Microsoft.Extensions.Diagnostics.HealthChecks;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -143,30 +144,31 @@ public async Task StartTask(TaskData taskData,
taskData.StartDate)
.Set(tdm => tdm.PodTtl,
taskData.PodTtl);
Logger.LogInformation("update task {taskId} to status {status}",
Logger.LogInformation("Trying to start task {taskId} and update to status {status}",
taskData.TaskId,
TaskStatus.Processing);
var res = await taskCollection.UpdateManyAsync(x => x.TaskId == taskData.TaskId && x.Status != TaskStatus.Completed && x.Status != TaskStatus.Cancelled,
var res = await taskCollection.UpdateManyAsync(x => x.TaskId == taskData.TaskId && x.Status != TaskStatus.Completed && x.Status != TaskStatus.Cancelled &&
x.Status != TaskStatus.Error && x.Status != TaskStatus.Retried,
updateDefinition,
cancellationToken: cancellationToken)
.ConfigureAwait(false);

switch (res.MatchedCount)
{
case 0:
var taskStatus = await GetTaskStatus(new[]
{
taskData.TaskId,
},
cancellationToken)
.ConfigureAwait(false);
var taskStatus = (await GetTaskStatus(new[]
{
taskData.TaskId,
},
cancellationToken)
.ConfigureAwait(false)).AsICollection();

if (!taskStatus.Any())
{
throw new TaskNotFoundException($"Task {taskData.TaskId} not found");
}

throw new ArmoniKException($"Task already in a terminal state - {taskStatus.Single()} to {TaskStatus.Processing}");
throw new TaskAlreadyInFinalStateException($"Task already in a terminal state - {taskStatus.Single()} to {TaskStatus.Processing}");
case > 1:
throw new ArmoniKException("Multiple tasks modified");
}
Expand Down
40 changes: 40 additions & 0 deletions Common/src/Exceptions/TaskAlreadyInFinalStateException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2023. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System;

namespace ArmoniK.Core.Common.Exceptions;

/// <summary>
///   Exception signalling that a task is already in a final state.
/// </summary>
/// <remarks>
///   Provides the three conventional exception constructors; the message and inner exception,
///   when supplied, are forwarded unchanged to <see cref="ArmoniKException" />.
/// </remarks>
[Serializable]
public class TaskAlreadyInFinalStateException : ArmoniKException
{
    /// <summary>
    ///   Creates the exception without a message.
    /// </summary>
    public TaskAlreadyInFinalStateException()
    {
    }

    /// <summary>
    ///   Creates the exception with a message.
    /// </summary>
    /// <param name="message">Text describing the error</param>
    public TaskAlreadyInFinalStateException(string message) : base(message)
    {
    }

    /// <summary>
    ///   Creates the exception with a message and the exception that caused it.
    /// </summary>
    /// <param name="message">Text describing the error</param>
    /// <param name="innerException">Exception passed to the base class as the inner exception</param>
    public TaskAlreadyInFinalStateException(string message, Exception innerException) : base(message, innerException)
    {
    }
}
6 changes: 6 additions & 0 deletions Common/src/Injection/Options/Pollster.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ public class Pollster
/// Negative values disable the check
/// </summary>
public int MaxErrorAllowed { get; set; } = 5;

/// <summary>
/// Timeout before releasing the currently acquired task and acquiring a new one.
/// This happens in parallel with the execution of another task.
/// </summary>
/// <value>Defaults to 10 seconds.</value>
public TimeSpan TimeoutBeforeNextAcquisition { get; set; } = TimeSpan.FromSeconds(10);
}
169 changes: 109 additions & 60 deletions Common/src/Pollster/Pollster.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
Expand Down Expand Up @@ -44,27 +45,29 @@ namespace ArmoniK.Core.Common.Pollster;

public class Pollster : IInitializable
{
private readonly ActivitySource activitySource_;
private readonly IAgentHandler agentHandler_;
private readonly DataPrefetcher dataPrefetcher_;
private readonly IHostApplicationLifetime lifeTime_;
private readonly ILogger<Pollster> logger_;
private readonly int messageBatchSize_;
private readonly IObjectStorage objectStorage_;
private readonly string ownerPodId_;
private readonly string ownerPodName_;
private readonly Injection.Options.Pollster pollsterOptions_;
private readonly IPullQueueStorage pullQueueStorage_;
private readonly IResultTable resultTable_;
private readonly ISessionTable sessionTable_;
private readonly ISubmitter submitter_;
private readonly ITaskProcessingChecker taskProcessingChecker_;
private readonly ITaskTable taskTable_;
private readonly IWorkerStreamHandler workerStreamHandler_;
private bool endLoopReached_;
private HealthCheckResult? healthCheckFailedResult_;
public Func<Task>? StopCancelledTask;
public string TaskProcessing;
private readonly ActivitySource activitySource_;
private readonly IAgentHandler agentHandler_;
private readonly DataPrefetcher dataPrefetcher_;
private readonly IHostApplicationLifetime lifeTime_;
private readonly ILogger<Pollster> logger_;
private readonly ILoggerFactory loggerFactory_;
private readonly int messageBatchSize_;
private readonly IObjectStorage objectStorage_;
private readonly string ownerPodId_;
private readonly string ownerPodName_;
private readonly Injection.Options.Pollster pollsterOptions_;
private readonly IPullQueueStorage pullQueueStorage_;
private readonly IResultTable resultTable_;
private readonly RunningTaskQueue runningTaskQueue_;
private readonly ISessionTable sessionTable_;
private readonly ISubmitter submitter_;
private readonly ITaskProcessingChecker taskProcessingChecker_;
private readonly ConcurrentDictionary<string, TaskHandler> taskProcessingDict_ = new();
private readonly ITaskTable taskTable_;
private readonly IWorkerStreamHandler workerStreamHandler_;
private bool endLoopReached_;
private HealthCheckResult? healthCheckFailedResult_;


public Pollster(IPullQueueStorage pullQueueStorage,
DataPrefetcher dataPrefetcher,
Expand All @@ -73,14 +76,16 @@ public Pollster(IPullQueueStorage pullQueueStorage,
IHostApplicationLifetime lifeTime,
ActivitySource activitySource,
ILogger<Pollster> logger,
ILoggerFactory loggerFactory,
IObjectStorage objectStorage,
IResultTable resultTable,
ISubmitter submitter,
ISessionTable sessionTable,
ITaskTable taskTable,
ITaskProcessingChecker taskProcessingChecker,
IWorkerStreamHandler workerStreamHandler,
IAgentHandler agentHandler)
IAgentHandler agentHandler,
RunningTaskQueue runningTaskQueue)
{
if (options.MessageBatchSize < 1)
{
Expand All @@ -89,6 +94,7 @@ public Pollster(IPullQueueStorage pullQueueStorage,
}

logger_ = logger;
loggerFactory_ = loggerFactory;
activitySource_ = activitySource;
pullQueueStorage_ = pullQueueStorage;
lifeTime_ = lifeTime;
Expand All @@ -103,12 +109,15 @@ public Pollster(IPullQueueStorage pullQueueStorage,
taskProcessingChecker_ = taskProcessingChecker;
workerStreamHandler_ = workerStreamHandler;
agentHandler_ = agentHandler;
TaskProcessing = "";
runningTaskQueue_ = runningTaskQueue;
ownerPodId_ = LocalIpFinder.LocalIpv4Address();
ownerPodName_ = Dns.GetHostName();
Failed = false;
}

/// <summary>
///   Keys currently held by the task processing dictionary.
///   NOTE(review): entries are keyed by task id where the dictionary is populated — confirm against the call site.
/// </summary>
public ICollection<string> TaskProcessing
{
    get
    {
        return taskProcessingDict_.Keys;
    }
}

/// <summary>
/// Is true when the MainLoop exited with an error
/// Used in Unit tests
Expand Down Expand Up @@ -194,6 +203,15 @@ public async Task<HealthCheckResult> Check(HealthCheckTag tag)
return result;
}

/// <summary>
///   Calls <c>StopCancelledTask</c> on every task handler held in the task processing dictionary.
/// </summary>
/// <remarks>
///   Handlers are awaited one after the other. If one of them throws, the exception propagates
///   to the caller and the handlers that have not been visited yet are left untouched.
/// </remarks>
/// <returns>
///   Task representing the asynchronous execution of the method
/// </returns>
public async Task StopCancelledTask()
{
    foreach (var handler in taskProcessingDict_.Values)
    {
        // Start the stop request for this handler, then wait for it before moving to the next one.
        var pendingStop = handler.StopCancelledTask();
        await pendingStop.ConfigureAwait(false);
    }
}

public async Task MainLoop(CancellationToken cancellationToken)
{
await Init(cancellationToken)
Expand Down Expand Up @@ -251,62 +269,98 @@ void RecordError(Exception e)

await foreach (var message in messages.ConfigureAwait(false))
{
using var scopedLogger = logger_.BeginNamedScope("Prefetch messageHandler",
("messageHandler", message.MessageId),
("taskId", message.TaskId),
("ownerPodId", ownerPodId_));
TaskProcessing = message.TaskId;
var taskHandlerLogger = loggerFactory_.CreateLogger<TaskHandler>();
using var _ = taskHandlerLogger.BeginNamedScope("Prefetch messageHandler",
("messageHandler", message.MessageId),
("taskId", message.TaskId),
("ownerPodId", ownerPodId_));

// ReSharper disable once ExplicitCallerInfoArgument
using var activity = activitySource_.StartActivity("ProcessQueueMessage");
activity?.SetBaggage("TaskId",
message.TaskId);
activity?.SetBaggage("messageId",
message.MessageId);

logger_.LogDebug("Start a new Task to process the messageHandler");
taskHandlerLogger.LogDebug("Start a new Task to process the messageHandler");

try
while (runningTaskQueue_.RemoveException(out var exception))
{
if (exception is RpcException rpcException && TaskHandler.IsStatusFatal(rpcException.StatusCode))
{
// This exception should stop pollster
exception.RethrowWithStacktrace();
}

RecordError(exception);
}

var taskHandler = new TaskHandler(sessionTable_,
taskTable_,
resultTable_,
submitter_,
dataPrefetcher_,
workerStreamHandler_,
message,
taskProcessingChecker_,
ownerPodId_,
ownerPodName_,
activitySource_,
agentHandler_,
taskHandlerLogger,
pollsterOptions_,
() => taskProcessingDict_.TryRemove(message.TaskId,
out var _),
cts);

if (!taskProcessingDict_.TryAdd(message.TaskId,
taskHandler))
{
await using var taskHandler = new TaskHandler(sessionTable_,
taskTable_,
resultTable_,
submitter_,
dataPrefetcher_,
workerStreamHandler_,
message,
taskProcessingChecker_,
ownerPodId_,
ownerPodName_,
activitySource_,
agentHandler_,
logger_,
pollsterOptions_,
cts);

StopCancelledTask = taskHandler.StopCancelledTask;
message.Status = QueueMessageStatus.Processed;
await taskHandler.DisposeAsync()
.ConfigureAwait(false);
continue;
}


try
{
var precondition = await taskHandler.AcquireTask()
.ConfigureAwait(false);

if (precondition)
{
await taskHandler.PreProcessing()
.ConfigureAwait(false);

await taskHandler.ExecuteTask()
.ConfigureAwait(false);
try
{
await taskHandler.PreProcessing()
.ConfigureAwait(false);
}
catch
{
await taskHandler.DisposeAsync()
.ConfigureAwait(false);
throw;
}

await taskHandler.PostProcessing()
.ConfigureAwait(false);
await runningTaskQueue_.WriteAsync(taskHandler,
cancellationToken)
.ConfigureAwait(false);

StopCancelledTask = null;
await runningTaskQueue_.WaitForNextWriteAsync(pollsterOptions_.TimeoutBeforeNextAcquisition,
cancellationToken)
.ConfigureAwait(false);

// If the task was successful, we can remove a failure
if (recordedErrors.Count > 0)
{
recordedErrors.Dequeue();
}
}
else
{
await taskHandler.DisposeAsync()
.ConfigureAwait(false);
}
}
catch (RpcException e) when (TaskHandler.IsStatusFatal(e.StatusCode))
{
Expand All @@ -317,11 +371,6 @@ await taskHandler.PostProcessing()
{
RecordError(e);
}
finally
{
StopCancelledTask = null;
TaskProcessing = string.Empty;
}
}
}
catch (RpcException e) when (e.StatusCode == StatusCode.Unavailable)
Expand Down
Loading

0 comments on commit 8b5ea41

Please sign in to comment.