Skip to content

Commit

Permalink
Fix behavior of SilkMarshal.StringToPtr and related methods on Linux (#…
Browse files Browse the repository at this point in the history
…2377)

* Add Silk.NET.Core.Tests project

* Add test cases for testing string encoding

* Fix issue where PtrToStringArray was ignoring the encoding parameter

* Add test case for testing LPWStr char width

* Use 4-byte LPWStrs on non-Windows platforms

* Use Encoding.UTF32 for PtrToString

* Simplify use of "when not Windows" clauses

* Also use Encoding.UTF32 for StringIntoSpan

* Update test cases and docs to show that LPWStr is UTF-32 on non-Windows

See #2377

* Don't test non-ascii characters in TestEncodingString/Array

This is because LPStr doesn't always support non-ascii characters.

---------

Co-authored-by: Dylan Perks <[email protected]>
  • Loading branch information
Exanite and Perksey authored Dec 7, 2024
1 parent 52f00d3 commit ee1d2f0
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 30 deletions.
44 changes: 44 additions & 0 deletions src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Silk.NET.Core.Native;
using Xunit;

Expand All @@ -15,6 +21,44 @@ public class TestSilkMarshal
NativeStringEncoding.LPWStr,
};

private readonly Encoding lpwStrEncoding = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
? Encoding.Unicode
: Encoding.UTF32;

private readonly int lpwStrCharacterWidth = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 2 : 4;

[Fact]
public unsafe void TestEncodingToLPWStr()
{
var input = "Hello world 🧵";

var expectedByteCount = lpwStrEncoding.GetByteCount(input);
var expected = new byte[expectedByteCount + lpwStrCharacterWidth];
lpwStrEncoding.GetBytes(input, expected);

var pointer = SilkMarshal.StringToPtr(input, NativeStringEncoding.LPWStr);
var pointerByteCount = lpwStrCharacterWidth * (int) SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr);

Assert.Equal(expected, new Span<byte>((void*)pointer, pointerByteCount + lpwStrCharacterWidth));
}

[Fact]
public unsafe void TestEncodingFromLPWStr()
{
var expected = "Hello world 🧵";

var inputByteCount = lpwStrEncoding.GetByteCount(expected);
var input = new byte[inputByteCount + lpwStrCharacterWidth];
lpwStrEncoding.GetBytes(expected, input);

fixed (byte* pInput = input)
{
var output = SilkMarshal.PtrToString((nint)pInput, NativeStringEncoding.LPWStr);

Assert.Equal(expected, output);
}
}

[Fact]
public void TestEncodingString()
{
Expand Down
3 changes: 3 additions & 0 deletions src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public enum NativeStringEncoding
LPStr = UnmanagedType.LPStr,
LPTStr = UnmanagedType.LPTStr,
LPUTF8Str = UnmanagedType.LPUTF8Str,
/// <summary>
/// On Windows, a null-terminated UTF-16 string. On other platforms, a null-terminated UTF-32 string.
/// </summary>
LPWStr = UnmanagedType.LPWStr,
WinString = UnmanagedType.WinString,
Ansi = LPStr,
Expand Down
133 changes: 103 additions & 30 deletions src/Core/Silk.NET.Core/Native/SilkMarshal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na
NativeStringEncoding.BStr => -1,
NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str
=> (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1,
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 4,
_ => -1
};

Expand Down Expand Up @@ -188,29 +189,38 @@ public static unsafe int StringIntoSpan
int convertedBytes;

fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
fixed (byte* bytes = span)
{
convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1);
}
convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1);
bytes[convertedBytes] = 0;
}

span[convertedBytes] = 0;
return ++convertedBytes;
return convertedBytes + 1;
}
case NativeStringEncoding.LPWStr:
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
fixed (byte* bytes = span)
{
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}

return input.Length + 1;
}
case NativeStringEncoding.LPWStr:
{
int convertedBytes;

fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
convertedBytes = Encoding.UTF32.GetBytes(firstChar, input.Length, bytes, span.Length - 4);
((uint*)bytes)[convertedBytes / 4] = 0;
}

return convertedBytes + 4;
}
default:
{
ThrowInvalidEncoding<GlobalMemory>();
Expand Down Expand Up @@ -311,7 +321,19 @@ static unsafe string BStrToString(nint ptr)
=> new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char)));

static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr);
static unsafe string WideToString(nint ptr) => new string((char*) ptr);

static unsafe string WideToString(nint ptr)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
return new string((char*) ptr);
}
else
{
var length = StringLength(ptr, NativeStringEncoding.LPWStr);
return Encoding.UTF32.GetString((byte*) ptr, 4 * (int) length);
}
};
}

/// <summary>
Expand Down Expand Up @@ -524,15 +546,41 @@ Func<nint, string> customUnmarshaller
/// </remarks>
#if NET6_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe nuint StringLength(
public static unsafe nuint StringLength
(
nint ptr,
NativeStringEncoding encoding = NativeStringEncoding.Ansi
) =>
(nuint)(
encoding == NativeStringEncoding.LPWStr
? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length
: MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length
);
)
{
switch (encoding)
{
default:
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length;
}
case NativeStringEncoding.LPWStr:
{
// No int overload for CreateReadOnlySpanFromNullTerminated
if (ptr == 0)
{
return 0;
}

nuint length = 0;
while (((uint*) ptr)![length] != 0)
{
length++;
}

return length;
}
}
}

#else
public static unsafe nuint StringLength(
nint ptr,
Expand All @@ -543,15 +591,40 @@ public static unsafe nuint StringLength(
{
return 0;
}
nuint ret;
for (
ret = 0;
encoding == NativeStringEncoding.LPWStr
? ((char*)ptr)![ret] != 0
: ((byte*)ptr)![ret] != 0;
ret++
) { }
return ret;

nuint length = 0;
switch (encoding)
{
default:
{
while (((byte*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
while (((char*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr:
{
while (((uint*) ptr)![length] != 0)
{
length++;
}

break;
}
}

return length;
}
#endif

Expand Down

0 comments on commit ee1d2f0

Please sign in to comment.