Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LocalForward API. #258

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class SshClient : IDisposable

Task<SshDataStream> OpenTcpConnectionAsync(string host, int port, CancellationToken cancellationToken = default);
Task<SshDataStream> OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default);
Task<LocalForward> StartForwardTcpAsync(EndPoint bindEndpoint, string remoteHost, int remotePort, CancellationToken cancellationToken = default);

Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken);
Task<SftpClient> OpenSftpClientAsync(SftpClientOptions? options = null, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -112,8 +113,15 @@ class RemoteProcess : IDisposable
}
class SshDataStream : Stream
{
void WriteEof();
CancellationToken StreamAborted { get; }
}
class LocalForward : IDisposable
{
EndPoint? EndPoint { get; }
CancellationToken ForwardStopped { get; }
void ThrowIfStopped();
}
class SftpClient : IDisposable
{
// Note: umask is applied on the server.
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/ISshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ interface ISshChannel
CancellationToken cancellationToken = default);

ValueTask WriteAsync(ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default);
void WriteEof();

Exception CreateCloseException();
}
273 changes: 273 additions & 0 deletions src/Tmds.Ssh/LocalForward.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System.Net;
using System.Net.Sockets;
using System.Diagnostics;
using Microsoft.Extensions.Logging;

namespace Tmds.Ssh;

public sealed class LocalForward : IDisposable
{
// Sentinel stop reasons.
private static readonly Exception ConnectionClosed = new();
private static readonly Exception Disposed = new();

private readonly SshSession _session;
private readonly ILogger<LocalForward> _logger;
private readonly CancellationTokenSource _cancel;

private Func<CancellationToken, Task<SshDataStream>>? _connectToRemote;
private Socket? _serverSocket;
private EndPoint? _localEndPoint;
private CancellationTokenRegistration _ctr;
private Exception? _stopReason;
private string _remoteEndPoint;

private bool IsDisposed => ReferenceEquals(_stopReason, Disposed);

internal LocalForward(SshSession session, ILogger<LocalForward> logger)
{
_logger = logger;
_session = session;
_cancel = new();

_remoteEndPoint = "";
}

internal void StartTcpForward(EndPoint bindEndpoint, string remoteHost, int remotePort)
{
ArgumentNullException.ThrowIfNull(bindEndpoint);
ArgumentException.ThrowIfNullOrEmpty(remoteHost);
if (remotePort < 0 || remotePort > 0xffff)
{
throw new ArgumentException(nameof(remotePort));
}
if (bindEndpoint is not IPEndPoint)
{
throw new ArgumentException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}.");
}

_remoteEndPoint = $"{remoteHost}:{remotePort}";
_connectToRemote = async ct => await _session.OpenTcpConnectionChannelAsync(remoteHost, remotePort, ct).ConfigureAwait(false);

Start(bindEndpoint);
}

private void Start(EndPoint bindEndpoint)
{
// Assign to bindEndPoint in case we fail to bind/listen so we have an address for logging.
_localEndPoint = bindEndpoint;

try
{
if (bindEndpoint is IPEndPoint ipEndPoint)
{
_serverSocket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
}
else
{
// Type must be validated before calling this method.
throw new InvalidOperationException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}.");
}

_serverSocket.Bind(bindEndpoint);
_serverSocket.Listen();

EndPoint localEndPoint = _serverSocket.LocalEndPoint!;
_localEndPoint = localEndPoint;

_ctr = _session.ConnectionClosed.UnsafeRegister(o => ((LocalForward)o!).Stop(ConnectionClosed), this);

_ = AcceptLoop(localEndPoint);
}
catch (Exception ex)
{
Stop(ex);

throw;
}
}

private async Task AcceptLoop(EndPoint localEndPoint)
{
try
{
_logger.ForwardStart(localEndPoint, _remoteEndPoint);
while (true)
{
Socket acceptedSocket = await _serverSocket!.AcceptAsync().ConfigureAwait(false);
_ = Accept(acceptedSocket, localEndPoint);
}
}
catch (Exception ex)
{
Stop(ex);
}
}

private async Task Accept(Socket acceptedSocket, EndPoint localEndpoint)
{
Debug.Assert(_connectToRemote is not null);
SshDataStream? forwardStream = null;
EndPoint? peerEndPoint = null;
try
{
peerEndPoint = acceptedSocket.RemoteEndPoint!;
_logger.AcceptConnection(localEndpoint, peerEndPoint, _remoteEndPoint);
acceptedSocket.NoDelay = true;

// We may want to add a timeout option, and the ability to stop the lister on some conditions like nr of successive fails to connect to the remote.
forwardStream = await _connectToRemote(_cancel!.Token).ConfigureAwait(false);
_ = ForwardConnectionAsync(new NetworkStream(acceptedSocket, ownsSocket: true), forwardStream, peerEndPoint);
}
catch (Exception ex)
{
_logger.ForwardConnectionFailed(peerEndPoint, _remoteEndPoint, ex);

acceptedSocket.Dispose();
forwardStream?.Dispose();
}
}

