Skip to content

Commit

Permalink
Implement server keep alive. (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Nov 19, 2024
1 parent 618dbde commit 93a01aa
Show file tree
Hide file tree
Showing 14 changed files with 359 additions and 8 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ class SshClientSettings
bool AutoReconnect { get; set; } = false;

bool TcpKeepAlive { get; set; } = true;
TimeSpan KeepAliveInterval { get; set; } = TimeSpan.Zero;
int KeepAliveCountMax = 3;

List<string> GlobalKnownHostsFilePaths { get; set; } = DefaultGlobalKnownHostsFilePaths;
List<string> UserKnownHostsFilePaths { get; set; } = DefaultUserKnownHostsFilePaths;
Expand Down Expand Up @@ -258,7 +260,7 @@ class SshConfigSettings

HostAuthentication? HostAuthentication { get; set; } // Called for Unknown when StrictHostKeyChecking is 'ask' (default)
}
public enum SshConfigOption
enum SshConfigOption
{
Hostname,
User,
Expand All @@ -281,7 +283,9 @@ public enum SshConfigOption
KexAlgorithms,
MACs,
PubkeyAcceptedAlgorithms,
TCPKeepAlive
TCPKeepAlive,
ServerAliveCountMax,
ServerAliveInterval
}
struct SshConfigOptionValue
{
Expand Down
70 changes: 69 additions & 1 deletion src/Tmds.Ssh/SocketSshConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// See file LICENSE for full license details.

using System.Buffers;
using System.Diagnostics;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging;
Expand All @@ -10,7 +11,6 @@ namespace Tmds.Ssh;

sealed class SocketSshConnection : SshConnection
{

private static ReadOnlySpan<byte> NewLine => new byte[] { (byte)'\r', (byte)'\n' };
private static readonly UTF8Encoding s_utf8Encoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

Expand All @@ -22,6 +22,63 @@ sealed class SocketSshConnection : SshConnection
private IPacketEncryptor _encryptor;
private uint _sendSequenceNumber;
private uint _receiveSequenceNumber;
private int _keepAlivePeriod;
private Action? _keepAliveCallback;
private Timer? _keepAliveTimer;
private int _lastReceivedTime;

public override void EnableKeepAlive(int period, Action callback)
{
if (period > 0)
{
if (_keepAliveTimer is not null)
{
throw new InvalidOperationException();
}

_keepAlivePeriod = period;
_keepAliveCallback = callback;
_lastReceivedTime = GetTime();
_keepAliveTimer = new Timer(o => ((SocketSshConnection)o!).OnKeepAliveTimerCallback(), this, -1, -1);
// Start timer after assigning the variable to ensure it is set when the callback is invoked.
_keepAliveTimer.Change(_keepAlivePeriod, _keepAlivePeriod);
}
}

private static int GetTime()
=> Environment.TickCount;

private static int GetElapsed(int previous)
=> Math.Max(GetTime() - previous, 0);

private void OnKeepAliveTimerCallback()
{
Debug.Assert(_keepAliveTimer is not null);
Debug.Assert(_keepAliveCallback is not null);

int elapsedTime = GetElapsed(_lastReceivedTime);
lock (_keepAliveTimer)
{
// Synchronize with dispose.
if (_keepAlivePeriod < 0)
{
return;
}

if (elapsedTime < _keepAlivePeriod)
{
// Wait for the period to expire.
_keepAliveTimer.Change(_keepAlivePeriod - elapsedTime, _keepAlivePeriod);
return;
}
else
{
_keepAliveTimer.Change(_keepAlivePeriod, _keepAlivePeriod);
}
}

_keepAliveCallback();
}

public SocketSshConnection(ILogger<SshClient> logger, SequencePool sequencePool, Socket socket) :
base(sequencePool)
Expand Down Expand Up @@ -102,6 +159,8 @@ public async override ValueTask<Packet> ReceivePacketAsync(CancellationToken ct,
{
if (_decryptor.TryDecrypt(_receiveBuffer, _receiveSequenceNumber, maxLength, out Packet packet))
{
_lastReceivedTime = GetTime();

_receiveSequenceNumber++;

using Packet p = packet;
Expand Down Expand Up @@ -163,6 +222,15 @@ public override void SetEncryptorDecryptor(IPacketEncryptor packetEncoder, IPack

public override void Dispose()
{
if (_keepAliveTimer is not null)
{
lock (_keepAliveTimer)
{
_keepAlivePeriod = -1;
_keepAliveTimer.Dispose();
}
}
_keepAliveTimer?.Dispose();
_receiveBuffer.Dispose();
_sendBuffer.Dispose();
_encryptor.Dispose();
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/SshClientSettings.Defaults.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ partial class SshClientSettings

private static bool DefaultTcpKeepAlive => true;

private static int DefaultKeepAliveCountMax => 3;

// Algorithms are in **order of preference**.
private readonly static List<Name> EmptyList = [];
internal readonly static List<Name> SupportedKeyExchangeAlgorithms = [ AlgorithmNames.EcdhSha2Nistp256, AlgorithmNames.EcdhSha2Nistp384, AlgorithmNames.EcdhSha2Nistp521 ];
Expand Down
4 changes: 3 additions & 1 deletion src/Tmds.Ssh/SshClientSettings.SshConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ internal static async ValueTask<SshClientSettings> LoadFromConfigAsync(string? u
MinimumRSAKeySize = sshConfig.RequiredRSASize ?? DefaultMinimumRSAKeySize,
Credentials = DetermineCredentials(sshConfig),
HashKnownHosts = sshConfig.HashKnownHosts ?? DefaultHashKnownHosts,
TcpKeepAlive = sshConfig.TcpKeepAlive ?? DefaultTcpKeepAlive
TcpKeepAlive = sshConfig.TcpKeepAlive ?? DefaultTcpKeepAlive,
KeepAliveCountMax = sshConfig.ServerAliveCountMax ?? DefaultKeepAliveCountMax,
KeepAliveInterval = sshConfig.ServerAliveInterval > 0 ? TimeSpan.FromSeconds(sshConfig.ServerAliveInterval.Value) : TimeSpan.Zero,
};
if (sshConfig.UserKnownHostsFiles is not null)
{
Expand Down
22 changes: 22 additions & 0 deletions src/Tmds.Ssh/SshClientSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ public sealed partial class SshClientSettings
private string _userName = "";
private List<Credential>? _credentials;
private TimeSpan _connectTimeout = DefaultConnectTimeout;
private TimeSpan _keepAliveInterval = TimeSpan.Zero;
private int _keepAliveCountMax = 3;
private List<string>? _userKnownHostsFilePaths;
private List<string>? _globalKnownHostsFilePaths;
private Dictionary<string, string>? _environmentVariables;
Expand Down Expand Up @@ -107,6 +109,26 @@ public TimeSpan ConnectTimeout
}
}

public int KeepAliveCountMax
{
get => _keepAliveCountMax;
set
{
ArgumentOutOfRangeException.ThrowIfLessThan(value, 0);
_keepAliveCountMax = value;
}
}

public TimeSpan KeepAliveInterval
{
get => _keepAliveInterval;
set
{
ArgumentOutOfRangeException.ThrowIfLessThan(value, TimeSpan.Zero);
_keepAliveInterval = value;
}
}

internal void Validate()
{
if (_credentials is not null)
Expand Down
12 changes: 10 additions & 2 deletions src/Tmds.Ssh/SshConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public struct AlgorithmList
public bool? HashKnownHosts { get; set; }
public List<string>? SendEnv { get; set; }
public bool? TcpKeepAlive { get; set; }
public int? ServerAliveCountMax { get; set; }
public int? ServerAliveInterval { get; set; }

internal static ValueTask<SshConfig> DetermineConfigForHost(string? userName, string host, int? port, IReadOnlyDictionary<SshConfigOption, SshConfigOptionValue>? options, IReadOnlyList<string> configFiles, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -447,6 +449,14 @@ private static void HandleMatchedKeyword(SshConfig config, ReadOnlySpan<char> ke
config.TcpKeepAlive ??= ParseYesNoKeywordValue(keyword, ref remainder);
break;

case "serveralivecountmax":
config.ServerAliveCountMax ??= NextTokenAsInt(keyword, ref remainder);
break;

case "serveraliveinterval":
config.ServerAliveInterval ??= NextTokenAsInt(keyword, ref remainder);
break;

/* The following options are unsupported,
we have some basic handling that checks the option value indicates the feature is disabled */
case "permitlocalcommand":
Expand Down Expand Up @@ -539,8 +549,6 @@ we have some basic handling that checks the option value indicates the feature i
// case "ipqos":
// case "streamlocalbindmask":
// case "streamlocalbindunlink":
// case "serveralivecountmax":
// case "serveraliveinterval":
// case "setenv":
// case "tag":
// case "proxycommand":
Expand Down
4 changes: 3 additions & 1 deletion src/Tmds.Ssh/SshConfigOption.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,7 @@ public enum SshConfigOption
KexAlgorithms,
MACs,
PubkeyAcceptedAlgorithms,
TCPKeepAlive
TCPKeepAlive,
ServerAliveCountMax,
ServerAliveInterval
}
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/SshConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ protected SshConnection(SequencePool sequencePool)

public SequencePool SequencePool { get; }

public abstract void EnableKeepAlive(int period, Action callback);

public abstract ValueTask<string> ReceiveLineAsync(int maxLength, CancellationToken ct);
public abstract ValueTask WriteLineAsync(string line, CancellationToken ct);

Expand Down
1 change: 1 addition & 0 deletions src/Tmds.Ssh/SshConnectionClosedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Tmds.Ssh;
public class SshConnectionClosedException : SshConnectionException
{
internal const string ConnectionClosedByPeer = "Connection closed by peer.";
internal const string ConnectionClosedByKeepAliveTimeout = "Connection closed due to keep alive timeout.";
internal const string ConnectionClosedByAbort = "Connection closed due to an unexpected error.";
internal const string ConnectionClosedByDispose = "Connection closed by dispose.";

Expand Down
12 changes: 12 additions & 0 deletions src/Tmds.Ssh/SshSequencePoolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,16 @@ uint32 bytes to add
writer.WriteUInt32(bytesToAdd);
return packet.Move();
}

public static Packet CreateKeepAliveMessage(this SequencePool sequencePool)
{
using var packet = sequencePool.RentPacket();
var writer = packet.GetWriter();
writer.WriteMessageId(MessageId.SSH_MSG_GLOBAL_REQUEST);
// The request name can be any unknown name (to trigger an SSH_MSG_REQUEST_FAILURE response).
// We use the same name as the OpenSSH client.
writer.WriteString("[email protected]");
writer.WriteBoolean(true); // want reply
return packet.Move();
}
}
Loading

0 comments on commit 93a01aa

Please sign in to comment.