Skip to content

Commit

Permalink
Replace SocketsHttpConnectionFactory with SocketsConnectionFactory (#…
Browse files Browse the repository at this point in the history
…40506)

Introduces SocketsConnectionFactory from  #40044 without the factory methods for NetworkStream and IDuplexPipe
  • Loading branch information
antonfirsov authored Aug 14, 2020
1 parent ff975e8 commit 8e5da14
Show file tree
Hide file tree
Showing 25 changed files with 779 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public abstract partial class SocketTestServer : IDisposable

protected abstract int Port { get; }
public abstract EndPoint EndPoint { get; }
public event Action<Socket> Accepted;

public static SocketTestServer SocketTestServerFactory(SocketImplementationType type, EndPoint endpoint, ProtocolType protocolType = ProtocolType.Tcp)
{
Expand All @@ -23,6 +24,9 @@ public static SocketTestServer SocketTestServerFactory(SocketImplementationType
return SocketTestServerFactory(type, DefaultNumConnections, DefaultReceiveBufferSize, address, out port);
}

public static SocketTestServer SocketTestServerFactory(SocketImplementationType type, IPAddress address)
=> SocketTestServerFactory(type, address, out _);

public static SocketTestServer SocketTestServerFactory(
SocketImplementationType type,
int numConnections,
Expand Down Expand Up @@ -60,5 +64,7 @@ public void Dispose()
}

protected abstract void Dispose(bool disposing);

protected void NotifyAccepted(Socket socket) => Accepted?.Invoke(socket);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ private void OnAccept(IAsyncResult result)
return;
}

NotifyAccepted(client);
ServerSocketState state = new ServerSocketState(client, _receiveBufferSize);
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ private void ProcessAccept(SocketAsyncEventArgs e)
// Get the socket for the accepted client connection and put it into the ReadEventArg object user token.
SocketAsyncEventArgs readEventArgs = _readWritePool.Pop();

NotifyAccepted(e.AcceptSocket);
((AsyncUserToken)readEventArgs.UserToken).Socket = e.AcceptSocket;

// As soon as the client is connected, post a receive to the connection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public static partial class PlatformDetection

