Skip to content

Commit

Permalink
Use Utils ObjectPool
Browse files Browse the repository at this point in the history
  • Loading branch information
lemaitre-aneo committed Jun 5, 2024
1 parent 0e22148 commit 8c11db1
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 246 deletions.
75 changes: 39 additions & 36 deletions Client/src/Common/Submitter/BaseClientSubmitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
using Google.Protobuf;

using Grpc.Core;
using Grpc.Net.Client;

using JetBrains.Annotations;

Expand Down Expand Up @@ -69,7 +70,7 @@ public abstract class BaseClientSubmitter<T>
/// <summary>
/// The channel pool to use for creating clients
/// </summary>
private ChannelPool? channelPool_;
private ObjectPool<GrpcChannel>? channelPool_;

/// <summary>
/// Base Object for all Client submitter
Expand All @@ -95,8 +96,8 @@ protected BaseClientSubmitter(Properties properties,
TaskOptions.PartitionId,
});

configuration_ = ChannelPool.WithChannel(channel => new Results.ResultsClient(channel).GetServiceConfiguration(new Empty())
.DataChunkMaxSize);
configuration_ = ChannelPool.WithInstance(channel => new Results.ResultsClient(channel).GetServiceConfiguration(new Empty())
.DataChunkMaxSize);
}

private ILoggerFactory LoggerFactory { get; }
Expand All @@ -115,7 +116,7 @@ protected BaseClientSubmitter(Properties properties,
/// <summary>
/// The channel pool to use for creating clients
/// </summary>
public ChannelPool ChannelPool
public ObjectPool<GrpcChannel> ChannelPool
=> channelPool_ ??= ClientServiceConnector.ControlPlaneConnectionPool(properties_,
LoggerFactory);

Expand All @@ -129,7 +130,7 @@ private Session CreateSession(IEnumerable<string> partitionIds)
{
using var _ = Logger.LogFunction();
Logger.LogDebug("Creating Session... ");
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var sessionsClient = new Sessions.SessionsClient(channel);
var createSessionReply = sessionsClient.CreateSession(new CreateSessionRequest
{
Expand Down Expand Up @@ -167,7 +168,7 @@ public TaskStatus GetTaskStatus(string taskId)
/// <returns></returns>
public IEnumerable<Tuple<string, TaskStatus>> GetTaskStatues(params string[] taskIds)
{
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var tasksClient = new Tasks.TasksClient(channel);
return tasksClient.ListTasks(new Filters
{
Expand Down Expand Up @@ -200,10 +201,10 @@ public IEnumerable<Tuple<string, TaskStatus>> GetTaskStatues(params string[] tas
// TODO: This function should not have Output as a return type because it is a gRPC type
public Output GetTaskOutputInfo(string taskId)
{
var getTaskResponse = ChannelPool.WithChannel(channel => new Tasks.TasksClient(channel).GetTask(new GetTaskRequest
{
TaskId = taskId,
}));
var getTaskResponse = ChannelPool.WithInstance(channel => new Tasks.TasksClient(channel).GetTask(new GetTaskRequest
{
TaskId = taskId,
}));
return new Output
{
Error = new Output.Types.Error
Expand Down Expand Up @@ -296,7 +297,7 @@ private IEnumerable<string> ChunkSubmitTasksWithDependencies(IEnumerable<Tuple<s
{
for (var nbRetry = 0; nbRetry < maxRetries; nbRetry++)
{
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var resultsClient = new Results.ResultsClient(channel);

try
Expand Down Expand Up @@ -400,7 +401,7 @@ private IEnumerable<string> ChunkSubmitTasksWithDependencies(IEnumerable<Tuple<s

for (var nbRetry = 0; nbRetry < maxRetries; nbRetry++)
{
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var tasksClient = new Tasks.TasksClient(channel);

try
Expand Down Expand Up @@ -500,7 +501,7 @@ public void WaitForTasksCompletion(IEnumerable<string> taskIds,
delayMs,
retry =>
{
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);

if (retry > 1)
Expand Down Expand Up @@ -565,7 +566,7 @@ public ResultStatusCollection GetResultStatus(IEnumerable<string> taskIds,
// TODO: use ListResult
var idStatusPair = result2TaskDic.Keys.ParallelSelect(async resultId =>
{
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var resultsClient = new Results.ResultsClient(channel);
var result = await resultsClient.GetResultAsync(new GetResultRequest
{
Expand Down Expand Up @@ -640,14 +641,14 @@ public ResultStatusCollection GetResultStatus(IEnumerable<string> taskIds,
nameof(GetResultIds));
}

return ChannelPool.WithChannel(channel => new Tasks.TasksClient(channel).GetResultIds(new GetResultIdsRequest
{
TaskId =
{
taskIds,
},
})
.TaskResults);
return ChannelPool.WithInstance(channel => new Tasks.TasksClient(channel).GetResultIds(new GetResultIdsRequest
{
TaskId =
{
taskIds,
},
})
.TaskResults);
},
true,
Logger,
Expand Down Expand Up @@ -680,7 +681,7 @@ public byte[] GetResult(string taskId,
ResultId = resultId,
Session = SessionId.Id,
};
using var channel = ChannelPool.GetChannel();
using var channel = ChannelPool.Get();
var eventsClient = new Events.EventsClient(channel);
eventsClient.WaitForResultsAsync(SessionId.Id,
new List<string>
Expand Down Expand Up @@ -740,8 +741,9 @@ public IEnumerable<Tuple<string, byte[]>> GetResults(IEnumerable<string> taskIds
public async Task<byte[]?> TryGetResultAsync(ResultRequest resultRequest,
CancellationToken cancellationToken = default)
{
using var channel = ChannelPool.GetChannel();
var resultsClient = new Results.ResultsClient(channel);
await using var channel = await ChannelPool.GetAsync(cancellationToken)
.ConfigureAwait(false);
var resultsClient = new Results.ResultsClient(channel);
var getResultResponse = await resultsClient.GetResultAsync(new GetResultRequest
{
ResultId = resultRequest.ResultId,
Expand Down Expand Up @@ -951,17 +953,18 @@ public IList<Tuple<string, byte[]>> TryGetResults(IList<string> resultIds)
/// <returns>Dictionary where each result name is associated with its result id</returns>
[PublicAPI]
public Dictionary<string, string> CreateResultsMetadata(IEnumerable<string> resultNames)
=> ChannelPool.WithChannel(c => new Results.ResultsClient(c).CreateResultsMetaData(new CreateResultsMetaDataRequest
{
SessionId = SessionId.Id,
Results =
{
resultNames.Select(name => new CreateResultsMetaDataRequest.Types.ResultCreate
{
Name = name,
}),
},
}))
=> ChannelPool.WithInstance(c => new Results.ResultsClient(c).CreateResultsMetaData(new CreateResultsMetaDataRequest
{
SessionId = SessionId.Id,
Results =
{
resultNames.Select(name
=> new CreateResultsMetaDataRequest.Types.ResultCreate
{
Name = name,
}),
},
}))
.Results.ToDictionary(r => r.Name,
r => r.ResultId);
}
188 changes: 0 additions & 188 deletions Client/src/Common/Submitter/ChannelPool.cs

This file was deleted.

Loading

0 comments on commit 8c11db1

Please sign in to comment.