Skip to content

Commit

Permalink
[Packetbeat] Fix Packetbeat parsing mongodb OP_MSG (#40589)
Browse files Browse the repository at this point in the history
* [Packetbeat] Fix Packetbeat parsing mongodb OP_MSG

* Fixes handling OP_MSG based request/response, missing "end" timestamp
  and "duration" field for the event
  • Loading branch information
aleksmaus authored Sep 4, 2024
1 parent 49582f4 commit b11b86a
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 92 deletions.
56 changes: 38 additions & 18 deletions packetbeat/protos/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
"github.com/elastic/beats/v7/packetbeat/procs"
"github.com/elastic/beats/v7/packetbeat/protos"
"github.com/elastic/beats/v7/packetbeat/protos/tcp"

"go.mongodb.org/mongo-driver/bson/primitive"
)

var debugf = logp.MakeDebug("mongodb")
Expand All @@ -54,7 +56,7 @@ type mongodbPlugin struct {

type transactionKey struct {
tcp common.HashableTCPTuple
id int
id int32
}

var unmatchedRequests = monitoring.NewInt(nil, "mongodb.unmatched_requests")
Expand Down Expand Up @@ -232,7 +234,7 @@ func (mongodb *mongodbPlugin) handleMongodb(

func (mongodb *mongodbPlugin) onRequest(conn *mongodbConnectionData, msg *mongodbMessage) {
// publish request only transaction
if !awaitsReply(msg.opCode) {
if !awaitsReply(msg) {
mongodb.onTransComplete(msg, nil)
return
}
Expand Down Expand Up @@ -273,7 +275,6 @@ func (mongodb *mongodbPlugin) onResponse(conn *mongodbConnectionData, msg *mongo
func (mongodb *mongodbPlugin) onTransComplete(requ, resp *mongodbMessage) {
trans := newTransaction(requ, resp)
debugf("Mongodb transaction completed: %s", trans.mongodb)

mongodb.publishTransaction(trans)
}

Expand All @@ -294,8 +295,9 @@ func newTransaction(requ, resp *mongodbMessage) *transaction {
}
trans.params = requ.params
trans.resource = requ.resource
trans.bytesIn = requ.messageLength
trans.bytesIn = int(requ.messageLength)
trans.documents = requ.documents
trans.requestDocuments = requ.documents // preserving request documents that contains mongodb query for the new OP_MSG based protocol
}

// fill response
Expand All @@ -308,7 +310,7 @@ func newTransaction(requ, resp *mongodbMessage) *transaction {
trans.documents = resp.documents

trans.endTime = resp.ts
trans.bytesOut = resp.messageLength
trans.bytesOut = int(resp.messageLength)

}

Expand All @@ -325,10 +327,17 @@ func (mongodb *mongodbPlugin) ReceivedFin(tcptuple *common.TCPTuple, dir uint8,
return private
}

func copyMapWithoutKey(d map[string]interface{}, key string) map[string]interface{} {
func copyMapWithoutKey(d map[string]interface{}, keys ...string) map[string]interface{} {
res := map[string]interface{}{}
for k, v := range d {
if k != key {
found := false
for _, excludeKey := range keys {
if k == excludeKey {
found = true
break
}
}
if !found {
res[k] = v
}
}
Expand All @@ -337,29 +346,40 @@ func copyMapWithoutKey(d map[string]interface{}, key string) map[string]interfac

func reconstructQuery(t *transaction, full bool) (query string) {
query = t.resource + "." + t.method + "("
var doc interface{}

if len(t.params) > 0 {
var err error
var params string
if !full {
// remove the actual data.
// TODO: review if we need to add other commands here
switch t.method {
case "insert":
params, err = doc2str(copyMapWithoutKey(t.params, "documents"))
doc = copyMapWithoutKey(t.params, "documents")
case "update":
params, err = doc2str(copyMapWithoutKey(t.params, "updates"))
doc = copyMapWithoutKey(t.params, "updates")
case "findandmodify":
params, err = doc2str(copyMapWithoutKey(t.params, "update"))
doc = copyMapWithoutKey(t.params, "update")
}
} else {
params, err = doc2str(t.params)
doc = t.params
}
if err != nil {
debugf("Error marshaling params: %v", err)
} else {
query += params
} else if len(t.requestDocuments) > 0 { // This recovers the query document from OP_MSG
if m, ok := t.requestDocuments[0].(primitive.M); ok {
excludeKeys := []string{"lsid"}
if !full {
excludeKeys = append(excludeKeys, "documents")
}
doc = copyMapWithoutKey(m, excludeKeys...)
}
}

queryString, err := doc2str(doc)
if err != nil {
debugf("Error marshaling query document: %v", err)
} else {
query += queryString
}

query += ")"
skip, _ := t.event["numberToSkip"].(int)
if skip > 0 {
Expand All @@ -370,7 +390,7 @@ func reconstructQuery(t *transaction, full bool) (query string) {
if limit > 0 && limit < 0x7fffffff {
query += fmt.Sprintf(".limit(%d)", limit)
}
return
return query
}

func (mongodb *mongodbPlugin) publishTransaction(t *transaction) {
Expand Down
48 changes: 20 additions & 28 deletions packetbeat/protos/mongodb/mongodb_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ func mongodbMessageParser(s *stream) (bool, bool) {
return true, false
}

if length > len(s.data) {
if int(length) > len(s.data) {
// Not yet reached the end of message
return true, false
}

// Tell decoder to only consider current message
d.truncate(length)
d.truncate(int(length))

// fill up the header common to all messages
// see http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#standard-message-header
Expand All @@ -72,8 +72,7 @@ func mongodbMessageParser(s *stream) (bool, bool) {
}

s.message.opCode = opCode
s.message.isResponse = false // default is that the message is a request. If not opReplyParse will set this to false
s.message.expectsResponse = false
s.message.isResponse = false // default is that the message is a request. If not opReplyParse will set this to true
debugf("opCode = %d (%v)", s.message.opCode, s.message.opCode)

// then split depending on operation type
Expand All @@ -93,11 +92,9 @@ func mongodbMessageParser(s *stream) (bool, bool) {
s.message.method = "insert"
return opInsertParse(d, s.message)
case opQuery:
s.message.expectsResponse = true
return opQueryParse(d, s.message)
case opGetMore:
s.message.method = "getMore"
s.message.expectsResponse = true
return opGetMoreParse(d, s.message)
case opDelete:
s.message.method = "delete"
Expand All @@ -107,6 +104,11 @@ func mongodbMessageParser(s *stream) (bool, bool) {
return opKillCursorsParse(d, s.message)
case opMsg:
s.message.method = "msg"
// The assumption is that the message with responseTo == 0 is the request
// TODO: handle the cases where moreToCome flag is set (multiple responses chained by responseTo)
if s.message.responseTo > 0 {
s.message.isResponse = true
}
return opMsgParse(d, s.message)
}

Expand Down Expand Up @@ -141,7 +143,7 @@ func opReplyParse(d *decoder, m *mongodbMessage) (bool, bool) {
debugf("Prepare to read %d document from reply", m.event["numberReturned"])

documents := make([]interface{}, numberReturned)
for i := 0; i < numberReturned; i++ {
for i := int32(0); i < numberReturned; i++ {
var document bson.M
document, err = d.readDocument()
if err != nil {
Expand Down Expand Up @@ -235,19 +237,6 @@ func opInsertParse(d *decoder, m *mongodbMessage) (bool, bool) {
return true, true
}

func extractDocuments(query map[string]interface{}) []interface{} {
docsVi, present := query["documents"]
if !present {
return []interface{}{}
}

docs, ok := docsVi.([]interface{})
if !ok {
return []interface{}{}
}
return docs
}

// Try to guess whether this key:value pair found in
// the query represents a command.
func isDatabaseCommand(key string, val interface{}) bool {
Expand Down Expand Up @@ -387,12 +376,14 @@ func opKillCursorsParse(d *decoder, m *mongodbMessage) (bool, bool) {

func opMsgParse(d *decoder, m *mongodbMessage) (bool, bool) {
// ignore flagbits
_, err := d.readInt32()
flagBits, err := d.readInt32()
if err != nil {
logp.Err("An error occurred while parsing OP_MSG message: %s", err)
return false, false
}

m.SetFlagBits(flagBits)

// read sections
kind, err := d.readByte()
if err != nil {
Expand Down Expand Up @@ -423,7 +414,7 @@ func opMsgParse(d *decoder, m *mongodbMessage) (bool, bool) {
}
m.event["message"] = cstring
var documents []interface{}
for d.i < start+size {
for d.i < start+int(size) {
document, err := d.readDocument()
if err != nil {
logp.Err("An error occurred while parsing OP_MSG message: %s", err)
Expand All @@ -432,7 +423,8 @@ func opMsgParse(d *decoder, m *mongodbMessage) (bool, bool) {
documents = append(documents, document)
}
m.documents = documents

case msgKindInternal:
// Ignore the internal purposes section
default:
logp.Err("Unknown message kind: %v", kind)
return false, false
Expand Down Expand Up @@ -482,25 +474,25 @@ func (d *decoder) readByte() (byte, error) {
return d.in[i], nil
}

func (d *decoder) readInt32() (int, error) {
func (d *decoder) readInt32() (int32, error) {
b, err := d.readBytes(4)
if err != nil {
return 0, err
}

return int((uint32(b[0]) << 0) |
return int32((uint32(b[0]) << 0) |
(uint32(b[1]) << 8) |
(uint32(b[2]) << 16) |
(uint32(b[3]) << 24)), nil
}

func (d *decoder) readInt64() (int, error) {
func (d *decoder) readInt64() (int64, error) {
b, err := d.readBytes(8)
if err != nil {
return 0, err
}

return int((uint64(b[0]) << 0) |
return int64((uint64(b[0]) << 0) |
(uint64(b[1]) << 8) |
(uint64(b[2]) << 16) |
(uint64(b[3]) << 24) |
Expand All @@ -516,7 +508,7 @@ func (d *decoder) readDocument() (bson.M, error) {
if err != nil {
return nil, err
}
d.i = start + documentLength
d.i = start + int(documentLength)
if len(d.in) < d.i {
return nil, errors.New("document out of bounds")
}
Expand Down
69 changes: 36 additions & 33 deletions packetbeat/protos/mongodb/mongodb_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
package mongodb

import (
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -77,6 +80,39 @@ func TestMongodbParser_simpleRequest(t *testing.T) {
}
}

func TestMongodbParser_OpMsg(t *testing.T) {
files := []string{
"1req.bin",
"1res.bin",
"2req.bin",
"2req.bin",
"3req.bin",
"3res.bin",
}

for _, fn := range files {
data, err := os.ReadFile(filepath.Join("testdata", fn))
if err != nil {
t.Fatal(err)
}

st := &stream{data: data, message: new(mongodbMessage)}

ok, complete := mongodbMessageParser(st)

if !ok {
t.Errorf("Parsing returned error")
}
if !complete {
t.Errorf("Expecting a complete message")
}
_, err = json.Marshal(st.message.documents)
if err != nil {
t.Fatal(err)
}
}
}

func TestMongodbParser_unknownOpCode(t *testing.T) {
var data []byte
data = addInt32(data, 16) // length = 16
Expand Down Expand Up @@ -107,39 +143,6 @@ func addInt32(in []byte, v int32) []byte {
return append(in, byte(u), byte(u>>8), byte(u>>16), byte(u>>24))
}

func Test_extract_documents(t *testing.T) {
type io struct {
Input map[string]interface{}
Output []interface{}
}
tests := []io{
{
Input: map[string]interface{}{
"a": 1,
"documents": []interface{}{"a", "b", "c"},
},
Output: []interface{}{"a", "b", "c"},
},
{
Input: map[string]interface{}{
"a": 1,
},
Output: []interface{}{},
},
{
Input: map[string]interface{}{
"a": 1,
"documents": 1,
},
Output: []interface{}{},
},
}

for _, test := range tests {
assert.Equal(t, test.Output, extractDocuments(test.Input))
}
}

func Test_isDatabaseCommand(t *testing.T) {
type io struct {
Key string
Expand Down
Loading

0 comments on commit b11b86a

Please sign in to comment.