diff --git a/pkg/agent/flags/flag_controller.go b/pkg/agent/flags/flag_controller.go index c18915f01..a1d9b6706 100644 --- a/pkg/agent/flags/flag_controller.go +++ b/pkg/agent/flags/flag_controller.go @@ -465,6 +465,17 @@ func (fc *FlagController) TraceSamplingRate() float64 { ).get(fc.getControlServerValue(keys.TraceSamplingRate)) } +func (fc *FlagController) SetTraceBatchTimeout(duration time.Duration) error { + return fc.setControlServerValue(keys.TraceBatchTimeout, durationToBytes(duration)) +} +func (fc *FlagController) TraceBatchTimeout() time.Duration { + return NewDurationFlagValue(fc.logger, keys.TraceBatchTimeout, + WithDefault(1*time.Minute), + WithMin(5*time.Second), + WithMax(1*time.Hour), + ).get(fc.getControlServerValue(keys.TraceBatchTimeout)) +} + func (fc *FlagController) SetLogIngestServerURL(url string) error { return fc.setControlServerValue(keys.LogIngestServerURL, []byte(url)) } diff --git a/pkg/agent/flags/keys/keys.go b/pkg/agent/flags/keys/keys.go index cc7df4e20..095dd7065 100644 --- a/pkg/agent/flags/keys/keys.go +++ b/pkg/agent/flags/keys/keys.go @@ -45,6 +45,7 @@ const ( UpdateDirectory FlagKey = "update_directory" ExportTraces FlagKey = "export_traces" TraceSamplingRate FlagKey = "trace_sampling_rate" + TraceBatchTimeout FlagKey = "trace_batch_timeout" LogIngestServerURL FlagKey = "log_ingest_url" TraceIngestServerURL FlagKey = "trace_ingest_url" DisableTraceIngestTLS FlagKey = "disable_trace_ingest_tls" diff --git a/pkg/agent/knapsack/knapsack.go b/pkg/agent/knapsack/knapsack.go index 29e687dca..217eb9351 100644 --- a/pkg/agent/knapsack/knapsack.go +++ b/pkg/agent/knapsack/knapsack.go @@ -393,6 +393,13 @@ func (k *knapsack) DisableTraceIngestTLS() bool { return k.flags.DisableTraceIngestTLS() } +func (k *knapsack) SetTraceBatchTimeout(duration time.Duration) error { + return k.flags.SetTraceBatchTimeout(duration) +} +func (k *knapsack) TraceBatchTimeout() time.Duration { + return k.flags.TraceBatchTimeout() +} + func (k *knapsack) SetLogIngestServerURL(url string) error { return k.flags.SetLogIngestServerURL(url) } diff --git a/pkg/agent/types/flags.go b/pkg/agent/types/flags.go index 2873faca4..ae764d5c6 100644 --- a/pkg/agent/types/flags.go +++ b/pkg/agent/types/flags.go @@ -189,6 +189,10 @@ type Flags interface { SetDisableTraceIngestTLS(enabled bool) error DisableTraceIngestTLS() bool + // TraceBatchTimeout is the maximum amount of time before the trace exporter will export the next batch of spans + SetTraceBatchTimeout(duration time.Duration) error + TraceBatchTimeout() time.Duration + // InModernStandby indicates whether a Windows machine is awake or in modern standby SetInModernStandby(enabled bool) error InModernStandby() bool diff --git a/pkg/agent/types/mocks/flags.go b/pkg/agent/types/mocks/flags.go index 87176985d..bdf098f96 100644 --- a/pkg/agent/types/mocks/flags.go +++ b/pkg/agent/types/mocks/flags.go @@ -1033,6 +1033,20 @@ func (_m *Flags) SetOsqueryVerbose(verbose bool) error { return r0 } +// SetTraceBatchTimeout provides a mock function with given fields: duration +func (_m *Flags) SetTraceBatchTimeout(duration time.Duration) error { + ret := _m.Called(duration) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Duration) error); ok { + r0 = rf(duration) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SetTraceIngestServerURL provides a mock function with given fields: url func (_m *Flags) SetTraceIngestServerURL(url string) error { ret := _m.Called(url) @@ -1103,6 +1117,20 @@ func (_m *Flags) SetUpdateDirectory(directory string) error { return r0 } +// TraceBatchTimeout provides a mock function with given fields: +func (_m *Flags) TraceBatchTimeout() time.Duration { + ret := _m.Called() + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + // TraceIngestServerURL provides a mock function with given fields: func (_m *Flags) TraceIngestServerURL() string { ret := _m.Called() diff --git a/pkg/agent/types/mocks/knapsack.go b/pkg/agent/types/mocks/knapsack.go index ecb1a4586..24c4ce90d 100644 --- a/pkg/agent/types/mocks/knapsack.go +++ b/pkg/agent/types/mocks/knapsack.go @@ -1212,6 +1212,20 @@ func (_m *Knapsack) SetOsqueryVerbose(verbose bool) error { return r0 } +// SetTraceBatchTimeout provides a mock function with given fields: duration +func (_m *Knapsack) SetTraceBatchTimeout(duration time.Duration) error { + ret := _m.Called(duration) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Duration) error); ok { + r0 = rf(duration) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SetTraceIngestServerURL provides a mock function with given fields: url func (_m *Knapsack) SetTraceIngestServerURL(url string) error { ret := _m.Called(url) @@ -1314,6 +1328,20 @@ func (_m *Knapsack) TokenStore() types.GetterSetterDeleterIteratorUpdater { return r0 } +// TraceBatchTimeout provides a mock function with given fields: +func (_m *Knapsack) TraceBatchTimeout() time.Duration { + ret := _m.Called() + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + // TraceIngestServerURL provides a mock function with given fields: func (_m *Knapsack) TraceIngestServerURL() string { ret := _m.Called() diff --git a/pkg/traces/exporter/exporter.go b/pkg/traces/exporter/exporter.go index 5a51a22e1..1548a6236 100644 --- a/pkg/traces/exporter/exporter.go +++ b/pkg/traces/exporter/exporter.go @@ -55,6 +55,7 @@ type TraceExporter struct { disableIngestTLS bool enabled bool traceSamplingRate float64 + batchTimeout time.Duration ctx context.Context // nolint:containedctx cancel context.CancelFunc interrupted bool @@ -90,13 +91,14 @@ func NewTraceExporter(ctx context.Context, k types.Knapsack, client osquery.Quer disableIngestTLS: k.DisableTraceIngestTLS(), enabled: k.ExportTraces(), traceSamplingRate: k.TraceSamplingRate(), + batchTimeout: k.TraceBatchTimeout(), ctx: ctx, cancel: cancel, } - // Observe ExportTraces and IngestServerURL changes to know when to start/stop exporting, and where - // to export to - t.knapsack.RegisterChangeObserver(t, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS) + // Observe changes to trace configuration to know when to start/stop exporting, and when + // to adjust exporting behavior + t.knapsack.RegisterChangeObserver(t, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout) if !t.enabled { return t, nil @@ -251,7 +253,7 @@ func (t *TraceExporter) setNewGlobalProvider() { parentBasedSampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(t.traceSamplingRate)) newProvider := sdktrace.NewTracerProvider( - sdktrace.WithBatcher(exp), + sdktrace.WithBatcher(exp, sdktrace.WithBatchTimeout(t.batchTimeout)), sdktrace.WithResource(r), sdktrace.WithSampler(parentBasedSampler), ) @@ -356,6 +358,15 @@ func (t *TraceExporter) FlagsChanged(flagKeys ...keys.FlagKey) { } } + // Handle trace_batch_timeout updates + if slices.Contains(flagKeys, keys.TraceBatchTimeout) { + if t.batchTimeout != t.knapsack.TraceBatchTimeout() { + t.batchTimeout = t.knapsack.TraceBatchTimeout() + needsNewProvider = true + level.Debug(t.logger).Log("msg", "updating trace batch timeout", "new_batch_timeout", t.batchTimeout) + } + } + if !t.enabled || !needsNewProvider { return } diff --git a/pkg/traces/exporter/exporter_test.go b/pkg/traces/exporter/exporter_test.go index 63f26c410..4ec5d1041 100644 --- a/pkg/traces/exporter/exporter_test.go +++ b/pkg/traces/exporter/exporter_test.go @@ -42,7 +42,8 @@ func TestNewTraceExporter(t *testing.T) { //nolint:paralleltest mockKnapsack.On("DisableTraceIngestTLS").Return(false) mockKnapsack.On("ExportTraces").Return(true) mockKnapsack.On("TraceSamplingRate").Return(1.0) - mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS).Return(nil) + mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute) + mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil) osqueryClient := mocks.NewQuerier(t) osqueryClient.On("Query", mock.Anything).Return([]map[string]string{ @@ -85,7 +86,8 @@ func TestNewTraceExporter_exportNotEnabled(t *testing.T) { mockKnapsack.On("DisableTraceIngestTLS").Return(false) mockKnapsack.On("ExportTraces").Return(false) mockKnapsack.On("TraceSamplingRate").Return(0.0) - mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS).Return(nil) + mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute) + mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil) traceExporter, err := NewTraceExporter(context.Background(), mockKnapsack, mocks.NewQuerier(t), log.NewNopLogger()) require.NoError(t, err) @@ -122,7 +124,8 @@ func TestInterrupt_Multiple(t *testing.T) { mockKnapsack.On("DisableTraceIngestTLS").Return(false) mockKnapsack.On("ExportTraces").Return(false) mockKnapsack.On("TraceSamplingRate").Return(0.0) - mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS).Return(nil) + mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute) + mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil) traceExporter, err := NewTraceExporter(context.Background(), mockKnapsack, mocks.NewQuerier(t), log.NewNopLogger()) require.NoError(t, err) @@ -603,6 +606,75 @@ func TestFlagsChanged_DisableTraceIngestTLS(t *testing.T) { //nolint:paralleltes } } +func TestFlagsChanged_TraceBatchTimeout(t *testing.T) { //nolint:paralleltest + tests := []struct { + testName string + currentBatchTimeout time.Duration + newBatchTimeout time.Duration + tracingEnabled bool + shouldReplaceProvider bool + }{ + { + testName: "update", + currentBatchTimeout: 1 * time.Minute, + newBatchTimeout: 5 * time.Second, + tracingEnabled: true, + shouldReplaceProvider: true, + }, + { + testName: "update but tracing not enabled", + currentBatchTimeout: 1 * time.Minute, + newBatchTimeout: 5 * time.Second, + tracingEnabled: false, + shouldReplaceProvider: false, + }, + { + testName: "no update", + currentBatchTimeout: 1 * time.Minute, + newBatchTimeout: 1 * time.Minute, + tracingEnabled: true, + shouldReplaceProvider: false, + }, + } + + for _, tt := range tests { //nolint:paralleltest + tt := tt + t.Run(tt.testName, func(t *testing.T) { + mockKnapsack := typesmocks.NewKnapsack(t) + mockKnapsack.On("TraceBatchTimeout").Return(tt.newBatchTimeout) + osqueryClient := mocks.NewQuerier(t) + + ctx, cancel := context.WithCancel(context.Background()) + traceExporter := &TraceExporter{ + knapsack: mockKnapsack, + osqueryClient: osqueryClient, + logger: log.NewNopLogger(), + attrs: make([]attribute.KeyValue, 0), + attrLock: sync.RWMutex{}, + ingestClientAuthenticator: newClientAuthenticator("test token", false), + ingestAuthToken: "test token", + ingestUrl: "localhost:4317", + disableIngestTLS: false, + enabled: tt.tracingEnabled, + traceSamplingRate: 1.0, + batchTimeout: tt.currentBatchTimeout, + ctx: ctx, + cancel: cancel, + } + + traceExporter.FlagsChanged(keys.TraceBatchTimeout) + + require.Equal(t, tt.newBatchTimeout, traceExporter.batchTimeout, "batch timeout value not updated") + + if tt.shouldReplaceProvider { + require.NotNil(t, traceExporter.provider) + } else { + require.Nil(t, traceExporter.provider) + } + }) + } +} + func testServerProvidedDataStore(t *testing.T) types.KVStore { s, err := storageci.NewStore(t, log.NewNopLogger(), storage.ServerProvidedDataStore.String()) require.NoError(t, err)