Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated cherry pick of #6664: Replace unsafe.Slice with memory copying to avoid potential #6672

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions pkg/agent/util/syscall/syscall_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,17 @@ type NetIOInterface interface {

type netIO struct {
syscallN func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno)
// It needs be declared as a variable and replaced during unit tests because the real getIPForwardTable function
// converts a Pointer to a uintptr as an argument of syscallN, while converting a uintptr back to a Pointer in the
// fake syscallN is not valid.
getIPForwardTable func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error)
}

func NewNetIO() NetIOInterface {
return &netIO{syscallN: syscall.SyscallN}
return &netIO{
syscallN: syscall.SyscallN,
getIPForwardTable: getIPForwardTable,
}
}

func (n *netIO) GetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) {
Expand Down Expand Up @@ -351,8 +358,8 @@ func (n *netIO) freeMibTable(table unsafe.Pointer) {
return
}

func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
r0, _, _ := n.syscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable)))
func getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
r0, _, _ := syscall.SyscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable)))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
Expand All @@ -362,13 +369,16 @@ func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTa
func (n *netIO) ListIPForwardRows(family uint16) ([]MibIPForwardRow, error) {
var table *MibIPForwardTable
err := n.getIPForwardTable(family, &table)
if table != nil {
defer n.freeMibTable(unsafe.Pointer(table))
}
if err != nil {
return nil, os.NewSyscallError("iphlpapi.GetIpForwardTable", err)
}
return unsafe.Slice(&table.Table[0], table.NumEntries), nil
defer n.freeMibTable(unsafe.Pointer(table))

// Copy the rows from the table into a new slice as the table's memory will be freed.
// Since MibIPForwardRow contains only value data (no references), the operation performs a deep copy.
rows := make([]MibIPForwardRow, 0, table.NumEntries)
rows = append(rows, unsafe.Slice(&table.Table[0], table.NumEntries)...)
return rows, nil
}

func NewIPForwardRow() *MibIPForwardRow {
Expand Down
83 changes: 77 additions & 6 deletions pkg/agent/util/syscall/syscall_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import (
"os"
"syscall"
"testing"
"unsafe"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRawSockAddrTranslation(t *testing.T) {
Expand Down Expand Up @@ -202,16 +204,85 @@ func TestIPForwardEntryOperations(t *testing.T) {
}
}

func TestListIPForwardRows(t *testing.T) {
func TestListIPForwardRowsFailure(t *testing.T) {
testNetIO := &netIO{
getIPForwardTable: func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
return syscall.Errno(22)
},
syscallN: func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) {
assert.Fail(t, "freeMibTable shouldn't be called")
return
},
}
wantErr := os.NewSyscallError("iphlpapi.GetIpForwardTable", syscall.Errno(22))
testNetIO := NewTestNetIO(22)
// Skipping no error case because converting uintptr back to Pointer is not valid in general.
gotRow, gotErr := testNetIO.ListIPForwardRows(AF_INET)
assert.Nil(t, gotRow)
gotRows, gotErr := testNetIO.ListIPForwardRows(AF_INET)
assert.Nil(t, gotRows)
assert.Equal(t, wantErr, gotErr)
}

func NewTestNetIO(wantR1 uintptr) NetIOInterface {
func TestListIPForwardRowsSuccess(t *testing.T) {
row1 := MibIPForwardRow{
Luid: 10,
Index: 11,
DestinationPrefix: AddressPrefix{
Prefix: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{10, 10, 10, 0},
},
prefixLength: 24,
},
NextHop: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{11, 11, 11, 11},
},
}
row2 := MibIPForwardRow{
Luid: 20,
Index: 21,
DestinationPrefix: AddressPrefix{
Prefix: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{20, 20, 20, 0},
},
prefixLength: 24,
},
NextHop: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{21, 21, 21, 21},
},
}
// The table contains two rows. Its memory address will be assigned to ipForwardTable when getIPForwardTable is called.
table := struct {
NumEntries uint32
Table [2]MibIPForwardRow
}{
NumEntries: 2,
Table: [2]MibIPForwardRow{row1, row2},
}
freeMibTableCalled := false
testNetIO := &netIO{
getIPForwardTable: func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
*ipForwardTable = (*MibIPForwardTable)(unsafe.Pointer(&table))
return nil
},
syscallN: func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) {
freeMibTableCalled = true
// Reset the rows.
table.Table[0] = MibIPForwardRow{}
table.Table[1] = MibIPForwardRow{}
return
},
}
gotRows, gotErr := testNetIO.ListIPForwardRows(AF_INET)
require.NoError(t, gotErr)
assert.True(t, freeMibTableCalled)
// It verifies that the returned rows are independent copies, not referencing to the original table's memory, by
// asserting they retain the exact same content as the original table whose rows have been reset by freeMibTable.
expectedRows := []MibIPForwardRow{row1, row2}
assert.Equal(t, expectedRows, gotRows)
}

func NewTestNetIO(wantR1 uintptr) *netIO {
mockSyscallN := func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) {
return wantR1, 0, 0
}
Expand Down
Loading