From 1e5aac44955de69cdcfdf30eb8b279bc1d0a477f Mon Sep 17 00:00:00 2001 From: Miccah Date: Mon, 25 Nov 2024 11:26:11 -0800 Subject: [PATCH] Add Scan method to SourceManager to scan a single SourceUnit (#3650) * renaming to enumeration * update enumeration * comments * remove commented out func * Add Scan method to SourceManager to scan a single SourceUnit * Add tests for each Enumerate and Scan * add source name to log * rename scanWithUnits * updating comments to be more clear --------- Co-authored-by: ahmed Co-authored-by: 0x1 <13666360+0x1@users.noreply.github.com> --- pkg/sources/source_manager.go | 122 ++++++++++++++++++++++++++++- pkg/sources/source_manager_test.go | 50 ++++++++++++ 2 files changed, 170 insertions(+), 2 deletions(-) diff --git a/pkg/sources/source_manager.go b/pkg/sources/source_manager.go index 3a3ce736d55e..43784141f4a8 100644 --- a/pkg/sources/source_manager.go +++ b/pkg/sources/source_manager.go @@ -212,6 +212,50 @@ func (s *SourceManager) Enumerate(ctx context.Context, sourceName string, source case s.firstErr <- err: default: } + progress.ReportError(Fatal{err}) + } + }() + return progress.Ref(), nil +} + +// Scan blocks until a resource is available to run the source against a single +// SourceUnit, then asynchronously runs it. Error information is stored and +// accessible via the JobProgressRef as it becomes available. +func (s *SourceManager) Scan(ctx context.Context, sourceName string, source Source, unit SourceUnit) (JobProgressRef, error) { + sourceID, jobID := source.SourceID(), source.JobID() + // Do preflight checks before waiting on the pool. + if err := s.preflightChecks(ctx); err != nil { + return JobProgressRef{ + SourceName: sourceName, + SourceID: sourceID, + JobID: jobID, + }, err + } + // Create a JobProgress object for tracking progress. 
+ ctx, cancel := context.WithCancelCause(ctx) + progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel)) + if err := s.sem.Acquire(ctx, 1); err != nil { + // Context cancelled. + progress.ReportError(Fatal{err}) + return progress.Ref(), Fatal{err} + } + s.wg.Add(1) + go func() { + // Call Finish after the semaphore has been released. + defer progress.Finish() + defer s.sem.Release(1) + defer s.wg.Done() + ctx := context.WithValues(ctx, + "source_manager_worker_id", common.RandomID(5), + ) + defer common.Recover(ctx) + defer cancel(nil) + if err := s.scan(ctx, source, progress, unit); err != nil { + select { + case s.firstErr <- err: + default: + } + progress.ReportError(Fatal{err}) } }() return progress.Ref(), nil @@ -320,7 +364,7 @@ func (s *SourceManager) run(ctx context.Context, source Source, report *JobProgr ctx = context.WithValue(ctx, "source_type", source.Type().String()) } - // Check for the preferred method of tracking source units. + // Check if source units are supported and configured. canUseSourceUnits := len(targets) == 0 && s.useSourceUnitsFunc != nil if enumChunker, ok := source.(SourceUnitEnumChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() { ctx.Logger().Info("running source", @@ -359,7 +403,7 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo ctx = context.WithValue(ctx, "source_type", source.Type().String()) } - // Check for the preferred method of tracking source units. + // Check if source units are supported and configured. 
canUseSourceUnits := s.useSourceUnitsFunc != nil if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() { ctx.Logger().Info("running source", @@ -369,6 +413,42 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String()) } +// scan runs a scan against a single SourceUnit as its only job. This method +// manages the lifecycle of the provided report. +func (s *SourceManager) scan(ctx context.Context, source Source, report *JobProgress, unit SourceUnit) error { + report.Start(time.Now()) + defer func() { report.End(time.Now()) }() + + defer func() { + if err := context.Cause(ctx); err != nil { + report.ReportError(Fatal{err}) + } + }() + + report.TrackProgress(source.GetProgress()) + if ctx.Value("job_id") == "" { + ctx = context.WithValue(ctx, "job_id", report.JobID) + } + if ctx.Value("source_id") == "" { + ctx = context.WithValue(ctx, "source_id", report.SourceID) + } + if ctx.Value("source_name") == "" { + ctx = context.WithValue(ctx, "source_name", report.SourceName) + } + if ctx.Value("source_type") == "" { + ctx = context.WithValue(ctx, "source_type", source.Type().String()) + } + + // Check if source units are supported and configured. + canUseSourceUnits := s.useSourceUnitsFunc != nil + if unitChunker, ok := source.(SourceUnitChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() { + ctx.Logger().Info("running source", + "with_units", true) + return s.scanWithUnit(ctx, unitChunker, report, unit) + } + return fmt.Errorf("source units not supported or configured for source: %s (%s)", report.SourceName, source.Type().String()) +} + // enumerateWithUnits is a helper method to enumerate a Source that is also a // SourceUnitEnumerator. This allows better introspection of what is getting // enumerated and any errors encountered. 
@@ -511,6 +591,44 @@ func (s *SourceManager) runWithUnits(ctx context.Context, source SourceUnitEnumC } } +// scanWithUnit produces chunks from a single SourceUnit. +func (s *SourceManager) scanWithUnit(ctx context.Context, source SourceUnitChunker, report *JobProgress, unit SourceUnit) error { + // Create the reporter passed to ChunkUnit, wired to this unit, a chunk + // channel of capacity defaultChannelSize, and the job report. + chunkReporter := &mgrChunkReporter{ + unit: unit, + chunkCh: make(chan *Chunk, defaultChannelSize), + report: report, + } + // Produce chunks from the given unit. + var chunkErr error + go func() { + report.StartUnitChunking(unit, time.Now()) + // TODO: Catch panics and add to report. + defer close(chunkReporter.chunkCh) + id, kind := unit.SourceUnitID() + ctx := context.WithValues(ctx, "unit_kind", kind, "unit", id) + ctx.Logger().V(3).Info("chunking unit") + if err := source.ChunkUnit(ctx, unit, chunkReporter); err != nil { + report.ReportError(Fatal{ChunkError{Unit: unit, Err: err}}) + chunkErr = Fatal{err} + } + }() + // Consume chunks and forward them to s.outputChunks. + // This anonymous function blocks until the chunkReporter.chunkCh is + // closed in the above goroutine. + func() { + defer func() { report.EndUnitChunking(unit, time.Now()) }() + for chunk := range chunkReporter.chunkCh { + if src, ok := source.(Source); ok { + chunk.JobID = src.JobID() + } + s.outputChunks <- chunk + } + }() + return chunkErr +} + // headlessAPI implements the apiClient interface locally. type headlessAPI struct { // Counters for assigning source and job IDs. 
diff --git a/pkg/sources/source_manager_test.go b/pkg/sources/source_manager_test.go index 4706ce141f40..c2dadd403ea2 100644 --- a/pkg/sources/source_manager_test.go +++ b/pkg/sources/source_manager_test.go @@ -173,6 +173,56 @@ func TestSourceManagerReport(t *testing.T) { } } +func TestSourceManagerEnumerate(t *testing.T) { + mgr := NewManager(WithBufferedOutput(8), WithSourceUnits()) + source, err := buildDummy(&counterChunker{count: 1}) + assert.NoError(t, err) + var enumeratedUnits []SourceUnit + reporter := visitorUnitReporter{ + ok: func(_ context.Context, unit SourceUnit) error { + enumeratedUnits = append(enumeratedUnits, unit) + return nil + }, + } + for i := 0; i < 3; i++ { + ref, err := mgr.Enumerate(context.Background(), "dummy", source, reporter) + <-ref.Done() + assert.NoError(t, err) + assert.NoError(t, ref.Snapshot().FatalError()) + // The Chunks channel should be empty because we only enumerated. + _, err = tryRead(mgr.Chunks()) + assert.Error(t, err) + // Each time the loop iterates, we add 1 unit to the slice. + assert.Equal(t, i+1, len(enumeratedUnits), ref.Snapshot()) + } +} + +func TestSourceManagerScan(t *testing.T) { + mgr := NewManager(WithBufferedOutput(8), WithSourceUnits()) + source, err := buildDummy(&counterChunker{count: 1}) + assert.NoError(t, err) + for i := 0; i < 3; i++ { + ref, err := mgr.Scan(context.Background(), "dummy", source, countChunk(123)) + <-ref.Done() + assert.NoError(t, err) + assert.NoError(t, ref.Snapshot().FatalError()) + chunk, err := tryRead(mgr.Chunks()) + assert.NoError(t, err) + assert.Equal(t, []byte{123}, chunk.Data) + // The Chunks channel should be empty now. 
+ _, err = tryRead(mgr.Chunks()) + assert.Error(t, err) + } +} + +type visitorUnitReporter struct { + ok func(context.Context, SourceUnit) error + err func(context.Context, error) error +} + +func (v visitorUnitReporter) UnitOk(ctx context.Context, u SourceUnit) error { return v.ok(ctx, u) } +func (v visitorUnitReporter) UnitErr(ctx context.Context, err error) error { return v.err(ctx, err) } + type unitChunk struct { unit string output string