From db78c5ed90c7a40b410b1b6667f2906ad5261a17 Mon Sep 17 00:00:00 2001 From: Jeremy Fox <109584719+d0g0x01@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:34:11 +0000 Subject: [PATCH] [ASENG-654] Mongo trace bugfix (#138) --- pkg/kubehound/graph/adapter/mongo.go | 6 +- .../storage/storedb/mocks/store_provider.go | 20 ++--- .../storage/storedb/mongo_provider.go | 89 +++++++++++++++---- .../storage/storedb/mongo_provider_test.go | 83 ----------------- pkg/kubehound/storage/storedb/mongo_writer.go | 61 +++++++------ .../storage/storedb/mongo_writer_test.go | 50 ++++++----- pkg/kubehound/storage/storedb/provider.go | 4 +- test/system/vertex.gen.go | 2 +- 8 files changed, 147 insertions(+), 168 deletions(-) delete mode 100644 pkg/kubehound/storage/storedb/mongo_provider_test.go diff --git a/pkg/kubehound/graph/adapter/mongo.go b/pkg/kubehound/graph/adapter/mongo.go index a27bfd837..3b0603f1c 100644 --- a/pkg/kubehound/graph/adapter/mongo.go +++ b/pkg/kubehound/graph/adapter/mongo.go @@ -11,12 +11,12 @@ import ( // MongoDB is a helper function to retrieve the store database object from a mongoDB provider. func MongoDB(store storedb.Provider) *mongo.Database { - mongoClient, ok := store.Raw().(*mongo.Client) + db, ok := store.Reader().(*mongo.Database) if !ok { - log.I.Fatalf("Invalid database provider type. Expected *mongo.Client, got %T", store.Raw()) + log.I.Fatalf("Invalid database provider type. Expected *mongo.Database, got %T", store.Reader()) } - return mongoClient.Database(storedb.MongoDatabaseName) + return db } // MongoCursorHandler is the default stream implementation to handle the query results from a mongoDB store provider. 
diff --git a/pkg/kubehound/storage/storedb/mocks/store_provider.go b/pkg/kubehound/storage/storedb/mocks/store_provider.go index 301c00bfb..055b7d2d6 100644 --- a/pkg/kubehound/storage/storedb/mocks/store_provider.go +++ b/pkg/kubehound/storage/storedb/mocks/store_provider.go @@ -272,8 +272,8 @@ func (_c *Provider_Prepare_Call) RunAndReturn(run func(context.Context) error) * return _c } -// Raw provides a mock function with given fields: -func (_m *Provider) Raw() interface{} { +// Reader provides a mock function with given fields: +func (_m *Provider) Reader() interface{} { ret := _m.Called() var r0 interface{} @@ -288,29 +288,29 @@ func (_m *Provider) Raw() interface{} { return r0 } -// Provider_Raw_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Raw' -type Provider_Raw_Call struct { +// Provider_Reader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Reader' +type Provider_Reader_Call struct { *mock.Call } -// Raw is a helper method to define mock.On call -func (_e *Provider_Expecter) Raw() *Provider_Raw_Call { - return &Provider_Raw_Call{Call: _e.mock.On("Raw")} +// Reader is a helper method to define mock.On call +func (_e *Provider_Expecter) Reader() *Provider_Reader_Call { + return &Provider_Reader_Call{Call: _e.mock.On("Reader")} } -func (_c *Provider_Raw_Call) Run(run func()) *Provider_Raw_Call { +func (_c *Provider_Reader_Call) Run(run func()) *Provider_Reader_Call { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *Provider_Raw_Call) Return(_a0 interface{}) *Provider_Raw_Call { +func (_c *Provider_Reader_Call) Return(_a0 interface{}) *Provider_Reader_Call { _c.Call.Return(_a0) return _c } -func (_c *Provider_Raw_Call) RunAndReturn(run func() interface{}) *Provider_Raw_Call { +func (_c *Provider_Reader_Call) RunAndReturn(run func() interface{}) *Provider_Reader_Call { _c.Call.Return(run) return _c } diff --git 
a/pkg/kubehound/storage/storedb/mongo_provider.go b/pkg/kubehound/storage/storedb/mongo_provider.go index 513631e5f..50ce10988 100644 --- a/pkg/kubehound/storage/storedb/mongo_provider.go +++ b/pkg/kubehound/storage/storedb/mongo_provider.go @@ -7,10 +7,12 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" "github.com/DataDog/KubeHound/pkg/telemetry/tag" + "github.com/hashicorp/go-multierror" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" + mongotrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go.mongodb.org/mongo-driver/mongo" ) const ( @@ -21,50 +23,81 @@ var ( _ Provider = (*MongoProvider)(nil) ) +// A MongoDB based store provider implementation. type MongoProvider struct { - client *mongo.Client - db *mongo.Database - tags []string + reader *mongo.Client // MongoDB client optimized for read operations + writer *mongo.Client // MongoDB client optimized for write operations + tags []string // Tags to be applied for telemetry } -func NewMongoProvider(ctx context.Context, url string, connectionTimeout time.Duration) (*MongoProvider, error) { - opts := options.Client() - opts.ApplyURI(url + fmt.Sprintf("/?connectTimeoutMS=%d", connectionTimeout)) +// createClient creates a new MongoDB client with the provided options. +func createClient(ctx context.Context, opts *options.ClientOptions, timeout time.Duration) (*mongo.Client, error) { client, err := mongo.Connect(ctx, opts) if err != nil { return nil, err } - ctx, cancel := context.WithTimeout(ctx, connectionTimeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() err = client.Ping(ctx, readpref.Primary()) if err != nil { return nil, err } - db := client.Database(MongoDatabaseName) + return client, nil +} + +// createReaderWriter creates a pair of MongoDB clients - one for writes and another for reads. 
+func createReaderWriter(ctx context.Context, url string, timeout time.Duration) (*mongo.Client, *mongo.Client, error) { + baseOpts := options.Client() + baseOpts.ApplyURI(url + fmt.Sprintf("/?connectTimeoutMS=%d", timeout)) + + writer, err := createClient(ctx, baseOpts, timeout) + if err != nil { + return nil, nil, err + } + + opts := baseOpts + opts.Monitor = mongotrace.NewMonitor() + reader, err := createClient(ctx, opts, timeout) + if err != nil { + _ = writer.Disconnect(ctx) + + return nil, nil, err + } + + return reader, writer, nil +} + +// NewMongoProvider creates a new instance of the MongoDB store provider +func NewMongoProvider(ctx context.Context, url string, connectionTimeout time.Duration) (*MongoProvider, error) { + reader, writer, err := createReaderWriter(ctx, url, connectionTimeout) + if err != nil { + return nil, err + } return &MongoProvider{ - client: client, - db: db, + reader: reader, + writer: writer, tags: append(tag.BaseTags, tag.Storage(StorageProviderName)), }, nil } func (mp *MongoProvider) Prepare(ctx context.Context) error { - collections, err := mp.db.ListCollectionNames(ctx, bson.M{}) + db := mp.writer.Database(MongoDatabaseName) + collections, err := db.ListCollectionNames(ctx, bson.M{}) if err != nil { return fmt.Errorf("listing mongo DB collections: %w", err) } for _, collectionName := range collections { - err = mp.db.Collection(collectionName).Drop(ctx) + err = db.Collection(collectionName).Drop(ctx) if err != nil { return fmt.Errorf("deleting mongo DB collection %s: %w", collectionName, err) } } - ib, err := NewIndexBuilder(mp.db) + ib, err := NewIndexBuilder(db) if err != nil { return fmt.Errorf("mongo DB index builder create: %w", err) } @@ -76,8 +109,8 @@ func (mp *MongoProvider) Prepare(ctx context.Context) error { return nil } -func (mp *MongoProvider) Raw() any { - return mp.client +func (mp *MongoProvider) Reader() any { + return mp.reader.Database(MongoDatabaseName) } func (mp *MongoProvider) Name() string { @@ -85,7 
+118,12 @@ func (mp *MongoProvider) Name() string { } func (mp *MongoProvider) HealthCheck(ctx context.Context) (bool, error) { - err := mp.client.Ping(ctx, nil) + err := mp.reader.Ping(ctx, nil) + if err != nil { + return false, err + } + + err = mp.writer.Ping(ctx, nil) if err != nil { return false, err } @@ -94,11 +132,26 @@ func (mp *MongoProvider) HealthCheck(ctx context.Context) (bool, error) { } func (mp *MongoProvider) Close(ctx context.Context) error { - return mp.client.Disconnect(ctx) + var res *multierror.Error + if mp.reader != nil { + err := mp.reader.Disconnect(ctx) + if err != nil { + res = multierror.Append(res, err) + } + } + + if mp.writer != nil { + err := mp.writer.Disconnect(ctx) + if err != nil { + res = multierror.Append(res, err) + } + } + + return res.ErrorOrNil() } func (mp *MongoProvider) BulkWriter(ctx context.Context, collection collections.Collection, opts ...WriterOption) (AsyncWriter, error) { - writer := NewMongoAsyncWriter(ctx, mp, collection, opts...) + writer := NewMongoAsyncWriter(ctx, mp.writer.Database(MongoDatabaseName), collection, opts...) 
return writer, nil } diff --git a/pkg/kubehound/storage/storedb/mongo_provider_test.go b/pkg/kubehound/storage/storedb/mongo_provider_test.go deleted file mode 100644 index c8555a883..000000000 --- a/pkg/kubehound/storage/storedb/mongo_provider_test.go +++ /dev/null @@ -1,83 +0,0 @@ -//nolint:containedctx,unused -package storedb - -import ( - "context" - "testing" - "time" - - "github.com/DataDog/KubeHound/pkg/config" - "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" - "go.mongodb.org/mongo-driver/mongo" -) - -func TestMongoProvider_BulkWriter(t *testing.T) { - t.Parallel() - // FIXME: we should probably setup a mongodb test server in CI for the system tests - if config.IsCI() { - t.Skip("Skip mongo tests in CI") - } - - ctx := context.Background() - provider, err := NewMongoProvider(ctx, MongoLocalDatabaseURL, 1*time.Second) - // TODO: add another check (env var maybe?) - // "integration test checks" - if err != nil { - t.Error("FAILED TO CONNECT TO LOCAL MONGO DB DURING TESTS, SKIPPING") - - return - } - - fakeCollection := collections.FakeCollection{} - - type fields struct { - client *mongo.Client - db *mongo.Database - collection *mongo.Collection - } - type args struct { - ctx context.Context - collection collections.Collection - opts []WriterOption - } - - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - { - name: "Bulk writer test with valid collection", - fields: fields{ - client: provider.client, - db: provider.db, - }, - args: args{ - ctx: context.Background(), - collection: fakeCollection, - }, - wantErr: false, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - mp := &MongoProvider{ - client: tt.fields.client, - db: tt.fields.db, - } - writer, err := mp.BulkWriter(tt.args.ctx, tt.args.collection, tt.args.opts...) 
- if (err != nil) != tt.wantErr { - t.Errorf("MongoProvider.BulkWriter() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if writer == nil { - t.Errorf("writer returned by BulkWriter is nil") - } - }) - } -} diff --git a/pkg/kubehound/storage/storedb/mongo_writer.go b/pkg/kubehound/storage/storedb/mongo_writer.go index db97dc71e..d84a4d0ad 100644 --- a/pkg/kubehound/storage/storedb/mongo_writer.go +++ b/pkg/kubehound/storage/storedb/mongo_writer.go @@ -25,28 +25,30 @@ const ( var _ AsyncWriter = (*MongoAsyncWriter)(nil) type MongoAsyncWriter struct { - mongodb *MongoProvider + collection collections.Collection ops []mongo.WriteModel - collection *mongo.Collection + opsLock *sync.RWMutex + dbWriter *mongo.Collection batchSize int consumerChan chan []mongo.WriteModel writingInFlight *sync.WaitGroup tags []string } -func NewMongoAsyncWriter(ctx context.Context, mp *MongoProvider, collection collections.Collection, opts ...WriterOption) *MongoAsyncWriter { +func NewMongoAsyncWriter(ctx context.Context, db *mongo.Database, collection collections.Collection, opts ...WriterOption) *MongoAsyncWriter { wOpts := &writerOptions{} for _, o := range opts { o(wOpts) } maw := MongoAsyncWriter{ - mongodb: mp, - collection: mp.db.Collection(collection.Name()), - batchSize: collection.BatchSize(), - tags: append(wOpts.Tags, tag.Collection(collection.Name())), - + dbWriter: db.Collection(collection.Name()), + batchSize: collection.BatchSize(), + tags: append(wOpts.Tags, tag.Collection(collection.Name())), + collection: collection, writingInFlight: &sync.WaitGroup{}, + ops: make([]mongo.WriteModel, 0), + opsLock: &sync.RWMutex{}, } maw.consumerChan = make(chan []mongo.WriteModel, consumerChanSize) maw.startBackgroundWriter(ctx) @@ -86,14 +88,12 @@ func (maw *MongoAsyncWriter) batchWrite(ctx context.Context, ops []mongo.WriteMo span, ctx := tracer.StartSpanFromContext(ctx, span.MongoDBBatchWrite, tracer.Measured()) span.SetTag(tag.CollectionTag, maw.collection.Name()) defer 
span.Finish() - - maw.writingInFlight.Add(1) defer maw.writingInFlight.Done() _ = statsd.Count(metric.ObjectWrite, int64(len(ops)), maw.tags, 1) bulkWriteOpts := options.BulkWrite().SetOrdered(false) - _, err := maw.collection.BulkWrite(ctx, ops, bulkWriteOpts) + _, err := maw.dbWriter.BulkWrite(ctx, ops, bulkWriteOpts) if err != nil { return fmt.Errorf("could not write in bulk to mongo: %w", err) } @@ -103,10 +103,16 @@ func (maw *MongoAsyncWriter) batchWrite(ctx context.Context, ops []mongo.WriteMo // Queue add a model to an asynchronous write queue. Non-blocking. func (maw *MongoAsyncWriter) Queue(ctx context.Context, model any) error { - maw.ops = append(maw.ops, mongo.NewInsertOneModel().SetDocument(model)) + maw.opsLock.Lock() + defer maw.opsLock.Unlock() + maw.ops = append(maw.ops, mongo.NewInsertOneModel().SetDocument(model)) if len(maw.ops) > maw.batchSize { - maw.consumerChan <- maw.ops + copied := make([]mongo.WriteModel, len(maw.ops)) + copy(copied, maw.ops) + + maw.writingInFlight.Add(1) + maw.consumerChan <- copied _ = statsd.Incr(metric.QueueSize, maw.tags, 1) // cleanup the ops array after we have copied it to the channel @@ -123,7 +129,7 @@ func (maw *MongoAsyncWriter) Flush(ctx context.Context) error { span.SetTag(tag.CollectionTag, maw.collection.Name()) defer span.Finish() - if maw.mongodb.client == nil { + if maw.dbWriter == nil { return fmt.Errorf("mongodb client is not initialized") } @@ -131,31 +137,30 @@ func (maw *MongoAsyncWriter) Flush(ctx context.Context) error { return fmt.Errorf("mongodb collection is not initialized") } - if len(maw.ops) == 0 { - log.Trace(ctx).Debugf("Skipping flush on %s as no write operations", maw.collection.Name()) - // we need to send something to the channel from this function whenever we don't return an error - // we cannot defer it because the go routine may last longer than the current function - // the defer is going to be executed at the return time, whetever or not the inner go routine is processing data 
- maw.writingInFlight.Wait() + maw.opsLock.Lock() + defer maw.opsLock.Unlock() - return nil - } + if len(maw.ops) != 0 { + maw.writingInFlight.Add(1) + err := maw.batchWrite(ctx, maw.ops) + if err != nil { + log.Trace(ctx).Errorf("batch write %s: %+v", maw.collection.Name(), err) + maw.writingInFlight.Wait() - err := maw.batchWrite(ctx, maw.ops) - if err != nil { - maw.writingInFlight.Wait() + return err + } - return err + maw.ops = nil } - maw.ops = nil + maw.writingInFlight.Wait() return nil } // Close cleans up any resources used by the AsyncWriter implementation. Writer cannot be reused after this call. func (maw *MongoAsyncWriter) Close(ctx context.Context) error { - if maw.mongodb.client == nil { + if maw.dbWriter == nil { return nil } diff --git a/pkg/kubehound/storage/storedb/mongo_writer_test.go b/pkg/kubehound/storage/storedb/mongo_writer_test.go index 9db29a26b..64c9835cd 100644 --- a/pkg/kubehound/storage/storedb/mongo_writer_test.go +++ b/pkg/kubehound/storage/storedb/mongo_writer_test.go @@ -9,6 +9,8 @@ import ( "github.com/DataDog/KubeHound/pkg/config" "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" "github.com/DataDog/KubeHound/pkg/telemetry/tag" + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/mongo" ) // We need a "complex" object to store in MongoDB @@ -17,6 +19,22 @@ type FakeElement struct { FieldB string } +type CleanupFunc func() + +func makeDB(t *testing.T) (*mongo.Database, CleanupFunc) { + t.Helper() + + mp, err := NewMongoProvider(context.Background(), MongoLocalDatabaseURL, 1*time.Second) + assert.NoError(t, err) + + db := mp.writer.Database("testdb") + cleanup := func() { + _ = mp.writer.Disconnect(context.Background()) + } + + return db, cleanup +} + func TestMongoAsyncWriter_Queue(t *testing.T) { t.Parallel() @@ -31,15 +49,8 @@ func TestMongoAsyncWriter_Queue(t *testing.T) { } ctx := context.Background() - mongoProvider, err := NewMongoProvider(ctx, 
MongoLocalDatabaseURL, 1*time.Second) - - // TODO: add another check (env var maybe?) - // "integration test checks" - if err != nil { - t.Error("FAILED TO CONNECT TO LOCAL MONGO DB DURING TESTS, SKIPPING") - - return - } + db, cleanup := makeDB(t) + t.Cleanup(cleanup) type args struct { ctx context.Context @@ -89,7 +100,7 @@ func TestMongoAsyncWriter_Queue(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - writer := NewMongoAsyncWriter(ctx, mongoProvider, collections.FakeCollection{}, WithTags([]string{tag.Storage("mongotest")})) + writer := NewMongoAsyncWriter(ctx, db, collections.FakeCollection{}, WithTags([]string{tag.Storage("mongotest")})) // insert multiple times if needed for _, args := range tt.args { if err := writer.Queue(args.ctx, args.model); (err != nil) != tt.wantErr { @@ -201,23 +212,16 @@ func TestMongoAsyncWriter_Flush(t *testing.T) { }, } + ctx := context.Background() + db, cleanup := makeDB(t) + t.Cleanup(cleanup) + for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - ctx := context.Background() - mongoProvider, err := NewMongoProvider(ctx, MongoLocalDatabaseURL, 1*time.Second) - // TODO: add another check (env var maybe?) 
- // "integration test checks" - if err != nil { - t.Error("FAILED TO CONNECT TO LOCAL MONGO DB DURING TESTS, SKIPPING") - - return - } - defer mongoProvider.Close(ctx) - - maw := NewMongoAsyncWriter(ctx, mongoProvider, collections.FakeCollection{}) + maw := NewMongoAsyncWriter(ctx, db, collections.FakeCollection{}) // insert multiple times if needed for _, args := range tt.argsQueue { if err := maw.Queue(args.ctx, args.model); (err != nil) != tt.wantErr { @@ -225,7 +229,7 @@ func TestMongoAsyncWriter_Flush(t *testing.T) { } } // blocking function - err = maw.Flush(tt.argsFlush.ctx) + err := maw.Flush(tt.argsFlush.ctx) if (err != nil) != tt.wantErr { t.Errorf("MongoAsyncWriter.Flush() error = %v, wantErr %v", err, tt.wantErr) diff --git a/pkg/kubehound/storage/storedb/provider.go b/pkg/kubehound/storage/storedb/provider.go index ff4391664..9cd1c8473 100644 --- a/pkg/kubehound/storage/storedb/provider.go +++ b/pkg/kubehound/storage/storedb/provider.go @@ -38,8 +38,8 @@ type Provider interface { // Prepare drops all collections from the database (usually to ensure a clean start) and recreates indices. Prepare(ctx context.Context) error - // Raw returns a handle to the underlying provider to allow implementation specific operations e.g db queries. - Raw() any + // Reader returns a handle to the underlying provider to allow implementation specific queries against the mongo DB + Reader() any // BulkWriter creates a new AsyncWriter instance to enable asynchronous bulk inserts. BulkWriter(ctx context.Context, collection collections.Collection, opts ...WriterOption) (AsyncWriter, error) diff --git a/test/system/vertex.gen.go b/test/system/vertex.gen.go index e4dca24c1..c671d2645 100644 --- a/test/system/vertex.gen.go +++ b/test/system/vertex.gen.go @@ -1,5 +1,5 @@ // PLEASE DO NOT EDIT -// THIS HAS BEEN GENERATED AUTOMATICALLY on 2023-10-24 18:42 +// THIS HAS BEEN GENERATED AUTOMATICALLY on 2023-10-30 17:12 // // Generate it with "go generate ./..." //