Skip to content

Commit

Permalink
Remove unnecessary DNS lookup in Connect (#1412)
Browse files Browse the repository at this point in the history
The library currently performs a DNS lookup of the desired host, takes the first
returned IP address and connects to that. Instead, we can just pass the hostname
down to System.Net.Sockets which will do the right thing, potentially trying
multiple addresses if needed.

Co-authored-by: Wojciech Nagórski <[email protected]>
  • Loading branch information
Rob-Hague and WojciechNagorski authored Jun 18, 2024
1 parent 27fad71 commit 919af75
Show file tree
Hide file tree
Showing 43 changed files with 147 additions and 154 deletions.
6 changes: 3 additions & 3 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
return socket;
}

public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
public static void Connect(Socket socket, EndPoint remoteEndpoint, TimeSpan connectTimeout)
{
ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
}

public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
public static async Task ConnectAsync(Socket socket, EndPoint remoteEndpoint, CancellationToken cancellationToken)
{
await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
}

private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
private static void ConnectCore(Socket socket, EndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
{
var connectCompleted = new ManualResetEvent(initialState: false);
var args = new SocketAsyncEventArgs
Expand Down
2 changes: 1 addition & 1 deletion src/Renci.SshNet/Abstractions/SocketExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public void GetResult()
}
}

public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
public static async Task ConnectAsync(this Socket socket, EndPoint remoteEndpoint, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

Expand Down
37 changes: 12 additions & 25 deletions src/Renci.SshNet/Connection/ConnectorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,21 @@ protected ConnectorBase(ISocketFactory socketFactory)
public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);

/// <summary>
/// Establishes a socket connection to the specified host and port.
/// Establishes a socket connection to the specified endpoint.
/// </summary>
/// <param name="host">The host name of the server to connect to.</param>
/// <param name="port">The port to connect to.</param>
/// <param name="endPoint">The <see cref="EndPoint"/> representing the server to connect to.</param>
/// <param name="timeout">The maximum time to wait for the connection to be established.</param>
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
protected Socket SocketConnect(string host, int port, TimeSpan timeout)
protected Socket SocketConnect(EndPoint endPoint, TimeSpan timeout)
{
var ipAddress = Dns.GetHostAddresses(host)[0];
var ep = new IPEndPoint(ipAddress, port);
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}'.", endPoint));

DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));

var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
var socket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);

try
{
SocketAbstraction.Connect(socket, ep, timeout);
SocketAbstraction.Connect(socket, endPoint, timeout);

const int socketBufferSize = 10 * Session.MaximumSshPacketSize;
socket.SendBufferSize = socketBufferSize;
Expand All @@ -62,31 +58,22 @@ protected Socket SocketConnect(string host, int port, TimeSpan timeout)
}

/// <summary>
/// Establishes a socket connection to the specified host and port.
/// Establishes a socket connection to the specified endpoint.
/// </summary>
/// <param name="host">The host name of the server to connect to.</param>
/// <param name="port">The port to connect to.</param>
/// <param name="endPoint">The <see cref="EndPoint"/> representing the server to connect to.</param>
/// <param name="cancellationToken">The cancellation token to observe.</param>
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
protected async Task<Socket> SocketConnectAsync(string host, int port, CancellationToken cancellationToken)
protected async Task<Socket> SocketConnectAsync(EndPoint endPoint, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

#if NET6_0_OR_GREATER
var ipAddress = (await Dns.GetHostAddressesAsync(host, cancellationToken).ConfigureAwait(false))[0];
#else
var ipAddress = (await Dns.GetHostAddressesAsync(host).ConfigureAwait(false))[0];
#endif

var ep = new IPEndPoint(ipAddress, port);

DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}'.", endPoint));

var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
var socket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
try
{
await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);
await SocketAbstraction.ConnectAsync(socket, endPoint, cancellationToken).ConfigureAwait(false);

const int socketBufferSize = 2 * Session.MaximumSshPacketSize;
socket.SendBufferSize = socketBufferSize;
Expand Down
7 changes: 4 additions & 3 deletions src/Renci.SshNet/Connection/DirectConnector.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Net.Sockets;
using System.Net;
using System.Net.Sockets;
using System.Threading;

namespace Renci.SshNet.Connection
Expand All @@ -12,12 +13,12 @@ public DirectConnector(ISocketFactory socketFactory)

public override Socket Connect(IConnectionInfo connectionInfo)
{
return SocketConnect(connectionInfo.Host, connectionInfo.Port, connectionInfo.Timeout);
return SocketConnect(new DnsEndPoint(connectionInfo.Host, connectionInfo.Port), connectionInfo.Timeout);
}