public static bool IsNetCore => Environment.Version.Major >= 5 || RuntimeInformation.FrameworkDescription.StartsWith(".NET Core", StringComparison.OrdinalIgnoreCase);
public static bool IsMonoRuntime => Type.GetType("Mono.RuntimeStructs") != null;
public static bool IsNotMonoRuntime => !IsMonoRuntime;
public static bool IsMonoInterpreter => GetIsRunningOnMonoInterpreter();
public static bool IsFreeBSD => RuntimeInformation.IsOSPlatform(OSPlatform.Create("FREEBSD"));
public static bool IsNetBSD => RuntimeInformation.IsOSPlatform(OSPlatform.Create("NETBSD"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@ public partial interface IConnectionProperties
{
bool TryGet(System.Type propertyKey, [System.Diagnostics.CodeAnalysis.NotNullWhenAttribute(true)] out object? property);
}
public partial class SocketsConnectionFactory : System.Net.Connections.ConnectionFactory
{
public SocketsConnectionFactory(System.Net.Sockets.AddressFamily addressFamily, System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { }
public SocketsConnectionFactory(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { }
public override System.Threading.Tasks.ValueTask<System.Net.Connections.Connection> ConnectAsync(System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties? options = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
protected virtual System.Net.Sockets.Socket CreateSocket(System.Net.Sockets.AddressFamily addressFamily, System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties? options) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<ItemGroup>
<ProjectReference Include="..\..\System.Runtime\ref\System.Runtime.csproj" />
<ProjectReference Include="..\..\System.Net.Primitives\ref\System.Net.Primitives.csproj" />
<ProjectReference Include="..\..\System.Net.Sockets\ref\System.Net.Sockets.csproj" />
<ProjectReference Include="..\..\System.IO.Pipelines\ref\System.IO.Pipelines.csproj" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(CommonPath)System\Net\NetworkErrorHelper.cs" Link="Common\System\Net\NetworkErrorHelper.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskToApm.cs" Link="Common\System\Threading\Tasks\TaskToApm.cs" />
<Compile Include="System\Net\Connections\ConnectionBase.cs" />
<Compile Include="System\Net\Connections\ConnectionCloseMethod.cs" />
Expand All @@ -13,14 +14,20 @@
<Compile Include="System\Net\Connections\DuplexPipeStream.cs" />
<Compile Include="System\Net\Connections\ConnectionFactory.cs" />
<Compile Include="System\Net\Connections\ConnectionListener.cs" />
<Compile Include="System\Net\Connections\DuplexStreamPipe.cs" />
<Compile Include="System\Net\Connections\IConnectionProperties.cs" />
<Compile Include="System\Net\Connections\Sockets\SocketConnection.cs" />
<Compile Include="System\Net\Connections\Sockets\SocketsConnectionFactory.cs" />
<Compile Include="System\Net\Connections\Sockets\TaskSocketAsyncEventArgs.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="System.Runtime" />
<Reference Include="System.Memory" />
<Reference Include="System.Net.Primitives" />
<Reference Include="System.Net.Sockets" />
<Reference Include="System.Threading" />
<Reference Include="System.Threading.Tasks" />
<Reference Include="System.IO.Pipelines" />
<Reference Include="Microsoft.Win32.Primitives" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,6 @@ protected virtual IDuplexPipe CreatePipe()
}
}

private sealed class DuplexStreamPipe : IDuplexPipe
{
private static readonly StreamPipeReaderOptions s_readerOpts = new StreamPipeReaderOptions(leaveOpen: true);
private static readonly StreamPipeWriterOptions s_writerOpts = new StreamPipeWriterOptions(leaveOpen: true);

public DuplexStreamPipe(Stream stream)
{
Input = PipeReader.Create(stream, s_readerOpts);
Output = PipeWriter.Create(stream, s_writerOpts);
}

public PipeReader Input { get; }

public PipeWriter Output { get; }
}

/// <summary>
/// Creates a connection for a <see cref="Stream"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using System.IO.Pipelines;

namespace System.Net.Connections
{
internal sealed class DuplexStreamPipe : IDuplexPipe
{
private static readonly StreamPipeReaderOptions s_readerOpts = new StreamPipeReaderOptions(leaveOpen: true);
private static readonly StreamPipeWriterOptions s_writerOpts = new StreamPipeWriterOptions(leaveOpen: true);

public DuplexStreamPipe(Stream stream)
{
Input = PipeReader.Create(stream, s_readerOpts);
Output = PipeWriter.Create(stream, s_writerOpts);
}

public PipeReader Input { get; }

public PipeWriter Output { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Http;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Runtime.ExceptionServices;
using System.Threading;
Expand All @@ -13,15 +13,16 @@ namespace System.Net.Connections
{
internal sealed class SocketConnection : Connection, IConnectionProperties
{
private readonly NetworkStream _stream;
private readonly Socket _socket;
private Stream? _stream;

public override EndPoint? RemoteEndPoint => _stream.Socket.RemoteEndPoint;
public override EndPoint? LocalEndPoint => _stream.Socket.LocalEndPoint;
public override EndPoint? RemoteEndPoint => _socket.RemoteEndPoint;
public override EndPoint? LocalEndPoint => _socket.LocalEndPoint;
public override IConnectionProperties ConnectionProperties => this;

public SocketConnection(Socket socket)
{
_stream = new NetworkStream(socket, ownsSocket: true);
_socket = socket;
}

protected override ValueTask CloseAsyncCore(ConnectionCloseMethod method, CancellationToken cancellationToken)
Expand All @@ -37,10 +38,11 @@ protected override ValueTask CloseAsyncCore(ConnectionCloseMethod method, Cancel
{
// Dispose must be called first in order to cause a connection reset,
// as NetworkStream.Dispose() will call Shutdown(Both).
_stream.Socket.Dispose();
_socket.Dispose();
}

_stream.Dispose();
// Since CreatePipe() calls CreateStream(), so _stream should be present even in the pipe case:
_stream?.Dispose();
}
catch (SocketException socketException)
{
Expand All @@ -54,18 +56,18 @@ protected override ValueTask CloseAsyncCore(ConnectionCloseMethod method, Cancel
return default;
}

protected override Stream CreateStream() => _stream;

bool IConnectionProperties.TryGet(Type propertyKey, [NotNullWhen(true)] out object? property)
{
if (propertyKey == typeof(Socket))
{
property = _stream.Socket;
property = _socket;
return true;
}

property = null;
return false;
}

protected override Stream CreateStream() => _stream ??= new NetworkStream(_socket, true);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Connections
{
/// <summary>
/// A <see cref="ConnectionFactory"/> to establish socket-based connections.
/// </summary>
/// <remarks>
/// When constructed with <see cref="ProtocolType.Tcp"/>, this factory will create connections with <see cref="Socket.NoDelay"/> enabled.
/// In case of IPv6 sockets <see cref="Socket.DualMode"/> is also enabled.
/// </remarks>
public class SocketsConnectionFactory : ConnectionFactory
{
private readonly AddressFamily _addressFamily;
private readonly SocketType _socketType;
private readonly ProtocolType _protocolType;

/// <summary>
/// Initializes a new instance of the <see cref="SocketsConnectionFactory"/> class.
/// </summary>
/// <param name="addressFamily">The <see cref="AddressFamily"/> to forward to the socket.</param>
/// <param name="socketType">The <see cref="SocketType"/> to forward to the socket.</param>
/// <param name="protocolType">The <see cref="ProtocolType"/> to forward to the socket.</param>
public SocketsConnectionFactory(
AddressFamily addressFamily,
SocketType socketType,
ProtocolType protocolType)
{
_addressFamily = addressFamily;
_socketType = socketType;
_protocolType = protocolType;
}

/// <summary>
/// Initializes a new instance of the <see cref="SocketsConnectionFactory"/> class
/// that will forward <see cref="AddressFamily.InterNetworkV6"/> to the Socket constructor.
/// </summary>
/// <param name="socketType">The <see cref="SocketType"/> to forward to the socket.</param>
/// <param name="protocolType">The <see cref="ProtocolType"/> to forward to the socket.</param>
/// <remarks>The created socket will be an IPv6 socket with <see cref="Socket.DualMode"/> enabled.</remarks>
public SocketsConnectionFactory(SocketType socketType, ProtocolType protocolType)
: this(AddressFamily.InterNetworkV6, socketType, protocolType)
{
}

/// <inheritdoc />
/// <exception cref="ArgumentNullException">When <paramref name="endPoint"/> is <see langword="null"/>.</exception>
public override async ValueTask<Connection> ConnectAsync(
EndPoint? endPoint,
IConnectionProperties? options = null,
CancellationToken cancellationToken = default)
{
if (endPoint == null) throw new ArgumentNullException(nameof(endPoint));
cancellationToken.ThrowIfCancellationRequested();

Socket socket = CreateSocket(_addressFamily, _socketType, _protocolType, endPoint, options);

try
{
using var args = new TaskSocketAsyncEventArgs();
args.RemoteEndPoint = endPoint;

if (socket.ConnectAsync(args))
{
using (cancellationToken.UnsafeRegister(static o => Socket.CancelConnectAsync((SocketAsyncEventArgs)o!), args))
{
await args.Task.ConfigureAwait(false);
}
}

if (args.SocketError != SocketError.Success)
{
if (args.SocketError == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}

throw NetworkErrorHelper.MapSocketException(new SocketException((int)args.SocketError));
}

return new SocketConnection(socket);
}
catch (SocketException socketException)
{
socket.Dispose();
throw NetworkErrorHelper.MapSocketException(socketException);
}
catch
{
socket.Dispose();
throw;
}
}

/// <summary>
/// Creates the socket that shall be used with the connection.
/// </summary>
/// <param name="addressFamily">The <see cref="AddressFamily"/> to forward to the socket.</param>
/// <param name="socketType">The <see cref="SocketType"/> to forward to the socket.</param>
/// <param name="protocolType">The <see cref="ProtocolType"/> to forward to the socket.</param>
/// <param name="endPoint">The <see cref="EndPoint"/> this socket will be connected to.</param>
/// <param name="options">Properties, if any, that might change how the socket is initialized.</param>
/// <returns>A new unconnected <see cref="Socket"/>.</returns>
/// <remarks>
/// In case of TCP sockets, the default implementation of this method will create a socket with <see cref="Socket.NoDelay"/> enabled.
/// In case of IPv6 sockets <see cref="Socket.DualMode"/> is also be enabled.
/// </remarks>
protected virtual Socket CreateSocket(
AddressFamily addressFamily,
SocketType socketType,
ProtocolType protocolType,
EndPoint? endPoint,
IConnectionProperties? options)
{
Socket socket = new Socket(addressFamily, socketType, protocolType);

if (protocolType == ProtocolType.Tcp)
{
socket.NoDelay = true;
}

if (addressFamily == AddressFamily.InterNetworkV6)
{
socket.DualMode = true;
}

return socket;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ internal sealed class TaskSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTas
public void GetResult(short token) => _valueTaskSource.GetResult(token);
public ValueTaskSourceStatus GetStatus(short token) => _valueTaskSource.GetStatus(token);
public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _valueTaskSource.OnCompleted(continuation, state, token, flags);
public void Complete() => _valueTaskSource.SetResult(0);

public TaskSocketAsyncEventArgs()
: base(unsafeSuppressExecutionContextFlow: true)
Expand Down
Loading

0 comments on commit 8e5da14

Please sign in to comment.