Skip to content

Commit

Permalink
fix: ensure task in error or retry children are properly cancelled (#790
Browse files Browse the repository at this point in the history
)

# Motivation

When task post-processing is not executed, child tasks and results
are not aborted. This leaves many tasks and results in the Creating
status, which adds a lot of noise when debugging and makes remnants of
failed tasks look like tasks that could still complete.

# Description

Move the cancellation of children (tasks and results) from
Agent.CancelChildTasks to Submitter.CompleteTaskAsync, ensuring that
children are cancelled every time a task is set to Error or Retry. It
uses the new CreatedBy field to find all related tasks and results.

# Testing

- Unit tests were added to make sure children are properly cancelled
when tasks return an Error.
- Validated using integration tests: HtcMock is able to produce
exceptions during task execution, generating retries. All child tasks
were cancelled.

# Impact

- Improved debugging
- Clearer end of session

# Checklist

- [x] My code adheres to the coding and style guidelines of the project.
- [x] I have performed a self-review of my code.
- [ ] I have commented my code, particularly in hard-to-understand
areas.
- [ ] I have made corresponding changes to the documentation.
- [x] I have thoroughly tested my modifications and added tests when
necessary.
- [x] Tests pass locally and in the CI.
- [x] I have assessed the performance impact of my modifications.
  • Loading branch information
aneojgurhem authored Nov 4, 2024
2 parents c956937 + a063622 commit a6e9a4c
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 40 deletions.
15 changes: 0 additions & 15 deletions Common/src/Pollster/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,21 +209,6 @@ public Task<string> GetDirectData(string token,
throw new NotImplementedException("Direct data are not implemented yet");
}

/// <inheritdoc />
public async Task CancelChildTasks(CancellationToken cancellationToken)
{
  // Only call the task table when this task actually created children.
  if (createdTasks_.Any())
  {
    var childTaskIds = createdTasks_.Select(request => request.TaskId)
                                    .AsICollection();

    await taskTable_.CancelTaskAsync(childTaskIds,
                                     cancellationToken)
                    .ConfigureAwait(false);
  }

  // Logged unconditionally, including when there were no children to cancel.
  logger_.LogDebug("Cancel {n} child tasks created by this task",
                   createdTasks_.Count);
}

/// <inheritdoc />
public async Task<ICollection<TaskCreationRequest>> SubmitTasks(ICollection<TaskSubmissionRequest> requests,
TaskOptions? taskOptions,
Expand Down
9 changes: 0 additions & 9 deletions Common/src/Pollster/IAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,4 @@ Task<ICollection<Result>> CreateResults(string
Task<ICollection<string>> NotifyResultData(string token,
ICollection<string> resultIds,
CancellationToken cancellationToken);

/// <summary>
///   Cancel the child tasks created by the task currently being processed
/// </summary>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
///   Task representing the asynchronous execution of the method
/// </returns>
Task CancelChildTasks(CancellationToken cancellationToken);
}
11 changes: 0 additions & 11 deletions Common/src/Pollster/TaskHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -960,11 +960,6 @@ public async Task PostProcessing()
await agent_.FinalizeTaskCreation(CancellationToken.None)
.ConfigureAwait(false);
}
else
{
await agent_.CancelChildTasks(CancellationToken.None)
.ConfigureAwait(false);
}

await submitter_.CompleteTaskAsync(taskData_,
sessionData_,
Expand Down Expand Up @@ -1089,12 +1084,6 @@ await submitter_.CompleteTaskAsync(taskData,
? QueueMessageStatus.Cancelled
: QueueMessageStatus.Processed;
}

if (agent_ is not null)
{
await agent_.CancelChildTasks(CancellationToken.None)
.ConfigureAwait(false);
}
}

