From fbeeaa7537853d6d7a936aa180714c9213d64f19 Mon Sep 17 00:00:00 2001 From: Chlins Zhang Date: Mon, 5 Jun 2023 15:12:54 +0800 Subject: [PATCH] fix: add checkpoint when enqueue scan tasks for scan all (#18680) Fix the scanAll cannot be stopped in case of large number of artifacts, add the checkpoint before submit scan tasks, mark the scanAll stopped flag in the redis. Fixes: #18044 Signed-off-by: chlins --- src/controller/artifact/helper.go | 8 ++- src/controller/scan/base_controller.go | 63 ++++++++++++++++++++- src/controller/scan/base_controller_test.go | 37 +++++++++--- src/controller/scan/callback_test.go | 2 +- src/controller/scan/controller.go | 10 ++++ src/server/v2.0/handler/scan_all.go | 11 ++-- src/server/v2.0/handler/scan_all_test.go | 1 + src/testing/controller/scan/controller.go | 14 +++++ 8 files changed, 127 insertions(+), 19 deletions(-) diff --git a/src/controller/artifact/helper.go b/src/controller/artifact/helper.go index 3a5d798e2ee..d5602dad822 100644 --- a/src/controller/artifact/helper.go +++ b/src/controller/artifact/helper.go @@ -40,7 +40,13 @@ func Iterator(ctx context.Context, chunkSize int, query *q.Query, option *Option } for _, artifact := range artifacts { - ch <- artifact + select { + case <-ctx.Done(): + log.G(ctx).Errorf("context done, list artifacts exited, error: %v", ctx.Err()) + return + case ch <- artifact: + continue + } } if len(artifacts) < chunkSize { diff --git a/src/controller/scan/base_controller.go b/src/controller/scan/base_controller.go index 2a2e94b6e98..bc504597276 100644 --- a/src/controller/scan/base_controller.go +++ b/src/controller/scan/base_controller.go @@ -21,6 +21,7 @@ import ( "reflect" "strings" "sync" + "time" "github.com/google/uuid" @@ -30,6 +31,7 @@ import ( sc "github.com/goharbor/harbor/src/controller/scanner" "github.com/goharbor/harbor/src/controller/tag" "github.com/goharbor/harbor/src/jobservice/job" + "github.com/goharbor/harbor/src/lib/cache" "github.com/goharbor/harbor/src/lib/config" "github.com/goharbor/harbor/src/lib/errors" "github.com/goharbor/harbor/src/lib/log" @@ -50,8 +52,12 @@ import ( "github.com/goharbor/harbor/src/pkg/task" ) -// DefaultController is a default singleton scan API controller. -var DefaultController = NewController() +var ( + // DefaultController is a default singleton scan API controller. + DefaultController = NewController() + + errScanAllStopped = errors.New("scanAll stopped") +) // const definitions const ( @@ -74,6 +80,9 @@ type uuidGenerator func() (string, error) // utility methods. type configGetter func(cfg string) (string, error) +// cacheGetter returns cache +type cacheGetter func() cache.Cache + // launchScanJobParam is a param to launch scan job. type launchScanJobParam struct { ExecutionID int64 @@ -109,6 +118,8 @@ type basicController struct { taskMgr task.Manager // Converter for V1 report to V2 report reportConverter postprocessors.NativeScanReportConverter + // cache stores the stop scan all marks + cache cacheGetter } // NewController news a scan API controller @@ -154,6 +165,9 @@ func NewController() Controller { taskMgr: task.Mgr, // Get the scan V1 to V2 report converters reportConverter: postprocessors.Converter, + cache: func() cache.Cache { + return cache.Default() + }, } } @@ -368,6 +382,44 @@ func (bc *basicController) ScanAll(ctx context.Context, trigger string, async bo return executionID, nil } +func (bc *basicController) StopScanAll(ctx context.Context, executionID int64, async bool) error { + stopScanAll := func(ctx context.Context, executionID int64) error { + // mark scan all stopped + if err := bc.markScanAllStopped(ctx, executionID); err != nil { + return err + } + // stop the execution and sub tasks + return bc.execMgr.Stop(ctx, executionID) + } + + if async { + go func() { + if err := stopScanAll(ctx, executionID); err != nil { + log.Errorf("failed to stop scan all, error: %v", err) + } + }() + return nil + } + + return stopScanAll(ctx, executionID) +} + +func scanAllStoppedKey(execID int64) string { + return fmt.Sprintf("scan_all:execution_id:%d:stopped", execID) +} + +func (bc *basicController) markScanAllStopped(ctx context.Context, execID int64) error { + // set the expire time to 2 hours, the duration should be large enough + // for controller to capture the stop flag, leverage the key recycled + // by redis TTL, no need to clean by scan controller as the new scan all + // will have a new unique execution id, the old key has no effects to anything. + return bc.cache().Save(ctx, scanAllStoppedKey(execID), "", 2*time.Hour) +} + +func (bc *basicController) isScanAllStopped(ctx context.Context, execID int64) bool { + return bc.cache().Contains(ctx, scanAllStoppedKey(execID)) +} + func (bc *basicController) startScanAll(ctx context.Context, executionID int64) error { batchSize := 50 @@ -379,8 +431,15 @@ func (bc *basicController) startScanAll(ctx context.Context, executionID int64) UnsupportCount int `json:"unsupport_count"` UnknowCount int `json:"unknow_count"` }{} + // with cancel function to signal downstream worker + ctx, cancel := context.WithCancel(ctx) + defer cancel() for artifact := range ar.Iterator(ctx, batchSize, nil, nil) { + if bc.isScanAllStopped(ctx, executionID) { + return errScanAllStopped + } + summary.TotalCount++ scan := func(ctx context.Context) error { diff --git a/src/controller/scan/base_controller_test.go b/src/controller/scan/base_controller_test.go index 8b781fc501c..d922c3f2228 100644 --- a/src/controller/scan/base_controller_test.go +++ b/src/controller/scan/base_controller_test.go @@ -30,6 +30,7 @@ import ( "github.com/goharbor/harbor/src/common/rbac" "github.com/goharbor/harbor/src/controller/artifact" "github.com/goharbor/harbor/src/controller/robot" + "github.com/goharbor/harbor/src/lib/cache" "github.com/goharbor/harbor/src/lib/config" "github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/q" @@ -49,6 +50,7 @@ import ( robottesting "github.com/goharbor/harbor/src/testing/controller/robot" scannertesting "github.com/goharbor/harbor/src/testing/controller/scanner" tagtesting "github.com/goharbor/harbor/src/testing/controller/tag" + mockcache "github.com/goharbor/harbor/src/testing/lib/cache" ormtesting "github.com/goharbor/harbor/src/testing/lib/orm" "github.com/goharbor/harbor/src/testing/mock" accessorytesting "github.com/goharbor/harbor/src/testing/pkg/accessory" @@ -77,6 +79,7 @@ type ControllerTestSuite struct { ar artifact.Controller c Controller reportConverter *postprocessorstesting.ScanReportV1ToV2Converter + cache *mockcache.Cache } // TestController is the entry point of ControllerTestSuite. @@ -271,6 +274,8 @@ func (suite *ControllerTestSuite) SetupSuite() { suite.taskMgr = &tasktesting.Manager{} + suite.cache = &mockcache.Cache{} + suite.c = &basicController{ manager: mgr, ar: suite.ar, @@ -298,6 +303,7 @@ func (suite *ControllerTestSuite) SetupSuite() { execMgr: suite.execMgr, taskMgr: suite.taskMgr, reportConverter: &postprocessorstesting.ScanReportV1ToV2Converter{}, + cache: func() cache.Cache { return suite.cache }, } } @@ -522,25 +528,25 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { func (suite *ControllerTestSuite) TestScanAll() { { // no artifacts found when scan all - ctx := context.TODO() - executionID := int64(1) suite.execMgr.On( - "Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE", + "Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE", ).Return(executionID, nil).Once() mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once() mock.OnAnything(suite.artifactCtl, "List").Return([]*artifact.Artifact{}, nil).Once() - suite.taskMgr.On("Count", ctx, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once() + suite.taskMgr.On("Count", mock.Anything, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once() mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once() - suite.execMgr.On("MarkDone", ctx, executionID, mock.Anything).Return(nil).Once() + suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once() - _, err := suite.c.ScanAll(ctx, "SCHEDULE", false) + suite.cache.On("Contains", mock.Anything, scanAllStoppedKey(1)).Return(false).Once() + + _, err := suite.c.ScanAll(context.TODO(), "SCHEDULE", false) suite.NoError(err) } @@ -551,7 +557,7 @@ func (suite *ControllerTestSuite) TestScanAll() { executionID := int64(1) suite.execMgr.On( - "Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE", + "Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE", ).Return(executionID, nil).Once() mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once() @@ -568,13 +574,28 @@ func (suite *ControllerTestSuite) TestScanAll() { mock.OnAnything(suite.reportMgr, "Create").Return("uuid", nil).Once() mock.OnAnything(suite.taskMgr, "Create").Return(int64(0), fmt.Errorf("failed")).Once() mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once() - suite.execMgr.On("MarkError", ctx, executionID, mock.Anything).Return(nil).Once() + suite.execMgr.On("MarkError", mock.Anything, executionID, mock.Anything).Return(nil).Once() _, err := suite.c.ScanAll(ctx, "SCHEDULE", false) suite.NoError(err) } } +func (suite *ControllerTestSuite) TestStopScanAll() { + mockExecID := int64(100) + // mock error case + mockErr := fmt.Errorf("stop scan all error") + suite.cache.On("Save", mock.Anything, scanAllStoppedKey(mockExecID), mock.Anything, mock.Anything).Return(mockErr).Once() + err := suite.c.StopScanAll(context.TODO(), mockExecID, false) + suite.EqualError(err, mockErr.Error()) + + // mock normal case + suite.cache.On("Save", mock.Anything, scanAllStoppedKey(mockExecID), mock.Anything, mock.Anything).Return(nil).Once() + suite.execMgr.On("Stop", mock.Anything, mockExecID).Return(nil).Once() + err = suite.c.StopScanAll(context.TODO(), mockExecID, false) + suite.NoError(err) +} + func (suite *ControllerTestSuite) TestDeleteReports() { suite.reportMgr.On("DeleteByDigests", context.TODO(), "digest").Return(nil).Once() diff --git a/src/controller/scan/callback_test.go b/src/controller/scan/callback_test.go index 19c3d565998..4a1be86b1f9 100644 --- a/src/controller/scan/callback_test.go +++ b/src/controller/scan/callback_test.go @@ -157,7 +157,7 @@ func (suite *CallbackTestSuite) TestScanAllCallback() { mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once() - suite.execMgr.On("MarkDone", context.TODO(), executionID, mock.Anything).Return(nil).Once() + suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once() suite.NoError(scanAllCallback(context.TODO(), "")) } diff --git a/src/controller/scan/controller.go b/src/controller/scan/controller.go index 5029a7dcc3f..b890ff7d3f0 100644 --- a/src/controller/scan/controller.go +++ b/src/controller/scan/controller.go @@ -115,6 +115,16 @@ type Controller interface { // error : non nil error if any errors occurred ScanAll(ctx context.Context, trigger string, async bool) (int64, error) + // StopScanAll stops the scanAll + // + // Arguments: + // ctx context.Context : the context for this method + // executionID int64 : the id of scan all execution + // async bool : stop scan all in background + // Returns: + // error : non nil error if any errors occurred + StopScanAll(ctx context.Context, executionID int64, async bool) error + // GetVulnerable returns the vulnerable of the artifact for the allowlist // // Arguments: diff --git a/src/server/v2.0/handler/scan_all.go b/src/server/v2.0/handler/scan_all.go index 1359fe8dee1..111101913d6 100644 --- a/src/server/v2.0/handler/scan_all.go +++ b/src/server/v2.0/handler/scan_all.go @@ -28,7 +28,6 @@ import ( "github.com/goharbor/harbor/src/controller/scanner" "github.com/goharbor/harbor/src/jobservice/job" "github.com/goharbor/harbor/src/lib/errors" - "github.com/goharbor/harbor/src/lib/log" "github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/q" "github.com/goharbor/harbor/src/pkg/scheduler" @@ -74,12 +73,10 @@ func (s *scanAllAPI) StopScanAll(ctx context.Context, params operation.StopScanA if execution == nil { return s.SendError(ctx, errors.BadRequestError(nil).WithMessage("no scan all job is found currently")) } - go func(ctx context.Context, eid int64) { - err := s.execMgr.Stop(ctx, eid) - if err != nil { - log.Errorf("failed to stop the execution of executionID=%+v", execution.ID) - } - }(s.makeCtx(), execution.ID) + + if err = s.scanCtl.StopScanAll(s.makeCtx(), execution.ID, true); err != nil { + return s.SendError(ctx, err) + } return operation.NewStopScanAllAccepted() } diff --git a/src/server/v2.0/handler/scan_all_test.go b/src/server/v2.0/handler/scan_all_test.go index 3003955a63e..61e3a77ed24 100644 --- a/src/server/v2.0/handler/scan_all_test.go +++ b/src/server/v2.0/handler/scan_all_test.go @@ -247,6 +247,7 @@ func (suite *ScanAllTestSuite) TestStopScanAll() { times := 3 suite.Security.On("IsAuthenticated").Return(true).Times(times) suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(times) + mock.OnAnything(suite.scanCtl, "StopScanAll").Return(nil).Times(times) mock.OnAnything(suite.scannerCtl, "ListRegistrations").Return([]*scanner.Registration{{ID: int64(1)}}, nil).Times(times) { diff --git a/src/testing/controller/scan/controller.go b/src/testing/controller/scan/controller.go index df31b2ab402..aae3b890750 100644 --- a/src/testing/controller/scan/controller.go +++ b/src/testing/controller/scan/controller.go @@ -205,6 +205,20 @@ func (_m *Controller) Stop(ctx context.Context, _a1 *artifact.Artifact) error { return r0 } +// StopScanAll provides a mock function with given fields: ctx, executionID, async +func (_m *Controller) StopScanAll(ctx context.Context, executionID int64, async bool) error { + ret := _m.Called(ctx, executionID, async) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, bool) error); ok { + r0 = rf(ctx, executionID, async) + } else { + r0 = ret.Error(0) + } + + return r0 +} + type mockConstructorTestingTNewController interface { mock.TestingT Cleanup(func())