public override System.Threading.Tasks.Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
{
return SocketConnectAsync(connectionInfo.Host, connectionInfo.Port, cancellationToken);
return SocketConnectAsync(new DnsEndPoint(connectionInfo.Host, connectionInfo.Port), cancellationToken);
}
}
}
8 changes: 3 additions & 5 deletions src/Renci.SshNet/Connection/ISocketFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@ namespace Renci.SshNet.Connection
internal interface ISocketFactory
{
/// <summary>
/// Creates a <see cref="Socket"/> with the specified <see cref="AddressFamily"/>,
/// <see cref="SocketType"/> and <see cref="ProtocolType"/> that does not use the
/// <c>Nagle</c> algorithm.
/// Creates a <see cref="Socket"/> with the specified <see cref="SocketType"/>
/// and <see cref="ProtocolType"/> that does not use the <c>Nagle</c> algorithm.
/// </summary>
/// <param name="addressFamily">The <see cref="AddressFamily"/>.</param>
/// <param name="socketType">The <see cref="SocketType"/>.</param>
/// <param name="protocolType">The <see cref="ProtocolType"/>.</param>
/// <returns>
/// The <see cref="Socket"/>.
/// </returns>
Socket Create(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType);
Socket Create(SocketType socketType, ProtocolType protocolType);
}
}
5 changes: 3 additions & 2 deletions src/Renci.SshNet/Connection/ProxyConnector.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -52,7 +53,7 @@ Task HandleProxyConnectAsync(IConnectionInfo connectionInfo, Socket socket, Canc
/// </returns>
public override Socket Connect(IConnectionInfo connectionInfo)
{
var socket = SocketConnect(connectionInfo.ProxyHost, connectionInfo.ProxyPort, connectionInfo.Timeout);
var socket = SocketConnect(new DnsEndPoint(connectionInfo.ProxyHost, connectionInfo.ProxyPort), connectionInfo.Timeout);

try
{
Expand All @@ -78,7 +79,7 @@ public override Socket Connect(IConnectionInfo connectionInfo)
/// </returns>
public override async Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
{
var socket = await SocketConnectAsync(connectionInfo.ProxyHost, connectionInfo.ProxyPort, cancellationToken).ConfigureAwait(false);
var socket = await SocketConnectAsync(new DnsEndPoint(connectionInfo.ProxyHost, connectionInfo.ProxyPort), cancellationToken).ConfigureAwait(false);

try
{
Expand Down
16 changes: 3 additions & 13 deletions src/Renci.SshNet/Connection/SocketFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,10 @@ namespace Renci.SshNet.Connection
/// </summary>
internal sealed class SocketFactory : ISocketFactory
{
/// <summary>
/// Creates a <see cref="Socket"/> with the specified <see cref="AddressFamily"/>,
/// <see cref="SocketType"/> and <see cref="ProtocolType"/> that does not use the
/// <c>Nagle</c> algorithm.
/// </summary>
/// <param name="addressFamily">The <see cref="AddressFamily"/>.</param>
/// <param name="socketType">The <see cref="SocketType"/>.</param>
/// <param name="protocolType">The <see cref="ProtocolType"/>.</param>
/// <returns>
/// The <see cref="Socket"/>.
/// </returns>
public Socket Create(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
/// <inheritdoc/>
public Socket Create(SocketType socketType, ProtocolType protocolType)
{
return new Socket(addressFamily, socketType, protocolType) { NoDelay = true };
return new Socket(socketType, protocolType) { NoDelay = true };
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ protected override void SetupData()
_stopWatch = new Stopwatch();
_actualException = null;

_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
}

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

Expand Down Expand Up @@ -95,7 +95,7 @@ public void ClientSocketShouldHaveBeenDisposed()
[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protected override void SetupData()
_stopWatch = new Stopwatch();
_disconnected = false;

_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);

_server = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.Port));
_server.Disconnected += (socket) => _disconnected = true;
Expand All @@ -42,7 +42,7 @@ protected override void SetupData()

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

Expand Down Expand Up @@ -106,7 +106,7 @@ public void NoBytesShouldHaveBeenReadFromSocket()
[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace Renci.SshNet.Tests.Classes.Connection
public class DirectConnectorTest_Connect_HostNameInvalid : DirectConnectorTestBase
{
private ConnectionInfo _connectionInfo;
private Socket _clientSocket;
private SocketException _actualException;

protected override void SetupData()
Expand All @@ -16,6 +17,7 @@ protected override void SetupData()

_connectionInfo = CreateConnectionInfo("invalid.");
_actualException = null;
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
}

protected override void Act()
Expand All @@ -31,6 +33,12 @@ protected override void Act()
}
}

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

[TestMethod]
public void ConnectShouldHaveThrownSocketException()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ protected override void SetupData()
_stopWatch = new Stopwatch();
_actualException = null;

_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
}

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

Expand Down Expand Up @@ -116,7 +116,7 @@ public void ClientSocketShouldHaveBeenDisposed()
[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ protected override void SetupData()
_stopWatch = new Stopwatch();
_actualException = null;

_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
}

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

Expand Down Expand Up @@ -105,7 +105,7 @@ public void ClientSocketShouldHaveBeenDisposed()
[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected override void SetupData()
};
_actualException = null;

_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);

_proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
_proxyServer.Disconnected += socket => _disconnected = true;
Expand All @@ -52,7 +52,7 @@ protected override void SetupData()

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

Expand Down Expand Up @@ -110,7 +110,7 @@ public void ClientSocketShouldHaveBeenDisposed()
[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public class HttpConnectorTest_Connect_ProxyHostInvalid : HttpConnectorTestBase
{
private ConnectionInfo _connectionInfo;
private SocketException _actualException;
private Socket _clientSocket;

protected override void SetupData()
{
Expand All @@ -24,6 +25,13 @@ protected override void SetupData()
"proxyPwd",
new KeyboardInteractiveAuthenticationMethod("user"));
_actualException = null;
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
}

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

protected override void Act()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ protected override void SetupData()
"\r\n");
_bytesReceivedByProxy = new List<byte>();
_disconnected = false;
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);

_proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
_proxyServer.Disconnected += (socket) => _disconnected = true;
Expand All @@ -72,7 +72,7 @@ protected override void SetupData()

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

Expand Down Expand Up @@ -131,7 +131,7 @@ public void ClientSocketShouldBeConnected()
[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
Expand Down
Loading

0 comments on commit 919af75

Please sign in to comment.