Skip to content

Commit

Permalink
Add Scan method to SourceManager to scan a single SourceUnit (#3650)
Browse files Browse the repository at this point in the history
* renaming to enumeration

* update enumeration

* comments

* remove commented out func

* Add Scan method to SourceManager to scan a single SourceUnit

* Add tests for each Enumerate and Scan

* add source name to log

* rename scanWithUnits

* updating comments to be more clear

---------

Co-authored-by: ahmed <[email protected]>
Co-authored-by: 0x1 <[email protected]>
  • Loading branch information
3 people authored Nov 25, 2024
1 parent 1276d26 commit 1e5aac4
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 2 deletions.
122 changes: 120 additions & 2 deletions pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,50 @@ func (s *SourceManager) Enumerate(ctx context.Context, sourceName string, source
case s.firstErr <- err:
default:
}
progress.ReportError(Fatal{err})
}
}()
return progress.Ref(), nil
}

// Scan blocks until a resource is available to run the source against a single
// SourceUnit, then asynchronously runs it. Error information is stored and
// accessible via the JobProgressRef as it becomes available.
func (s *SourceManager) Scan(ctx context.Context, sourceName string, source Source, unit SourceUnit) (JobProgressRef, error) {
sourceID, jobID := source.SourceID(), source.JobID()
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx); err != nil {
return JobProgressRef{
SourceName: sourceName,
SourceID: sourceID,
JobID: jobID,
}, err
}
// Create a JobProgress object for tracking progress.
ctx, cancel := context.WithCancelCause(ctx)
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
if err := s.sem.Acquire(ctx, 1); err != nil {
// Context cancelled.
progress.ReportError(Fatal{err})
return progress.Ref(), Fatal{err}
}
s.wg.Add(1)
go func() {
// Call Finish after the semaphore has been released.
defer progress.Finish()
defer s.sem.Release(1)
defer s.wg.Done()
ctx := context.WithValues(ctx,
"source_manager_worker_id", common.RandomID(5),
)
defer common.Recover(ctx)
defer cancel(nil)
if err := s.scan(ctx, source, progress, unit); err != nil {
select {
case s.firstErr <- err:
default:
}
progress.ReportError(Fatal{err})
}
}()
return progress.Ref(), nil
Expand Down Expand Up @@ -320,7 +364,7 @@ func (s *SourceManager) run(ctx context.Context, source Source, report *JobProgr
ctx = context.WithValue(ctx, "source_type", source.Type().String())
}

// Check for the preferred method of tracking source units.
// Check if source units are supported and configured.
canUseSourceUnits := len(targets) == 0 && s.useSourceUnitsFunc != nil
if enumChunker, ok := source.(SourceUnitEnumChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
ctx.Logger().Info("running source",
Expand Down Expand Up @@ -359,7 +403,7 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo
ctx = context.WithValue(ctx, "source_type", source.Type().String())
}

// Check for the preferred method of tracking source units.
// Check if source units are supported and configured.
canUseSourceUnits := s.useSourceUnitsFunc != nil
if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
ctx.Logger().Info("running source",
Expand All @@ -369,6 +413,42 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo
return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String())
}

// scan runs a scan against a single SourceUnit as its only job. This method
// manages the lifecycle of the provided report.
func (s *SourceManager) scan(ctx context.Context, source Source, report *JobProgress, unit SourceUnit) error {
report.Start(time.Now())
defer func() { report.End(time.Now()) }()

defer func() {
if err := context.Cause(ctx); err != nil {
report.ReportError(Fatal{err})
}
}()

report.TrackProgress(source.GetProgress())
if ctx.Value("job_id") == "" {
ctx = context.WithValue(ctx, "job_id", report.JobID)
}
if ctx.Value("source_id") == "" {
ctx = context.WithValue(ctx, "source_id", report.SourceID)
}
if ctx.Value("source_name") == "" {
ctx = context.WithValue(ctx, "source_name", report.SourceName)
}
if ctx.Value("source_type") == "" {
ctx = context.WithValue(ctx, "source_type", source.Type().String())
}

// Check if source units are supported and configured.
canUseSourceUnits := s.useSourceUnitsFunc != nil
if unitChunker, ok := source.(SourceUnitChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
ctx.Logger().Info("running source",
"with_units", true)
return s.scanWithUnit(ctx, unitChunker, report, unit)
}
return fmt.Errorf("source units not supported or configured for source: %s (%s)", report.SourceName, source.Type().String())
}

