Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure task in error or retry children are properly cancelled #790

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions Common/src/Pollster/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,21 +209,6 @@ public Task<string> GetDirectData(string token,
throw new NotImplementedException("Direct data are not implemented yet");
}

/// <inheritdoc />
public async Task CancelChildTasks(CancellationToken cancellationToken)
{
  var hasCreatedTasks = createdTasks_.Any();

  if (hasCreatedTasks)
  {
    // Gather the task id of every entry recorded in createdTasks_ before issuing a single cancel call.
    var childTaskIds = createdTasks_.Select(created => created.TaskId)
                                    .AsICollection();

    await taskTable_.CancelTaskAsync(childTaskIds,
                                     cancellationToken)
                    .ConfigureAwait(false);
  }

  // The count is logged whether or not a cancel request was actually sent.
  logger_.LogDebug("Cancel {n} child tasks created by this task",
                   createdTasks_.Count);
}

/// <inheritdoc />
public async Task<ICollection<TaskCreationRequest>> SubmitTasks(ICollection<TaskSubmissionRequest> requests,
TaskOptions? taskOptions,
Expand Down
9 changes: 0 additions & 9 deletions Common/src/Pollster/IAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,4 @@ Task<ICollection<Result>> CreateResults(string
Task<ICollection<string>> NotifyResultData(string token,
ICollection<string> resultIds,
CancellationToken cancellationToken);

/// <summary>
///   Cancels the child tasks that were created by the task currently being processed
/// </summary>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
///   Task representing the asynchronous execution of the method
/// </returns>
Task CancelChildTasks(CancellationToken cancellationToken);
}
11 changes: 0 additions & 11 deletions Common/src/Pollster/TaskHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -960,11 +960,6 @@ public async Task PostProcessing()
await agent_.FinalizeTaskCreation(CancellationToken.None)
.ConfigureAwait(false);
}
else
{
await agent_.CancelChildTasks(CancellationToken.None)
.ConfigureAwait(false);
}

await submitter_.CompleteTaskAsync(taskData_,
sessionData_,
Expand Down Expand Up @@ -1089,12 +1084,6 @@ await submitter_.CompleteTaskAsync(taskData,
? QueueMessageStatus.Cancelled
: QueueMessageStatus.Processed;
}

if (agent_ is not null)
{
await agent_.CancelChildTasks(CancellationToken.None)
.ConfigureAwait(false);
}
}

// Rethrowing enables the Pollster main loop to record the error
Expand Down
14 changes: 14 additions & 0 deletions Common/src/gRPC/Services/Submitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,20 @@ await resultTable_.MarkAsDeleted(taskData.PayloadId,

break;
case OutputStatus.Error:

await taskTable_.UpdateManyTasks(data => data.CreatedBy == taskData.TaskId,
new UpdateDefinition<TaskData>().Set(data => data.Status,
TaskStatus.Cancelled),
CancellationToken.None)
.ConfigureAwait(false);

await resultTable_.UpdateManyResults(data => data.CreatedBy == taskData.TaskId,
new UpdateDefinition<Result>().Set(data => data.Status,
ResultStatus.Aborted),
CancellationToken.None)
.ConfigureAwait(false);


// TODO FIXME: nothing will resubmit the task if there is a crash here
if (resubmit && taskData.RetryOfIds.Count < taskData.Options.MaxRetries)
{
Expand Down
3 changes: 0 additions & 3 deletions Common/tests/Helpers/SimpleAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ public Task<ICollection<string>> NotifyResultData(string token,
=> Task.FromResult(Array.Empty<string>()
.AsICollection());

/// <summary>
///   No-op implementation: returns an already completed task.
/// </summary>
/// <param name="cancellationToken">Unused by this implementation</param>
/// <returns>
///   <see cref="Task.CompletedTask" />
/// </returns>
public Task CancelChildTasks(CancellationToken cancellationToken)
{
  return Task.CompletedTask;
}

/// <summary>
///   Only asks the garbage collector to skip finalization of this instance.
/// </summary>
public void Dispose()
{
  GC.SuppressFinalize(this);
}
}
6 changes: 4 additions & 2 deletions Common/tests/Helpers/SimpleWorkerStreamHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace ArmoniK.Core.Common.Tests.Helpers;

public class SimpleWorkerStreamHandler : IWorkerStreamHandler
{
/// <summary>
///   Output handed back by <see cref="StartTaskProcessing" />.
///   Exposed as a property rather than a public field so that tests can still
///   override it through an object initializer; defaults to
///   <c>new(OutputStatus.Success, "")</c>.
/// </summary>
public Output Output { get; set; } = new(OutputStatus.Success,
                                         "");

/// <summary>
///   Always reports a healthy state, whatever the requested tag.
/// </summary>
/// <param name="tag">Health check tag; not inspected by this implementation</param>
/// <returns>
///   A completed task holding <c>HealthCheckResult.Healthy()</c>
/// </returns>
public Task<HealthCheckResult> Check(HealthCheckTag tag)
{
  var healthy = HealthCheckResult.Healthy();
  return Task.FromResult(healthy);
}

Expand All @@ -42,6 +45,5 @@ public Task<Output> StartTaskProcessing(TaskData taskData,
string token,
string dataFolder,
CancellationToken cancellationToken)
=> Task.FromResult(new Output(OutputStatus.Success,
""));
=> Task.FromResult(Output);
}
4 changes: 4 additions & 0 deletions Common/tests/Helpers/TestTaskHandlerProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,8 @@ public void Dispose()
.Wait();
GC.SuppressFinalize(this);
}

/// <summary>
///   Resolves a service of type <typeparamref name="T" /> from the service provider of the wrapped application.
/// </summary>
/// <typeparam name="T">Type of the service to resolve</typeparam>
/// <returns>
///   The service instance returned by the application's service provider
/// </returns>
public T GetRequiredService<T>()
  where T : notnull
{
  var services = app_.Services;
  return services.GetRequiredService<T>();
}
}
45 changes: 45 additions & 0 deletions Common/tests/Helpers/WrapperAgentHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2024. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Core.Common.Pollster;
using ArmoniK.Core.Common.Storage;

