Skip to content

Commit

Permalink
Use 4-byte LPWStrs on non-Windows platforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Exanite committed Dec 6, 2024
1 parent ead35dd commit af1f79d
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 28 deletions.
32 changes: 29 additions & 3 deletions src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class TestSilkMarshal
};

[Fact]
public unsafe void TestEncodingLPWStr()
public unsafe void TestEncodingToLPWStr()
{
var input = "Hello world";

Expand All @@ -30,7 +30,7 @@ public unsafe void TestEncodingLPWStr()
Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr));

// Use short for comparison
Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }, new Span<short>((void*)pointer, input.Length));
Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span<short>((void*)pointer, input.Length + 1));
}
else
{
Expand All @@ -39,7 +39,33 @@ public unsafe void TestEncodingLPWStr()
Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr));

// Use int for comparison
Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }, new Span<int>((void*)pointer, input.Length));
Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span<int>((void*)pointer, input.Length + 1));
}
}

[Fact]
public unsafe void TestEncodingFromLPWStr()
{
var expected = "Hello world";

// LPWStr is 2 bytes on Windows, 4 bytes elsewhere (usually)
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
var characters = new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 };
fixed (short* pCharacters = characters)
{
var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr);
Assert.Equal(expected, output);
}
}
else
{
var characters = new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 };
fixed (int* pCharacters = characters)
{
var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr);
Assert.Equal(expected, output);
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public enum NativeStringEncoding
LPStr = UnmanagedType.LPStr,
LPTStr = UnmanagedType.LPTStr,
LPUTF8Str = UnmanagedType.LPUTF8Str,
/// <summary>
/// On Windows, a 2-byte, null-terminated Unicode character string. On other platforms, each character will be 4 bytes instead.
/// </summary>
LPWStr = UnmanagedType.LPWStr,
WinString = UnmanagedType.WinString,
Ansi = LPStr,
Expand Down
139 changes: 114 additions & 25 deletions src/Core/Silk.NET.Core/Native/SilkMarshal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na
NativeStringEncoding.BStr => -1,
NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str
=> (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1,
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 4,
_ => -1
};

Expand Down Expand Up @@ -198,19 +199,35 @@ public static unsafe int StringIntoSpan
span[convertedBytes] = 0;
return ++convertedBytes;
}
case NativeStringEncoding.LPWStr:
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
fixed (byte* bytes = span)
{
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}

return input.Length + 1;
}
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
var maxLength = span.Length / 2;
var i = 0;
while (firstChar[i] != 0 && i < maxLength - 1)
{
((uint*)bytes)[i] = firstChar[i];
i++;
}

((uint*)bytes)[i] = default;

return i * 4;
}
}
default:
{
ThrowInvalidEncoding<GlobalMemory>();
Expand Down Expand Up @@ -238,7 +255,7 @@ public static nint AllocateString(int length, NativeStringEncoding encoding = Na
NativeStringEncoding.LPWStr => Allocate(length),
_ => ThrowInvalidEncoding<nint>()
};

/// <summary>
/// Free a string pointer
/// </summary>
Expand Down Expand Up @@ -311,7 +328,28 @@ static unsafe string BStrToString(nint ptr)
=> new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char)));

static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr);
static unsafe string WideToString(nint ptr) => new string((char*) ptr);

static unsafe string WideToString(nint ptr)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
return new string((char*) ptr);
}
else
{
var length = StringLength(ptr, NativeStringEncoding.LPWStr);
var characters = new ushort[length];
for (var i = 0; i < (uint)length; i++)
{
characters[i] = (ushort)((uint*)ptr)[i];
}

fixed (ushort* pCharacters = characters)
{
return new string((char*)pCharacters);
}
}
};
}

/// <summary>
Expand Down Expand Up @@ -524,15 +562,41 @@ Func<nint, string> customUnmarshaller
/// </remarks>
#if NET6_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe nuint StringLength(
public static unsafe nuint StringLength
(
nint ptr,
NativeStringEncoding encoding = NativeStringEncoding.Ansi
) =>
(nuint)(
encoding == NativeStringEncoding.LPWStr
? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length
: MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length
);
)
{
switch (encoding)
{
default:
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length;
}
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
// No int overload for CreateReadOnlySpanFromNullTerminated
if (ptr == 0)
{
return 0;
}

nuint length = 0;
while (((uint*) ptr)![length] != 0)
{
length++;
}

return length;
}
}
}

#else
public static unsafe nuint StringLength(
nint ptr,
Expand All @@ -543,15 +607,40 @@ public static unsafe nuint StringLength(
{
return 0;
}
nuint ret;
for (
ret = 0;
encoding == NativeStringEncoding.LPWStr
? ((char*)ptr)![ret] != 0
: ((byte*)ptr)![ret] != 0;
ret++
) { }
return ret;

nuint length = 0;
switch (encoding)
{
default:
{
while (((byte*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
while (((char*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
while (((uint*) ptr)![length] != 0)
{
length++;
}

break;
}
}

return length;
}
#endif

Expand Down

0 comments on commit af1f79d

Please sign in to comment.