Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Merge pull request #33300 from luigiberrettini/keep-alive-iocontrol-f…
Browse files Browse the repository at this point in the history
…rom-setsockopt

Use *SocketOption for keep-alive on Windows versions less than 10 v1709
  • Loading branch information
stephentoub authored Nov 15, 2018
2 parents e0cb81d + a075c33 commit 3b358e2
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/System.Net.Sockets/src/System.Net.Sockets.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
<Compile Include="System\Net\Sockets\SendPacketsElementFlags.Windows.cs" />
<Compile Include="System\Net\Sockets\Socket.Windows.cs" />
<Compile Include="System\Net\Sockets\SocketAsyncEventArgs.Windows.cs" />
<Compile Include="System\Net\Sockets\IOControlKeepAlive.Windows.cs" />
<Compile Include="System\Net\Sockets\SocketPal.Windows.cs" />
<Compile Include="System\Net\Sockets\TransmitFileAsyncResult.Windows.cs" />
<Compile Include="System\Net\Sockets\UnixDomainSocketEndPoint.Windows.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace System.Net.Sockets
{
// Allow hiding keep-alive time and interval handling behind *SocketOption
// on Windows < 10 v1709 that only supports set via IOControl
internal sealed class IOControlKeepAlive
{
private const uint WindowsDefaultTimeMs = 7200000u;
private const uint WindowsDefaultIntervalMs = 1000u;
private static readonly bool s_supportsKeepAliveViaSocketOption = SupportsKeepAliveViaSocketOption();
private static readonly ConditionalWeakTable<SafeSocketHandle, IOControlKeepAlive> s_socketKeepAliveTable = new ConditionalWeakTable<SafeSocketHandle, IOControlKeepAlive>();
[ThreadStatic]
private static byte[] s_keepAliveValuesBuffer;

private uint _timeMs = WindowsDefaultTimeMs;
private uint _intervalMs = WindowsDefaultIntervalMs;

public static bool IsNeeded => !s_supportsKeepAliveViaSocketOption;

public static SocketError Get(SafeSocketHandle handle, SocketOptionName optionName, byte[] optionValueSeconds, ref int optionLength)
{
if (optionValueSeconds == null ||
!BitConverter.TryWriteBytes(optionValueSeconds.AsSpan(), Get(handle, optionName)))
{
return SocketError.Fault;
}

optionLength = optionValueSeconds.Length;
return SocketError.Success;
}

public static int Get(SafeSocketHandle handle, SocketOptionName optionName)
{
if (s_socketKeepAliveTable.TryGetValue(handle, out IOControlKeepAlive ioControlKeepAlive))
{
return optionName == SocketOptionName.TcpKeepAliveTime ?
MillisecondsToSeconds(ioControlKeepAlive._timeMs) :
MillisecondsToSeconds(ioControlKeepAlive._intervalMs);
}

return optionName == SocketOptionName.TcpKeepAliveTime ?
MillisecondsToSeconds(WindowsDefaultTimeMs) :
MillisecondsToSeconds(WindowsDefaultIntervalMs);
}

public static SocketError Set(SafeSocketHandle handle, SocketOptionName optionName, byte[] optionValueSeconds)
{
if (optionValueSeconds == null ||
optionValueSeconds.Length < sizeof(int))
{
return SocketError.Fault;
}

return Set(handle, optionName, BitConverter.ToInt32(optionValueSeconds, 0));
}

public static SocketError Set(SafeSocketHandle handle, SocketOptionName optionName, int optionValueSeconds)
{
IOControlKeepAlive ioControlKeepAlive = s_socketKeepAliveTable.GetOrCreateValue(handle);
if (optionName == SocketOptionName.TcpKeepAliveTime)
{
ioControlKeepAlive._timeMs = SecondsToMilliseconds(optionValueSeconds);
}
else
{
ioControlKeepAlive._intervalMs = SecondsToMilliseconds(optionValueSeconds);
}

byte[] buffer = s_keepAliveValuesBuffer ?? (s_keepAliveValuesBuffer = new byte[3 * sizeof(uint)]);
ioControlKeepAlive.Fill(buffer);
int realOptionLength = 0;
return SocketPal.WindowsIoctl(handle, unchecked((int)IOControlCode.KeepAliveValues), buffer, null, out realOptionLength);
}

private static bool SupportsKeepAliveViaSocketOption()
{
AddressFamily addressFamily = Socket.OSSupportsIPv4 ? AddressFamily.InterNetwork : AddressFamily.InterNetworkV6;
using (Socket socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp))
{
int time = MillisecondsToSeconds(WindowsDefaultTimeMs);
SocketError timeErrCode = Interop.Winsock.setsockopt(
socket.SafeHandle,
SocketOptionLevel.Tcp,
SocketOptionName.TcpKeepAliveTime,
ref time,
sizeof(int));

int interval = MillisecondsToSeconds(WindowsDefaultIntervalMs);
SocketError intervalErrCode = Interop.Winsock.setsockopt(
socket.SafeHandle,
SocketOptionLevel.Tcp,
SocketOptionName.TcpKeepAliveInterval,
ref interval,
sizeof(int));

return
timeErrCode == SocketError.Success &&
intervalErrCode == SocketError.Success;
}
}

private static int MillisecondsToSeconds(uint milliseconds) => (int)(milliseconds / 1000u);

private static uint SecondsToMilliseconds(int seconds) => (uint)seconds * 1000u;

private void Fill(byte[] buffer)
{
Debug.Assert(buffer != null);
Debug.Assert(buffer.Length == 3 * sizeof(uint));

const uint OnOff = 1u;
bool written =
BitConverter.TryWriteBytes(buffer.AsSpan(), OnOff) |
BitConverter.TryWriteBytes(buffer.AsSpan(sizeof(uint)), _timeMs) |
BitConverter.TryWriteBytes(buffer.AsSpan(sizeof(uint) * 2), _intervalMs);
Debug.Assert(written);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ public static unsafe SocketError SetSockOpt(SafeSocketHandle handle, SocketOptio
{
fixed (byte* pinnedValue = optionValue)
{
Interop.Error err = Interop.Sys.SetSockOpt(handle, optionLevel, optionName, pinnedValue, optionValue.Length);
Interop.Error err = Interop.Sys.SetSockOpt(handle, optionLevel, optionName, pinnedValue, optionValue != null ? optionValue.Length : 0);
return GetErrorAndTrackSetting(handle, optionLevel, optionName, err);
}
}
Expand Down
62 changes: 49 additions & 13 deletions src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public static SocketError Connect(SafeSocketHandle handle, byte[] peerAddress, i
IntPtr.Zero);
return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success;
}

public static SocketError Send(SafeSocketHandle handle, IList<ArraySegment<byte>> buffers, SocketFlags socketFlags, out int bytesTransferred)
{
const int StackThreshold = 16; // arbitrary limit to avoid too much space on stack (note: may be over-sized, that's OK - length passed separately)
Expand Down Expand Up @@ -489,24 +490,44 @@ public static SocketError WindowsIoctl(SafeSocketHandle handle, int ioControlCod

public static unsafe SocketError SetSockOpt(SafeSocketHandle handle, SocketOptionLevel optionLevel, SocketOptionName optionName, int optionValue)
{
SocketError errorCode = Interop.Winsock.setsockopt(
handle,
optionLevel,
optionName,
ref optionValue,
sizeof(int));
SocketError errorCode;
if (optionLevel == SocketOptionLevel.Tcp &&
(optionName == SocketOptionName.TcpKeepAliveTime || optionName == SocketOptionName.TcpKeepAliveInterval) &&
IOControlKeepAlive.IsNeeded)
{
errorCode = IOControlKeepAlive.Set(handle, optionName, optionValue);
}
else
{
errorCode = Interop.Winsock.setsockopt(
handle,
optionLevel,
optionName,
ref optionValue,
sizeof(int));
}
return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success;
}

public static SocketError SetSockOpt(SafeSocketHandle handle, SocketOptionLevel optionLevel, SocketOptionName optionName, byte[] optionValue)
{
SocketError errorCode = Interop.Winsock.setsockopt(
handle,
optionLevel,
optionName,
optionValue,
optionValue != null ? optionValue.Length : 0);
return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success;
SocketError errorCode;
if (optionLevel == SocketOptionLevel.Tcp &&
(optionName == SocketOptionName.TcpKeepAliveTime || optionName == SocketOptionName.TcpKeepAliveInterval) &&
IOControlKeepAlive.IsNeeded)
{
return IOControlKeepAlive.Set(handle, optionName, optionValue);
}
else
{
errorCode = Interop.Winsock.setsockopt(
handle,
optionLevel,
optionName,
optionValue,
optionValue != null ? optionValue.Length : 0);
return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success;
}
}

public static void SetReceivingDualModeIPv4PacketInformation(Socket socket)
Expand Down Expand Up @@ -599,6 +620,14 @@ public static void SetIPProtectionLevel(Socket socket, SocketOptionLevel optionL

public static SocketError GetSockOpt(SafeSocketHandle handle, SocketOptionLevel optionLevel, SocketOptionName optionName, out int optionValue)
{
if (optionLevel == SocketOptionLevel.Tcp &&
(optionName == SocketOptionName.TcpKeepAliveTime || optionName == SocketOptionName.TcpKeepAliveInterval) &&
IOControlKeepAlive.IsNeeded)
{
optionValue = IOControlKeepAlive.Get(handle, optionName);
return SocketError.Success;
}

int optionLength = 4; // sizeof(int)
SocketError errorCode = Interop.Winsock.getsockopt(
handle,
Expand All @@ -611,6 +640,13 @@ public static SocketError GetSockOpt(SafeSocketHandle handle, SocketOptionLevel

public static SocketError GetSockOpt(SafeSocketHandle handle, SocketOptionLevel optionLevel, SocketOptionName optionName, byte[] optionValue, ref int optionLength)
{
if (optionLevel == SocketOptionLevel.Tcp &&
(optionName == SocketOptionName.TcpKeepAliveTime || optionName == SocketOptionName.TcpKeepAliveInterval) &&
IOControlKeepAlive.IsNeeded)
{
return IOControlKeepAlive.Get(handle, optionName, optionValue, ref optionLength);
}

SocketError errorCode = Interop.Winsock.getsockopt(
handle,
optionLevel,
Expand Down
Loading

0 comments on commit 3b358e2

Please sign in to comment.