From 2a3766644d5e4a25f74baf57929904ab1500bb7e Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Wed, 15 Jan 2020 22:09:11 +0000 Subject: [PATCH] Postgres connector validates packet lengths --- .../connectors/tcp/pg/protocol/protocol.go | 7 ++- .../tcp/pg/protocol/protocol_test.go | 63 +++++++++++++++++++ .../connectors/tcp/pg/protocol/startup.go | 5 +- 3 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 internal/plugin/connectors/tcp/pg/protocol/protocol_test.go diff --git a/internal/plugin/connectors/tcp/pg/protocol/protocol.go b/internal/plugin/connectors/tcp/pg/protocol/protocol.go index 50a4ffe7d..e8f80e0f5 100644 --- a/internal/plugin/connectors/tcp/pg/protocol/protocol.go +++ b/internal/plugin/connectors/tcp/pg/protocol/protocol.go @@ -16,6 +16,7 @@ package protocol import ( "encoding/binary" + "errors" "io" ) @@ -66,7 +67,7 @@ func ReadStartupMessage(client io.Reader) ([]byte, error) { // ReadMessage accepts an incoming message. The first byte is the message type, the second int32 // is the message length, and the rest of the bytes are the message body. func ReadMessage(client io.Reader) (messageType byte, message []byte, err error) { - var messageTypeBytes = make([]byte, 1) + messageTypeBytes := make([]byte, 1) if err = binary.Read(client, binary.BigEndian, &messageTypeBytes); err != nil { return } @@ -84,6 +85,10 @@ func readMessage(client io.Reader) (message []byte, err error) { return } + if messageLength < 4 { + err = errors.New("invalid message length < 4") + return + } // Build a buffer of the appropriate size and fill it message = make([]byte, messageLength-4) if _, err = io.ReadFull(client, message); err != nil { diff --git a/internal/plugin/connectors/tcp/pg/protocol/protocol_test.go b/internal/plugin/connectors/tcp/pg/protocol/protocol_test.go new file mode 100644 index 000000000..29fa29e56 --- /dev/null +++ b/internal/plugin/connectors/tcp/pg/protocol/protocol_test.go @@ -0,0 +1,63 @@ +package protocol + +import ( + "encoding/binary" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReadMessage(t *testing.T) { + t.Run("parses contents", func(t *testing.T) { + r, w := net.Pipe() + expectedMessageType := byte(12) + expectedMessage := []byte{0,1,2,3,4} + + go func() { + err := binary.Write(w, binary.BigEndian, expectedMessageType) + if err != nil { + panic(err) + } + err = binary.Write(w, binary.BigEndian, int32(len(expectedMessage) + 4)) + if err != nil { + panic(err) + } + + _, err = w.Write(expectedMessage) + if err != nil { + panic(err) + } + }() + messageType, message, err := ReadMessage(r) + + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, expectedMessage, message) + assert.Equal(t, expectedMessageType, messageType) + }) + + t.Run("validates message length", func(t *testing.T) { + r, w := net.Pipe() + expectedMessageType := byte(12) + // a message length less than 4 is invalid + expectedMessageLength := int32(3) + + go func() { + err := binary.Write(w, binary.BigEndian, expectedMessageType) + if err != nil { + panic(err) + } + err = binary.Write(w, binary.BigEndian, expectedMessageLength) + if err != nil { + panic(err) + } + }() + _, _, err := ReadMessage(r) + if assert.Error(t, err) { + return + } + }) +} diff --git a/internal/plugin/connectors/tcp/pg/protocol/startup.go b/internal/plugin/connectors/tcp/pg/protocol/startup.go index cdeefa4ba..02b7bac12 100644 --- a/internal/plugin/connectors/tcp/pg/protocol/startup.go +++ b/internal/plugin/connectors/tcp/pg/protocol/startup.go @@ -24,10 +24,13 @@ func ParseStartupMessage(message []byte) (version int32, options map[string]stri options = make(map[string]string) for { param, err := messageBuffer.ReadString() - value, err := messageBuffer.ReadString() if err != nil || param == "\x00" { break } + value, err := messageBuffer.ReadString() + if err != nil || value == "\x00" { + break + } options[param] = value }