// Rethrow enable the recording of the error by the Pollster Main loop
Expand Down
14 changes: 14 additions & 0 deletions Common/src/gRPC/Services/Submitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,20 @@ await resultTable_.MarkAsDeleted(taskData.PayloadId,

break;
case OutputStatus.Error:

await taskTable_.UpdateManyTasks(data => data.CreatedBy == taskData.TaskId,
new UpdateDefinition<TaskData>().Set(data => data.Status,
TaskStatus.Cancelled),
CancellationToken.None)
.ConfigureAwait(false);

await resultTable_.UpdateManyResults(data => data.CreatedBy == taskData.TaskId,
new UpdateDefinition<Result>().Set(data => data.Status,
ResultStatus.Aborted),
CancellationToken.None)
.ConfigureAwait(false);


// TODO FIXME: nothing will resubmit the task if there is a crash there
if (resubmit && taskData.RetryOfIds.Count < taskData.Options.MaxRetries)
{
Expand Down
3 changes: 0 additions & 3 deletions Common/tests/Helpers/SimpleAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ public Task<ICollection<string>> NotifyResultData(string token,
=> Task.FromResult(Array.Empty<string>()
.AsICollection());

/// <summary>
///   No-op implementation: returns an already completed task.
/// </summary>
/// <param name="cancellationToken">Unused by this implementation</param>
public Task CancelChildTasks(CancellationToken cancellationToken)
{
  return Task.CompletedTask;
}

/// <summary>
///   Only asks the garbage collector not to run a finalizer for this instance.
/// </summary>
public void Dispose()
{
  GC.SuppressFinalize(this);
}
}
6 changes: 4 additions & 2 deletions Common/tests/Helpers/SimpleWorkerStreamHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace ArmoniK.Core.Common.Tests.Helpers;

public class SimpleWorkerStreamHandler : IWorkerStreamHandler
{
/// <summary>
///   Output handed back by <see cref="StartTaskProcessing" />. Defaults to
///   <see cref="OutputStatus.Success" /> with an empty string; tests can replace it
///   (for instance through an object initializer) to simulate another outcome.
/// </summary>
/// <remarks>
///   Exposed as an auto-property rather than a public field so the member can
///   evolve without changing how callers read or assign it.
/// </remarks>
public Output Output { get; set; } = new(OutputStatus.Success,
                                         "");

public Task<HealthCheckResult> Check(HealthCheckTag tag)
=> Task.FromResult(HealthCheckResult.Healthy());

Expand All @@ -42,6 +45,5 @@ public Task<Output> StartTaskProcessing(TaskData taskData,
string token,
string dataFolder,
CancellationToken cancellationToken)
=> Task.FromResult(new Output(OutputStatus.Success,
""));
=> Task.FromResult(Output);
}
4 changes: 4 additions & 0 deletions Common/tests/Helpers/TestTaskHandlerProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,8 @@ public void Dispose()
.Wait();
GC.SuppressFinalize(this);
}

/// <summary>
///   Resolves a service of type <typeparamref name="T" /> from the services of the
///   application held by this provider.
/// </summary>
/// <typeparam name="T">Type of the service to resolve</typeparam>
/// <returns>
///   The resolved service instance
/// </returns>
public T GetRequiredService<T>()
  where T : notnull
{
  return app_.Services.GetRequiredService<T>();
}
}
45 changes: 45 additions & 0 deletions Common/tests/Helpers/WrapperAgentHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2024. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Core.Common.Pollster;
using ArmoniK.Core.Common.Storage;

using Microsoft.Extensions.Logging;

namespace ArmoniK.Core.Common.Tests.Helpers;

/// <summary>
///   Test helper implementing <see cref="IAgentHandler" /> that always hands out the
///   <see cref="IAgent" /> instance it was constructed with.
/// </summary>
public class WrapperAgentHandler : IAgentHandler
{
  private readonly IAgent agent_;

  /// <summary>
  ///   Wraps an already built agent.
  /// </summary>
  /// <param name="agent">Agent returned by every call to <see cref="Start" /></param>
  public WrapperAgentHandler(IAgent agent)
  {
    agent_ = agent;
  }