using Microsoft.Extensions.Logging;

namespace ArmoniK.Core.Common.Tests.Helpers;

/// <summary>
///   Test implementation of <see cref="IAgentHandler" /> that always hands out
///   the <see cref="IAgent" /> instance it was constructed with.
/// </summary>
public class WrapperAgentHandler : IAgentHandler
{
  private readonly IAgent wrapped_;

  /// <summary>
  ///   Stores the agent that <see cref="Start" /> will return.
  /// </summary>
  /// <param name="agent">Pre-built agent to expose through this handler</param>
  public WrapperAgentHandler(IAgent agent)
  {
    wrapped_ = agent;
  }

  /// <summary>
  ///   Does nothing and completes immediately.
  /// </summary>
  /// <param name="cancellationToken">Unused by this implementation</param>
  /// <returns>
  ///   An already completed task
  /// </returns>
  public Task Stop(CancellationToken cancellationToken)
  {
    return Task.CompletedTask;
  }

  /// <summary>
  ///   Returns the wrapped agent; none of the arguments are used.
  /// </summary>
  /// <returns>
  ///   A completed task holding the agent given to the constructor
  /// </returns>
  public Task<IAgent> Start(string            token,
                            ILogger           logger,
                            SessionData       sessionData,
                            TaskData          taskData,
                            string            folder,
                            CancellationToken cancellationToken)
  {
    return Task.FromResult(wrapped_);
  }
}
164 changes: 164 additions & 0 deletions Common/tests/Pollster/TaskHandlerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,26 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Core.Base;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC.Services;
using ArmoniK.Core.Common.Meter;
using ArmoniK.Core.Common.Pollster;
using ArmoniK.Core.Common.Pollster.TaskProcessingChecker;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Core.Common.Stream.Worker;
using ArmoniK.Core.Common.Tests.Helpers;
using ArmoniK.Core.Common.Utils;

using Grpc.Core;

Expand Down Expand Up @@ -1582,6 +1587,165 @@ await testServiceProvider.TaskHandler.PostProcessing()
sqmh.Status);
}

/// <summary>
///   Runs a task whose worker reports an error after the task has created two
///   results and submitted one child task through a real <see cref="Agent" />,
///   then checks that the task ends in error, that the child task is cancelled
///   and that both results are aborted.
/// </summary>
[Test]
public async Task ExecuteErrorTaskAndAbortChildrenShouldSucceed()
{
// Queue message handler stub; its final Status is asserted near the end of the test.
var sqmh = new SimpleQueueMessageHandler
{
CancellationToken = CancellationToken.None,
Status = QueueMessageStatus.Waiting,
MessageId = Guid.NewGuid()
.ToString(),
};

// Worker stub configured to answer with an Error output.
var sh = new SimpleWorkerStreamHandler
{
Output = new Output(OutputStatus.Error,
"Error task to validate child tasks are cancelled properly"),
};
using var testServiceProvider = new TestTaskHandlerProvider(sh,
new SimpleAgentHandler(),
sqmh);

// Initialize the provider through the shared helper; only the task id and the session id are kept.
var (taskId, _, _, _, sessionId) = await InitProviderRunnableTask(testServiceProvider)
.ConfigureAwait(false);


var sessionData = await testServiceProvider.SessionTable.GetSessionAsync(sessionId)
.ConfigureAwait(false);
var taskData = await testServiceProvider.TaskTable.ReadTaskAsync(taskId,
CancellationToken.None)
.ConfigureAwait(false);

var token = Guid.NewGuid()
.ToString();

// Build a real Agent from the provider's services so the task under test can create results and submit a child task.
var agent = new Agent(testServiceProvider.GetRequiredService<ISubmitter>(),
testServiceProvider.GetRequiredService<IObjectStorage>(),
testServiceProvider.GetRequiredService<IPushQueueStorage>(),
testServiceProvider.GetRequiredService<IResultTable>(),
testServiceProvider.GetRequiredService<ITaskTable>(),
sessionData,
taskData,
Path.GetTempFileName(),
token,
testServiceProvider.Logger);

// Result created together with its data; it is used below as the payload of the child task.
var payloadId = (await agent.CreateResults(token,
new[]
{
(new ResultCreationRequest(sessionId,
"payload"), new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes("payload"))),
},
CancellationToken.None)
.ConfigureAwait(false)).Single()
.ResultId;

