diff --git a/Client/src/Common/Submitter/BaseClientSubmitter.cs b/Client/src/Common/Submitter/BaseClientSubmitter.cs index f28fff6ee..354e9dab0 100644 --- a/Client/src/Common/Submitter/BaseClientSubmitter.cs +++ b/Client/src/Common/Submitter/BaseClientSubmitter.cs @@ -38,6 +38,7 @@ using Google.Protobuf; using Grpc.Core; +using Grpc.Net.Client; using JetBrains.Annotations; @@ -69,7 +70,7 @@ public abstract class BaseClientSubmitter /// /// The channel pool to use for creating clients /// - private ChannelPool? channelPool_; + private ObjectPool? channelPool_; /// /// Base Object for all Client submitter @@ -95,8 +96,8 @@ protected BaseClientSubmitter(Properties properties, TaskOptions.PartitionId, }); - configuration_ = ChannelPool.WithChannel(channel => new Results.ResultsClient(channel).GetServiceConfiguration(new Empty()) - .DataChunkMaxSize); + configuration_ = ChannelPool.WithInstance(channel => new Results.ResultsClient(channel).GetServiceConfiguration(new Empty()) + .DataChunkMaxSize); } private ILoggerFactory LoggerFactory { get; } @@ -115,7 +116,7 @@ protected BaseClientSubmitter(Properties properties, /// /// The channel pool to use for creating clients /// - public ChannelPool ChannelPool + public ObjectPool ChannelPool => channelPool_ ??= ClientServiceConnector.ControlPlaneConnectionPool(properties_, LoggerFactory); @@ -129,7 +130,7 @@ private Session CreateSession(IEnumerable partitionIds) { using var _ = Logger.LogFunction(); Logger.LogDebug("Creating Session... "); - using var channel = ChannelPool.GetChannel(); + using var channel = ChannelPool.Get(); var sessionsClient = new Sessions.SessionsClient(channel); var createSessionReply = sessionsClient.CreateSession(new CreateSessionRequest { @@ -167,7 +168,7 @@ public TaskStatus GetTaskStatus(string taskId) /// public IEnumerable> GetTaskStatues(params string[] taskIds) { - using var channel = ChannelPool.GetChannel(); + using var channel = ChannelPool.Get(); var tasksClient = new Tasks.TasksClient(channel); return tasksClient.ListTasks(new Filters { @@ -200,10 +201,10 @@ public IEnumerable> GetTaskStatues(params string[] tas // TODO: This function should not have Output as a return type because it is a gRPC type public Output GetTaskOutputInfo(string taskId) { - var getTaskResponse = ChannelPool.WithChannel(channel => new Tasks.TasksClient(channel).GetTask(new GetTaskRequest - { - TaskId = taskId, - })); + var getTaskResponse = ChannelPool.WithInstance(channel => new Tasks.TasksClient(channel).GetTask(new GetTaskRequest + { + TaskId = taskId, + })); return new Output { Error = new Output.Types.Error @@ -296,7 +297,7 @@ private IEnumerable ChunkSubmitTasksWithDependencies(IEnumerable ChunkSubmitTasksWithDependencies(IEnumerable taskIds, delayMs, retry => { - using var channel = ChannelPool.GetChannel(); + using var channel = ChannelPool.Get(); var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel); if (retry > 1) @@ -565,7 +566,7 @@ public ResultStatusCollection GetResultStatus(IEnumerable taskIds, // TODO: use ListResult var idStatusPair = result2TaskDic.Keys.ParallelSelect(async resultId => { - using var channel = ChannelPool.GetChannel(); + using var channel = ChannelPool.Get(); var resultsClient = new Results.ResultsClient(channel); var result = await resultsClient.GetResultAsync(new GetResultRequest { @@ -640,14 +641,14 @@ public ResultStatusCollection GetResultStatus(IEnumerable taskIds, nameof(GetResultIds)); } - return ChannelPool.WithChannel(channel => new Tasks.TasksClient(channel).GetResultIds(new GetResultIdsRequest - { - TaskId = - { - taskIds, - }, - }) - .TaskResults); + return ChannelPool.WithInstance(channel => new Tasks.TasksClient(channel).GetResultIds(new GetResultIdsRequest + { + TaskId = + { + taskIds, + }, + }) + .TaskResults); }, true, Logger, @@ -680,7 +681,7 @@ public byte[] GetResult(string taskId, ResultId = resultId, Session = SessionId.Id, }; - using var channel = ChannelPool.GetChannel(); + using var channel = ChannelPool.Get(); var eventsClient = new Events.EventsClient(channel); eventsClient.WaitForResultsAsync(SessionId.Id, new List @@ -740,8 +741,9 @@ public IEnumerable> GetResults(IEnumerable taskIds public async Task TryGetResultAsync(ResultRequest resultRequest, CancellationToken cancellationToken = default) { - using var channel = ChannelPool.GetChannel(); - var resultsClient = new Results.ResultsClient(channel); + await using var channel = await ChannelPool.GetAsync(cancellationToken) + .ConfigureAwait(false); + var resultsClient = new Results.ResultsClient(channel); var getResultResponse = await resultsClient.GetResultAsync(new GetResultRequest { ResultId = resultRequest.ResultId, @@ -951,17 +953,18 @@ public IList> TryGetResults(IList resultIds) /// Dictionary where each result name is associated with its result id [PublicAPI] public Dictionary CreateResultsMetadata(IEnumerable resultNames) - => ChannelPool.WithChannel(c => new Results.ResultsClient(c).CreateResultsMetaData(new CreateResultsMetaDataRequest - { - SessionId = SessionId.Id, - Results = - { - resultNames.Select(name => new CreateResultsMetaDataRequest.Types.ResultCreate - { - Name = name, - }), - }, - })) + => ChannelPool.WithInstance(c => new Results.ResultsClient(c).CreateResultsMetaData(new CreateResultsMetaDataRequest + { + SessionId = SessionId.Id, + Results = + { + resultNames.Select(name + => new CreateResultsMetaDataRequest.Types.ResultCreate + { + Name = name, + }), + }, + })) .Results.ToDictionary(r => r.Name, r => r.ResultId); } diff --git a/Client/src/Common/Submitter/ChannelPool.cs b/Client/src/Common/Submitter/ChannelPool.cs deleted file mode 100644 index 6450a5c23..000000000 --- a/Client/src/Common/Submitter/ChannelPool.cs +++ /dev/null @@ -1,188 +0,0 @@ -// This file is part of the ArmoniK project -// -// Copyright (C) ANEO, 2021-2024. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License") -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using System; -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; - -using Grpc.Net.Client; - -using Microsoft.Extensions.Logging; - -namespace ArmoniK.DevelopmentKit.Client.Common.Submitter; - -/// -/// Helper to have a connection pool for gRPC services -/// -public sealed class ChannelPool -{ - private readonly Func channelFactory_; - - private readonly ILogger? logger_; - - private readonly ConcurrentBag pool_; - - /// - /// Constructs a new channelPool - /// - /// Function used to create new channels - /// loggerFactory used to instantiate a logger for the pool - public ChannelPool(Func channelFactory, - ILoggerFactory? loggerFactory = null) - { - channelFactory_ = channelFactory; - pool_ = new ConcurrentBag(); - logger_ = loggerFactory?.CreateLogger(); - } - - /// - /// Get a channel from the pool. If the pool is empty, create a new channel - /// - /// A GrpcChannel used by nobody else - private GrpcChannel AcquireChannel() - { - if (pool_.TryTake(out var channel)) - { - if (ShutdownOnFailure(channel)) - { - logger_?.LogDebug("Got an invalid channel {channel} from pool", - channel); - } - else - { - logger_?.LogDebug("Acquired already existing channel {channel} from pool", - channel); - return channel; - } - } - - channel = channelFactory_(); - logger_?.LogInformation("Created and acquired new channel {channel} from pool", - channel); - return channel; - } - - /// - /// Release a GrpcChannel to the pool that could be reused later by someone else - /// - /// Channel to release - private void ReleaseChannel(GrpcChannel channel) - { - if (ShutdownOnFailure(channel)) - { - logger_?.LogDebug("Shutdown unhealthy channel {channel}", - channel); - } - else - { - logger_?.LogDebug("Released channel {channel} to pool", - channel); - pool_.Add(channel); - } - } - - /// - /// Check the state of a channel and shutdown it in case of failure - /// - /// Channel to check the state - /// True if the channel has been shut down - private static bool ShutdownOnFailure(GrpcChannel channel) - { - try - { -#if NET5_0_OR_GREATER - switch (channel.State) - { - case ConnectivityState.TransientFailure: - channel.ShutdownAsync() - .Wait(); - channel.Dispose(); - return true; - case ConnectivityState.Shutdown: - return true; - case ConnectivityState.Idle: - case ConnectivityState.Connecting: - case ConnectivityState.Ready: - default: - return false; - } -#else - _ = channel; - return false; -#endif - } - catch (InvalidOperationException) - { - return false; - } - } - - /// - /// Get a channel that will be automatically released when disposed - /// - /// - public ChannelGuard GetChannel() - => new(this); - - /// - /// Call f with an acquired channel - /// - /// Function to be called - /// Type of the return type of f - /// Value returned by f - public T WithChannel(Func f) - { - using var channel = GetChannel(); - return f(channel); - } - - /// - /// Helper class that acquires a channel from a pool when constructed, and releases it when disposed - /// - public sealed class ChannelGuard : IDisposable - { - /// - /// Channel that is used by nobody else - /// - [SuppressMessage("Usage", - "CA2213:Disposable fields should be disposed")] - private readonly GrpcChannel channel_; - - private readonly ChannelPool pool_; - - /// - /// Acquire a channel that will be released when disposed - /// - /// - public ChannelGuard(ChannelPool channelPool) - { - pool_ = channelPool; - channel_ = channelPool.AcquireChannel(); - } - - /// - public void Dispose() - => pool_.ReleaseChannel(channel_); - - /// - /// Implicit convert a ChannelGuard into a ChannelBase - /// - /// ChannelGuard - /// GrpcChannel - public static implicit operator GrpcChannel(ChannelGuard guard) - => guard.channel_; - } -} diff --git a/Client/src/Common/Submitter/ClientServiceConnector.cs b/Client/src/Common/Submitter/ClientServiceConnector.cs index 641bbcc00..26e71c3fe 100644 --- a/Client/src/Common/Submitter/ClientServiceConnector.cs +++ b/Client/src/Common/Submitter/ClientServiceConnector.cs @@ -14,8 +14,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System.Threading.Tasks; + using ArmoniK.Api.Client.Options; using ArmoniK.Api.Client.Submitter; +using ArmoniK.Utils; + +using Grpc.Net.Client; using Microsoft.Extensions.Logging; @@ -33,8 +38,8 @@ public class ClientServiceConnector /// Configuration Properties /// Optional logger factory /// The connection pool - public static ChannelPool ControlPlaneConnectionPool(Properties properties, - ILoggerFactory? loggerFactory = null) + public static ObjectPool ControlPlaneConnectionPool(Properties properties, + ILoggerFactory? loggerFactory = null) { var options = new GrpcClient { @@ -52,7 +57,32 @@ public static ChannelPool ControlPlaneConnectionPool(Properties properties, ProxyPassword = properties.ProxyPassword, }; - return new ChannelPool(() => GrpcChannelFactory.CreateChannel(options, - loggerFactory?.CreateLogger(typeof(ClientServiceConnector)))); + return new ObjectPool(ct => new ValueTask(GrpcChannelFactory.CreateChannel(options, + loggerFactory?.CreateLogger(typeof(ClientServiceConnector)))), + + +#if NET5_0_OR_GREATER + async (channel, ct) => + { +switch (channel.State) + { + case ConnectivityState.TransientFailure: + await channel.ShutdownAsync() + .ConfigureAwait(false); + return false; + case ConnectivityState.Shutdown: + return false; + case ConnectivityState.Idle: + case ConnectivityState.Connecting: + case ConnectivityState.Ready: + default: + return true; + } + } +#else + (_, + _) => new ValueTask(true) +#endif + ); } } diff --git a/Client/src/Symphony/ArmonikSymphonyClient.cs b/Client/src/Symphony/ArmonikSymphonyClient.cs index 38bdeeefd..e7e163ea8 100644 --- a/Client/src/Symphony/ArmonikSymphonyClient.cs +++ b/Client/src/Symphony/ArmonikSymphonyClient.cs @@ -18,6 +18,9 @@ using ArmoniK.DevelopmentKit.Client.Common; using ArmoniK.DevelopmentKit.Client.Common.Submitter; using ArmoniK.DevelopmentKit.Common; +using ArmoniK.Utils; + +using Grpc.Net.Client; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; @@ -63,7 +66,7 @@ public ArmonikSymphonyClient(IConfiguration configuration, /// public string SectionGrpc { get; set; } = "Grpc"; - private ChannelPool GrpcPool { get; set; } + private ObjectPool GrpcPool { get; set; } private IConfiguration Configuration { get; } diff --git a/Client/src/Unified/Factory/SessionServiceFactory.cs b/Client/src/Unified/Factory/SessionServiceFactory.cs index ab8a28446..8f15d2751 100644 --- a/Client/src/Unified/Factory/SessionServiceFactory.cs +++ b/Client/src/Unified/Factory/SessionServiceFactory.cs @@ -20,9 +20,12 @@ using ArmoniK.DevelopmentKit.Client.Unified.Services; using ArmoniK.DevelopmentKit.Client.Unified.Services.Admin; using ArmoniK.DevelopmentKit.Common; +using ArmoniK.Utils; using Google.Protobuf.WellKnownTypes; +using Grpc.Net.Client; + using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -51,7 +54,7 @@ public SessionServiceFactory(ILoggerFactory? loggerFactory = null) private ILogger Logger { get; } - private ChannelPool? GrpcPool { get; set; } + private ObjectPool? GrpcPool { get; set; } private ILoggerFactory LoggerFactory { get; } diff --git a/Client/src/Unified/Services/Admin/AdminMonitoringService.cs b/Client/src/Unified/Services/Admin/AdminMonitoringService.cs index 3447329f7..b9577179f 100644 --- a/Client/src/Unified/Services/Admin/AdminMonitoringService.cs +++ b/Client/src/Unified/Services/Admin/AdminMonitoringService.cs @@ -24,6 +24,9 @@ using ArmoniK.Api.gRPC.V1.SortDirection; using ArmoniK.Api.gRPC.V1.Tasks; using ArmoniK.DevelopmentKit.Client.Common.Submitter; +using ArmoniK.Utils; + +using Grpc.Net.Client; using Microsoft.Extensions.Logging; @@ -39,15 +42,15 @@ namespace ArmoniK.DevelopmentKit.Client.Unified.Services.Admin; /// public class AdminMonitoringService { - private readonly ChannelPool channelPool_; + private readonly ObjectPool channelPool_; /// /// The constructor to instantiate this service /// /// The entry point to the control plane /// The factory logger to create logger - public AdminMonitoringService(ChannelPool channelPool, - ILoggerFactory? loggerFactory = null) + public AdminMonitoringService(ObjectPool channelPool, + ILoggerFactory? loggerFactory = null) { Logger = loggerFactory?.CreateLogger(); channelPool_ = channelPool; @@ -61,7 +64,7 @@ public AdminMonitoringService(ChannelPool channelPool, /// public void GetServiceConfiguration() { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var resultsClient = new Results.ResultsClient(channel); var configuration = resultsClient.GetServiceConfiguration(new Empty()); Logger?.LogInformation($"This configuration will be update in the nex version [ {configuration} ]"); @@ -74,7 +77,7 @@ public void GetServiceConfiguration() /// the sessionId of the session to cancel public void CancelSession(string sessionId) { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var sessionsClient = new Sessions.SessionsClient(channel); sessionsClient.CancelSession(new CancelSessionRequest { @@ -98,7 +101,7 @@ public IEnumerable ListAllTasksBySession(string sessionId) public IEnumerable ListTasksBySession(string sessionId, params TaskStatus[] taskStatus) { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var tasksClient = new Tasks.TasksClient(channel); return tasksClient.ListTasks(new Filters @@ -159,7 +162,7 @@ public IEnumerable ListCancelledTasks(string sessionId) /// The list of filtered session public IEnumerable ListAllSessions() { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var sessionsClient = new Sessions.SessionsClient(channel); return sessionsClient.ListSessions(new ListSessionsRequest()) .Sessions.Select(session => session.SessionId); @@ -172,7 +175,7 @@ public IEnumerable ListAllSessions() /// returns a list of session filtered public IEnumerable ListRunningSessions() { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var sessionsClient = new Sessions.SessionsClient(channel); return sessionsClient.ListSessions(new ListSessionsRequest { @@ -213,7 +216,7 @@ public IEnumerable ListRunningSessions() /// returns a list of session filtered public IEnumerable ListCancelledSessions() { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var sessionsClient = new Sessions.SessionsClient(channel); return sessionsClient.ListSessions(new ListSessionsRequest { @@ -281,7 +284,7 @@ public int CountErrorTasksBySession(string sessionId) public int CountTaskBySession(string sessionId, params TaskStatus[] taskStatus) { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var tasksClient = new Tasks.TasksClient(channel); return tasksClient.CountTasksByStatus(new CountTasksByStatusRequest { @@ -322,7 +325,7 @@ public int CountCompletedTasksBySession(string sessionId) /// the taskIds list to cancel public void CancelTasksBySession(IEnumerable taskIds) { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var tasksClient = new Tasks.TasksClient(channel); tasksClient.CancelTasks(new CancelTasksRequest { @@ -340,7 +343,7 @@ public void CancelTasksBySession(IEnumerable taskIds) /// returns a list of pair TaskId/TaskStatus public IEnumerable> GetTaskStatus(IEnumerable taskIds) { - using var channel = channelPool_.GetChannel(); + using var channel = channelPool_.Get(); var tasksClient = new Tasks.TasksClient(channel); return tasksClient.ListTasks(new Filters { diff --git a/Client/src/Unified/Services/Submitter/Service.cs b/Client/src/Unified/Services/Submitter/Service.cs index 39794bbd0..4689d3244 100644 --- a/Client/src/Unified/Services/Submitter/Service.cs +++ b/Client/src/Unified/Services/Submitter/Service.cs @@ -35,6 +35,7 @@ using ArmoniK.Utils; using Grpc.Core; +using Grpc.Net.Client; using JetBrains.Annotations; @@ -796,8 +797,8 @@ private void ResultTask() /// gRPC channel // TODO: Refactor test to remove this // ReSharper disable once UnusedMember.Global - public ChannelBase GetChannel() - => SessionService.ChannelPool.GetChannel(); + public ObjectPool GetChannelPool() + => SessionService.ChannelPool; /// /// Class to return TaskId and the result diff --git a/Common/src/Common/ArmoniK.DevelopmentKit.Common.csproj b/Common/src/Common/ArmoniK.DevelopmentKit.Common.csproj index d4b9dc587..4af40e4fb 100644 --- a/Common/src/Common/ArmoniK.DevelopmentKit.Common.csproj +++ b/Common/src/Common/ArmoniK.DevelopmentKit.Common.csproj @@ -8,7 +8,7 @@ - + diff --git a/Tests/ArmoniK.EndToEndTests/ArmoniK.EndToEndTests.Client/Tests/AggregationPriority/AggregationPriorityTest.cs b/Tests/ArmoniK.EndToEndTests/ArmoniK.EndToEndTests.Client/Tests/AggregationPriority/AggregationPriorityTest.cs index 1d0b7e4a8..2924c6be9 100644 --- a/Tests/ArmoniK.EndToEndTests/ArmoniK.EndToEndTests.Client/Tests/AggregationPriority/AggregationPriorityTest.cs +++ b/Tests/ArmoniK.EndToEndTests/ArmoniK.EndToEndTests.Client/Tests/AggregationPriority/AggregationPriorityTest.cs @@ -156,7 +156,11 @@ private async Task> GetDistribution(int nRows) var taskRawData = new List(); - await foreach (var taskRaw in RetrieveAllTasksStats(service.GetChannel(), + await using var channel = await service!.GetChannelPool() + .GetAsync(CancellationToken.None) + .ConfigureAwait(false); + + await foreach (var taskRaw in RetrieveAllTasksStats(channel, new Filters { Or =