Skip to content

Commit

Permalink
refactor: Remove usage of gRPC defined TaskStatus in internal interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
aneojgurhem committed Sep 26, 2023
1 parent d2e0e0a commit 5bd5da9
Show file tree
Hide file tree
Showing 43 changed files with 419 additions and 230 deletions.
94 changes: 46 additions & 48 deletions Adaptors/Memory/src/TaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC.Convertors;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Core.Utils;

using Microsoft.Extensions.Diagnostics.HealthChecks;
using Microsoft.Extensions.Logging;

using TaskStatus = ArmoniK.Api.gRPC.V1.TaskStatus;
using TaskStatus = ArmoniK.Core.Common.Storage.TaskStatus;

namespace ArmoniK.Core.Adapters.Memory;

Expand Down Expand Up @@ -168,41 +169,6 @@ public Task<IEnumerable<PartitionTaskStatusCount>> CountPartitionTasksAsync(Canc
return Task.FromResult(res as IEnumerable<PartitionTaskStatusCount>);
}

/// <inheritdoc />
public async Task<int> CountAllTasksAsync(TaskStatus status,
CancellationToken cancellationToken = default)
{
var count = 0;

foreach (var session in session2TaskIds_.Keys)
{
var statusFilter = new TaskFilter
{
Included = new TaskFilter.Types.StatusesRequest
{
Statuses =
{
status,
},
},
Session = new TaskFilter.Types.IdsRequest
{
Ids =
{
session,
},
},
};

count += await ListTasksAsync(statusFilter,
cancellationToken)
.CountAsync(cancellationToken)
.ConfigureAwait(false);
}

return count;
}

/// <inheritdoc />
public Task DeleteTaskAsync(string id,
CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -235,9 +201,9 @@ public IAsyncEnumerable<string> ListTasksAsync(TaskFilter filter,
{
TaskFilter.StatusesOneofCase.None => true,
TaskFilter.StatusesOneofCase.Included => filter.Included.Statuses.Contains(taskId2TaskData_[taskId]
.Status),
.Status.ToGrpcStatus()),
TaskFilter.StatusesOneofCase.Excluded => !filter.Excluded.Statuses.Contains(taskId2TaskData_[taskId]
.Status),
.Status.ToGrpcStatus()),
_ => throw new ArgumentException("Filter is set to an unknown StatusesCase."),
})
.ToAsyncEnumerable();
Expand Down Expand Up @@ -416,16 +382,6 @@ public Task<TaskData> ReleaseTask(TaskData taskData,
};
}));

/// <inheritdoc />
public Task<IEnumerable<GetTaskStatusReply.Types.IdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default)
=> Task.FromResult(taskId2TaskData_.Where(tdm => taskIds.Contains(tdm.Key))
.Select(model => new GetTaskStatusReply.Types.IdStatus
{
Status = model.Value.Status,
TaskId = model.Value.TaskId,
}));

public IAsyncEnumerable<(string taskId, IEnumerable<string> expectedOutputKeys)> GetTasksExpectedOutputKeys(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default)
=> taskId2TaskData_.Where(pair => taskIds.Contains(pair.Key))
Expand Down Expand Up @@ -508,4 +464,46 @@ public Task<HealthCheckResult> Check(HealthCheckTag tag)
=> Task.FromResult(isInitialized_
? HealthCheckResult.Healthy()
: HealthCheckResult.Unhealthy());

/// <inheritdoc />
public async Task<int> CountAllTasksAsync(TaskStatus status,
CancellationToken cancellationToken = default)
{
var count = 0;

foreach (var session in session2TaskIds_.Keys)
{
var statusFilter = new TaskFilter
{
Included = new TaskFilter.Types.StatusesRequest
{
Statuses =
{
status.ToGrpcStatus(),
},
},
Session = new TaskFilter.Types.IdsRequest
{
Ids =
{
session,
},
},
};

count += await ListTasksAsync(statusFilter,
cancellationToken)
.CountAsync(cancellationToken)
.ConfigureAwait(false);
}

return count;
}

