Skip to content

Commit

Permalink
Add LocalForward API. (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Dec 6, 2024
1 parent dc78b64 commit 58d8af9
Show file tree
Hide file tree
Showing 11 changed files with 527 additions and 16 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class SshClient : IDisposable

Task<SshDataStream> OpenTcpConnectionAsync(string host, int port, CancellationToken cancellationToken = default);
Task<SshDataStream> OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default);
Task<LocalForward> StartForwardTcpAsync(EndPoint bindEndpoint, string remoteHost, int remotePort, CancellationToken cancellationToken = default);

Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken);
Task<SftpClient> OpenSftpClientAsync(SftpClientOptions? options = null, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -112,8 +113,15 @@ class RemoteProcess : IDisposable
}
class SshDataStream : Stream
{
void WriteEof();
CancellationToken StreamAborted { get; }
}
class LocalForward : IDisposable
{
EndPoint? EndPoint { get; }
CancellationToken ForwardStopped { get; }
void ThrowIfStopped();
}
class SftpClient : IDisposable
{
// Note: umask is applied on the server.
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/ISshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ interface ISshChannel
CancellationToken cancellationToken = default);

ValueTask WriteAsync(ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default);
void WriteEof();

Exception CreateCloseException();
}
273 changes: 273 additions & 0 deletions src/Tmds.Ssh/LocalForward.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System.Net;
using System.Net.Sockets;
using System.Diagnostics;
using Microsoft.Extensions.Logging;

namespace Tmds.Ssh;

public sealed class LocalForward : IDisposable
{
// Sentinel stop reasons.
private static readonly Exception ConnectionClosed = new();
private static readonly Exception Disposed = new();

private readonly SshSession _session;
private readonly ILogger<LocalForward> _logger;
private readonly CancellationTokenSource _cancel;

private Func<CancellationToken, Task<SshDataStream>>? _connectToRemote;
private Socket? _serverSocket;
private EndPoint? _localEndPoint;
private CancellationTokenRegistration _ctr;
private Exception? _stopReason;
private string _remoteEndPoint;

private bool IsDisposed => ReferenceEquals(_stopReason, Disposed);

internal LocalForward(SshSession session, ILogger<LocalForward> logger)
{
_logger = logger;
_session = session;
_cancel = new();

_remoteEndPoint = "";
}

internal void StartTcpForward(EndPoint bindEndpoint, string remoteHost, int remotePort)
{
ArgumentNullException.ThrowIfNull(bindEndpoint);
ArgumentException.ThrowIfNullOrEmpty(remoteHost);
if (remotePort < 0 || remotePort > 0xffff)
{
throw new ArgumentException(nameof(remotePort));
}
if (bindEndpoint is not IPEndPoint)
{
throw new ArgumentException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}.");
}

_remoteEndPoint = $"{remoteHost}:{remotePort}";
_connectToRemote = async ct => await _session.OpenTcpConnectionChannelAsync(remoteHost, remotePort, ct).ConfigureAwait(false);

Start(bindEndpoint);
}

private void Start(EndPoint bindEndpoint)
{
// Assign to bindEndPoint in case we fail to bind/listen so we have an address for logging.
_localEndPoint = bindEndpoint;

try
{
if (bindEndpoint is IPEndPoint ipEndPoint)
{
_serverSocket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
}
else
{
// Type must be validated before calling this method.
throw new InvalidOperationException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}.");
}

_serverSocket.Bind(bindEndpoint);
_serverSocket.Listen();

EndPoint localEndPoint = _serverSocket.LocalEndPoint!;
_localEndPoint = localEndPoint;

_ctr = _session.ConnectionClosed.UnsafeRegister(o => ((LocalForward)o!).Stop(ConnectionClosed), this);

_ = AcceptLoop(localEndPoint);
}
catch (Exception ex)
{
Stop(ex);

throw;
}
}

private async Task AcceptLoop(EndPoint localEndPoint)
{
try
{
_logger.ForwardStart(localEndPoint, _remoteEndPoint);
while (true)
{
Socket acceptedSocket = await _serverSocket!.AcceptAsync().ConfigureAwait(false);
_ = Accept(acceptedSocket, localEndPoint);
}
}
catch (Exception ex)
{
Stop(ex);
}
}

private async Task Accept(Socket acceptedSocket, EndPoint localEndpoint)
{
Debug.Assert(_connectToRemote is not null);
SshDataStream? forwardStream = null;
EndPoint? peerEndPoint = null;
try
{
peerEndPoint = acceptedSocket.RemoteEndPoint!;
_logger.AcceptConnection(localEndpoint, peerEndPoint, _remoteEndPoint);
acceptedSocket.NoDelay = true;

// We may want to add a timeout option, and the ability to stop the lister on some conditions like nr of successive fails to connect to the remote.
forwardStream = await _connectToRemote(_cancel!.Token).ConfigureAwait(false);
_ = ForwardConnectionAsync(new NetworkStream(acceptedSocket, ownsSocket: true), forwardStream, peerEndPoint);
}
catch (Exception ex)
{
_logger.ForwardConnectionFailed(peerEndPoint, _remoteEndPoint, ex);

acceptedSocket.Dispose();
forwardStream?.Dispose();
}
}

