Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert GetAllMountPoints to UnmanagedCallersOnly #73278

Merged
merged 5 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,52 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private unsafe delegate void MountPointFound(byte* name);
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetAllMountPoints")]
private static unsafe partial int GetAllMountPoints(delegate* unmanaged<void*, byte*, void> onFound, void* context);

[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetAllMountPoints", SetLastError = true)]
private static partial int GetAllMountPoints(MountPointFound mpf);
private struct AllMountPointsContext
{
internal List<string> _results;
internal ExceptionDispatchInfo? _exception;
}

[UnmanagedCallersOnly]
private static unsafe void AddMountPoint(void* context, byte* name)
{
ref AllMountPointsContext callbackContext = ref Unsafe.As<byte, AllMountPointsContext>(ref *(byte*)context);

try
{
callbackContext._results.Add(Marshal.PtrToStringUTF8((IntPtr)name)!);
}
catch (Exception e)
{
callbackContext._exception = ExceptionDispatchInfo.Capture(e);
}
}

internal static string[] GetAllMountPoints()
{
int count = 0;
var found = new string[4];
AllMountPointsContext context = default;
context._results = new List<string>();

unsafe
{
GetAllMountPoints((byte* name) =>
{
if (count == found.Length)
{
Array.Resize(ref found, count * 2);
}
found[count++] = Marshal.PtrToStringUTF8((IntPtr)name)!;
});
GetAllMountPoints(&AddMountPoint, Unsafe.AsPointer(ref context));
}

Array.Resize(ref found, count);
return found;
context._exception?.Throw();

return context._results.ToArray();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ internal struct HTTP_BINDING_INFO
internal static partial uint HttpInitialize(HTTPAPI_VERSION version, uint flags, IntPtr pReserved);

[LibraryImport(Libraries.HttpApi, SetLastError = true)]
internal static partial uint HttpSetUrlGroupProperty(ulong urlGroupId, HTTP_SERVER_PROPERTY serverProperty, IntPtr pPropertyInfo, uint propertyInfoLength);
internal static unsafe partial uint HttpSetUrlGroupProperty(ulong urlGroupId, HTTP_SERVER_PROPERTY serverProperty, void* pPropertyInfo, uint propertyInfoLength);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

[LibraryImport(Libraries.HttpApi, SetLastError = true)]
internal static unsafe partial uint HttpCreateServerSession(HTTPAPI_VERSION version, ulong* serverSessionId, uint reserved);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
Link="Common\Interop\Unix\Interop.MountPoints.FormatInfo.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="System.Collections" />
<Reference Include="System.Memory" />
<Reference Include="System.Runtime" />
<Reference Include="System.Runtime.InteropServices" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ public bool UnsafeConnectionNtlmAuthentication
private Dictionary<ulong, DisconnectAsyncResult> DisconnectResults =>
LazyInitializer.EnsureInitialized(ref _disconnectResults, () => new Dictionary<ulong, DisconnectAsyncResult>());

private void SetUrlGroupProperty(Interop.HttpApi.HTTP_SERVER_PROPERTY property, IntPtr info, uint infosize)
private unsafe void SetUrlGroupProperty(Interop.HttpApi.HTTP_SERVER_PROPERTY property, void* info, uint infosize)
{
Debug.Assert(_urlGroupId != 0, "SetUrlGroupProperty called with invalid url group id");
Debug.Assert(info != IntPtr.Zero, "SetUrlGroupProperty called with invalid pointer");
Debug.Assert(info != null, "SetUrlGroupProperty called with invalid pointer");

//
// Set the url group property using Http Api.
Expand Down Expand Up @@ -128,11 +128,9 @@ internal void SetServerTimeout(int[] timeouts, uint minSendBytesPerSecond)
(ushort)timeouts[(int)Interop.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait];
timeoutinfo.MinSendRate = minSendBytesPerSecond;

IntPtr infoptr = new IntPtr(&timeoutinfo);

SetUrlGroupProperty(
Interop.HttpApi.HTTP_SERVER_PROPERTY.HttpServerTimeoutsProperty,
infoptr, (uint)sizeof(Interop.HttpApi.HTTP_TIMEOUT_LIMIT_INFO));
&timeoutinfo, (uint)sizeof(Interop.HttpApi.HTTP_TIMEOUT_LIMIT_INFO));
}

public HttpListenerTimeoutManager TimeoutManager
Expand Down Expand Up @@ -307,10 +305,8 @@ private void AttachRequestQueueToUrlGroup()
info.Flags = Interop.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT;
info.RequestQueueHandle = _currentSession!.RequestQueueHandle.DangerousGetHandle();

IntPtr infoptr = new IntPtr(&info);

SetUrlGroupProperty(Interop.HttpApi.HTTP_SERVER_PROPERTY.HttpServerBindingProperty,
infoptr, (uint)sizeof(Interop.HttpApi.HTTP_BINDING_INFO));
&info, (uint)sizeof(Interop.HttpApi.HTTP_BINDING_INFO));
}

