diff --git a/pkg/agent/util/syscall/syscall_windows.go b/pkg/agent/util/syscall/syscall_windows.go index c0b5d29e9ca..2624ed1b45c 100644 --- a/pkg/agent/util/syscall/syscall_windows.go +++ b/pkg/agent/util/syscall/syscall_windows.go @@ -308,10 +308,17 @@ type NetIOInterface interface { type netIO struct { syscallN func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) + // It needs be declared as a variable and replaced during unit tests because the real getIPForwardTable function + // converts a Pointer to a uintptr as an argument of syscallN, while converting a uintptr back to a Pointer in the + // fake syscallN is not valid. + getIPForwardTable func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) } func NewNetIO() NetIOInterface { - return &netIO{syscallN: syscall.SyscallN} + return &netIO{ + syscallN: syscall.SyscallN, + getIPForwardTable: getIPForwardTable, + } } func (n *netIO) GetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) { @@ -351,8 +358,8 @@ func (n *netIO) freeMibTable(table unsafe.Pointer) { return } -func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) { - r0, _, _ := n.syscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable))) +func getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) { + r0, _, _ := syscall.SyscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable))) if r0 != 0 { errcode = syscall.Errno(r0) } @@ -362,13 +369,16 @@ func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTa func (n *netIO) ListIPForwardRows(family uint16) ([]MibIPForwardRow, error) { var table *MibIPForwardTable err := n.getIPForwardTable(family, &table) - if table != nil { - defer n.freeMibTable(unsafe.Pointer(table)) - } if err != nil { return nil, os.NewSyscallError("iphlpapi.GetIpForwardTable", err) } - return unsafe.Slice(&table.Table[0], table.NumEntries), nil + defer n.freeMibTable(unsafe.Pointer(table)) + + // Copy the rows from the table into a new slice as the table's memory will be freed. + // Since MibIPForwardRow contains only value data (no references), the operation performs a deep copy. + rows := make([]MibIPForwardRow, 0, table.NumEntries) + rows = append(rows, unsafe.Slice(&table.Table[0], table.NumEntries)...) + return rows, nil } func NewIPForwardRow() *MibIPForwardRow { diff --git a/pkg/agent/util/syscall/syscall_windows_test.go b/pkg/agent/util/syscall/syscall_windows_test.go index e7f9d36be42..5b5b589d3d5 100644 --- a/pkg/agent/util/syscall/syscall_windows_test.go +++ b/pkg/agent/util/syscall/syscall_windows_test.go @@ -19,8 +19,10 @@ import ( "os" "syscall" "testing" + "unsafe" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRawSockAddrTranslation(t *testing.T) { @@ -202,16 +204,85 @@ func TestIPForwardEntryOperations(t *testing.T) { } } -func TestListIPForwardRows(t *testing.T) { +func TestListIPForwardRowsFailure(t *testing.T) { + testNetIO := &netIO{ + getIPForwardTable: func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) { + return syscall.Errno(22) + }, + syscallN: func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) { + assert.Fail(t, "freeMibTable shouldn't be called") + return + }, + } wantErr := os.NewSyscallError("iphlpapi.GetIpForwardTable", syscall.Errno(22)) - testNetIO := NewTestNetIO(22) - // Skipping no error case because converting uintptr back to Pointer is not valid in general. - gotRow, gotErr := testNetIO.ListIPForwardRows(AF_INET) - assert.Nil(t, gotRow) + gotRows, gotErr := testNetIO.ListIPForwardRows(AF_INET) + assert.Nil(t, gotRows) assert.Equal(t, wantErr, gotErr) } -func NewTestNetIO(wantR1 uintptr) NetIOInterface { +func TestListIPForwardRowsSuccess(t *testing.T) { + row1 := MibIPForwardRow{ + Luid: 10, + Index: 11, + DestinationPrefix: AddressPrefix{ + Prefix: RawSockAddrInet{ + Family: AF_INET, + data: [26]byte{10, 10, 10, 0}, + }, + prefixLength: 24, + }, + NextHop: RawSockAddrInet{ + Family: AF_INET, + data: [26]byte{11, 11, 11, 11}, + }, + } + row2 := MibIPForwardRow{ + Luid: 20, + Index: 21, + DestinationPrefix: AddressPrefix{ + Prefix: RawSockAddrInet{ + Family: AF_INET, + data: [26]byte{20, 20, 20, 0}, + }, + prefixLength: 24, + }, + NextHop: RawSockAddrInet{ + Family: AF_INET, + data: [26]byte{21, 21, 21, 21}, + }, + } + // The table contains two rows. Its memory address will be assigned to ipForwardTable when getIPForwardTable is called. + table := struct { + NumEntries uint32 + Table [2]MibIPForwardRow + }{ + NumEntries: 2, + Table: [2]MibIPForwardRow{row1, row2}, + } + freeMibTableCalled := false + testNetIO := &netIO{ + getIPForwardTable: func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) { + *ipForwardTable = (*MibIPForwardTable)(unsafe.Pointer(&table)) + return nil + }, + syscallN: func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) { + freeMibTableCalled = true + // Reset the rows. + table.Table[0] = MibIPForwardRow{} + table.Table[1] = MibIPForwardRow{} + return + }, + } + gotRows, gotErr := testNetIO.ListIPForwardRows(AF_INET) + require.NoError(t, gotErr) + assert.True(t, freeMibTableCalled) + // It verifies that the returned rows are independent copies, not referencing to the original table's memory, by + // asserting they retain the exact same content as the original table whose rows have been reset by freeMibTable. + expectedRows := []MibIPForwardRow{row1, row2} + assert.Equal(t, expectedRows, gotRows) +} + +func NewTestNetIO(wantR1 uintptr) *netIO { mockSyscallN := func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) { return wantR1, 0, 0 }