Skip to content

Commit

Permalink
refactor: reduce number of methods in ISessionTable and implement the…
Browse files Browse the repository at this point in the history
…m as extensions methods
  • Loading branch information
aneojgurhem committed Sep 29, 2023
1 parent f806170 commit e9d5924
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 224 deletions.
63 changes: 7 additions & 56 deletions Adaptors/Memory/src/SessionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
using ArmoniK.Core.Common.gRPC.Convertors;
using ArmoniK.Core.Common.Storage;
using ArmoniK.Utils;

Expand Down Expand Up @@ -78,36 +76,13 @@ public Task<string> SetSessionDataAsync(IEnumerable<string> partitionIds,
}

/// <inheritdoc />
public Task<SessionData> GetSessionAsync(string sessionId,
CancellationToken cancellationToken = default)
{
if (!storage_.ContainsKey(sessionId))
{
throw new SessionNotFoundException($"Key '{sessionId}' not found");
}

return Task.FromResult(storage_[sessionId]);
}

/// <inheritdoc />
public Task<bool> IsSessionCancelledAsync(string sessionId,
CancellationToken cancellationToken = default)
=> Task.FromResult(GetSessionAsync(sessionId,
cancellationToken)
.Result.Status == SessionStatus.Cancelled);

/// <inheritdoc />
public Task<TaskOptions> GetDefaultTaskOptionAsync(string sessionId,
CancellationToken cancellationToken = default)
{
if (!storage_.ContainsKey(sessionId))
{
throw new SessionNotFoundException($"Key '{sessionId}' not found");
}

return Task.FromResult(storage_[sessionId]
.Options);
}
public IAsyncEnumerable<T> FindSessionsAsync<T>(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, T>> selector,
CancellationToken cancellationToken = default)
=> storage_.Select(pair => pair.Value)
.Where(filter.Compile())
.Select(selector.Compile())
.ToAsyncEnumerable();

/// <inheritdoc />
public Task<SessionData> CancelSessionAsync(string sessionId,
Expand Down Expand Up @@ -144,30 +119,6 @@ public Task DeleteSessionAsync(string sessionId,
return Task.CompletedTask;
}


/// <inheritdoc />
public IAsyncEnumerable<string> ListSessionsAsync(SessionFilter sessionFilter,
CancellationToken cancellationToken = default)
{
var rawList = storage_.Keys.ToAsyncEnumerable();

if (sessionFilter.Sessions.Any())
{
rawList = storage_.Keys.Intersect(sessionFilter.Sessions)
.ToAsyncEnumerable();
}

return rawList.Where(sessionId => sessionFilter.StatusesCase switch
{
SessionFilter.StatusesOneofCase.None => true,
SessionFilter.StatusesOneofCase.Included => sessionFilter.Included.Statuses.Contains(storage_[sessionId]
.Status.ToGrpcStatus()),
SessionFilter.StatusesOneofCase.Excluded => !sessionFilter.Excluded.Statuses.Contains(storage_[sessionId]
.Status.ToGrpcStatus()),
_ => throw new ArgumentException("Filter is set to an unknown StatusesCase."),
});
}

/// <inheritdoc />
public Task<(IEnumerable<SessionData> sessions, long totalCount)> ListSessionsAsync(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, object?>> orderField,
Expand Down
93 changes: 10 additions & 83 deletions Adaptors/MongoDB/src/SessionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.Common.Utils;
using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Adapters.MongoDB.Common;
using ArmoniK.Core.Adapters.MongoDB.Table;
using ArmoniK.Core.Adapters.MongoDB.Table.DataModel;
using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;
Expand Down Expand Up @@ -88,67 +84,18 @@ await sessionCollection.InsertOneAsync(data,
}

/// <inheritdoc />
public async Task<SessionData> GetSessionAsync(string sessionId,
CancellationToken cancellationToken = default)
public IAsyncEnumerable<T> FindSessionsAsync<T>(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, T>> selector,
CancellationToken cancellationToken = default)
{
using var _ = Logger.LogFunction(sessionId);
using var activity = activitySource_.StartActivity($"{nameof(GetSessionAsync)}");
activity?.SetTag($"{nameof(GetSessionAsync)}_sessionId",
sessionId);
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();


try
{
return await sessionCollection.Find(session => session.SessionId == sessionId)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Key '{sessionId}' not found",
e);
}
}


/// <inheritdoc />
public async Task<bool> IsSessionCancelledAsync(string sessionId,
CancellationToken cancellationToken = default)
{
using var _ = Logger.LogFunction(sessionId);
using var activity = activitySource_.StartActivity($"{nameof(IsSessionCancelledAsync)}");
activity?.SetTag($"{nameof(IsSessionCancelledAsync)}_sessionId",
sessionId);

return (await GetSessionAsync(sessionId,
cancellationToken)
.ConfigureAwait(false)).Status == SessionStatus.Cancelled;
}

/// <inheritdoc />
public async Task<TaskOptions> GetDefaultTaskOptionAsync(string sessionId,
CancellationToken cancellationToken = default)
{
using var activity = activitySource_.StartActivity($"{nameof(GetDefaultTaskOptionAsync)}");
activity?.SetTag($"{nameof(GetDefaultTaskOptionAsync)}_sessionId",
sessionId);
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();
using var activity = activitySource_.StartActivity($"{nameof(FindSessionsAsync)}");
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();

try
{
return await sessionCollection.Find(sdm => sdm.SessionId == sessionId)
.Project(sdm => sdm.Options)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Key '{sessionId}' not found",
e);
}
return sessionCollection.Find(sessionHandle,
filter)
.Project(selector)
.ToAsyncEnumerable(cancellationToken);
}

