Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Commit of Transport Creation Handler #2569

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/Microsoft.Data.SqlClient.sln
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{4F3CD363-B1E6-4D6D-9466-97D78A56BE45}"
ProjectSection(SolutionItems) = preProject
Directory.Build.props = Directory.Build.props
NuGet.config = NuGet.config
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlServer.Server", "Microsoft.SqlServer.Server\Microsoft.SqlServer.Server.csproj", "{A314812A-7820-4565-A2A8-ABBE391C11E4}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@
<Compile Include="Microsoft\Data\SqlClientX\Handlers\HandlerRequest.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\HandlerRequestType.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\IHandler.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\TransportCreation\IpAddressVersionComparer.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\TransportCreation\TransportCreationHandler.cs" />
<Compile Include="Microsoft\Data\SqlClientX\IO\TdsWriteStream.cs" />
<Compile Include="Microsoft\Data\SqlClientX\SqlConnectionX.cs" />
<EmbeddedResource Include="$(CommonSourceRoot)Resources\Strings.resx">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ namespace Microsoft.Data.SqlClientX.Handlers
internal class ConnectionHandlerContext : HandlerRequest
{
/// <summary>
/// Class that contains data required to handle a connection request.
/// Stream used by readers.
/// </summary>
public SqlConnectionString ConnectionString { get; set; }
public Stream ConnectionStream { get; set; }

/// <summary>
/// Stream used by readers.
/// Class that contains data required to handle a connection request.
/// </summary>
public Stream ConnectionStream { get; set; }
public SqlConnectionString ConnectionString { get; set; }

/// <summary>
/// Class required by DataSourceParser and Transport layer.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClientX.Handlers.TransportCreation
{
/// <summary>
/// Comparer that sorts IP addresses based on the version of the internet protocol it is using.
/// This class cannot be instantiated, so to use it, use the singleton instances (doubleton?)
/// <see cref="InstanceV4"/> and <see cref="InstanceV6"/>.
/// </summary>
internal sealed class IpAddressVersionSorter : IComparer<IPAddress>
{
private readonly AddressFamily _preferredAddressFamily;

private IpAddressVersionSorter(AddressFamily preferredAddressFamily)
{
_preferredAddressFamily = preferredAddressFamily;
}

/// <summary>
/// Gets a singleton instance that ranks IPv4 addresses higher than IPv6 addresses.
/// </summary>
public static IpAddressVersionSorter InstanceV4 { get; } =
new IpAddressVersionSorter(AddressFamily.InterNetwork);

/// <summary>
/// Gets a singleton instance that ranks IPv6 addresses higher than IPv4 addresses.
/// </summary>
public static IpAddressVersionSorter InstanceV6 { get; } =
new IpAddressVersionSorter(AddressFamily.InterNetworkV6);

/// <inheritdoc />
public int Compare(IPAddress x, IPAddress y)
{
if (x is null) { throw new ArgumentNullException(nameof(x)); }
if (y is null) { throw new ArgumentNullException(nameof(y)); }

if (x.AddressFamily == y.AddressFamily)
{
// Versions are the same, it's a tie.
return 0;
}

return x.AddressFamily == _preferredAddressFamily ? 1 : -1;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.Data.SqlClient.SNI;

namespace Microsoft.Data.SqlClientX.Handlers.TransportCreation
{
internal sealed class TransportCreationHandler : IHandler<ConnectionHandlerContext>
{
private const int KeepAliveIntervalSeconds = 1;
private const int KeepAliveTimeSeconds = 30;

#if NET8_0_OR_GREATER
private static readonly TimeSpan DefaultPollTimeout = TimeSpan.FromSeconds(30);
#else
private const int DefaultPollTimeout = 30 * 100000; // 30 seconds as microseconds
#endif

/// <inheritdoc />
public IHandler<ConnectionHandlerContext> NextHandler { get; set; }

/// <inheritdoc />
public async ValueTask Handle(ConnectionHandlerContext context, bool isAsync, CancellationToken ct)
{
Debug.Assert(context.DataSource is not null, "context.DataSource is null");

try
{
// @TODO: Build CoR for handling the different protocols in order
if (context.DataSource.ResolvedProtocol is DataSource.Protocol.TCP)
{
context.ConnectionStream = await HandleTcpRequest(context, isAsync, ct).ConfigureAwait(false);
}
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
else
{
throw new NotImplementedException();
}
}
catch (Exception e)
{
context.Error = e;
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
return;
}

if (NextHandler is not null)
{
await NextHandler.Handle(context, isAsync, ct).ConfigureAwait(false);
}
}

private ValueTask<Stream> HandleNamedPipeRequest()
{
throw new NotImplementedException();
}

private ValueTask<Stream> HandleSharedMemoryRequest()
{
throw new NotImplementedException();
}

private async ValueTask<Stream> HandleTcpRequest(ConnectionHandlerContext context, bool isAsync, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
benrr101 marked this conversation as resolved.
Show resolved Hide resolved

// DNS lookup
IPAddress[] ipAddresses = isAsync
? await Dns.GetHostAddressesAsync(context.DataSource.ServerName, ct).ConfigureAwait(false)
: Dns.GetHostAddresses(context.DataSource.ServerName);
if (ipAddresses is null || ipAddresses.Length == 0)
{
throw new SocketException((int)SocketError.HostNotFound);
}

// If there is an IP version preference, apply it
switch (context.ConnectionString.IPAddressPreference)
{
case SqlConnectionIPAddressPreference.IPv4First:
Array.Sort(ipAddresses, IpAddressVersionSorter.InstanceV4);
break;

case SqlConnectionIPAddressPreference.IPv6First:
Array.Sort(ipAddresses, IpAddressVersionSorter.InstanceV6);
break;

case SqlConnectionIPAddressPreference.UsePlatformDefault:
default:
// Not sorting necessary
break;
}

// Attempt to connect to one of the matching IP addresses
// @TODO: Handle opening in parallel
Socket socket = null;
var socketOpenExceptions = new List<Exception>();

int portToUse = context.DataSource.ResolvedPort < 0
? context.DataSource.Port
: context.DataSource.ResolvedPort;
var ipEndpoint = new IPEndPoint(IPAddress.None, portToUse); // Allocate once
foreach (IPAddress ipAddress in ipAddresses)
{
ipEndpoint.Address = ipAddress;
try
{
socket = await OpenSocket(ipEndpoint, isAsync, ct).ConfigureAwait(false);
break;
}
catch(Exception e)
{
socketOpenExceptions.Add(e);
}
}

// If no socket succeeded, throw
if (socket is null)
{
// If there are any socket exceptions in the collected exceptions, throw the first
// one. If there are not, collect all exceptions and throw them as an aggregate.
foreach (Exception exception in socketOpenExceptions)
{
if (exception is SocketException)
{
throw exception;
}
}

throw new AggregateException(socketOpenExceptions);
}

// Create the stream for the socket
return new NetworkStream(socket);
}

private async ValueTask<Socket> OpenSocket(IPEndPoint ipEndPoint, bool isAsync, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sync cases, the callers will provide a default Cancellation token. the ct.ThrowIfCancellationRequested() may not be needed in sync paths.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to utilize the cancellation token as the way to signal that either the timeout has expired or the user has requested cancellation. This will be adopted in a later PR. As such, I think it would still be valuable to keep the cancellation token in the sync code path. @edwardneal had a good suggestion for how to cancel while the socket is opening by registering event on the cancellation token.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still not sure how this would look end to end. If this theory holds, then great, else we will modify.


var socket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { Blocking = false };

// Enable keep-alive
socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, KeepAliveIntervalSeconds);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, KeepAliveTimeSeconds);

try
{
if (isAsync)
{
#if NET6_0_OR_GREATER
await socket.ConnectAsync(ipEndPoint, ct).ConfigureAwait(false);
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
#else
// @TODO: Only real way to cancel this is to register a cancellation token event and dispose of the socket.
await new TaskFactory(ct).FromAsync(socket.BeginConnect, socket.EndConnect, ipEndPoint, null)
.ConfigureAwait(false);
#endif
}
else
{
OpenSocketSync(socket, ipEndPoint, ct);
}
}
catch (Exception)
{
socket.Dispose();
throw;
}

// Connection is established
socket.Blocking = true;
socket.NoDelay = true;

return socket;
}

private void OpenSocketSync(Socket socket, IPEndPoint ipEndPoint, CancellationToken ct)
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
{
ct.ThrowIfCancellationRequested();

try
{
socket.Connect(ipEndPoint);
}
catch (SocketException e)
{
// Because the socket is configured to be non-blocking, any operation that would
// block will throw an exception indicating it would block. Since opening a TCP
// connection will always block, we expect to get an exception for it, and will
// ignore it. This allows us to immediately return from connect and poll it,
// allowing us to observe timeouts and cancellation.
if (e.SocketErrorCode is not SocketError.WouldBlock)
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
{
throw;
}
}

// Poll the socket until it is open
// @TODO: This method can't be cancelled, so we should consider pooling smaller timeouts and looping while
// there is still time left on the timer, checking cancellation token each time.
if (!socket.Poll(DefaultPollTimeout, SelectMode.SelectWrite))
benrr101 marked this conversation as resolved.
Show resolved Hide resolved
{
throw new TimeoutException("Socket failed to open within timeout period.");
}
}
}
}