From b766c1bebb28086d0263dbf98451b5d216fd43cd Mon Sep 17 00:00:00 2001 From: Christian Date: Tue, 23 Feb 2021 11:45:22 +0100 Subject: [PATCH] Refactor keep alive message handling in order to fix issues with sudden disconnects. --- MQTTnet.sln.DotSettings | 1 + Source/MQTTnet/Client/MqttClient.cs | 60 ++++----- .../PacketDispatcher/IMqttPacketAwaiter.cs | 2 + .../PacketDispatcher/MqttPacketAwaiter.cs | 16 ++- .../MqttPacketAwaiterPacketFilter.cs | 11 ++ .../PacketDispatcher/MqttPacketDispatcher.cs | 118 +++++++++--------- Source/MQTTnet/Server/MqttClientConnection.cs | 11 +- 7 files changed, 123 insertions(+), 96 deletions(-) create mode 100644 Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs diff --git a/MQTTnet.sln.DotSettings b/MQTTnet.sln.DotSettings index 0b9e2dc5f..e81ffe218 100644 --- a/MQTTnet.sln.DotSettings +++ b/MQTTnet.sln.DotSettings @@ -12,4 +12,5 @@ True True True + True True \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index b16cac7e8..0107ffb4d 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -14,7 +14,6 @@ using MQTTnet.Packets; using MQTTnet.Protocol; using System; -using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using MQTTnet.Implementations; @@ -25,8 +24,6 @@ public class MqttClient : Disposable, IMqttClient { readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); - readonly Stopwatch _sendTracker = new Stopwatch(); - readonly Stopwatch _receiveTracker = new Stopwatch(); readonly object _disconnectLock = new object(); readonly IMqttClientAdapterFactory _adapterFactory; @@ -44,6 +41,8 @@ public class MqttClient : Disposable, IMqttClient long _isDisconnectPending; bool _isConnected; MqttClientDisconnectReason _disconnectReason; + + DateTime _lastPacketSentTimestamp; public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger) { @@ -79,7 +78,7 @@ public async Task ConnectAsync(IMqttClientOptions Options = options; _packetIdentifierProvider.Reset(); - _packetDispatcher.Cancel(); + _packetDispatcher.CancelAll(); _backgroundCancellationTokenSource = new CancellationTokenSource(); var backgroundCancellationToken = _backgroundCancellationTokenSource.Token; @@ -102,8 +101,7 @@ public async Task ConnectAsync(IMqttClientOptions authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false); } - _sendTracker.Restart(); - _receiveTracker.Restart(); + _lastPacketSentTimestamp = DateTime.UtcNow; if (Options.KeepAlivePeriod != TimeSpan.Zero) { @@ -391,8 +389,8 @@ Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - _sendTracker.Restart(); - + _lastPacketSentTimestamp = DateTime.UtcNow; + return _adapter.SendPacketAsync(packet, cancellationToken); } @@ -400,18 +398,17 @@ async Task SendAndReceiveAsync(MqttBasePacket { cancellationToken.ThrowIfCancellationRequested(); - ushort identifier = 0; - if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0) + ushort packetIdentifier = 0; + if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier) { - identifier = packetWithIdentifier.PacketIdentifier; + packetIdentifier = packetWithIdentifier.PacketIdentifier; } - using (var packetAwaiter = _packetDispatcher.AddAwaiter(identifier)) + using (var packetAwaiter = _packetDispatcher.AddAwaiter(packetIdentifier)) { try { - _sendTracker.Restart(); - await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); + await SendAsync(requestPacket, cancellationToken).ConfigureAwait(false); } catch (Exception exception) { @@ -446,15 +443,15 @@ async Task TrySendKeepAliveMessagesAsync(CancellationToken cancellationToken) while (!cancellationToken.IsCancellationRequested) { // Values described here: [MQTT-3.1.2-24]. - var waitTime = keepAlivePeriod - _sendTracker.Elapsed; + var timeWithoutPacketSent = DateTime.UtcNow - _lastPacketSentTimestamp; - if (waitTime <= TimeSpan.Zero) + if (timeWithoutPacketSent > keepAlivePeriod) { - await SendAndReceiveAsync(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false); + await PingAsync(cancellationToken).ConfigureAwait(false); } // Wait a fixed time in all cases. Calculation of the remaining time is complicated - // due to some edge cases and was buggy in the past. Now we wait half a second because the + // due to some edge cases and was buggy in the past. Now we wait several ms because the // min keep alive value is one second so that the server will wait 1.5 seconds for a PING // packet. await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken).ConfigureAwait(false); @@ -538,7 +535,7 @@ async Task TryReceivePacketsAsync(CancellationToken cancellationToken) _logger.Error(exception, "Error while receiving packets."); } - _packetDispatcher.Dispatch(exception); + _packetDispatcher.FailAll(exception); if (!DisconnectIsPending()) { @@ -555,8 +552,6 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke { try { - _receiveTracker.Restart(); - if (packet is MqttPublishPacket publishPacket) { EnqueueReceivedPublishPacket(publishPacket); @@ -569,10 +564,6 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke { await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false); } - else if (packet is MqttPingReqPacket) - { - await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false); - } else if (packet is MqttDisconnectPacket disconnectPacket) { await ProcessReceivedDisconnectPacket(disconnectPacket).ConfigureAwait(false); @@ -581,9 +572,20 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke { await ProcessReceivedAuthPacket(authPacket).ConfigureAwait(false); } + else if (packet is MqttPingRespPacket) + { + _packetDispatcher.TryDispatch(packet); + } + else if (packet is MqttPingReqPacket) + { + throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a Client to the Server only."); + } else { - _packetDispatcher.Dispatch(packet); + if (!_packetDispatcher.TryDispatch(packet)) + { + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + } } } catch (Exception exception) @@ -605,7 +607,7 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke _logger.Error(exception, "Error while receiving packets."); } - _packetDispatcher.Dispatch(exception); + _packetDispatcher.FailAll(exception); if (!DisconnectIsPending()) { @@ -703,8 +705,8 @@ Task ProcessReceivedDisconnectPacket(MqttDisconnectPacket disconnectPacket) _disconnectReason = (MqttClientDisconnectReason)(disconnectPacket.ReasonCode ?? MqttDisconnectReasonCode.NormalDisconnection); // Also dispatch disconnect to waiting threads to generate a proper exception. - _packetDispatcher.Dispatch(disconnectPacket); - + _packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket)); + if (!DisconnectIsPending()) { return DisconnectInternalAsync(_packetReceiverTask, null, null); diff --git a/Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs index c1c2d4eb9..8324d4ca0 100644 --- a/Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs +++ b/Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs @@ -5,6 +5,8 @@ namespace MQTTnet.PacketDispatcher { public interface IMqttPacketAwaiter : IDisposable { + MqttPacketAwaiterPacketFilter PacketFilter { get; } + void Complete(MqttBasePacket packet); void Fail(Exception exception); diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs index 6935e746a..9ed6e71ce 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs @@ -9,12 +9,16 @@ namespace MQTTnet.PacketDispatcher public sealed class MqttPacketAwaiter : IMqttPacketAwaiter where TPacket : MqttBasePacket { readonly TaskCompletionSource _taskCompletionSource; - readonly ushort? _packetIdentifier; readonly MqttPacketDispatcher _owningPacketDispatcher; - public MqttPacketAwaiter(ushort? packetIdentifier, MqttPacketDispatcher owningPacketDispatcher) + public MqttPacketAwaiter(ushort packetIdentifier, MqttPacketDispatcher owningPacketDispatcher) { - _packetIdentifier = packetIdentifier; + PacketFilter = new MqttPacketAwaiterPacketFilter + { + Type = typeof(TPacket), + Identifier = packetIdentifier + }; + _owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher)); #if NET452 _taskCompletionSource = new TaskCompletionSource(); @@ -22,7 +26,9 @@ public MqttPacketAwaiter(ushort? packetIdentifier, MqttPacketDispatcher owningPa _taskCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); #endif } - + + public MqttPacketAwaiterPacketFilter PacketFilter { get; } + public async Task WaitOneAsync(TimeSpan timeout) { using (var timeoutToken = new CancellationTokenSource(timeout)) @@ -82,7 +88,7 @@ public void Cancel() public void Dispose() { - _owningPacketDispatcher.RemoveAwaiter(_packetIdentifier); + _owningPacketDispatcher.RemoveAwaiter(this); } } } \ No newline at end of file diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs new file mode 100644 index 000000000..a36acc5a7 --- /dev/null +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs @@ -0,0 +1,11 @@ +using System; + +namespace MQTTnet.PacketDispatcher +{ + public sealed class MqttPacketAwaiterPacketFilter + { + public Type Type { get; set; } + + public ushort Identifier { get; set; } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs index 306ca45c7..aafabf234 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs @@ -1,103 +1,101 @@ -using MQTTnet.Exceptions; using MQTTnet.Packets; using System; -using System.Collections.Concurrent; +using System.Collections.Generic; namespace MQTTnet.PacketDispatcher { public sealed class MqttPacketDispatcher { - readonly ConcurrentDictionary, IMqttPacketAwaiter> _awaiters = new ConcurrentDictionary, IMqttPacketAwaiter>(); + readonly List _awaiters = new List(); - public void Dispatch(Exception exception) + public void FailAll(Exception exception) { - foreach (var awaiter in _awaiters) + if (exception == null) throw new ArgumentNullException(nameof(exception)); + + lock (_awaiters) { - awaiter.Value.Fail(exception); - } + foreach (var awaiter in _awaiters) + { + awaiter.Fail(exception); + } - _awaiters.Clear(); + _awaiters.Clear(); + } } - public bool TryDispatch(MqttBasePacket packet) + public void CancelAll() { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - - if (packet is MqttDisconnectPacket disconnectPacket) + lock (_awaiters) { - foreach (var packetAwaiter in _awaiters) + foreach (var entry in _awaiters) { - packetAwaiter.Value.Fail(new MqttUnexpectedDisconnectReceivedException(disconnectPacket)); + entry.Cancel(); } - return true; + _awaiters.Clear(); } - + } + + public bool TryDispatch(MqttBasePacket packet) + { + if (packet == null) throw new ArgumentNullException(nameof(packet)); + ushort identifier = 0; - if (packet is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0) + if (packet is IMqttPacketWithIdentifier packetWithIdentifier) { identifier = packetWithIdentifier.PacketIdentifier; } - var type = packet.GetType(); - var key = new Tuple(identifier, type); - - if (_awaiters.TryRemove(key, out var awaiter)) + var packetType = packet.GetType(); + var matchingAwaiters = new List(); + + lock (_awaiters) { - awaiter.Complete(packet); - return true; - } - - return false; - } - - public void Dispatch(MqttBasePacket packet) - { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - - if (!TryDispatch(packet)) - { - throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + for (var i = _awaiters.Count - 1; i >= 0; i--) + { + var entry = _awaiters[i]; + + // Note: The PingRespPacket will also arrive here and has NO identifier but there + // is code which waits for it. So the code must be able to deal with filters which + // are referring to the type only (identifier is 0)! + if (entry.PacketFilter.Type != packetType || entry.PacketFilter.Identifier != identifier) + { + continue; + } + + matchingAwaiters.Add(entry); + _awaiters.RemoveAt(i); + } } - } - - public void Cancel() - { - foreach (var awaiter in _awaiters) + + foreach (var matchingEntry in matchingAwaiters) { - awaiter.Value.Cancel(); + matchingEntry.Complete(packet); } - _awaiters.Clear(); + return matchingAwaiters.Count > 0; } - - public MqttPacketAwaiter AddAwaiter(ushort? identifier) where TResponsePacket : MqttBasePacket + + public MqttPacketAwaiter AddAwaiter(ushort packetIdentifier) where TResponsePacket : MqttBasePacket { - if (!identifier.HasValue) - { - identifier = 0; - } - - var awaiter = new MqttPacketAwaiter(identifier, this); + var awaiter = new MqttPacketAwaiter(packetIdentifier, this); - var key = new Tuple(identifier.Value, typeof(TResponsePacket)); - if (!_awaiters.TryAdd(key, awaiter)) + lock (_awaiters) { - throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{key.Item2.Name}' with identifier {key.Item1}."); + _awaiters.Add(awaiter); } - + return awaiter; } - public void RemoveAwaiter(ushort? identifier) where TResponsePacket : MqttBasePacket + public void RemoveAwaiter(IMqttPacketAwaiter awaiter) { - if (!identifier.HasValue) + if (awaiter == null) throw new ArgumentNullException(nameof(awaiter)); + + lock (_awaiters) { - identifier = 0; + _awaiters.Remove(awaiter); } - - var key = new Tuple(identifier.Value, typeof(TResponsePacket)); - _awaiters.TryRemove(key, out _); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index 75083a545..8f33bb8ba 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -212,6 +212,10 @@ async Task RunInternalAsync(CancellationToken cancellationToken) { await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false); } + else if (packet is MqttPingRespPacket) + { + throw new MqttProtocolViolationException("A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ Packet only."); + } else if (packet is MqttDisconnectPacket) { Session.WillMessage = null; @@ -222,7 +226,10 @@ async Task RunInternalAsync(CancellationToken cancellationToken) } else { - _packetDispatcher.Dispatch(packet); + if (!_packetDispatcher.TryDispatch(packet)) + { + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + } } } } @@ -255,7 +262,7 @@ async Task RunInternalAsync(CancellationToken cancellationToken) Session.WillMessage = null; } - _packetDispatcher.Cancel(); + _packetDispatcher.CancelAll(); _logger.Info("Client '{0}': Connection stopped.", ClientId);