/// <inheritdoc />
Expand Down Expand Up @@ -206,26 +153,6 @@ public async Task DeleteSessionAsync(string sessionId,
}
}

/// <inheritdoc />
public async IAsyncEnumerable<string> ListSessionsAsync(SessionFilter sessionFilter,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var _ = Logger.LogFunction();
using var activity = activitySource_.StartActivity($"{nameof(ListSessionsAsync)}");
var sessionHandle = sessionProvider_.Get();
var sessionCollection = sessionCollectionProvider_.Get();

await foreach (var sessionId in sessionCollection.AsQueryable(sessionHandle)
.FilterQuery(sessionFilter)
.Select(model => model.SessionId)
.ToAsyncEnumerable()
.WithCancellation(cancellationToken)
.ConfigureAwait(false))
{
yield return sessionId;
}
}

public async Task<(IEnumerable<SessionData> sessions, long totalCount)> ListSessionsAsync(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, object?>> orderField,
bool ascOrder,
Expand Down
45 changes: 7 additions & 38 deletions Common/src/Storage/ISessionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Core.Base;
using ArmoniK.Core.Base.DataStructures;

Expand Down Expand Up @@ -53,37 +52,17 @@ Task<string> SetSessionDataAsync(IEnumerable<string> partitionIds,
CancellationToken cancellationToken = default);

/// <summary>
/// Get SessionData from sessionId
/// Find all sessions matching the given filter and ordering
/// </summary>
/// <param name="sessionId">Id of the session to get</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Data of the session
/// </returns>
Task<SessionData> GetSessionAsync(string sessionId,
CancellationToken cancellationToken = default);

/// <summary>
/// Query a session status to check if it is canceled
/// </summary>
/// <param name="sessionId">Id of the session to check</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Boolean representing the cancelation status of the session
/// </returns>
Task<bool> IsSessionCancelledAsync(string sessionId,
CancellationToken cancellationToken = default);

/// <summary>
/// Get default task metadata for a session given its id
/// </summary>
/// <param name="sessionId">Id of the target session</param>
/// <param name="filter">Filter to select sessions</param>
/// <param name="selector">Expression to select part of the returned session data</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Default task metadata of this session
/// Session metadata matching the request
/// </returns>
Task<TaskOptions> GetDefaultTaskOptionAsync(string sessionId,
CancellationToken cancellationToken = default);
IAsyncEnumerable<T> FindSessionsAsync<T>(Expression<Func<SessionData, bool>> filter,
Expression<Func<SessionData, T>> selector,
CancellationToken cancellationToken = default);

/// <summary>
/// Cancel a session
Expand All @@ -107,16 +86,6 @@ Task<SessionData> CancelSessionAsync(string sessionId,
Task DeleteSessionAsync(string sessionId,
CancellationToken cancellationToken = default);

/// <summary>
/// List all sessions matching a given filter
/// </summary>
/// <param name="sessionFilter">Session filter describing the sessions to be listed </param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Collection of sessions that matched the filter
/// </returns>
IAsyncEnumerable<string> ListSessionsAsync(SessionFilter sessionFilter,
CancellationToken cancellationToken = default);

/// <summary>
/// List all sessions matching the given request
Expand Down
113 changes: 113 additions & 0 deletions Common/src/Storage/SessionTableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2023. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Core.Base.DataStructures;
using ArmoniK.Core.Common.Exceptions;

namespace ArmoniK.Core.Common.Storage;

public static class SessionTableExtensions
{
/// <summary>
/// Get SessionData from sessionId
/// </summary>
/// <param name="sessionTable">Interface to manage sessions lifecycle</param>
/// <param name="sessionId">Id of the session to get</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Data of the session
/// </returns>
public static async Task<SessionData> GetSessionAsync(this ISessionTable sessionTable,
string sessionId,
CancellationToken cancellationToken = default)
{
try
{
return await sessionTable.FindSessionsAsync(data => data.SessionId == sessionId,
data => data,
cancellationToken)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Session {sessionId} not found.",
e);
}
}

/// <summary>
/// Query a session status to check if it is canceled
/// </summary>
/// <param name="sessionTable">Interface to manage sessions lifecycle</param>
/// <param name="sessionId">Id of the session to check</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Boolean representing the cancellation status of the session
/// </returns>
public static async Task<bool> IsSessionCancelledAsync(this ISessionTable sessionTable,
string sessionId,
CancellationToken cancellationToken = default)
{
try
{
return await sessionTable.FindSessionsAsync(data => data.SessionId == sessionId,
data => data.Status == SessionStatus.Cancelled,
cancellationToken)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Session {sessionId} not found.",
e);
}
}

/// <summary>
/// Get default task metadata for a session given its id
/// </summary>
/// <param name="sessionTable">Interface to manage sessions lifecycle</param>
/// <param name="sessionId">Id of the target session</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Default task metadata of this session
/// </returns>
public static async Task<TaskOptions> GetDefaultTaskOptionAsync(this ISessionTable sessionTable,
string sessionId,
CancellationToken cancellationToken = default)
{
try
{
return await sessionTable.FindSessionsAsync(data => data.SessionId == sessionId,
data => data.Options,
cancellationToken)
.SingleAsync(cancellationToken)
.ConfigureAwait(false);
}
catch (InvalidOperationException e)
{
throw new SessionNotFoundException($"Session {sessionId} not found.",
e);
}
}
}
Loading

0 comments on commit e9d5924

Please sign in to comment.