Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: reduce number of methods in ISessionTable and implement them as extensions methods #514

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 7 additions & 56 deletions Adaptors/Memory/src/SessionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC.Convertors;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Utils;

Expand Down Expand Up @@ -78,36 +76,13 @@ public Task<string> SetSessionDataAsync(IEnumerable<string> partitionIds,
}

/// <inheritdoc />
public Task<SessionData> GetSessionAsync(string sessionId,
CancellationToken cancellationToken = default)
{
if (!storage_.ContainsKey(sessionId))
{
throw new SessionNotFoundException($"Key '{sessionId}' not found");
}

return Task.FromResult(storage_[sessionId]);
}

/// <inheritdoc />
public Task<bool> IsSessionCancelledAsync(string sessionId,
CancellationToken cancellationToken = default)
=> Task.FromResult(GetSessionAsync(sessionId,
cancellationToken)
.Result.Status == SessionStatus.Cancelled);

/// <inheritdoc />
public Task<TaskOptions> GetDefaultTaskOptionAsync(string sessionId,
CancellationToken cancellationToken = default)
{
if (!storage_.ContainsKey(sessionId))
{
throw new SessionNotFoundException($"Key '{sessionId}' not found");
}

return Task.FromResult(storage_[sessionId]
.Options);
}
public IAsyncEnumerable<T> FindSessionsAsync<T>(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, T>> selector,
CancellationToken cancellationToken = default)
=> storage_.Select(pair => pair.Value)
.Where(filter.Compile())
.Select(selector.Compile())
.ToAsyncEnumerable();

/// <inheritdoc />
public Task<SessionData> CancelSessionAsync(string sessionId,
Expand Down Expand Up @@ -144,30 +119,6 @@ public Task DeleteSessionAsync(string sessionId,
return Task.CompletedTask;
}


/// <inheritdoc />
public IAsyncEnumerable<string> ListSessionsAsync(SessionFilter sessionFilter,
CancellationToken cancellationToken = default)
{
var rawList = storage_.Keys.ToAsyncEnumerable();

if (sessionFilter.Sessions.Any())
{
rawList = storage_.Keys.Intersect(sessionFilter.Sessions)
.ToAsyncEnumerable();
}

return rawList.Where(sessionId => sessionFilter.StatusesCase switch
{
SessionFilter.StatusesOneofCase.None => true,
SessionFilter.StatusesOneofCase.Included => sessionFilter.Included.Statuses.Contains(storage_[sessionId]
.Status.ToGrpcStatus()),
SessionFilter.StatusesOneofCase.Excluded => !sessionFilter.Excluded.Statuses.Contains(storage_[sessionId]
.Status.ToGrpcStatus()),
_ => throw new ArgumentException("Filter is set to an unknown StatusesCase."),
});
}

/// <inheritdoc />
public Task<(IEnumerable<SessionData> sessions, long totalCount)> ListSessionsAsync(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, object?>> orderField,
Expand Down
93 changes: 10 additions & 83 deletions Adaptors/MongoDB/src/SessionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.Common.Utils;
using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Adapters.MongoDB.Common;
using ArmoniK.Core.Adapters.MongoDB.Table;
using ArmoniK.Core.Adapters.MongoDB.Table.DataModel;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
Expand Down Expand Up @@ -88,67 +84,18 @@ await sessionCollection.InsertOneAsync(data,
}

/// <inheritdoc />
public async Task<SessionData> GetSessionAsync(string sessionId,
CancellationToken cancellationToken = default)
public IAsyncEnumerable<T> FindSessionsAsync<T>(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, T>> selector,
CancellationToken cancellationToken = default)
{
using var _ = Logger.LogFunction(sessionId);
using var activity = activitySource_.StartActivity($"{nameof(GetSessionAsync)}");
activity?.SetTag($"{nameof(GetSessionAsync)}_sessionId",
sessionId);
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();


try
{
return await sessionCollection.Find(session => session.SessionId == sessionId)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Key '{sessionId}' not found",
e);
}
}


/// <inheritdoc />
public async Task<bool> IsSessionCancelledAsync(string sessionId,
CancellationToken cancellationToken = default)
{
using var _ = Logger.LogFunction(sessionId);
using var activity = activitySource_.StartActivity($"{nameof(IsSessionCancelledAsync)}");
activity?.SetTag($"{nameof(IsSessionCancelledAsync)}_sessionId",
sessionId);

return (await GetSessionAsync(sessionId,
cancellationToken)
.ConfigureAwait(false)).Status == SessionStatus.Cancelled;
}

/// <inheritdoc />
public async Task<TaskOptions> GetDefaultTaskOptionAsync(string sessionId,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(GetDefaultTaskOptionAsync)}");
activity?.SetTag($"{nameof(GetDefaultTaskOptionAsync)}_sessionId",
sessionId);
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();
using var activity = activitySource_.StartActivity($"{nameof(FindSessionsAsync)}");
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();

try
{
return await sessionCollection.Find(sdm => sdm.SessionId == sessionId)
.Project(sdm => sdm.Options)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Key '{sessionId}' not found",
e);
}
return sessionCollection.Find(sessionHandle,
filter)
.Project(selector)
.ToAsyncEnumerable(cancellationToken);
}

