Skip to content

Commit

Permalink
Refactor keep alive message handling in order to fix issues with sudd…
Browse files Browse the repository at this point in the history
…en disconnects.
  • Loading branch information
chkr1011 committed Feb 23, 2021
1 parent 2d4ea8e commit b766c1b
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 96 deletions.
1 change: 1 addition & 0 deletions MQTTnet.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002ECSharpPlaceAttributeOnSameLineMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateBlankLinesAroundFieldToBlankLinesAroundProperty/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateThisQualifierSettings/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=PINGREQ/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unsub/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
60 changes: 31 additions & 29 deletions Source/MQTTnet/Client/MqttClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
using MQTTnet.Packets;
using MQTTnet.Protocol;
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Implementations;
Expand All @@ -25,8 +24,6 @@ public class MqttClient : Disposable, IMqttClient
{
readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
readonly Stopwatch _sendTracker = new Stopwatch();
readonly Stopwatch _receiveTracker = new Stopwatch();
readonly object _disconnectLock = new object();

readonly IMqttClientAdapterFactory _adapterFactory;
Expand All @@ -44,6 +41,8 @@ public class MqttClient : Disposable, IMqttClient
long _isDisconnectPending;
bool _isConnected;
MqttClientDisconnectReason _disconnectReason;

DateTime _lastPacketSentTimestamp;

public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logger)
{
Expand Down Expand Up @@ -79,7 +78,7 @@ public async Task<MqttClientAuthenticateResult> ConnectAsync(IMqttClientOptions
Options = options;

_packetIdentifierProvider.Reset();
_packetDispatcher.Cancel();
_packetDispatcher.CancelAll();

_backgroundCancellationTokenSource = new CancellationTokenSource();
var backgroundCancellationToken = _backgroundCancellationTokenSource.Token;
Expand All @@ -102,8 +101,7 @@ public async Task<MqttClientAuthenticateResult> ConnectAsync(IMqttClientOptions
authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false);
}

_sendTracker.Restart();
_receiveTracker.Restart();
_lastPacketSentTimestamp = DateTime.UtcNow;

if (Options.KeepAlivePeriod != TimeSpan.Zero)
{
Expand Down Expand Up @@ -391,27 +389,26 @@ Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

_sendTracker.Restart();

_lastPacketSentTimestamp = DateTime.UtcNow;
return _adapter.SendPacketAsync(packet, cancellationToken);
}

async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket
{
cancellationToken.ThrowIfCancellationRequested();

ushort identifier = 0;
if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier && packetWithIdentifier.PacketIdentifier > 0)
ushort packetIdentifier = 0;
if (requestPacket is IMqttPacketWithIdentifier packetWithIdentifier)
{
identifier = packetWithIdentifier.PacketIdentifier;
packetIdentifier = packetWithIdentifier.PacketIdentifier;
}

using (var packetAwaiter = _packetDispatcher.AddAwaiter<TResponsePacket>(identifier))
using (var packetAwaiter = _packetDispatcher.AddAwaiter<TResponsePacket>(packetIdentifier))
{
try
{
_sendTracker.Restart();
await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false);
await SendAsync(requestPacket, cancellationToken).ConfigureAwait(false);
}
catch (Exception exception)
{
Expand Down Expand Up @@ -446,15 +443,15 @@ async Task TrySendKeepAliveMessagesAsync(CancellationToken cancellationToken)
while (!cancellationToken.IsCancellationRequested)
{
// Values described here: [MQTT-3.1.2-24].
var waitTime = keepAlivePeriod - _sendTracker.Elapsed;
var timeWithoutPacketSent = DateTime.UtcNow - _lastPacketSentTimestamp;

if (waitTime <= TimeSpan.Zero)
if (timeWithoutPacketSent > keepAlivePeriod)
{
await SendAndReceiveAsync<MqttPingRespPacket>(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false);
await PingAsync(cancellationToken).ConfigureAwait(false);
}

// Wait a fixed time in all cases. Calculation of the remaining time is complicated
// due to some edge cases and was buggy in the past. Now we wait half a second because the
// due to some edge cases and was buggy in the past. Now we wait several ms because the
// min keep alive value is one second so that the server will wait 1.5 seconds for a PING
// packet.
await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -538,7 +535,7 @@ async Task TryReceivePacketsAsync(CancellationToken cancellationToken)
_logger.Error(exception, "Error while receiving packets.");
}

_packetDispatcher.Dispatch(exception);
_packetDispatcher.FailAll(exception);

if (!DisconnectIsPending())
{
Expand All @@ -555,8 +552,6 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke
{
try
{
_receiveTracker.Restart();

if (packet is MqttPublishPacket publishPacket)
{
EnqueueReceivedPublishPacket(publishPacket);
Expand All @@ -569,10 +564,6 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke
{
await ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken).ConfigureAwait(false);
}
else if (packet is MqttPingReqPacket)
{
await SendAsync(MqttPingRespPacket.Instance, cancellationToken).ConfigureAwait(false);
}
else if (packet is MqttDisconnectPacket disconnectPacket)
{
await ProcessReceivedDisconnectPacket(disconnectPacket).ConfigureAwait(false);
Expand All @@ -581,9 +572,20 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke
{
await ProcessReceivedAuthPacket(authPacket).ConfigureAwait(false);
}
else if (packet is MqttPingRespPacket)
{
_packetDispatcher.TryDispatch(packet);
}
else if (packet is MqttPingReqPacket)
{
throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a Client to the Server only.");
}
else
{
_packetDispatcher.Dispatch(packet);
if (!_packetDispatcher.TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
}
}
}
catch (Exception exception)
Expand All @@ -605,7 +607,7 @@ async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToke
_logger.Error(exception, "Error while receiving packets.");
}

_packetDispatcher.Dispatch(exception);
_packetDispatcher.FailAll(exception);

if (!DisconnectIsPending())
{
Expand Down Expand Up @@ -703,8 +705,8 @@ Task ProcessReceivedDisconnectPacket(MqttDisconnectPacket disconnectPacket)
_disconnectReason = (MqttClientDisconnectReason)(disconnectPacket.ReasonCode ?? MqttDisconnectReasonCode.NormalDisconnection);

// Also dispatch disconnect to waiting threads to generate a proper exception.
_packetDispatcher.Dispatch(disconnectPacket);

_packetDispatcher.FailAll(new MqttUnexpectedDisconnectReceivedException(disconnectPacket));
if (!DisconnectIsPending())
{
return DisconnectInternalAsync(_packetReceiverTask, null, null);
Expand Down
2 changes: 2 additions & 0 deletions Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ namespace MQTTnet.PacketDispatcher
{
public interface IMqttPacketAwaiter : IDisposable
{
MqttPacketAwaiterPacketFilter PacketFilter { get; }

void Complete(MqttBasePacket packet);

void Fail(Exception exception);
Expand Down
16 changes: 11 additions & 5 deletions Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,26 @@ namespace MQTTnet.PacketDispatcher
public sealed class MqttPacketAwaiter<TPacket> : IMqttPacketAwaiter where TPacket : MqttBasePacket
{
readonly TaskCompletionSource<MqttBasePacket> _taskCompletionSource;
readonly ushort? _packetIdentifier;
readonly MqttPacketDispatcher _owningPacketDispatcher;

public MqttPacketAwaiter(ushort? packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
public MqttPacketAwaiter(ushort packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
{
_packetIdentifier = packetIdentifier;
PacketFilter = new MqttPacketAwaiterPacketFilter
{
Type = typeof(TPacket),
Identifier = packetIdentifier
};

_owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher));
#if NET452
_taskCompletionSource = new TaskCompletionSource<MqttBasePacket>();
#else
_taskCompletionSource = new TaskCompletionSource<MqttBasePacket>(TaskCreationOptions.RunContinuationsAsynchronously);
#endif
}


public MqttPacketAwaiterPacketFilter PacketFilter { get; }

public async Task<TPacket> WaitOneAsync(TimeSpan timeout)
{
using (var timeoutToken = new CancellationTokenSource(timeout))
Expand Down Expand Up @@ -82,7 +88,7 @@ public void Cancel()

public void Dispose()
{
_owningPacketDispatcher.RemoveAwaiter<TPacket>(_packetIdentifier);
_owningPacketDispatcher.RemoveAwaiter(this);
}
}
}
11 changes: 11 additions & 0 deletions Source/MQTTnet/PacketDispatcher/MqttPacketAwaiterPacketFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System;

namespace MQTTnet.PacketDispatcher
{
public sealed class MqttPacketAwaiterPacketFilter
{
public Type Type { get; set; }

public ushort Identifier { get; set; }
}
}
Loading

0 comments on commit b766c1b

Please sign in to comment.