From 58d8af9ad6bf92c78fa4a23da3edb5fe7dba4c4c Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Fri, 6 Dec 2024 09:15:13 +0100 Subject: [PATCH] Add LocalForward API. (#258) --- README.md | 8 + src/Tmds.Ssh/ISshChannel.cs | 2 + src/Tmds.Ssh/LocalForward.cs | 273 ++++++++++++++++++++++ src/Tmds.Ssh/SshChannel.cs | 46 +++- src/Tmds.Ssh/SshClient.cs | 16 +- src/Tmds.Ssh/SshDataStream.cs | 10 + src/Tmds.Ssh/SshLoggers.cs | 6 +- src/Tmds.Ssh/SshPortForwardLogger.cs | 57 +++++ src/Tmds.Ssh/SshSequencePoolExtensions.cs | 13 ++ src/Tmds.Ssh/SshSession.cs | 12 +- test/Tmds.Ssh.Tests/LocalForwardTests.cs | 100 ++++++++ 11 files changed, 527 insertions(+), 16 deletions(-) create mode 100644 src/Tmds.Ssh/LocalForward.cs create mode 100644 src/Tmds.Ssh/SshPortForwardLogger.cs create mode 100644 test/Tmds.Ssh.Tests/LocalForwardTests.cs diff --git a/README.md b/README.md index b242a22..1947313 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ class SshClient : IDisposable Task OpenTcpConnectionAsync(string host, int port, CancellationToken cancellationToken = default); Task OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default); + Task StartForwardTcpAsync(EndPoint bindEndpoint, string remoteHost, int remotePort, CancellationToken cancellationToken = default); Task OpenSftpClientAsync(CancellationToken cancellationToken); Task OpenSftpClientAsync(SftpClientOptions? options = null, CancellationToken cancellationToken = default) @@ -112,8 +113,15 @@ class RemoteProcess : IDisposable } class SshDataStream : Stream { + void WriteEof(); CancellationToken StreamAborted { get; } } +class LocalForward : IDisposable +{ + EndPoint? EndPoint { get; } + CancellationToken ForwardStopped { get; } + void ThrowIfStopped(); +} class SftpClient : IDisposable { // Note: umask is applied on the server. diff --git a/src/Tmds.Ssh/ISshChannel.cs b/src/Tmds.Ssh/ISshChannel.cs index 6c407ab..7bc36ae 100644 --- a/src/Tmds.Ssh/ISshChannel.cs +++ b/src/Tmds.Ssh/ISshChannel.cs @@ -17,5 +17,7 @@ interface ISshChannel CancellationToken cancellationToken = default); ValueTask WriteAsync(ReadOnlyMemory data, CancellationToken cancellationToken = default); + void WriteEof(); + Exception CreateCloseException(); } \ No newline at end of file diff --git a/src/Tmds.Ssh/LocalForward.cs b/src/Tmds.Ssh/LocalForward.cs new file mode 100644 index 0000000..42ae7b7 --- /dev/null +++ b/src/Tmds.Ssh/LocalForward.cs @@ -0,0 +1,273 @@ +// This file is part of Tmds.Ssh which is released under MIT. +// See file LICENSE for full license details. + +using System.Net; +using System.Net.Sockets; +using System.Diagnostics; +using Microsoft.Extensions.Logging; + +namespace Tmds.Ssh; + +public sealed class LocalForward : IDisposable +{ + // Sentinel stop reasons. + private static readonly Exception ConnectionClosed = new(); + private static readonly Exception Disposed = new(); + + private readonly SshSession _session; + private readonly ILogger _logger; + private readonly CancellationTokenSource _cancel; + + private Func>? _connectToRemote; + private Socket? _serverSocket; + private EndPoint? _localEndPoint; + private CancellationTokenRegistration _ctr; + private Exception? _stopReason; + private string _remoteEndPoint; + + private bool IsDisposed => ReferenceEquals(_stopReason, Disposed); + + internal LocalForward(SshSession session, ILogger logger) + { + _logger = logger; + _session = session; + _cancel = new(); + + _remoteEndPoint = ""; + } + + internal void StartTcpForward(EndPoint bindEndpoint, string remoteHost, int remotePort) + { + ArgumentNullException.ThrowIfNull(bindEndpoint); + ArgumentException.ThrowIfNullOrEmpty(remoteHost); + if (remotePort < 0 || remotePort > 0xffff) + { + throw new ArgumentException(nameof(remotePort)); + } + if (bindEndpoint is not IPEndPoint) + { + throw new ArgumentException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}."); + } + + _remoteEndPoint = $"{remoteHost}:{remotePort}"; + _connectToRemote = async ct => await _session.OpenTcpConnectionChannelAsync(remoteHost, remotePort, ct).ConfigureAwait(false); + + Start(bindEndpoint); + } + + private void Start(EndPoint bindEndpoint) + { + // Assign to bindEndPoint in case we fail to bind/listen so we have an address for logging. + _localEndPoint = bindEndpoint; + + try + { + if (bindEndpoint is IPEndPoint ipEndPoint) + { + _serverSocket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + } + else + { + // Type must be validated before calling this method. + throw new InvalidOperationException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}."); + } + + _serverSocket.Bind(bindEndpoint); + _serverSocket.Listen(); + + EndPoint localEndPoint = _serverSocket.LocalEndPoint!; + _localEndPoint = localEndPoint; + + _ctr = _session.ConnectionClosed.UnsafeRegister(o => ((LocalForward)o!).Stop(ConnectionClosed), this); + + _ = AcceptLoop(localEndPoint); + } + catch (Exception ex) + { + Stop(ex); + + throw; + } + } + + private async Task AcceptLoop(EndPoint localEndPoint) + { + try + { + _logger.ForwardStart(localEndPoint, _remoteEndPoint); + while (true) + { + Socket acceptedSocket = await _serverSocket!.AcceptAsync().ConfigureAwait(false); + _ = Accept(acceptedSocket, localEndPoint); + } + } + catch (Exception ex) + { + Stop(ex); + } + } + + private async Task Accept(Socket acceptedSocket, EndPoint localEndpoint) + { + Debug.Assert(_connectToRemote is not null); + SshDataStream? forwardStream = null; + EndPoint? peerEndPoint = null; + try + { + peerEndPoint = acceptedSocket.RemoteEndPoint!; + _logger.AcceptConnection(localEndpoint, peerEndPoint, _remoteEndPoint); + acceptedSocket.NoDelay = true; + + // We may want to add a timeout option, and the ability to stop the lister on some conditions like nr of successive fails to connect to the remote. + forwardStream = await _connectToRemote(_cancel!.Token).ConfigureAwait(false); + _ = ForwardConnectionAsync(new NetworkStream(acceptedSocket, ownsSocket: true), forwardStream, peerEndPoint); + } + catch (Exception ex) + { + _logger.ForwardConnectionFailed(peerEndPoint, _remoteEndPoint, ex); + + acceptedSocket.Dispose(); + forwardStream?.Dispose(); + } + } + + private async Task ForwardConnectionAsync(NetworkStream socketStream, SshDataStream sshStream, EndPoint peerEndPoint) + { + Exception? exception = null; + try + { + _logger.ForwardConnection(peerEndPoint, _remoteEndPoint); + Task first, second; + try + { + Task copy1 = CopyTillEofAsync(socketStream, sshStream, sshStream.WriteMaxPacketDataLength); + Task copy2 = CopyTillEofAsync(sshStream, socketStream, sshStream.ReadMaxPacketDataLength); + + first = await Task.WhenAny(copy1, copy2).ConfigureAwait(false); + second = first == copy1 ? copy2 : copy1; + } + finally + { + // When the copy stops in one direction, stop it in the other direction too. + // Though TCP allows data still to be received when the writing is shutdown + // application protocols (usually) follow the pattern of only closing + // when they will no longer receive. + socketStream.Dispose(); + sshStream.Dispose(); + } + // The dispose will cause the second copy to stop. + await second.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + + await first.ConfigureAwait(false); // Throws if faulted. + } + catch (Exception ex) + { + exception = ex; + } + finally + { + if (exception is null) + { + _logger.ForwardConnectionClosed(peerEndPoint, _remoteEndPoint); + } + else + { + _logger.ForwardConnectionAborted(peerEndPoint, _remoteEndPoint, exception); + } + } + + static async Task CopyTillEofAsync(Stream from, Stream to, int bufferSize) + { + await from.CopyToAsync(to, bufferSize).ConfigureAwait(false); + if (to is NetworkStream ns) + { + ns.Socket.Shutdown(SocketShutdown.Send); + } + else if (to is SshDataStream ds) + { + ds.WriteEof(); + } + } + } + + public EndPoint? EndPoint + { + get + { + ObjectDisposedException.ThrowIf(IsDisposed, this); + return _localEndPoint; + } + } + + public CancellationToken ForwardStopped + { + get + { + ObjectDisposedException.ThrowIf(IsDisposed, this); + return _cancel.Token; + } + } + + public void ThrowIfStopped() + { + Exception? stopReason = _stopReason; + if (ReferenceEquals(stopReason, Disposed)) + { + throw new ObjectDisposedException(typeof(LocalForward).FullName); + } + else if (ReferenceEquals(stopReason, ConnectionClosed)) + { + throw _session.CreateCloseException(); + } + else if (stopReason is not null) + { + throw new SshException($"{nameof(LocalForward)} stopped due to an unexpected error.", stopReason); + } + } + + private void Stop(Exception stopReason) + { + bool disposing = ReferenceEquals(stopReason, Disposed); + if (disposing) + { + if (Interlocked.Exchange(ref _stopReason, Disposed) != null) + { + return; + } + } + else + { + if (Interlocked.CompareExchange(ref _stopReason, stopReason, null) != null) + { + return; + } + } + + if (_localEndPoint is not null) + { + if (IsDisposed) + { + _logger.ForwardStopped(_localEndPoint, _remoteEndPoint); + } + else if (ReferenceEquals(stopReason, ConnectionClosed)) + { + if (_logger.IsEnabled(LogLevel.Error)) + { + _logger.ForwardAborted(_localEndPoint, _remoteEndPoint, _session.CreateCloseException()); + } + } + else + { + _logger.ForwardAborted(_localEndPoint, _remoteEndPoint, stopReason); + } + _ctr.Dispose(); + _localEndPoint = null; + _serverSocket?.Dispose(); + } + + _cancel.Cancel(); + } + + public void Dispose() + => Stop(Disposed); +} \ No newline at end of file diff --git a/src/Tmds.Ssh/SshChannel.cs b/src/Tmds.Ssh/SshChannel.cs index 35a39ff..ab006ef 100644 --- a/src/Tmds.Ssh/SshChannel.cs +++ b/src/Tmds.Ssh/SshChannel.cs @@ -71,6 +71,7 @@ public CancellationToken ChannelAborted private int _disposed; private int _sendWindow; private Exception? _abortReason = null; + private bool _eofSent; private object _gate => _receiveQueue; public async ValueTask<(ChannelReadType ReadType, int BytesRead)> ReadAsync @@ -173,16 +174,50 @@ string data } } + public void WriteEof() + { + ThrowIfDisposed(); + ThrowIfAborted(); + ThrowIfEofSent(); + + _eofSent = true; + TrySendEofMessage(); + } + + private void ThrowIfEofSent() + { + if (_eofSent) + { + ThrowEofSent(); + } + + static void ThrowEofSent() + { + throw new InvalidOperationException("EOF already sent."); + } + } + + private void ThrowIfAborted() + { + if (_abortState >= (int)AbortState.Closed) + { + ThrowCloseException(); + } + } + + private void ThrowCloseException() + { + throw CreateCloseException(); + } + public async ValueTask WriteAsync(ReadOnlyMemory memory, CancellationToken cancellationToken) { ThrowIfDisposed(); + ThrowIfEofSent(); while (memory.Length > 0) { - if (_abortState >= (int)AbortState.Closed) - { - throw CreateCloseException(); - } + ThrowIfAborted(); int sendWindow = Volatile.Read(ref _sendWindow); if (sendWindow > 0) @@ -595,6 +630,9 @@ private Exception CreateObjectDisposedException() private void TrySendChannelFailureMessage() => TrySendPacket(_sequencePool.CreateChannelFailureMessage(RemoteChannel)); + private void TrySendEofMessage() + => TrySendPacket(_sequencePool.CreateChannelEofMessage(RemoteChannel)); + private void TrySendChannelDataMessage(ReadOnlyMemory memory) => TrySendPacket(_sequencePool.CreateChannelDataMessage(RemoteChannel, memory)); } diff --git a/src/Tmds.Ssh/SshClient.cs b/src/Tmds.Ssh/SshClient.cs index 2fddacc..25b0fab 100644 --- a/src/Tmds.Ssh/SshClient.cs +++ b/src/Tmds.Ssh/SshClient.cs @@ -2,6 +2,7 @@ // See file LICENSE for full license details. using System.Diagnostics; +using System.Net; using System.Runtime.CompilerServices; using System.Text; using Microsoft.Extensions.Logging; @@ -271,17 +272,22 @@ public async Task ExecuteSubsystemAsync(string subsystem, Execute public async Task OpenTcpConnectionAsync(string host, int port, CancellationToken cancellationToken = default) { SshSession session = await GetSessionAsync(cancellationToken).ConfigureAwait(false); - var channel = await session.OpenTcpConnectionChannelAsync(typeof(SshDataStream), host, port, cancellationToken).ConfigureAwait(false); - - return new SshDataStream(channel); + return await session.OpenTcpConnectionChannelAsync(host, port, cancellationToken).ConfigureAwait(false); } public async Task OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default) { SshSession session = await GetSessionAsync(cancellationToken).ConfigureAwait(false); - var channel = await session.OpenUnixConnectionChannelAsync(typeof(SshDataStream), path, cancellationToken).ConfigureAwait(false); + return await session.OpenUnixConnectionChannelAsync(path, cancellationToken).ConfigureAwait(false); + } + + public async Task StartForwardTcpAsync(EndPoint bindEndpoint, string remoteHost, int remotePort, CancellationToken cancellationToken = default) + { + SshSession session = await GetSessionAsync(cancellationToken).ConfigureAwait(false); - return new SshDataStream(channel); + var forward = new LocalForward(session, _loggers.GetLocalPortForwardLogger()); + forward.StartTcpForward(bindEndpoint, remoteHost, remotePort); + return forward; } public Task OpenSftpClientAsync(CancellationToken cancellationToken) diff --git a/src/Tmds.Ssh/SshDataStream.cs b/src/Tmds.Ssh/SshDataStream.cs index 2a0c6ec..0f3443d 100644 --- a/src/Tmds.Ssh/SshDataStream.cs +++ b/src/Tmds.Ssh/SshDataStream.cs @@ -12,6 +12,11 @@ internal SshDataStream(ISshChannel channel) _channel = channel; } + // OpenSSH uses the maximum packet sizes as how much data may fit into an SSH_MSG_CHANNEL_DATA packet. + // We're following that behavior and don't subtract bytes for the header. + internal int ReadMaxPacketDataLength => _channel.ReceiveMaxPacket; + internal int WriteMaxPacketDataLength => _channel.SendMaxPacket; + public CancellationToken StreamAborted => _channel.ChannelAborted; @@ -81,6 +86,11 @@ protected override void Dispose(bool disposing) } } + public void WriteEof() + { + _channel.WriteEof(); + } + public override async ValueTask WriteAsync(System.ReadOnlyMemory buffer, CancellationToken cancellationToken = default(CancellationToken)) { try diff --git a/src/Tmds.Ssh/SshLoggers.cs b/src/Tmds.Ssh/SshLoggers.cs index 97133d1..d1a3e1b 100644 --- a/src/Tmds.Ssh/SshLoggers.cs +++ b/src/Tmds.Ssh/SshLoggers.cs @@ -5,11 +5,15 @@ namespace Tmds.Ssh; sealed class SshLoggers { + private readonly ILoggerFactory _loggerFactory; + public ILogger SshClientLogger { get; } + public ILogger GetLocalPortForwardLogger() => _loggerFactory.CreateLogger(); + public SshLoggers(ILoggerFactory? loggerFactory = null) { - loggerFactory ??= new NullLoggerFactory(); + _loggerFactory = loggerFactory ??= new NullLoggerFactory(); SshClientLogger = loggerFactory.CreateLogger(); } diff --git a/src/Tmds.Ssh/SshPortForwardLogger.cs b/src/Tmds.Ssh/SshPortForwardLogger.cs new file mode 100644 index 0000000..9330f26 --- /dev/null +++ b/src/Tmds.Ssh/SshPortForwardLogger.cs @@ -0,0 +1,57 @@ +using System.Buffers; +using System.Net; +using System.Net.Security; +using Microsoft.Extensions.Logging; + +namespace Tmds.Ssh; + +static partial class SshPortForwardLogger +{ + [LoggerMessage( + EventId = 0, + Level = LogLevel.Information, + Message = "Forwarding connections from '{BindEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardStart(this ILogger logger, EndPoint bindEndPoint, string remoteEndPoint); + + [LoggerMessage( + EventId = 1, + Level = LogLevel.Information, + Message = "Accepted connection at '{BindEndPoint}' from '{PeerEndPoint}' to forward to '{RemoteEndPoint}'")] + public static partial void AcceptConnection(this ILogger logger, EndPoint bindEndPoint, EndPoint peerEndPoint, string remoteEndPoint); + + [LoggerMessage( + EventId = 2, + Level = LogLevel.Error, + Message = "Failed to forward connection from '{PeerEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardConnectionFailed(this ILogger logger, EndPoint? peerEndPoint, string remoteEndPoint, Exception exception); + + [LoggerMessage( + EventId = 3, + Level = LogLevel.Information, + Message = "Forwarding connection from '{PeerEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardConnection(this ILogger logger, EndPoint peerEndPoint, string remoteEndPoint); + + [LoggerMessage( + EventId = 4, + Level = LogLevel.Information, + Message = "Closed forwarded connection from '{PeerEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardConnectionClosed(this ILogger logger, EndPoint peerEndPoint, string remoteEndPoint); + + [LoggerMessage( + EventId = 5, + Level = LogLevel.Error, + Message = "Aborted forwarded connection from '{PeerEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardConnectionAborted(this ILogger logger, EndPoint peerEndPoint, string remoteEndPoint, Exception exception); + + [LoggerMessage( + EventId = 6, + Level = LogLevel.Information, + Message = "Stopped forwarding connections from '{BindEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardStopped(this ILogger logger, EndPoint bindEndPoint, string remoteEndPoint); + + [LoggerMessage( + EventId = 7, + Level = LogLevel.Error, + Message = "Aborted forwarding connections from '{BindEndPoint}' to '{RemoteEndPoint}'")] + public static partial void ForwardAborted(this ILogger logger, EndPoint bindEndPoint, string remoteEndPoint, Exception exception); +} \ No newline at end of file diff --git a/src/Tmds.Ssh/SshSequencePoolExtensions.cs b/src/Tmds.Ssh/SshSequencePoolExtensions.cs index 1bb8a3c..e1ade0d 100644 --- a/src/Tmds.Ssh/SshSequencePoolExtensions.cs +++ b/src/Tmds.Ssh/SshSequencePoolExtensions.cs @@ -47,6 +47,19 @@ uint32 recipient channel return packet.Move(); } + public static Packet CreateChannelEofMessage(this SequencePool sequencePool, uint remoteChannel) + { + /* + byte SSH_MSG_CHANNEL_EOF + uint32 recipient channel + */ + using var packet = sequencePool.RentPacket(); + var writer = packet.GetWriter(); + writer.WriteMessageId(MessageId.SSH_MSG_CHANNEL_EOF); + writer.WriteUInt32(remoteChannel); + return packet.Move(); + } + public static Packet CreateChannelDataMessage(this SequencePool sequencePool, uint remoteChannel, ReadOnlyMemory memory) { /* diff --git a/src/Tmds.Ssh/SshSession.cs b/src/Tmds.Ssh/SshSession.cs index 0799ac7..5d35fff 100644 --- a/src/Tmds.Ssh/SshSession.cs +++ b/src/Tmds.Ssh/SshSession.cs @@ -808,9 +808,9 @@ public async Task OpenRemoteProcessChannelAsync(Type channelType, s public async Task OpenRemoteSubsystemChannelAsync(Type channelType, string subsystem, CancellationToken cancellationToken) => await OpenSubsystemChannelAsync(channelType, null, subsystem, cancellationToken).ConfigureAwait(false); - public async Task OpenTcpConnectionChannelAsync(Type channelType, string host, int port, CancellationToken cancellationToken) + public async Task OpenTcpConnectionChannelAsync(string host, int port, CancellationToken cancellationToken) { - SshChannel channel = CreateChannel(channelType); + SshChannel channel = CreateChannel(typeof(SshDataStream)); try { IPAddress originatorIP = IPAddress.Any; @@ -818,7 +818,7 @@ public async Task OpenTcpConnectionChannelAsync(Type channelType, s channel.TrySendChannelOpenDirectTcpIpMessage(host, (uint)port, originatorIP, (uint)originatorPort); await channel.ReceiveChannelOpenConfirmationAsync(cancellationToken).ConfigureAwait(false); - return channel; + return new SshDataStream(channel); } catch { @@ -827,15 +827,15 @@ public async Task OpenTcpConnectionChannelAsync(Type channelType, s } } - public async Task OpenUnixConnectionChannelAsync(Type channelType, string path, CancellationToken cancellationToken) + public async Task OpenUnixConnectionChannelAsync(string path, CancellationToken cancellationToken) { - SshChannel channel = CreateChannel(channelType); + SshChannel channel = CreateChannel(typeof(SshDataStream)); try { channel.TrySendChannelOpenDirectStreamLocalMessage(path); await channel.ReceiveChannelOpenConfirmationAsync(cancellationToken).ConfigureAwait(false); - return channel; + return new SshDataStream(channel); } catch { diff --git a/test/Tmds.Ssh.Tests/LocalForwardTests.cs b/test/Tmds.Ssh.Tests/LocalForwardTests.cs new file mode 100644 index 0000000..77d27c6 --- /dev/null +++ b/test/Tmds.Ssh.Tests/LocalForwardTests.cs @@ -0,0 +1,100 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using System.Text; +using Xunit; + +namespace Tmds.Ssh.Tests; + +[Collection(nameof(SshServerCollection))] +public class LocalForwardTests +{ + const int SocatStartDelay = 100; + private readonly SshServer _sshServer; + + public LocalForwardTests(SshServer sshServer) + { + _sshServer = sshServer; + } + + [Fact] + public async Task Forwards() + { + using var client = await _sshServer.CreateClientAsync(); + + // start a an echo server using socat. + const int socatPort = 1234; + using var soCatProcess = await client.ExecuteAsync($"socat -v tcp-l:{socatPort},fork exec:'/bin/cat'"); + await Task.Delay(SocatStartDelay); // wait a little for socat to start. + + using var localForward = await client.StartForwardTcpAsync(new IPEndPoint(IPAddress.Loopback, 0), "localhost", socatPort); + + byte[] helloWorldBytes = Encoding.UTF8.GetBytes("hello world"); + byte[] receiveBuffer = new byte[128]; + for (int i = 0; i < 2; i++) + { + EndPoint endPoint = localForward.EndPoint!; + using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(endPoint); + + for (int j = 0; j < 2; j++) + { + await socket.SendAsync(helloWorldBytes); + + int bytesRead = await socket.ReceiveAsync(receiveBuffer); + Assert.Equal(helloWorldBytes.Length, bytesRead); + Assert.Equal(helloWorldBytes, receiveBuffer.AsSpan(0, bytesRead).ToArray()); + } + + socket.Shutdown(SocketShutdown.Send); + int received = await socket.ReceiveAsync(receiveBuffer); + Assert.Equal(0, received); + } + } + + [Fact] + public async Task StopsWhenDisposed() + { + using var client = await _sshServer.CreateClientAsync(); + + using var localForward = await client.StartForwardTcpAsync(new IPEndPoint(IPAddress.Loopback, 0), "localhost", 5000); + CancellationToken ct = localForward.ForwardStopped; + EndPoint? endPoint = localForward.EndPoint; + + Assert.False(ct.IsCancellationRequested); + Assert.NotNull(endPoint); + + localForward.Dispose(); + + Assert.True(ct.IsCancellationRequested); + Assert.Throws(() => localForward.EndPoint); + Assert.Throws(() => localForward.ForwardStopped); + Assert.Throws(() => localForward.ThrowIfStopped()); + + using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await Assert.ThrowsAnyAsync(async () => await socket.ConnectAsync(endPoint)); + } + + [Fact] + public async Task StopsWhenClientDisconnects() + { + using var client = await _sshServer.CreateClientAsync(); + + using var localForward = await client.StartForwardTcpAsync(new IPEndPoint(IPAddress.Loopback, 0), "localhost", 5000); + CancellationToken ct = localForward.ForwardStopped; + EndPoint? endPoint = localForward.EndPoint; + + Assert.False(ct.IsCancellationRequested); + Assert.NotNull(endPoint); + + client.Dispose(); + + Assert.True(ct.IsCancellationRequested); + Assert.Null(localForward.EndPoint); + Assert.True(localForward.ForwardStopped.IsCancellationRequested); + Assert.Throws(() => localForward.ThrowIfStopped()); + + using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await Assert.ThrowsAnyAsync(async () => await socket.ConnectAsync(endPoint)); + } +}