private void DetachRequestQueueFromUrlGroup()
Expand All @@ -328,11 +324,9 @@ private void DetachRequestQueueFromUrlGroup()
info.Flags = Interop.HttpApi.HTTP_FLAGS.NONE;
info.RequestQueueHandle = IntPtr.Zero;

IntPtr infoptr = new IntPtr(&info);

uint statusCode = Interop.HttpApi.HttpSetUrlGroupProperty(_urlGroupId,
Interop.HttpApi.HTTP_SERVER_PROPERTY.HttpServerBindingProperty,
infoptr, (uint)sizeof(Interop.HttpApi.HTTP_BINDING_INFO));
&info, (uint)sizeof(Interop.HttpApi.HTTP_BINDING_INFO));

if (statusCode != Interop.HttpApi.ERROR_SUCCESS)
{
Expand Down
8 changes: 4 additions & 4 deletions src/native/libs/System.Native/pal_mount.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#endif
#endif

int32_t SystemNative_GetAllMountPoints(MountPointFound onFound)
int32_t SystemNative_GetAllMountPoints(MountPointFound onFound, void* context)
{
#if HAVE_MNTINFO
// getmntinfo returns pointers to OS-internal structs, so we don't need to worry about free'ing the object
Expand All @@ -41,7 +41,7 @@ int32_t SystemNative_GetAllMountPoints(MountPointFound onFound)
int count = getmntinfo(&mounts, MNT_WAIT);
for (int32_t i = 0; i < count; i++)
{
onFound(mounts[i].f_mntonname);
onFound(context, mounts[i].f_mntonname);
}

return 0;
Expand All @@ -56,7 +56,7 @@ int32_t SystemNative_GetAllMountPoints(MountPointFound onFound)
struct mnttab entry;
while(getmntent(fp, &entry) == 0)
{
onFound(entry.mnt_mountp);
onFound(context, entry.mnt_mountp);
}

result = fclose(fp);
Expand All @@ -79,7 +79,7 @@ int32_t SystemNative_GetAllMountPoints(MountPointFound onFound)
struct mntent entry;
while (getmntent_r(fp, &entry, buffer, STRING_BUFFER_SIZE) != NULL)
{
onFound(entry.mnt_dir);
onFound(context, entry.mnt_dir);
}

result = endmntent(fp);
Expand Down
4 changes: 2 additions & 2 deletions src/native/libs/System.Native/pal_mount.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ typedef struct
* Using the callback pattern allows us to limit the number of allocs we do and makes it
* cleaner on the managed side since we don't have to worry about cleaning up any unmanaged memory.
*/
typedef void (*MountPointFound)(const char* name);
typedef void (*MountPointFound)(void* context, const char* name);

/**
* Gets the space information for the given mount point and populates the input struct with the data.
Expand All @@ -45,4 +45,4 @@ PALEXPORT int32_t SystemNative_GetFormatInfoForMountPoint(
* function pointer once-per-mount-point to prevent heap allocs
* as much as possible.
*/
PALEXPORT int32_t SystemNative_GetAllMountPoints(MountPointFound onFound);
PALEXPORT int32_t SystemNative_GetAllMountPoints(MountPointFound onFound, void* context);