  /// <summary>
  ///   Returns the wrapped agent; none of the arguments are used.
  /// </summary>
  /// <returns>
  ///   A completed task holding the agent given at construction
  /// </returns>
  public Task<IAgent> Start(string            token,
                            ILogger           logger,
                            SessionData       sessionData,
                            TaskData          taskData,
                            string            folder,
                            CancellationToken cancellationToken)
  {
    return Task.FromResult(agent_);
  }

  /// <summary>
  ///   Does nothing and returns an already completed task.
  /// </summary>
  /// <param name="cancellationToken">Unused by this implementation</param>
  public Task Stop(CancellationToken cancellationToken)
  {
    return Task.CompletedTask;
  }
}
164 changes: 164 additions & 0 deletions Common/tests/Pollster/TaskHandlerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,26 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Core.Base;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC.Services;
using ArmoniK.Core.Common.Meter;
using ArmoniK.Core.Common.Pollster;
using ArmoniK.Core.Common.Pollster.TaskProcessingChecker;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Core.Common.Stream.Worker;
using ArmoniK.Core.Common.Tests.Helpers;
using ArmoniK.Core.Common.Utils;

using Grpc.Core;

Expand Down Expand Up @@ -1582,6 +1587,165 @@ await testServiceProvider.TaskHandler.PostProcessing()
sqmh.Status);
}

/// <summary>
///   Runs a task whose worker returns an Error output after the task created two
///   results and one child task through a real <see cref="Agent" />, then checks that
///   the parent task ends in Error, the child task is Cancelled and both results are
///   Aborted.
/// </summary>
[Test]
public async Task ExecuteErrorTaskAndAbortChildrenShouldSucceed()
{
// Queue message stub; its final Status is asserted at the end of the test.
var sqmh = new SimpleQueueMessageHandler
{
CancellationToken = CancellationToken.None,
Status = QueueMessageStatus.Waiting,
MessageId = Guid.NewGuid()
.ToString(),
};

// Worker stub configured to report an Error output.
var sh = new SimpleWorkerStreamHandler
{
Output = new Output(OutputStatus.Error,
"Error task to validate child tasks are cancelled properly"),
};
using var testServiceProvider = new TestTaskHandlerProvider(sh,
new SimpleAgentHandler(),
sqmh);

var (taskId, _, _, _, sessionId) = await InitProviderRunnableTask(testServiceProvider)
.ConfigureAwait(false);


// Read back the session and the parent task: both are needed to build the agent by hand.
var sessionData = await testServiceProvider.SessionTable.GetSessionAsync(sessionId)
.ConfigureAwait(false);
var taskData = await testServiceProvider.TaskTable.ReadTaskAsync(taskId,
CancellationToken.None)
.ConfigureAwait(false);

var token = Guid.NewGuid()
.ToString();

// A real Agent is built from the provider's services so the results and the task
// created below go through the actual creation code on behalf of the parent task.
// NOTE(review): presumably this is what fills the CreatedBy field — confirm in Agent.
var agent = new Agent(testServiceProvider.GetRequiredService<ISubmitter>(),
testServiceProvider.GetRequiredService<IObjectStorage>(),
testServiceProvider.GetRequiredService<IPushQueueStorage>(),
testServiceProvider.GetRequiredService<IResultTable>(),
testServiceProvider.GetRequiredService<ITaskTable>(),
sessionData,
taskData,
Path.GetTempFileName(),
token,
testServiceProvider.Logger);

// First child result, created together with its data ("payload").
var payloadId = (await agent.CreateResults(token,
new[]
{
(new ResultCreationRequest(sessionId,
"payload"), new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes("payload"))),
},
CancellationToken.None)
.ConfigureAwait(false)).Single()
.ResultId;

// Second child result, created as metadata only ("output").
var output = (await agent.CreateResultsMetaData(token,
new[]
{
new ResultCreationRequest(sessionId,
"output"),
},
CancellationToken.None)
.ConfigureAwait(false)).Single()
.ResultId;

