From fe0640d8909c3fa85ce64c1e492af3bb798be0b0 Mon Sep 17 00:00:00 2001 From: Orestis Floros Date: Tue, 10 Oct 2023 13:14:03 +0200 Subject: [PATCH] Azure: Batch Fetcher: Group resources by subscription IDs and types (#1400) This has the side-benefit of only requiring one query for all types. Utilizes lo.GroupBy Fixes #1391 --- .../fetching/fetchers/azure/batch_fetcher.go | 25 +-- .../fetchers/azure/batch_fetcher_test.go | 144 ++++++++++++++++-- 2 files changed, 149 insertions(+), 20 deletions(-) diff --git a/resources/fetching/fetchers/azure/batch_fetcher.go b/resources/fetching/fetchers/azure/batch_fetcher.go index 6df0362189..c488ecd370 100644 --- a/resources/fetching/fetchers/azure/batch_fetcher.go +++ b/resources/fetching/fetchers/azure/batch_fetcher.go @@ -22,6 +22,8 @@ import ( "fmt" "github.com/elastic/elastic-agent-libs/logp" + "github.com/samber/lo" + "golang.org/x/exp/maps" "github.com/elastic/cloudbeat/resources/fetching" "github.com/elastic/cloudbeat/resources/providers/azurelib" @@ -50,21 +52,26 @@ func NewAzureBatchAssetFetcher(log *logp.Logger, ch chan fetching.ResourceInfo, func (f *AzureBatchAssetFetcher) Fetch(ctx context.Context, cMetadata fetching.CycleMetadata) error { f.log.Info("Starting AzureBatchAssetFetcher.Fetch") - for assetType, pair := range AzureBatchAssets { - assets, err := f.provider.ListAllAssetTypesByName([]string{assetType}) - if err != nil { - return err - } + allAssets, err := f.provider.ListAllAssetTypesByName(maps.Keys(AzureBatchAssets)) + if err != nil { + return err + } - if len(assets) == 0 { - continue - } + if len(allAssets) == 0 { + return nil + } + + assetGroups := lo.GroupBy(allAssets, func(item inventory.AzureAsset) string { + return fmt.Sprintf("%s-%s", item.SubscriptionId, item.Type) + }) + + for _, assets := range assetGroups { + pair := AzureBatchAssets[assets[0].Type] select { case <-ctx.Done(): f.log.Infof("AzureBatchAssetFetcher.Fetch context err: %s", ctx.Err().Error()) return nil - // TODO: Groups by subscription id to create multiple batches of assets case f.resourceCh <- fetching.ResourceInfo{ CycleMetadata: cMetadata, Resource: &AzureBatchResource{ diff --git a/resources/fetching/fetchers/azure/batch_fetcher_test.go b/resources/fetching/fetchers/azure/batch_fetcher_test.go index 438f4abcef..a32845cd58 100644 --- a/resources/fetching/fetchers/azure/batch_fetcher_test.go +++ b/resources/fetching/fetchers/azure/batch_fetcher_test.go @@ -20,10 +20,12 @@ package fetchers import ( "context" "fmt" + "strconv" "testing" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "golang.org/x/exp/maps" "github.com/elastic/cloudbeat/resources/fetching" "github.com/elastic/cloudbeat/resources/providers/azurelib/inventory" @@ -51,9 +53,6 @@ func (s *AzureBatchAssetFetcherTestSuite) TearDownTest() { } func (s *AzureBatchAssetFetcherTestSuite) TestFetcher_Fetch() { - ctx := context.Background() - - mockInventoryService := &inventory.MockServiceAPI{} mockAssets := map[string][]inventory.AzureAsset{ inventory.ActivityLogAlertAssetType: { { @@ -97,26 +96,29 @@ func (s *AzureBatchAssetFetcherTestSuite) TestFetcher_Fetch() { }, } + mockInventoryService := inventory.NewMockServiceAPI(s.T()) mockInventoryService.EXPECT(). ListAllAssetTypesByName(mock.AnythingOfType("[]string")). RunAndReturn(func(types []string) ([]inventory.AzureAsset, error) { - s.Require().Len(types, 1) - mockAssetsList, ok := mockAssets[types[0]] - s.Require().True(ok) - return mockAssetsList, nil - }).Times(len(mockAssets)) - defer mockInventoryService.AssertExpectations(s.T()) + s.ElementsMatch(maps.Keys(mockAssets), types) + + var result []inventory.AzureAsset + for _, tpe := range types { + result = append(result, mockAssets[tpe]...) + } + return result, nil + }).Once() fetcher := AzureBatchAssetFetcher{ log: testhelper.NewLogger(s.T()), resourceCh: s.resourceCh, provider: mockInventoryService, } - err := fetcher.Fetch(ctx, fetching.CycleMetadata{}) + err := fetcher.Fetch(context.Background(), fetching.CycleMetadata{}) s.Require().NoError(err) results := testhelper.CollectResources(s.resourceCh) - s.Require().Len(results, len(mockAssets)) + s.Len(results, len(mockAssets)) for assetType, expectedAssets := range mockAssets { result := findResult(results, assetType) @@ -158,6 +160,126 @@ func (s *AzureBatchAssetFetcherTestSuite) TestFetcher_Fetch() { } } +func (s *AzureBatchAssetFetcherTestSuite) TestFetcher_Fetch_Batches() { + var mockAssets []inventory.AzureAsset + for i, variableFields := range []struct { + sub string + tpe string + }{ + { + // 0 + sub: "1", + tpe: inventory.ActivityLogAlertAssetType, + }, + { + // 1 + sub: "1", + tpe: inventory.ActivityLogAlertAssetType, + }, + { + // 2 + sub: "2", + tpe: inventory.ActivityLogAlertAssetType, + }, + { + // 3 + sub: "3", + tpe: inventory.BastionAssetType, + }, + { + // 4 + sub: "1", + tpe: inventory.BastionAssetType, + }, + { + // 5 + sub: "2", + tpe: inventory.ActivityLogAlertAssetType, + }, + { + // 6 + sub: "3", + tpe: inventory.BastionAssetType, + }, + { + // 7 + sub: "4", + tpe: inventory.BastionAssetType, + }, + } { + id := strconv.Itoa(i) + mockAssets = append(mockAssets, inventory.AzureAsset{ + Id: "id" + id, + Name: "name" + id, + Location: "loc" + id, + Properties: map[string]any{"key" + id: "value" + id}, + ResourceGroup: "rg" + id, + SubscriptionId: variableFields.sub, + TenantId: "tenant", + Type: variableFields.tpe, + Sku: "sku" + id, + }) + } + + mockInventoryService := inventory.NewMockServiceAPI(s.T()) + mockInventoryService.EXPECT(). + ListAllAssetTypesByName(mock.AnythingOfType("[]string")). + Return(mockAssets, nil) + fetcher := AzureBatchAssetFetcher{ + log: testhelper.NewLogger(s.T()), + resourceCh: s.resourceCh, + provider: mockInventoryService, + } + + err := fetcher.Fetch(context.Background(), fetching.CycleMetadata{}) + s.Require().NoError(err) + results := testhelper.CollectResources(s.resourceCh) + + s.Len(results, 5) + s.ElementsMatch([]fetching.ResourceInfo{ + { // sub 1 + Resource: &AzureBatchResource{ + Type: fetching.MonitoringIdentity, + SubType: fetching.AzureActivityLogAlertType, + Assets: []inventory.AzureAsset{mockAssets[0], mockAssets[1]}, + }, + CycleMetadata: fetching.CycleMetadata{Sequence: 0}, + }, + { // sub 2 + Resource: &AzureBatchResource{ + Type: fetching.MonitoringIdentity, + SubType: fetching.AzureActivityLogAlertType, + Assets: []inventory.AzureAsset{mockAssets[2], mockAssets[5]}, + }, + CycleMetadata: fetching.CycleMetadata{Sequence: 0}, + }, + { // sub 1 + Resource: &AzureBatchResource{ + Type: fetching.CloudDns, + SubType: fetching.AzureBastionType, + Assets: []inventory.AzureAsset{mockAssets[4]}, + }, + CycleMetadata: fetching.CycleMetadata{Sequence: 0}, + }, + { // sub 3 + Resource: &AzureBatchResource{ + Type: fetching.CloudDns, + SubType: fetching.AzureBastionType, + Assets: []inventory.AzureAsset{mockAssets[3], mockAssets[6]}, + }, + CycleMetadata: fetching.CycleMetadata{Sequence: 0}, + }, + { // sub 4 + Resource: &AzureBatchResource{ + Type: fetching.CloudDns, + SubType: fetching.AzureBastionType, + Assets: []inventory.AzureAsset{mockAssets[7]}, + }, + CycleMetadata: fetching.CycleMetadata{Sequence: 0}, + }, + }, results) +} + func findResult(results []fetching.ResourceInfo, assetType string) *fetching.ResourceInfo { for _, result := range results { if result.GetData().([]inventory.AzureAsset)[0].Type == assetType {