Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: recreate AMQP connection when dropped #605

Merged
merged 10 commits into from
Jan 31, 2024
100 changes: 60 additions & 40 deletions Adaptors/Amqp/src/ConnectionAmqp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,50 +37,74 @@ namespace ArmoniK.Core.Adapters.Amqp;
[UsedImplicitly]
public class ConnectionAmqp : IConnectionAmqp
{
private readonly AsyncLazy connectionTask_;
private readonly ILogger<ConnectionAmqp> logger_;
private readonly QueueCommon.Amqp options_;
private bool isInitialized_;
private readonly ExecutionSingleizer<Connection> connectionSingleizer_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private readonly ExecutionSingleizer<Connection> connectionSingleizer_;
private readonly ExecutionSingleizer<Connection> connectionSingleizer_ = new();

private readonly ILogger<ConnectionAmqp> logger_;
private readonly QueueCommon.Amqp options_;
private Connection? connection_;

public ConnectionAmqp(QueueCommon.Amqp options,
ILogger<ConnectionAmqp> logger)
{
options_ = options;
logger_ = logger;
connectionTask_ = new AsyncLazy(() => InitTask(this));
options_ = options;
logger_ = logger;
connectionSingleizer_ = new ExecutionSingleizer<Connection>();
}

public Connection? Connection { get; private set; }

public Task<HealthCheckResult> Check(HealthCheckTag tag)
=> tag switch
{
HealthCheckTag.Startup or HealthCheckTag.Readiness => Task.FromResult(isInitialized_
HealthCheckTag.Startup or HealthCheckTag.Readiness => Task.FromResult(connection_ is not null
? HealthCheckResult.Healthy()
: HealthCheckResult.Unhealthy($"{nameof(ConnectionAmqp)} is not yet initialized.")),
HealthCheckTag.Liveness => Task.FromResult(isInitialized_ && Connection is not null && Connection.ConnectionState == ConnectionState.Opened
HealthCheckTag.Liveness => Task.FromResult(connection_ is not null && connection_.ConnectionState == ConnectionState.Opened
? HealthCheckResult.Healthy()
: HealthCheckResult.Unhealthy($"{nameof(ConnectionAmqp)} not initialized or connection dropped.")),
_ => throw new ArgumentOutOfRangeException(nameof(tag),
tag,
null),
};

public async Task Init(CancellationToken cancellationToken = default)
=> await connectionTask_;
public Task Init(CancellationToken cancellationToken = default)
=> GetConnectionAsync(cancellationToken);

private static async Task InitTask(ConnectionAmqp conn,
CancellationToken cancellationToken = default)
public async Task<Connection> GetConnectionAsync(CancellationToken cancellationToken = default)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove the `async` modifier if you return the Task from the ExecutionSingleizer directly instead of awaiting it.

{
conn.logger_.LogInformation("Get address for session");
var address = new Address(conn.options_.Host,
conn.options_.Port,
conn.options_.User,
conn.options_.Password,
scheme: conn.options_.Scheme);
if (connection_ is not null && !connection_.IsClosed)
{
return connection_;
}

return await connectionSingleizer_.Call(async token =>
{
// Re-check inside the singleizer call to avoid a TOCTOU problem: the connection state may have changed since the check above.
if (connection_ is not null && !connection_.IsClosed)
{
return connection_;
}

var conn = await CreateConnection(options_,
logger_,
token)
.ConfigureAwait(false);
connection_ = conn;
return conn;
},
cancellationToken)
.ConfigureAwait(false);
}

private static async Task<Connection> CreateConnection(QueueCommon.Amqp options,
ILogger logger,
CancellationToken cancellationToken = default)
{
var address = new Address(options.Host,
options.Port,
options.User,
options.Password,
scheme: options.Scheme);

var connectionFactory = new ConnectionFactory();
if (conn.options_.Scheme.Equals("AMQPS"))
if (options.Scheme.Equals("AMQPS"))
{
connectionFactory.SSL.RemoteCertificateValidationCallback = delegate(object _,
X509Certificate? _,
Expand All @@ -89,45 +113,41 @@ private static async Task InitTask(ConnectionAmqp conn,
{
switch (errors)
{
case SslPolicyErrors.RemoteCertificateNameMismatch when conn.options_.AllowHostMismatch:
case SslPolicyErrors.RemoteCertificateNameMismatch when options.AllowHostMismatch:
case SslPolicyErrors.None:
return true;
default:
conn.logger_.LogError("SSL error : {error}",
errors);
logger.LogError("SSL error : {error}",
errors);
return false;
}
};
}

var retry = 0;
for (; retry < conn.options_.MaxRetries; retry++)
for (; retry < options.MaxRetries; retry++)
{
try
{
conn.Connection = await connectionFactory.CreateAsync(address)
.ConfigureAwait(false);
conn.Connection.AddClosedCallback((_,
e) => OnCloseConnection(e,
conn.logger_));
break;
var connection = await connectionFactory.CreateAsync(address)
.ConfigureAwait(false);
connection.AddClosedCallback((_,
e) => OnCloseConnection(e,
logger));

return connection;
}
catch (Exception ex)
{
conn.logger_.LogInformation(ex,
"Retrying to create connection");
logger.LogInformation(ex,
"Retrying to create connection");
await Task.Delay(1000 * retry,
cancellationToken)
.ConfigureAwait(false);
}
}

if (retry == conn.options_.MaxRetries)
{
throw new TimeoutException($"{nameof(conn.options_.MaxRetries)} reached");
}

conn.isInitialized_ = true;
throw new TimeoutException($"{nameof(options.MaxRetries)} reached");
}

private static void OnCloseConnection(Error? error,
Expand Down
5 changes: 4 additions & 1 deletion Adaptors/Amqp/src/IConnectionAmqp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System.Threading;
using System.Threading.Tasks;

using Amqp;

using ArmoniK.Core.Base;
Expand All @@ -23,5 +26,5 @@ namespace ArmoniK.Core.Adapters.Amqp;

public interface IConnectionAmqp : IInitializable
{
public Connection? Connection { get; }
Task<Connection> GetConnectionAsync(CancellationToken cancellationToken = default);
}
12 changes: 7 additions & 5 deletions Adaptors/Amqp/src/PullQueueStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ await ConnectionAmqp.Init(cancellationToken)

senders_ = Enumerable.Range(0,
NbLinks)
.Select(i => new AsyncLazy<ISenderLink>(() => new SenderLink(new Session(ConnectionAmqp.Connection),
$"{Options.PartitionId}###SenderLink{i}",
$"{Options.PartitionId}###q{i}")))
.Select(i => new AsyncLazy<ISenderLink>(async () => new SenderLink(new Session(await ConnectionAmqp.GetConnectionAsync()
.ConfigureAwait(false)),
$"{Options.PartitionId}###SenderLink{i}",
$"{Options.PartitionId}###q{i}")))
.ToArray();

var senders = senders_.Select(lazy => lazy.Value)
Expand Down Expand Up @@ -174,9 +175,10 @@ await Task.Delay(retry * retry * baseDelay_,

private AsyncLazy<IReceiverLink> CreateReceiver(IConnectionAmqp connection,
int link)
=> new(() =>
=> new(async () =>
{
var rl = new ReceiverLink(new Session(connection.Connection),
var rl = new ReceiverLink(new Session(await connection.GetConnectionAsync()
.ConfigureAwait(false)),
$"{Options.PartitionId}###ReceiverLink{link}",
$"{Options.PartitionId}###q{link}");

Expand Down
6 changes: 4 additions & 2 deletions Adaptors/Amqp/src/PushQueueStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ public PushQueueStorage(QueueCommon.Amqp options,
{
logger_ = logger;
sessionPool_ = new ObjectPool<Session>(200,
() => new Session(connectionAmqp.Connection),
session => !session.IsClosed);
async token => new Session(await connectionAmqp.GetConnectionAsync(token)
.ConfigureAwait(false)),
(session,
_) => new ValueTask<bool>(!session.IsClosed));
}

/// <inheritdoc />
Expand Down
12 changes: 0 additions & 12 deletions Adaptors/Amqp/src/QueueMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
using System.Threading.Tasks;

using Amqp;
using Amqp.Framing;

using ArmoniK.Core.Base;

Expand Down Expand Up @@ -68,17 +67,6 @@ public async ValueTask DisposeAsync()
switch (Status)
{
case QueueMessageStatus.Postponed:
await sender_.SendAsync(new Message(message_.Body)
{
Header = new Header
{
Priority = message_.Header.Priority,
},
Properties = new Properties(),
})
.ConfigureAwait(false);
receiver_.Accept(message_);
break;
case QueueMessageStatus.Failed:
case QueueMessageStatus.Running:
case QueueMessageStatus.Waiting:
Expand Down
19 changes: 11 additions & 8 deletions Common/tests/Helpers/SimpleAmqpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ public class SimpleAmqpClient : IConnectionAmqp, IAsyncDisposable
private readonly Address address_;
private readonly ConnectionFactory connectionFactory_;
private readonly ILoggerFactory loggerFactory_;
private bool isInitialized_;

private Connection? connection_;
private bool isInitialized_;

public SimpleAmqpClient()
{
Expand All @@ -48,25 +50,26 @@ public SimpleAmqpClient()

public async ValueTask DisposeAsync()
{
if (Connection is not null && Connection.ConnectionState == ConnectionState.Opened)
if (connection_ is not null && connection_.ConnectionState == ConnectionState.Opened)
{
await Connection.CloseAsync()
.ConfigureAwait(false);
await connection_.CloseAsync()
.ConfigureAwait(false);
}

loggerFactory_.Dispose();
GC.SuppressFinalize(this);
}

public Connection? Connection { get; private set; }

public async Task Init(CancellationToken cancellation)
{
Connection = await connectionFactory_.CreateAsync(address_)
.ConfigureAwait(false);
connection_ = await connectionFactory_.CreateAsync(address_)
.ConfigureAwait(false);
isInitialized_ = true;
}

public Task<Connection> GetConnectionAsync(CancellationToken cancellationToken = default)
=> Task.FromResult(connection_!);

public Task<HealthCheckResult> Check(HealthCheckTag tag)
=> Task.FromResult(isInitialized_
? HealthCheckResult.Healthy()
Expand Down
Loading