/// <inheritdoc />
public Task<IEnumerable<TaskIdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default)
=> Task.FromResult(taskId2TaskData_.Where(tdm => taskIds.Contains(tdm.Key))
.Select(model => new TaskIdStatus(model.Value.TaskId,
model.Value.Status)));
}
145 changes: 71 additions & 74 deletions Adaptors/MongoDB/src/TaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
using MongoDB.Driver;
using MongoDB.Driver.Linq;

using TaskStatus = ArmoniK.Api.gRPC.V1.TaskStatus;
using TaskStatus = ArmoniK.Core.Common.Storage.TaskStatus;

namespace ArmoniK.Core.Adapters.MongoDB;

Expand Down Expand Up @@ -174,27 +174,6 @@ public async Task StartTask(TaskData taskData,
}
}

/// <inheritdoc />
public async Task<IEnumerable<TaskStatusCount>> CountTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(CountTasksAsync)}");

var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();


var res = await taskCollection.AsQueryable(sessionHandle)
.FilterQuery(filter)
.GroupBy(model => model.Status)
.Select(models => new TaskStatusCount(models.Key,
models.Count()))
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

return res;
}

/// <inheritdoc />
public async Task<IEnumerable<TaskStatusCount>> CountTasksAsync(Expression<Func<TaskData, bool>> filter,
CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -239,21 +218,6 @@ public async Task<IEnumerable<PartitionTaskStatusCount>> CountPartitionTasksAsyn
return res;
}

/// <inheritdoc />
public Task<int> CountAllTasksAsync(TaskStatus status,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(CountAllTasksAsync)}");

var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

var res = taskCollection.AsQueryable(sessionHandle)
.Count(model => model.Status == status);

return Task.FromResult(res);
}

/// <inheritdoc />
public async Task DeleteTaskAsync(string id,
CancellationToken cancellationToken = default)
Expand All @@ -278,24 +242,6 @@ public async Task DeleteTaskAsync(string id,
}
}

/// <inheritdoc />
public async IAsyncEnumerable<string> ListTasksAsync(TaskFilter filter,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(ListTasksAsync)}");
var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

await foreach (var taskId in taskCollection.AsQueryable(sessionHandle)
.FilterQuery(filter)
.Select(model => model.TaskId)
.ToAsyncEnumerable(cancellationToken)
.ConfigureAwait(false))
{
yield return taskId;
}
}

/// <inheritdoc />
public async Task<(IEnumerable<T> tasks, long totalCount)> ListTasksAsync<T>(Expression<Func<TaskData, bool>> filter,
Expression<Func<TaskData, object?>> orderField,
Expand Down Expand Up @@ -571,25 +517,6 @@ await ReadTaskAsync(taskData.TaskId,
.ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<IEnumerable<GetTaskStatusReply.Types.IdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(GetTaskStatus)}");
var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

return await taskCollection.AsQueryable(sessionHandle)
.Where(tdm => taskIds.Contains(tdm.TaskId))
.Select(model => new GetTaskStatusReply.Types.IdStatus
{
Status = model.Status,
TaskId = model.TaskId,
})
.ToListAsync(cancellationToken)
.ConfigureAwait(false);
}

public IAsyncEnumerable<(string taskId, IEnumerable<string> expectedOutputKeys)> GetTasksExpectedOutputKeys(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -720,4 +647,74 @@ await taskCollectionProvider_.Init(cancellationToken)
isInitialized_ = true;
}
}

/// <inheritdoc />
public Task<int> CountAllTasksAsync(TaskStatus status,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(CountAllTasksAsync)}");

var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

var res = taskCollection.AsQueryable(sessionHandle)
.Count(model => model.Status == status);

return Task.FromResult(res);
}

/// <inheritdoc />
public async Task<IEnumerable<TaskIdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(GetTaskStatus)}");
var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