/// <inheritdoc />
Expand Down Expand Up @@ -206,26 +153,6 @@ public async Task DeleteSessionAsync(string sessionId,
}
}

/// <inheritdoc />
public async IAsyncEnumerable<string> ListSessionsAsync(SessionFilter sessionFilter,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var _ = Logger.LogFunction();
using var activity = activitySource_.StartActivity($"{nameof(ListSessionsAsync)}");
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();

await foreach (var sessionId in sessionCollection.AsQueryable(sessionHandle)
.FilterQuery(sessionFilter)
.Select(model => model.SessionId)
.ToAsyncEnumerable()
.WithCancellation(cancellationToken)
.ConfigureAwait(false))
{
yield return sessionId;
}
}

public async Task<(IEnumerable<SessionData> sessions, long totalCount)> ListSessionsAsync(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, object?>> orderField,
bool ascOrder,
Expand Down
45 changes: 7 additions & 38 deletions Common/src/Storage/ISessionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base;
using ArmoniK.Core.Base.DataStructures;

Expand Down Expand Up @@ -53,37 +52,17 @@ Task<string> SetSessionDataAsync(IEnumerable<string> partitionIds,
CancellationToken cancellationToken = default);

/// <summary>
/// Get SessionData from sessionId
/// Find all sessions matching the given filter and ordering
/// </summary>
/// <param name="sessionId">Id of the session to get</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Data of the session
/// </returns>
Task<SessionData> GetSessionAsync(string sessionId,
CancellationToken cancellationToken = default);

/// <summary>
/// Query a session status to check if it is canceled
/// </summary>
/// <param name="sessionId">Id of the session to check</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Boolean representing the cancelation status of the session
/// </returns>
Task<bool> IsSessionCancelledAsync(string sessionId,
CancellationToken cancellationToken = default);

/// <summary>
/// Get default task metadata for a session given its id
/// </summary>
/// <param name="sessionId">Id of the target session</param>
/// <param name="filter">Filter to select sessions</param>
/// <param name="selector">Expression to select part of the returned session data</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Default task metadata of this session
/// Session metadata matching the request
/// </returns>
Task<TaskOptions> GetDefaultTaskOptionAsync(string sessionId,
CancellationToken cancellationToken = default);
IAsyncEnumerable<T> FindSessionsAsync<T>(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, T>> selector,
CancellationToken cancellationToken = default);

/// <summary>
/// Cancel a session
Expand All @@ -107,16 +86,6 @@ Task<SessionData> CancelSessionAsync(string sessionId,
Task DeleteSessionAsync(string sessionId,
CancellationToken cancellationToken = default);

/// <summary>
/// List all sessions matching a given filter
/// </summary>
/// <param name="sessionFilter">Session filter describing the sessions to be listed </param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Collection of sessions that matched the filter
/// </returns>
IAsyncEnumerable<string> ListSessionsAsync(SessionFilter sessionFilter,
CancellationToken cancellationToken = default);

/// <summary>
/// List all sessions matching the given request
Expand Down
113 changes: 113 additions & 0 deletions Common/src/Storage/SessionTableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2023. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;

namespace ArmoniK.Core.Common.Storage;

public static class SessionTableExtensions
{
/// <summary>
/// Get SessionData from sessionId
/// </summary>
/// <param name="sessionTable">Interface to manage sessions lifecycle</param>
/// <param name="sessionId">Id of the session to get</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Data of the session
/// </returns>
public static async Task<SessionData> GetSessionAsync(this ISessionTable sessionTable,
string sessionId,
CancellationToken cancellationToken = default)
{
try
{
return await sessionTable.FindSessionsAsync(data => data.SessionId == sessionId,
data => data,
cancellationToken)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Session {sessionId} not found.",
e);
}
}

/// <summary>
/// Query a session status to check if it is canceled
/// </summary>
/// <param name="sessionTable">Interface to manage sessions lifecycle</param>
/// <param name="sessionId">Id of the session to check</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Boolean representing the cancellation status of the session
/// </returns>
public static async Task<bool> IsSessionCancelledAsync(this ISessionTable sessionTable,
string sessionId,
CancellationToken cancellationToken = default)
{
try
{
return await sessionTable.FindSessionsAsync(data => data.SessionId == sessionId,
data => data.Status == SessionStatus.Cancelled,
cancellationToken)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Session {sessionId} not found.",
e);
}
}

/// <summary>
/// Get default task metadata for a session given its id
/// </summary>
/// <param name="sessionTable">Interface to manage sessions lifecycle</param>
/// <param name="sessionId">Id of the target session</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Default task metadata of this session
/// </returns>
public static async Task<TaskOptions> GetDefaultTaskOptionAsync(this ISessionTable sessionTable,
string sessionId,
CancellationToken cancellationToken = default)
{
try
{
return await sessionTable.FindSessionsAsync(data => data.SessionId == sessionId,
data => data.Options,
cancellationToken)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Session {sessionId} not found.",
e);
}
}
}
Loading