From c69479f401174a248eff0bfceaf36c765d1f53a2 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Thu, 24 Feb 2022 14:06:24 -0500 Subject: [PATCH] Address PR comments --- lib/srv/db/sqlserver/protocol/fuzz_test.go | 8 ++++++-- lib/srv/db/sqlserver/protocol/login7.go | 16 ++++------------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/lib/srv/db/sqlserver/protocol/fuzz_test.go b/lib/srv/db/sqlserver/protocol/fuzz_test.go index 2ecfecf637f36..30a62298e78f3 100644 --- a/lib/srv/db/sqlserver/protocol/fuzz_test.go +++ b/lib/srv/db/sqlserver/protocol/fuzz_test.go @@ -24,6 +24,8 @@ package protocol import ( "bytes" "testing" + + "github.com/stretchr/testify/require" ) func FuzzMSSQLLogin(f *testing.F) { @@ -44,7 +46,9 @@ func FuzzMSSQLLogin(f *testing.F) { f.Fuzz(func(t *testing.T, packet []byte) { reader := bytes.NewReader(packet) - // no assertion, check for panic - ReadLogin7Packet(reader) + + require.NotPanics(t, func() { + _, _ = ReadLogin7Packet(reader) + }) }) } diff --git a/lib/srv/db/sqlserver/protocol/login7.go b/lib/srv/db/sqlserver/protocol/login7.go index 8753a31b94f59..a413fb559ed46 100644 --- a/lib/srv/db/sqlserver/protocol/login7.go +++ b/lib/srv/db/sqlserver/protocol/login7.go @@ -141,17 +141,13 @@ func ReadLogin7Packet(r io.Reader) (*Login7Packet, error) { }, nil } -// errInvalidPackage is returned when Login7 package contains invalid data. -var errInvalidPackage = trace.Errorf("invalid login7 packet") +// errInvalidPacket is returned when Login7 package contains invalid data. +var errInvalidPacket = trace.Errorf("invalid login7 packet") // readUsername reads username from login7 package. func readUsername(pkt *Packet, header Login7Header) (string, error) { if len(pkt.Data) < int(header.IbUserName)+int(header.CchUserName)*2 { - return "", errInvalidPackage - } - - if len(pkt.Data) <= int(header.IbUserName) { - return "", errInvalidPackage + return "", errInvalidPacket } // Decode username and database from the packet. Offset/length are counted @@ -167,11 +163,7 @@ func readUsername(pkt *Packet, header Login7Header) (string, error) { // readDatabase reads database name from login7 package. func readDatabase(pkt *Packet, header Login7Header) (string, error) { if len(pkt.Data) < int(header.IbDatabase)+int(header.CchDatabase)*2 { - return "", errInvalidPackage - } - - if len(pkt.Data) < int(header.IbDatabase) { - return "", errInvalidPackage + return "", errInvalidPacket } database, err := mssql.ParseUCS2String(