Skip to content

Commit

Permalink
refactor: Move grpc based interfaces to the task table in Convertors …
Browse files Browse the repository at this point in the history
…folder
  • Loading branch information
aneojgurhem committed Sep 29, 2023
1 parent 5682628 commit d4a3dec
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 297 deletions.
86 changes: 5 additions & 81 deletions Adaptors/Memory/src/TaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.Common.Utils;
using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC.Convertors;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Core.Utils;

Expand Down Expand Up @@ -128,20 +125,6 @@ public Task StartTask(TaskData taskData,
return Task.CompletedTask;
}

/// <inheritdoc />
public async Task<IEnumerable<TaskStatusCount>> CountTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default)
=> await ListTasksAsync(filter,
cancellationToken)
.Select(taskId => taskId2TaskData_[taskId]
.Status)
.GroupBy(status => status)
.SelectAwait(async grouping => new TaskStatusCount(grouping.Key,
await grouping.CountAsync(cancellationToken)
.ConfigureAwait(false)))
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

/// <inheritdoc />
public Task<IEnumerable<TaskStatusCount>> CountTasksAsync(Expression<Func<TaskData, bool>> filter,
CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -182,33 +165,6 @@ public Task DeleteTaskAsync(string id,
out _));
}

/// <inheritdoc />
public IAsyncEnumerable<string> ListTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default)
{
IEnumerable<string> rawList = filter.IdsCase switch
{
TaskFilter.IdsOneofCase.None =>
throw new ArgumentException("Filter is not properly initialized. Either the session or the tasks are required",
nameof(filter)),
TaskFilter.IdsOneofCase.Session => filter.Session.Ids.SelectMany(s => session2TaskIds_[s])
.ToImmutableList(),
TaskFilter.IdsOneofCase.Task => filter.Task.Ids,
_ => throw new ArgumentException("Filter is set to an unknown IdsCase."),
};

return rawList.Where(taskId => filter.StatusesCase switch
{
TaskFilter.StatusesOneofCase.None => true,
TaskFilter.StatusesOneofCase.Included => filter.Included.Statuses.Contains(taskId2TaskData_[taskId]
.Status.ToGrpcStatus()),
TaskFilter.StatusesOneofCase.Excluded => !filter.Excluded.Statuses.Contains(taskId2TaskData_[taskId]
.Status.ToGrpcStatus()),
_ => throw new ArgumentException("Filter is set to an unknown StatusesCase."),
})
.ToAsyncEnumerable();
}

/// <inheritdoc />
public Task<(IEnumerable<T> tasks, long totalCount)> ListTasksAsync<T>(Expression<Func<TaskData, bool>> filter,
Expression<Func<TaskData, object?>> orderField,
Expand Down Expand Up @@ -333,8 +289,7 @@ public Task<Output> GetTaskOutput(string taskId,
throw new TaskNotFoundException($"Key '{taskId}' not found");
}

return Task.FromResult(taskId2TaskData_[taskId]
.Output);
return Task.FromResult(taskId2TaskData_[taskId].Output);
}

/// <inheritdoc />
Expand Down Expand Up @@ -397,8 +352,7 @@ public Task<IEnumerable<string>> GetParentTaskIds(string taskId,
throw new TaskNotFoundException($"Key '{taskId}' not found");
}

return Task.FromResult(taskId2TaskData_[taskId]
.ParentTaskIds as IEnumerable<string>);
return Task.FromResult(taskId2TaskData_[taskId].ParentTaskIds as IEnumerable<string>);
}

/// <inheritdoc />
Expand Down Expand Up @@ -466,39 +420,9 @@ public Task<HealthCheckResult> Check(HealthCheckTag tag)
: HealthCheckResult.Unhealthy());

/// <inheritdoc />
public async Task<int> CountAllTasksAsync(TaskStatus status,
CancellationToken cancellationToken = default)
{
var count = 0;

foreach (var session in session2TaskIds_.Keys)
{
var statusFilter = new TaskFilter
{
Included = new TaskFilter.Types.StatusesRequest
{
Statuses =
{
status.ToGrpcStatus(),
},
},
Session = new TaskFilter.Types.IdsRequest
{
Ids =
{
session,
},
},
};

count += await ListTasksAsync(statusFilter,
cancellationToken)
.CountAsync(cancellationToken)
.ConfigureAwait(false);
}

return count;
}
public Task<int> CountAllTasksAsync(TaskStatus status,
CancellationToken cancellationToken = default)
=> Task.FromResult(taskId2TaskData_.Count(pair => pair.Value.Status == status));

/// <inheritdoc />
public Task<IEnumerable<TaskIdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
Expand Down
31 changes: 0 additions & 31 deletions Adaptors/MongoDB/src/Table/MongoQueryableExt.cs

This file was deleted.