// Result created as metadata only (no data is provided here).
var output = (await agent.CreateResultsMetaData(token,
new[]
{
new ResultCreationRequest(sessionId,
"output"),
},
CancellationToken.None)
.ConfigureAwait(false)).Single()
.ResultId;

// Child task submitted through the agent.
// NOTE(review): the third argument presumably lists the expected outputs and the fourth the data dependencies - confirm against TaskSubmissionRequest.
var task = (await agent.SubmitTasks(new List<TaskSubmissionRequest>
{
new(payloadId,
null,
new List<string>
{
output,
},
new List<string>()),
},
null,
sessionId,
token,
CancellationToken.None)
.ConfigureAwait(false)).Single()
.TaskId;

// WrapperAgentHandler hands the pre-built agent above to the TaskHandler.
var agentHandler = new WrapperAgentHandler(agent);

// The TaskHandler is built by hand (instead of using testServiceProvider.TaskHandler) so that it receives agentHandler.
var taskHandler = new TaskHandler(testServiceProvider.GetRequiredService<ISessionTable>(),
testServiceProvider.GetRequiredService<ITaskTable>(),
testServiceProvider.GetRequiredService<IResultTable>(),
testServiceProvider.GetRequiredService<ISubmitter>(),
testServiceProvider.GetRequiredService<DataPrefetcher>(),
sh,
sqmh,
testServiceProvider.GetRequiredService<ITaskProcessingChecker>(),
"ownerpodid",
"ownerpodname",
testServiceProvider.GetRequiredService<ActivitySource>(),
agentHandler,
testServiceProvider.GetRequiredService<ILogger>(),
testServiceProvider.GetRequiredService<Injection.Options.Pollster>(),
() =>
{
},
testServiceProvider.GetRequiredService<ExceptionManager>(),
testServiceProvider.GetRequiredService<FunctionExecutionMetrics<TaskHandler>>());

sqmh.TaskId = taskId;

// Run the task handler steps in order: acquire, pre-process, execute, post-process.
var acquired = await taskHandler.AcquireTask()
.ConfigureAwait(false);

Assert.AreEqual(AcquisitionStatus.Acquired,
acquired);

await taskHandler.PreProcessing()
.ConfigureAwait(false);

await taskHandler.ExecuteTask()
.ConfigureAwait(false);

await taskHandler.PostProcessing()
.ConfigureAwait(false);

// The task under test must end in Error with its date and duration fields populated.
taskData = await testServiceProvider.TaskTable.ReadTaskAsync(taskId,
CancellationToken.None)
.ConfigureAwait(false);

Console.WriteLine(taskData);

Assert.AreEqual(TaskStatus.Error,
taskData.Status);
Assert.IsNotNull(taskData.StartDate);
Assert.IsNotNull(taskData.EndDate);
Assert.IsNotNull(taskData.ProcessingToEndDuration);
Assert.IsNotNull(taskData.CreationToEndDuration);
Assert.Greater(taskData.CreationToEndDuration,
taskData.ProcessingToEndDuration);

Assert.AreEqual(QueueMessageStatus.Processed,
sqmh.Status);

// The child task submitted through the agent must be Cancelled.
taskData = await testServiceProvider.TaskTable.ReadTaskAsync(task,
CancellationToken.None)
.ConfigureAwait(false);
Console.WriteLine(taskData);
Assert.AreEqual(TaskStatus.Cancelled,
taskData.Status);

// Both results created through the agent must be Aborted.
var result = await testServiceProvider.ResultTable.GetResult(payloadId)
.ConfigureAwait(false);

Assert.AreEqual(ResultStatus.Aborted,
result.Status);

result = await testServiceProvider.ResultTable.GetResult(output)
.ConfigureAwait(false);

Assert.AreEqual(ResultStatus.Aborted,
result.Status);
}


private class ObjectStorageThrowNotFound : IObjectStorage
{
public Task<HealthCheckResult> Check(HealthCheckTag tag)
Expand Down