private async Task ForwardConnectionAsync(NetworkStream socketStream, SshDataStream sshStream, EndPoint peerEndPoint)
{
Exception? exception = null;
try
{
_logger.ForwardConnection(peerEndPoint, _remoteEndPoint);
Task first, second;
try
{
Task copy1 = CopyTillEofAsync(socketStream, sshStream, sshStream.WriteMaxPacketDataLength);
Task copy2 = CopyTillEofAsync(sshStream, socketStream, sshStream.ReadMaxPacketDataLength);

first = await Task.WhenAny(copy1, copy2).ConfigureAwait(false);
second = first == copy1 ? copy2 : copy1;
}
finally
{
// When the copy stops in one direction, stop it in the other direction too.
// Though TCP allows data still to be received when the writing is shutdown
// application protocols (usually) follow the pattern of only closing
// when they will no longer receive.
socketStream.Dispose();
sshStream.Dispose();
}
// The dispose will cause the second copy to stop.
await second.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

await first.ConfigureAwait(false); // Throws if faulted.
}
catch (Exception ex)
{
exception = ex;
}
finally
{
if (exception is null)
{
_logger.ForwardConnectionClosed(peerEndPoint, _remoteEndPoint);
}
else
{
_logger.ForwardConnectionAborted(peerEndPoint, _remoteEndPoint, exception);
}
}

static async Task CopyTillEofAsync(Stream from, Stream to, int bufferSize)
{
await from.CopyToAsync(to, bufferSize).ConfigureAwait(false);
if (to is NetworkStream ns)
{
ns.Socket.Shutdown(SocketShutdown.Send);
}
else if (to is SshDataStream ds)
{
ds.WriteEof();
}
}
}

public EndPoint? EndPoint
{
get
{
ObjectDisposedException.ThrowIf(IsDisposed, this);
return _localEndPoint;
}
}

public CancellationToken ForwardStopped
{
get
{
ObjectDisposedException.ThrowIf(IsDisposed, this);
return _cancel.Token;
}
}

public void ThrowIfStopped()
{
Exception? stopReason = _stopReason;
if (ReferenceEquals(stopReason, Disposed))
{
throw new ObjectDisposedException(typeof(LocalForward).FullName);
}
else if (ReferenceEquals(stopReason, ConnectionClosed))
{
throw _session.CreateCloseException();
}
else if (stopReason is not null)
{
throw new SshException($"{nameof(LocalForward)} stopped due to an unexpected error.", stopReason);
}
}

private void Stop(Exception stopReason)
{
bool disposing = ReferenceEquals(stopReason, Disposed);
if (disposing)
{
if (Interlocked.Exchange(ref _stopReason, Disposed) != null)
{
return;
}
}
else
{
if (Interlocked.CompareExchange(ref _stopReason, stopReason, null) != null)
{
return;
}
}

if (_localEndPoint is not null)
{
if (IsDisposed)
{
_logger.ForwardStopped(_localEndPoint, _remoteEndPoint);
}
else if (ReferenceEquals(stopReason, ConnectionClosed))
{
if (_logger.IsEnabled(LogLevel.Error))
{
_logger.ForwardAborted(_localEndPoint, _remoteEndPoint, _session.CreateCloseException());
}
}
else
{
_logger.ForwardAborted(_localEndPoint, _remoteEndPoint, stopReason);
}
_ctr.Dispose();
_localEndPoint = null;
_serverSocket?.Dispose();
}

_cancel.Cancel();
}

public void Dispose()
=> Stop(Disposed);
}
46 changes: 42 additions & 4 deletions src/Tmds.Ssh/SshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public CancellationToken ChannelAborted
private int _disposed;
private int _sendWindow;
private Exception? _abortReason = null;
private bool _eofSent;
private object _gate => _receiveQueue;

public async ValueTask<(ChannelReadType ReadType, int BytesRead)> ReadAsync
Expand Down Expand Up @@ -173,16 +174,50 @@ string data
}
}

public void WriteEof()
{
ThrowIfDisposed();
ThrowIfAborted();
ThrowIfEofSent();

_eofSent = true;
TrySendEofMessage();
}

private void ThrowIfEofSent()
{
if (_eofSent)
{
ThrowEofSent();
}

static void ThrowEofSent()
{
throw new InvalidOperationException("EOF already sent.");
}
}

private void ThrowIfAborted()
{
if (_abortState >= (int)AbortState.Closed)
{
ThrowCloseException();
}
}

private void ThrowCloseException()
{
throw CreateCloseException();
}

public async ValueTask WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfEofSent();

while (memory.Length > 0)
{
if (_abortState >= (int)AbortState.Closed)
{
throw CreateCloseException();
}
ThrowIfAborted();

int sendWindow = Volatile.Read(ref _sendWindow);
if (sendWindow > 0)
Expand Down Expand Up @@ -595,6 +630,9 @@ private Exception CreateObjectDisposedException()
private void TrySendChannelFailureMessage()
=> TrySendPacket(_sequencePool.CreateChannelFailureMessage(RemoteChannel));

private void TrySendEofMessage()
=> TrySendPacket(_sequencePool.CreateChannelEofMessage(RemoteChannel));

private void TrySendChannelDataMessage(ReadOnlyMemory<byte> memory)
=> TrySendPacket(_sequencePool.CreateChannelDataMessage(RemoteChannel, memory));
}
Loading