Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Ensure HttpListener request buffer is aligned as required by the host…
Browse files Browse the repository at this point in the history
… processor (#25563)

On Windows, the HttpRecieveHttpRequest function requires a buffer with a memory alignment greater than or equal to the required alignment of the HTTP_REQUEST struct.

This fix ensures that alignment requirements are respected when allocating buffers in HttpListener RequestContextBase. Since HttpReceiveHttpRequest copies both the HTTP_REQUEST struct and the variable length request body into the buffer, we need to be able to allocate a buffer with variable size and with a set alignment. Since C# does not provide a method for specifying the alignment of byte arrays, I switched the underlying buffer to be unmanaged. This unmanaged buffer is allocated using Marshal.AllocHGlobal, which allocates memory at the maximum alignment required by the host processor.
  • Loading branch information
rmkerr authored Dec 1, 2017
1 parent 2a6c455 commit 357327a
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 144 deletions.
180 changes: 83 additions & 97 deletions src/Common/src/Interop/Windows/HttpApi/Interop.HttpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -786,75 +786,69 @@ internal static unsafe string GetVerb(HTTP_REQUEST* request)
return GetVerb(request, 0);
}

internal static unsafe string GetVerb(byte[] memoryBlob, IntPtr originalAddress)
internal static unsafe string GetVerb(IntPtr memoryBlob, IntPtr originalAddress)
{
fixed (byte* pMemoryBlob = memoryBlob)
{
return GetVerb((HTTP_REQUEST*)pMemoryBlob, pMemoryBlob - (byte*)originalAddress);
}
return GetVerb((HTTP_REQUEST*)memoryBlob.ToPointer(), (byte*)memoryBlob - (byte*)originalAddress);
}

// Server API

