diff --git a/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs b/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs index 874f95b1f3..ce0082b541 100644 --- a/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs +++ b/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs @@ -18,7 +18,7 @@ public class TestSilkMarshal }; [Fact] - public unsafe void TestEncodingLPWStr() + public unsafe void TestEncodingToLPWStr() { var input = "Hello world"; @@ -30,7 +30,7 @@ public unsafe void TestEncodingLPWStr() Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr)); // Use short for comparison - Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }, new Span((void*)pointer, input.Length)); + Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span((void*)pointer, input.Length + 1)); } else { @@ -39,7 +39,33 @@ public unsafe void TestEncodingLPWStr() Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr)); // Use int for comparison - Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }, new Span((void*)pointer, input.Length)); + Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span((void*)pointer, input.Length + 1)); + } + } + + [Fact] + public unsafe void TestEncodingFromLPWStr() + { + var expected = "Hello world"; + + // LPWStr is 2 bytes on Windows, 4 bytes elsewhere (usually) + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var characters = new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }; + fixed (short* pCharacters = characters) + { + var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr); + Assert.Equal(expected, output); + } + } + else + { + var characters = new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }; + fixed (int* pCharacters = characters) + { + var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr); + Assert.Equal(expected, output); + } } } diff --git a/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs b/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs index c759a97408..f0d7341299 100644 --- a/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs +++ b/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs @@ -9,6 +9,9 @@ public enum NativeStringEncoding LPStr = UnmanagedType.LPStr, LPTStr = UnmanagedType.LPTStr, LPUTF8Str = UnmanagedType.LPUTF8Str, + /// + /// On Windows, a 2-byte, null-terminated Unicode character string. On other platforms, each character will be 4 bytes instead. + /// LPWStr = UnmanagedType.LPWStr, WinString = UnmanagedType.WinString, Ansi = LPStr, diff --git a/src/Core/Silk.NET.Core/Native/SilkMarshal.cs b/src/Core/Silk.NET.Core/Native/SilkMarshal.cs index 9bebfc1d5e..91b304844f 100644 --- a/src/Core/Silk.NET.Core/Native/SilkMarshal.cs +++ b/src/Core/Silk.NET.Core/Native/SilkMarshal.cs @@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na NativeStringEncoding.BStr => -1, NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str => (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1, - NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2, + NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2, + NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 4, _ => -1 }; @@ -198,19 +199,35 @@ public static unsafe int StringIntoSpan span[convertedBytes] = 0; return ++convertedBytes; } - case NativeStringEncoding.LPWStr: + case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows): { fixed (char* firstChar = input) + fixed (byte* bytes = span) { - fixed (byte* bytes = span) - { - Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2); - ((char*)bytes)[input.Length] = default; - } + Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2); + ((char*)bytes)[input.Length] = default; } return input.Length + 1; } + case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + fixed (char* firstChar = input) + fixed (byte* bytes = span) + { + var maxLength = span.Length / 2; + var i = 0; + while (firstChar[i] != 0 && i < maxLength - 1) + { + ((uint*)bytes)[i] = firstChar[i]; + i++; + } + + ((uint*)bytes)[i] = default; + + return i * 4; + } + } default: { ThrowInvalidEncoding(); @@ -238,7 +255,7 @@ public static nint AllocateString(int length, NativeStringEncoding encoding = Na NativeStringEncoding.LPWStr => Allocate(length), _ => ThrowInvalidEncoding() }; - + /// /// Free a string pointer /// @@ -311,7 +328,28 @@ static unsafe string BStrToString(nint ptr) => new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char))); static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr); - static unsafe string WideToString(nint ptr) => new string((char*) ptr); + + static unsafe string WideToString(nint ptr) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return new string((char*) ptr); + } + else + { + var length = StringLength(ptr, NativeStringEncoding.LPWStr); + var characters = new ushort[length]; + for (var i = 0; i < (uint)length; i++) + { + characters[i] = (ushort)((uint*)ptr)[i]; + } + + fixed (ushort* pCharacters = characters) + { + return new string((char*)pCharacters); + } + } + }; } /// @@ -524,15 +562,41 @@ Func customUnmarshaller /// #if NET6_0_OR_GREATER [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe nuint StringLength( + public static unsafe nuint StringLength + ( nint ptr, NativeStringEncoding encoding = NativeStringEncoding.Ansi - ) => - (nuint)( - encoding == NativeStringEncoding.LPWStr - ? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length - : MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length - ); + ) + { + switch (encoding) + { + default: + { + return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length; + } + case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length; + } + case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + // No int overload for CreateReadOnlySpanFromNullTerminated + if (ptr == 0) + { + return 0; + } + + nuint length = 0; + while (((uint*) ptr)![length] != 0) + { + length++; + } + + return length; + } + } + } + #else public static unsafe nuint StringLength( nint ptr, @@ -543,15 +607,40 @@ public static unsafe nuint StringLength( { return 0; } - nuint ret; - for ( - ret = 0; - encoding == NativeStringEncoding.LPWStr - ? ((char*)ptr)![ret] != 0 - : ((byte*)ptr)![ret] != 0; - ret++ - ) { } - return ret; + + nuint length = 0; + switch (encoding) + { + default: + { + while (((byte*) ptr)![length] != 0) + { + length++; + } + + break; + } + case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + while (((char*) ptr)![length] != 0) + { + length++; + } + + break; + } + case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + while (((uint*) ptr)![length] != 0) + { + length++; + } + + break; + } + } + + return length; } #endif