// Child task referencing the two results created above.
// NOTE(review): the arguments look like (payload id, task options, expected outputs,
// data dependencies) — confirm against TaskSubmissionRequest.
var task = (await agent.SubmitTasks(new List<TaskSubmissionRequest>
{
new(payloadId,
null,
new List<string>
{
output,
},
new List<string>()),
},
null,
sessionId,
token,
CancellationToken.None)
.ConfigureAwait(false)).Single()
.TaskId;

// WrapperAgentHandler returns the agent prepared above from Start, so the handler
// under test works with the agent that created the children.
var agentHandler = new WrapperAgentHandler(agent);

// The TaskHandler is built manually (rather than taken from the provider) in order to
// inject the wrapping agent handler.
var taskHandler = new TaskHandler(testServiceProvider.GetRequiredService<ISessionTable>(),
testServiceProvider.GetRequiredService<ITaskTable>(),
testServiceProvider.GetRequiredService<IResultTable>(),
testServiceProvider.GetRequiredService<ISubmitter>(),
testServiceProvider.GetRequiredService<DataPrefetcher>(),
sh,
sqmh,
testServiceProvider.GetRequiredService<ITaskProcessingChecker>(),
"ownerpodid",
"ownerpodname",
testServiceProvider.GetRequiredService<ActivitySource>(),
agentHandler,
testServiceProvider.GetRequiredService<ILogger>(),
testServiceProvider.GetRequiredService<Injection.Options.Pollster>(),
() =>
{
},
testServiceProvider.GetRequiredService<ExceptionManager>(),
testServiceProvider.GetRequiredService<FunctionExecutionMetrics<TaskHandler>>());

sqmh.TaskId = taskId;

// Act: run the task through acquisition, pre-processing, execution and post-processing.
var acquired = await taskHandler.AcquireTask()
.ConfigureAwait(false);

Assert.AreEqual(AcquisitionStatus.Acquired,
acquired);

await taskHandler.PreProcessing()
.ConfigureAwait(false);

await taskHandler.ExecuteTask()
.ConfigureAwait(false);

await taskHandler.PostProcessing()
.ConfigureAwait(false);

// Assert: the parent task ended in Error with its dates and durations filled in.
taskData = await testServiceProvider.TaskTable.ReadTaskAsync(taskId,
CancellationToken.None)
.ConfigureAwait(false);

Console.WriteLine(taskData);

Assert.AreEqual(TaskStatus.Error,
taskData.Status);
Assert.IsNotNull(taskData.StartDate);
Assert.IsNotNull(taskData.EndDate);
Assert.IsNotNull(taskData.ProcessingToEndDuration);
Assert.IsNotNull(taskData.CreationToEndDuration);
Assert.Greater(taskData.CreationToEndDuration,
taskData.ProcessingToEndDuration);

Assert.AreEqual(QueueMessageStatus.Processed,
sqmh.Status);

// Assert: the child task submitted through the agent was cancelled.
taskData = await testServiceProvider.TaskTable.ReadTaskAsync(task,
CancellationToken.None)
.ConfigureAwait(false);
Console.WriteLine(taskData);
Assert.AreEqual(TaskStatus.Cancelled,
taskData.Status);

// Assert: both results created through the agent were aborted.
var result = await testServiceProvider.ResultTable.GetResult(payloadId)
.ConfigureAwait(false);

Assert.AreEqual(ResultStatus.Aborted,
result.Status);

result = await testServiceProvider.ResultTable.GetResult(output)
.ConfigureAwait(false);

Assert.AreEqual(ResultStatus.Aborted,
result.Status);
}


private class ObjectStorageThrowNotFound : IObjectStorage
{
public Task<HealthCheckResult> Check(HealthCheckTag tag)
Expand Down

0 comments on commit a6e9a4c

Please sign in to comment.