diff --git a/sys/windows/syscall_windows_test.go b/sys/windows/syscall_windows_test.go index 6178e4a..3cc18b4 100644 --- a/sys/windows/syscall_windows_test.go +++ b/sys/windows/syscall_windows_test.go @@ -224,6 +224,7 @@ func TestReadProcessUnicodeString(t *testing.T) { if err != nil { t.Fatal(err) } + defer syscall.CloseHandle(h) info, err := NtQueryProcessBasicInformation(h) if err != nil { t.Fatal(err) @@ -236,10 +237,36 @@ func TestReadProcessUnicodeString(t *testing.T) { if err != nil { t.Fatal(err) } - defer syscall.CloseHandle(h) assert.NoError(t, err) assert.NotEmpty(t, read) } + +const currentProcessHandle = syscall.Handle(^uintptr(0)) + +func TestReadProcessUnicodeStringTerminator(t *testing.T) { + data := []byte{'H', 0, 'E', 0, 'L', 0, 'L', 0, 'O', 0, 0, 0} + for n := len(data); n >= 0; n-- { + us := UnicodeString{ + Buffer: uintptr(unsafe.Pointer(&data[0])), + Size: uint16(n), + } + read, err := ReadProcessUnicodeString(currentProcessHandle, &us) + if err != nil { + t.Fatal(err) + } + nRead := len(read) + // Strings must match + assert.True(t, nRead >= n) + assert.Equal(t, data[:n], read[:n]) + // result is an array of uint16, can't have odd length. + assert.True(t, nRead&1 == 0) + // Must include a zero terminator at the end. + assert.True(t, nRead >= 2) + assert.Zero(t, read[nRead-1]) + assert.Zero(t, read[nRead-2]) + } +} + func TestReadProcessUnicodeStringInvalidHandle(t *testing.T) { var handle syscall.Handle var cmd = UnicodeString{Size: 5, MaximumLength: 400, Buffer: 400} @@ -264,6 +291,34 @@ func TestByteSliceToStringSliceEmptyBytes(t *testing.T) { assert.Empty(t, cmd) } +func mkUtf16bytes(s string) []byte { + n := len(s) + b := make([]byte, n * 2) + for idx, val := range s { + *(*uint16)(unsafe.Pointer(&b[idx*2])) = uint16(val) + } + return b +} + +func TestByteSliceToStringSliceNotTerminated(t *testing.T) { + b := mkUtf16bytes("Hello World") + cmd, err := ByteSliceToStringSlice(b) + assert.NoError(t, err) + assert.Len(t, cmd, 2) + assert.Equal(t, "Hello", cmd[0]) + assert.Equal(t, "World", cmd[1]) +} + + +func TestByteSliceToStringSliceNotOddSize(t *testing.T) { + b := mkUtf16bytes("BAD")[:5] + cmd, err := ByteSliceToStringSlice(b) + assert.NoError(t, err) + assert.Len(t, cmd, 1) + // Odd character is dropped + assert.Equal(t, "BA", cmd[0]) +} + func TestReadProcessMemory(t *testing.T) { h, err := syscall.OpenProcess(syscall.PROCESS_QUERY_INFORMATION|PROCESS_VM_READ, false, uint32(syscall.Getpid())) if err != nil {