42 changes: 0 additions & 42 deletions Adaptors/MongoDB/src/TaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,11 @@
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Adapters.MongoDB.Common;
using ArmoniK.Core.Adapters.MongoDB.Options;
using ArmoniK.Core.Adapters.MongoDB.Table;
using ArmoniK.Core.Adapters.MongoDB.Table.DataModel;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
Expand Down Expand Up @@ -678,43 +675,4 @@ public async Task<IEnumerable<TaskIdStatus>> GetTaskStatus(IEnumerable<string> t
.ToListAsync(cancellationToken)
.ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<IEnumerable<TaskStatusCount>> CountTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(CountTasksAsync)}");

var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();


var res = await taskCollection.AsQueryable(sessionHandle)
.FilterQuery(filter)
.GroupBy(model => model.Status)
.Select(models => new TaskStatusCount(models.Key,
models.Count()))
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

return res;
}

/// <inheritdoc />
public async IAsyncEnumerable<string> ListTasksAsync(TaskFilter filter,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(ListTasksAsync)}");
var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

await foreach (var taskId in taskCollection.AsQueryable(sessionHandle)
.FilterQuery(filter)
.Select(model => model.TaskId)
.ToAsyncEnumerable(cancellationToken)
.ConfigureAwait(false))
{
yield return taskId;
}
}
}
22 changes: 0 additions & 22 deletions Common/src/Storage/ITaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base;
using ArmoniK.Core.Common.Exceptions;

Expand Down Expand Up @@ -105,16 +104,6 @@ Task<bool> IsTaskCancelledAsync(string taskId,
Task StartTask(TaskData taskData,
CancellationToken cancellationToken = default);

/// <summary>
/// Count tasks matching a given filter
/// </summary>
/// <param name="filter">Task Filter describing the tasks to be counted</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// The number of tasks that matched the filter
/// </returns>
Task<IEnumerable<TaskStatusCount>> CountTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default);

/// <summary>
/// Count tasks matching a given filter
Expand Down Expand Up @@ -158,17 +147,6 @@ Task<int> CountAllTasksAsync(TaskStatus status,
Task DeleteTaskAsync(string id,
CancellationToken cancellationToken = default);

/// <summary>
/// List all tasks matching a given filter
/// </summary>
/// <param name="filter">Task Filter describing the tasks to be counted</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// List of tasks that matched the filter
/// </returns>
IAsyncEnumerable<string> ListTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default);

/// <summary>
/// List all tasks matching the given filter and ordering
/// </summary>
Expand Down
45 changes: 0 additions & 45 deletions Common/src/Storage/TaskTableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC;
using ArmoniK.Core.Common.gRPC.Convertors;

namespace ArmoniK.Core.Common.Storage;

public static class TaskTableExtensions
Expand All @@ -40,14 +35,6 @@ public static class TaskTableExtensions
TaskStatus.Timeout,
};

public static async Task<int> CancelTasks(this ITaskTable taskTable,
TaskFilter filter,
CancellationToken cancellationToken = default)
=> (int)await taskTable.UpdateAllTaskStatusAsync(filter,
TaskStatus.Cancelling,
cancellationToken)
.ConfigureAwait(false);

/// <summary>
/// Change the status of the task to canceled
/// </summary>
Expand Down Expand Up @@ -192,38 +179,6 @@ public static async Task<bool> SetTaskRetryAsync(this ITaskTable taskTable,
return task.Status != TaskStatus.Retried;
}

/// <summary>
/// Update the statuses of all tasks matching a given filter
/// </summary>
/// <param name="taskTable">Interface to manage tasks lifecycle</param>
/// <param name="filter">Task Filter describing the tasks whose status should be updated</param>
/// <param name="status">The new task status</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// The number of updated tasks
/// </returns>
public static async Task<long> UpdateAllTaskStatusAsync(this ITaskTable taskTable,
TaskFilter filter,
TaskStatus status,
CancellationToken cancellationToken = default)
{
if (filter.Included != null && (filter.Included.Statuses.Contains(TaskStatus.Completed.ToGrpcStatus()) ||
filter.Included.Statuses.Contains(TaskStatus.Cancelled.ToGrpcStatus()) ||
filter.Included.Statuses.Contains(TaskStatus.Error.ToGrpcStatus()) ||
filter.Included.Statuses.Contains(TaskStatus.Retried.ToGrpcStatus())))
{
throw new ArmoniKException("The given TaskFilter contains a terminal state, update forbidden");
}

return await taskTable.UpdateManyTasks(filter.ToFilterExpression(),
new List<(Expression<Func<TaskData, object?>> selector, object? newValue)>
{
(tdm => tdm.Status, status),
},
cancellationToken)
.ConfigureAwait(false);
}

/// <summary>
/// Cancels all tasks in a given session
/// </summary>
Expand Down
Loading

0 comments on commit d4a3dec

Please sign in to comment.