return await taskCollection.AsQueryable(sessionHandle)
.Where(tdm => taskIds.Contains(tdm.TaskId))
.Select(model => new TaskIdStatus(model.TaskId,
model.Status))
.ToListAsync(cancellationToken)
.ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<IEnumerable<TaskStatusCount>> CountTasksAsync(TaskFilter filter,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(CountTasksAsync)}");

var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();


var res = await taskCollection.AsQueryable(sessionHandle)
.FilterQuery(filter)
.GroupBy(model => model.Status)
.Select(models => new TaskStatusCount(models.Key,
models.Count()))
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

return res;
}

/// <inheritdoc />
public async IAsyncEnumerable<string> ListTasksAsync(TaskFilter filter,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(ListTasksAsync)}");
var sessionHandle = sessionProvider_.Get();
var taskCollection = taskCollectionProvider_.Get();

await foreach (var taskId in taskCollection.AsQueryable(sessionHandle)
.FilterQuery(filter)
.Select(model => model.TaskId)
.ToAsyncEnumerable(cancellationToken)
.ConfigureAwait(false))
{
yield return taskId;
}
}
}
1 change: 1 addition & 0 deletions Adaptors/MongoDB/tests/BsonSerializerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

using Output = ArmoniK.Core.Common.Storage.Output;
using TaskOptions = ArmoniK.Core.Base.DataStructures.TaskOptions;
using TaskStatus = ArmoniK.Core.Common.Storage.TaskStatus;

namespace ArmoniK.Core.Adapters.MongoDB.Tests;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
using System;
using System.Collections.Generic;

using ArmoniK.Api.gRPC.V1;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Core.Common.Utils;

using NUnit.Framework;

using Output = ArmoniK.Core.Common.Storage.Output;
using TaskOptions = ArmoniK.Core.Base.DataStructures.TaskOptions;

namespace ArmoniK.Core.Adapters.MongoDB.Tests;

[TestFixture(TestOf = typeof(ExpressionsBuilders))]
Expand Down
2 changes: 1 addition & 1 deletion Common/src/Pollster/TaskHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
using Microsoft.Extensions.Logging;

using Output = ArmoniK.Api.gRPC.V1.Output;
using TaskStatus = ArmoniK.Api.gRPC.V1.TaskStatus;
using TaskStatus = ArmoniK.Core.Common.Storage.TaskStatus;

namespace ArmoniK.Core.Common.Pollster;

Expand Down
2 changes: 0 additions & 2 deletions Common/src/Storage/Events/NewTask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
using System.Collections.Generic;
using System.Linq;

using ArmoniK.Api.gRPC.V1;

namespace ArmoniK.Core.Common.Storage.Events;

/// <summary>
Expand Down
2 changes: 0 additions & 2 deletions Common/src/Storage/Events/TaskStatusUpdate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using ArmoniK.Api.gRPC.V1;

namespace ArmoniK.Core.Common.Storage.Events;

/// <summary>
Expand Down
6 changes: 2 additions & 4 deletions Common/src/Storage/ITaskTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@

using Microsoft.Extensions.Logging;

using TaskStatus = ArmoniK.Api.gRPC.V1.TaskStatus;

namespace ArmoniK.Core.Common.Storage;

/// <summary>
Expand Down Expand Up @@ -324,8 +322,8 @@ Task<TaskData> ReleaseTask(TaskData taskData,
/// <returns>
/// Reply status metadata
/// </returns>
Task<IEnumerable<GetTaskStatusReply.Types.IdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default);
Task<IEnumerable<TaskIdStatus>> GetTaskStatus(IEnumerable<string> taskIds,
CancellationToken cancellationToken = default);

/// <summary>
/// Get expected output keys of tasks given their ids
Expand Down
2 changes: 0 additions & 2 deletions Common/src/Storage/PartitionTaskStatusCount.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using ArmoniK.Api.gRPC.V1;

namespace ArmoniK.Core.Common.Storage;

/// <summary>
Expand Down
Loading

0 comments on commit 5bd5da9

Please sign in to comment.