// enumerateWithUnits is a helper method to enumerate a Source that is also a
// SourceUnitEnumerator. This allows better introspection of what is getting
// enumerated and any errors encountered.
Expand Down Expand Up @@ -511,6 +591,44 @@ func (s *SourceManager) runWithUnits(ctx context.Context, source SourceUnitEnumC
}
}

// scanWithUnit produces chunks from a single SourceUnit.
func (s *SourceManager) scanWithUnit(ctx context.Context, source SourceUnitChunker, report *JobProgress, unit SourceUnit) error {
// Create a function that will save the first error encountered (if
// any) and discard the rest.
chunkReporter := &mgrChunkReporter{
unit: unit,
chunkCh: make(chan *Chunk, defaultChannelSize),
report: report,
}
// Produce chunks from the given unit.
var chunkErr error
go func() {
report.StartUnitChunking(unit, time.Now())
// TODO: Catch panics and add to report.
defer close(chunkReporter.chunkCh)
id, kind := unit.SourceUnitID()
ctx := context.WithValues(ctx, "unit_kind", kind, "unit", id)
ctx.Logger().V(3).Info("chunking unit")
if err := source.ChunkUnit(ctx, unit, chunkReporter); err != nil {
report.ReportError(Fatal{ChunkError{Unit: unit, Err: err}})
chunkErr = Fatal{err}
}
}()
// Consume chunks and export chunks.
// This anonymous function blocks until the chunkReporter.chunkCh is
// closed in the above goroutine.
func() {
defer func() { report.EndUnitChunking(unit, time.Now()) }()
for chunk := range chunkReporter.chunkCh {
if src, ok := source.(Source); ok {
chunk.JobID = src.JobID()
}
s.outputChunks <- chunk
}
}()
return chunkErr
}

// headlessAPI implements the apiClient interface locally.
type headlessAPI struct {
// Counters for assigning source and job IDs.
Expand Down
50 changes: 50 additions & 0 deletions pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,56 @@ func TestSourceManagerReport(t *testing.T) {
}
}

func TestSourceManagerEnumerate(t *testing.T) {
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
source, err := buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
var enumeratedUnits []SourceUnit
reporter := visitorUnitReporter{
ok: func(_ context.Context, unit SourceUnit) error {
enumeratedUnits = append(enumeratedUnits, unit)
return nil
},
}
for i := 0; i < 3; i++ {
ref, err := mgr.Enumerate(context.Background(), "dummy", source, reporter)
<-ref.Done()
assert.NoError(t, err)
assert.NoError(t, ref.Snapshot().FatalError())
// The Chunks channel should be empty because we only enumerated.
_, err = tryRead(mgr.Chunks())
assert.Error(t, err)
// Each time the loop iterates, we add 1 unit to the slice.
assert.Equal(t, i+1, len(enumeratedUnits), ref.Snapshot())
}
}

func TestSourceManagerScan(t *testing.T) {
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
source, err := buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
for i := 0; i < 3; i++ {
ref, err := mgr.Scan(context.Background(), "dummy", source, countChunk(123))
<-ref.Done()
assert.NoError(t, err)
assert.NoError(t, ref.Snapshot().FatalError())
chunk, err := tryRead(mgr.Chunks())
assert.NoError(t, err)
assert.Equal(t, []byte{123}, chunk.Data)
// The Chunks channel should be empty now.
_, err = tryRead(mgr.Chunks())
assert.Error(t, err)
}
}

type visitorUnitReporter struct {
ok func(context.Context, SourceUnit) error
err func(context.Context, error) error
}

func (v visitorUnitReporter) UnitOk(ctx context.Context, u SourceUnit) error { return v.ok(ctx, u) }
func (v visitorUnitReporter) UnitErr(ctx context.Context, err error) error { return v.err(ctx, err) }

type unitChunk struct {
unit string
output string
Expand Down

0 comments on commit 1e5aac4

Please sign in to comment.