internal static unsafe WebHeaderCollection GetHeaders(byte[] memoryBlob, IntPtr originalAddress)
internal static unsafe WebHeaderCollection GetHeaders(IntPtr memoryBlob, IntPtr originalAddress)
{
NetEventSource.Enter(null);

// Return value.
WebHeaderCollection headerCollection = new WebHeaderCollection();
fixed (byte* pMemoryBlob = memoryBlob)
{
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
long fixup = pMemoryBlob - (byte*)originalAddress;
int index;
byte* pMemoryBlob = (byte*)memoryBlob;
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
long fixup = pMemoryBlob - (byte*)originalAddress;
int index;

// unknown headers
if (request->Headers.UnknownHeaderCount != 0)
// unknown headers
if (request->Headers.UnknownHeaderCount != 0)
{
HTTP_UNKNOWN_HEADER* pUnknownHeader = (HTTP_UNKNOWN_HEADER*)(fixup + (byte*)request->Headers.pUnknownHeaders);
for (index = 0; index < request->Headers.UnknownHeaderCount; index++)
{
HTTP_UNKNOWN_HEADER* pUnknownHeader = (HTTP_UNKNOWN_HEADER*)(fixup + (byte*)request->Headers.pUnknownHeaders);
for (index = 0; index < request->Headers.UnknownHeaderCount; index++)
// For unknown headers, when header value is empty, RawValueLength will be 0 and
// pRawValue will be null.
if (pUnknownHeader->pName != null && pUnknownHeader->NameLength > 0)
{
// For unknown headers, when header value is empty, RawValueLength will be 0 and
// pRawValue will be null.
if (pUnknownHeader->pName != null && pUnknownHeader->NameLength > 0)
string headerName = new string(pUnknownHeader->pName + fixup, 0, pUnknownHeader->NameLength);
string headerValue;
if (pUnknownHeader->pRawValue != null && pUnknownHeader->RawValueLength > 0)
{
string headerName = new string(pUnknownHeader->pName + fixup, 0, pUnknownHeader->NameLength);
string headerValue;
if (pUnknownHeader->pRawValue != null && pUnknownHeader->RawValueLength > 0)
{
headerValue = new string(pUnknownHeader->pRawValue + fixup, 0, pUnknownHeader->RawValueLength);
}
else
{
headerValue = string.Empty;
}
headerCollection.Add(headerName, headerValue);
headerValue = new string(pUnknownHeader->pRawValue + fixup, 0, pUnknownHeader->RawValueLength);
}
pUnknownHeader++;
else
{
headerValue = string.Empty;
}
headerCollection.Add(headerName, headerValue);
}
pUnknownHeader++;
}
}

// known headers
HTTP_KNOWN_HEADER* pKnownHeader = &request->Headers.KnownHeaders;
for (index = 0; index < HttpHeaderRequestMaximum; index++)
// known headers
HTTP_KNOWN_HEADER* pKnownHeader = &request->Headers.KnownHeaders;
for (index = 0; index < HttpHeaderRequestMaximum; index++)
{
// For known headers, when header value is empty, RawValueLength will be 0 and
// pRawValue will point to empty string ("\0")
if (pKnownHeader->pRawValue != null)
{
// For known headers, when header value is empty, RawValueLength will be 0 and
// pRawValue will point to empty string ("\0")
if (pKnownHeader->pRawValue != null)
{
string headerValue = new string(pKnownHeader->pRawValue + fixup, 0, pKnownHeader->RawValueLength);
headerCollection.Add(HTTP_REQUEST_HEADER_ID.ToString(index), headerValue);
}
pKnownHeader++;
string headerValue = new string(pKnownHeader->pRawValue + fixup, 0, pKnownHeader->RawValueLength);
headerCollection.Add(HTTP_REQUEST_HEADER_ID.ToString(index), headerValue);
}
pKnownHeader++;
}

NetEventSource.Exit(null);
return headerCollection;
}


internal static unsafe uint GetChunks(byte[] memoryBlob, IntPtr originalAddress, ref int dataChunkIndex, ref uint dataChunkOffset, byte[] buffer, int offset, int size)
internal static unsafe uint GetChunks(IntPtr memoryBlob, IntPtr originalAddress, ref int dataChunkIndex, ref uint dataChunkOffset, byte[] buffer, int offset, int size)
{
if (NetEventSource.IsEnabled)
{
Expand All @@ -863,92 +857,86 @@ internal static unsafe uint GetChunks(byte[] memoryBlob, IntPtr originalAddress,

// Return value.
uint dataRead = 0;
fixed (byte* pMemoryBlob = memoryBlob)
byte* pMemoryBlob = (byte*)memoryBlob;
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
long fixup = pMemoryBlob - (byte*)originalAddress;

if (request->EntityChunkCount > 0 && dataChunkIndex < request->EntityChunkCount && dataChunkIndex != -1)
{
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
long fixup = pMemoryBlob - (byte*)originalAddress;
HTTP_DATA_CHUNK* pDataChunk = (HTTP_DATA_CHUNK*)(fixup + (byte*)&request->pEntityChunks[dataChunkIndex]);

if (request->EntityChunkCount > 0 && dataChunkIndex < request->EntityChunkCount && dataChunkIndex != -1)
fixed (byte* pReadBuffer = buffer)
{
HTTP_DATA_CHUNK* pDataChunk = (HTTP_DATA_CHUNK*)(fixup + (byte*)&request->pEntityChunks[dataChunkIndex]);
byte* pTo = &pReadBuffer[offset];

fixed (byte* pReadBuffer = buffer)
while (dataChunkIndex < request->EntityChunkCount && dataRead < size)
{
byte* pTo = &pReadBuffer[offset];

while (dataChunkIndex < request->EntityChunkCount && dataRead < size)
if (dataChunkOffset >= pDataChunk->BufferLength)
{
dataChunkOffset = 0;
dataChunkIndex++;
pDataChunk++;
}
else
{
if (dataChunkOffset >= pDataChunk->BufferLength)
byte* pFrom = pDataChunk->pBuffer + dataChunkOffset + fixup;

uint bytesToRead = pDataChunk->BufferLength - (uint)dataChunkOffset;
if (bytesToRead > (uint)size)
{
dataChunkOffset = 0;
dataChunkIndex++;
pDataChunk++;
bytesToRead = (uint)size;
}
else
for (uint i = 0; i < bytesToRead; i++)
{
byte* pFrom = pDataChunk->pBuffer + dataChunkOffset + fixup;

uint bytesToRead = pDataChunk->BufferLength - (uint)dataChunkOffset;
if (bytesToRead > (uint)size)
{
bytesToRead = (uint)size;
}
for (uint i = 0; i < bytesToRead; i++)
{
*(pTo++) = *(pFrom++);
}
dataRead += bytesToRead;
dataChunkOffset += bytesToRead;
*(pTo++) = *(pFrom++);
}
dataRead += bytesToRead;
dataChunkOffset += bytesToRead;
}
}
}
//we're finished.
if (dataChunkIndex == request->EntityChunkCount)
{
dataChunkIndex = -1;
}
}

//we're finished.
if (dataChunkIndex == request->EntityChunkCount)
{
dataChunkIndex = -1;
}

if (NetEventSource.IsEnabled)
{
NetEventSource.Exit(null);
}
return dataRead;
}

internal static unsafe HTTP_VERB GetKnownVerb(byte[] memoryBlob, IntPtr originalAddress)
internal static unsafe HTTP_VERB GetKnownVerb(IntPtr memoryBlob, IntPtr originalAddress)
{
NetEventSource.Enter(null);

// Return value.
HTTP_VERB verb = HTTP_VERB.HttpVerbUnknown;
fixed (byte* pMemoryBlob = memoryBlob)

HTTP_REQUEST* request = (HTTP_REQUEST*)memoryBlob.ToPointer();
if ((int)request->Verb > (int)HTTP_VERB.HttpVerbUnparsed && (int)request->Verb < (int)HTTP_VERB.HttpVerbMaximum)
{
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
if ((int)request->Verb > (int)HTTP_VERB.HttpVerbUnparsed && (int)request->Verb < (int)HTTP_VERB.HttpVerbMaximum)
{
verb = request->Verb;
}
verb = request->Verb;
}

NetEventSource.Exit(null);
return verb;
}

internal static unsafe IPEndPoint GetRemoteEndPoint(byte[] memoryBlob, IntPtr originalAddress)
internal static unsafe IPEndPoint GetRemoteEndPoint(IntPtr memoryBlob, IntPtr originalAddress)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(null);

SocketAddress v4address = new SocketAddress(AddressFamily.InterNetwork, IPv4AddressSize);
SocketAddress v6address = new SocketAddress(AddressFamily.InterNetworkV6, IPv6AddressSize);

fixed (byte* pMemoryBlob = memoryBlob)
{
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
IntPtr address = request->Address.pRemoteAddress != null ? (IntPtr)(pMemoryBlob - (byte*)originalAddress + (byte*)request->Address.pRemoteAddress) : IntPtr.Zero;
CopyOutAddress(address, ref v4address, ref v6address);
}
byte* pMemoryBlob = (byte*)memoryBlob;
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
IntPtr address = request->Address.pRemoteAddress != null ? (IntPtr)(pMemoryBlob - (byte*)originalAddress + (byte*)request->Address.pRemoteAddress) : IntPtr.Zero;
CopyOutAddress(address, ref v4address, ref v6address);

IPEndPoint endpoint = null;
if (v4address != null)
Expand All @@ -964,19 +952,17 @@ internal static unsafe IPEndPoint GetRemoteEndPoint(byte[] memoryBlob, IntPtr or
return endpoint;
}

internal static unsafe IPEndPoint GetLocalEndPoint(byte[] memoryBlob, IntPtr originalAddress)
internal static unsafe IPEndPoint GetLocalEndPoint(IntPtr memoryBlob, IntPtr originalAddress)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(null);

SocketAddress v4address = new SocketAddress(AddressFamily.InterNetwork, IPv4AddressSize);
SocketAddress v6address = new SocketAddress(AddressFamily.InterNetworkV6, IPv6AddressSize);

fixed (byte* pMemoryBlob = memoryBlob)
{
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
IntPtr address = request->Address.pLocalAddress != null ? (IntPtr)(pMemoryBlob - (byte*)originalAddress + (byte*)request->Address.pLocalAddress) : IntPtr.Zero;
CopyOutAddress(address, ref v4address, ref v6address);
}
byte* pMemoryBlob = (byte*)memoryBlob;
HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob;
IntPtr address = request->Address.pLocalAddress != null ? (IntPtr)(pMemoryBlob - (byte*)originalAddress + (byte*)request->Address.pLocalAddress) : IntPtr.Zero;
CopyOutAddress(address, ref v4address, ref v6address);

IPEndPoint endpoint = null;
if (v4address != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ internal AsyncRequestContext(ThreadPoolBoundHandle boundHandle, ListenerAsyncRes

private Interop.HttpApi.HTTP_REQUEST* Allocate(ThreadPoolBoundHandle boundHandle, uint size)
{
uint newSize = size != 0 ? size : RequestBuffer == null ? 4096 : Size;
uint newSize = size != 0 ? size : RequestBuffer == IntPtr.Zero ? 4096 : Size;
if (_nativeOverlapped != null)
{
#if DEBUG
Expand All @@ -57,7 +57,7 @@ internal AsyncRequestContext(ThreadPoolBoundHandle boundHandle, ListenerAsyncRes
_boundHandle = boundHandle;
_nativeOverlapped = boundHandle.AllocateNativeOverlapped(ListenerAsyncResult.IOCallback, state: _result, pinData: RequestBuffer);

return (Interop.HttpApi.HTTP_REQUEST*)Marshal.UnsafeAddrOfPinnedArrayElement(RequestBuffer, 0);
return (Interop.HttpApi.HTTP_REQUEST*)RequestBuffer.ToPointer();
}

internal void Reset(ThreadPoolBoundHandle boundHandle, ulong requestId, uint size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ internal HttpListenerRequest(HttpListenerContext httpContext, RequestContextBase
// Note: RequestBuffer may get moved in memory. If you dereference a pointer from inside the RequestBuffer,
// you must use 'OriginalBlobAddress' below to adjust the location of the pointer to match the location of
// RequestBuffer.
internal byte[] RequestBuffer
internal IntPtr RequestBuffer
{
get
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Runtime.InteropServices;

namespace System.Net
{
internal abstract unsafe class RequestContextBase : IDisposable
{
private Interop.HttpApi.HTTP_REQUEST* _memoryBlob;
private Interop.HttpApi.HTTP_REQUEST* _originalBlobAddress;
private byte[] _backingBuffer;
private IntPtr _backingBuffer = IntPtr.Zero;
private int _backingBufferLength = 0;

// Must call this from derived class' constructors.
protected void BaseConstruction(Interop.HttpApi.HTTP_REQUEST* requestBlob)
Expand All @@ -29,7 +31,7 @@ protected void BaseConstruction(Interop.HttpApi.HTTP_REQUEST* requestBlob)
// before an object (HttpListenerRequest) which closes the RequestContext on demand is returned to the application.
internal void ReleasePins()
{
Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::ReleasePins()|ReleasePins() called twice.");
Debug.Assert(_memoryBlob != null || _backingBuffer == IntPtr.Zero, "RequestContextBase::ReleasePins()|ReleasePins() called twice.");
_originalBlobAddress = _memoryBlob;
UnsetBlob();
OnReleasePins();
Expand All @@ -48,7 +50,14 @@ public void Dispose()
Dispose(true);
}

protected virtual void Dispose(bool disposing) { }
protected virtual void Dispose(bool disposing)
{
if (_backingBuffer != IntPtr.Zero)
{
Marshal.FreeHGlobal(_backingBuffer);
_backingBuffer = IntPtr.Zero;
}
}

~RequestContextBase()
{
Expand All @@ -59,12 +68,12 @@ internal Interop.HttpApi.HTTP_REQUEST* RequestBlob
{
get
{
Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::Dispose()|RequestBlob requested after ReleasePins().");
Debug.Assert(_memoryBlob != null || _backingBuffer == IntPtr.Zero, "RequestContextBase::Dispose()|RequestBlob requested after ReleasePins().");
return _memoryBlob;
}
}

internal byte[] RequestBuffer
internal IntPtr RequestBuffer
{
get
{
Expand All @@ -76,7 +85,7 @@ internal uint Size
{
get
{
return (uint)_backingBuffer.Length;
return (uint)_backingBufferLength;
}
}

Expand All @@ -91,7 +100,7 @@ internal IntPtr OriginalBlobAddress

protected void SetBlob(Interop.HttpApi.HTTP_REQUEST* requestBlob)
{
Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::Dispose()|SetBlob() called after ReleasePins().");
Debug.Assert(_memoryBlob != null || _backingBuffer == IntPtr.Zero, "RequestContextBase::Dispose()|SetBlob() called after ReleasePins().");
if (requestBlob == null)
{
UnsetBlob();
Expand All @@ -116,7 +125,16 @@ protected void UnsetBlob()

protected void SetBuffer(int size)
{
_backingBuffer = size == 0 ? null : new byte[size];
if (_backingBuffer != IntPtr.Zero)
{
Marshal.FreeHGlobal(_backingBuffer);
}

_backingBuffer = size == 0 ? IntPtr.Zero : Marshal.AllocHGlobal(size);
_backingBufferLength = size;

// Zero out the contents of the buffer.
new Span<byte>(_backingBuffer.ToPointer(), size).Fill(0);
}
}
}
Loading

0 comments on commit 357327a

Please sign in to comment.