private async Task ForwardConnectionAsync(NetworkStream socketStream, SshDataStream sshStream, EndPoint peerEndPoint)
{
Exception? exception = null;
try
{
_logger.ForwardConnection(peerEndPoint, _remoteEndPoint);
Task first, second;
try
{
Task copy1 = CopyTillEofAsync(socketStream, sshStream, sshStream.WriteMaxPacketDataLength);
Task copy2 = CopyTillEofAsync(sshStream, socketStream, sshStream.ReadMaxPacketDataLength);

first = await Task.WhenAny(copy1, copy2).ConfigureAwait(false);
second = first == copy1 ? copy2 : copy1;
}
finally
{
// When the copy stops in one direction, stop it in the other direction too.
// Though TCP allows data still to be received when the writing is shutdown
// application protocols (usually) follow the pattern of only closing
// when they will no longer receive.
socketStream.Dispose();
sshStream.Dispose();
}
// The dispose will cause the second copy to stop.
await second.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

await first.ConfigureAwait(false); // Throws if faulted.
}
catch (Exception ex)
{
exception = ex;
}
finally
{
if (exception is null)
{
_logger.ForwardConnectionClosed(peerEndPoint, _remoteEndPoint);
}
else
{
_logger.ForwardConnectionAborted(peerEndPoint, _remoteEndPoint, exception);
}
}

static async Task CopyTillEofAsync(Stream from, Stream to, int bufferSize)
{
await from.CopyToAsync(to, bufferSize).ConfigureAwait(false);
if (to is NetworkStream ns)
{
ns.Socket.Shutdown(SocketShutdown.Send);
}
else if (to is SshDataStream ds)
{
ds.WriteEof();
}
}
}

public EndPoint? EndPoint
{
get
{
ObjectDisposedException.ThrowIf(IsDisposed, this);
return _localEndPoint;
}
}

public CancellationToken ForwardStopped
{
get
{
ObjectDisposedException.ThrowIf(IsDisposed, this);
return _cancel.Token;
}
}

public void ThrowIfStopped()
{
Exception? stopReason = _stopReason;
if (ReferenceEquals(stopReason, Disposed))
{
throw new ObjectDisposedException(typeof(LocalForward).FullName);
}
else if (ReferenceEquals(stopReason, ConnectionClosed))
{
throw _session.CreateCloseException();
}
else if (stopReason is not null)
{
throw new SshException($"{nameof(LocalForward)} stopped due to an unexpected error.", stopReason);
}
}

private void Stop(Exception stopReason)
{
bool disposing = ReferenceEquals(stopReason, Disposed);
if (disposing)
{
if (Interlocked.Exchange(ref _stopReason, Disposed) != null)
{
return;
}
}
else
{
if (Interlocked.CompareExchange(ref _stopReason, stopReason, null) != null)
{
return;
}
}

if (_localEndPoint is not null)
{
if (IsDisposed)
{
_logger.ForwardStopped(_localEndPoint, _remoteEndPoint);
}
else if (ReferenceEquals(stopReason, ConnectionClosed))
{
if (_logger.IsEnabled(LogLevel.Error))
{
_logger.ForwardAborted(_localEndPoint, _remoteEndPoint, _session.CreateCloseException());
}
}
else
{
_logger.ForwardAborted(_localEndPoint, _remoteEndPoint, stopReason);
}
_ctr.Dispose();
_localEndPoint = null;
_serverSocket?.Dispose();
}

_cancel.Cancel();
}

public void Dispose()
=> Stop(Disposed);
}
46 changes: 42 additions & 4 deletions src/Tmds.Ssh/SshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public CancellationToken ChannelAborted
private int _disposed;
private int _sendWindow;
private Exception? _abortReason = null;
private bool _eofSent;
private object _gate => _receiveQueue;

public async ValueTask<(ChannelReadType ReadType, int BytesRead)> ReadAsync
Expand Down Expand Up @@ -173,16 +174,50 @@ string data
}
}

public void WriteEof()
{
ThrowIfDisposed();
ThrowIfAborted();
ThrowIfEofSent();

_eofSent = true;
TrySendEofMessage();
}

private void ThrowIfEofSent()
{
if (_eofSent)
{
ThrowEofSent();
}

static void ThrowEofSent()
{
throw new InvalidOperationException("EOF already sent.");
}
}

private void ThrowIfAborted()
{
if (_abortState >= (int)AbortState.Closed)
{
ThrowCloseException();
}
}

private void ThrowCloseException()
{
throw CreateCloseException();
}

public async ValueTask WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfEofSent();

while (memory.Length > 0)
{
if (_abortState >= (int)AbortState.Closed)
{
throw CreateCloseException();
}
ThrowIfAborted();

int sendWindow = Volatile.Read(ref _sendWindow);
if (sendWindow > 0)
Expand Down Expand Up @@ -595,6 +630,9 @@ private Exception CreateObjectDisposedException()
private void TrySendChannelFailureMessage()
=> TrySendPacket(_sequencePool.CreateChannelFailureMessage(RemoteChannel));

private void TrySendEofMessage()
=> TrySendPacket(_sequencePool.CreateChannelEofMessage(RemoteChannel));

private void TrySendChannelDataMessage(ReadOnlyMemory<byte> memory)
=> TrySendPacket(_sequencePool.CreateChannelDataMessage(RemoteChannel, memory));
}
Loading

0 comments on commit 58d8af9

Please sign in to comment.