diff --git a/x-pack/apm-server/sampling/processor.go b/x-pack/apm-server/sampling/processor.go
index 52f8cc7d16b..d4561bbbecf 100644
--- a/x-pack/apm-server/sampling/processor.go
+++ b/x-pack/apm-server/sampling/processor.go
@@ -367,8 +367,9 @@ func (p *Processor) Run() error {
 
 	remoteSampledTraceIDs := make(chan string)
 	localSampledTraceIDs := make(chan string)
-	errgroup, ctx := errgroup.WithContext(context.Background())
-	errgroup.Go(func() error {
+	publishSampledTraceIDs := make(chan string)
+	g, ctx := errgroup.WithContext(context.Background())
+	g.Go(func() error {
 		select {
 		case <-ctx.Done():
 			return ctx.Err()
@@ -376,7 +377,7 @@ func (p *Processor) Run() error {
 			return context.Canceled
 		}
 	})
-	errgroup.Go(func() error {
+	g.Go(func() error {
 		// This goroutine is responsible for periodically garbage
 		// collecting the Badger value log, using the recommended
 		// discard ratio of 0.5.
@@ -394,11 +395,14 @@ func (p *Processor) Run() error {
 			}
 		}
 	})
-	errgroup.Go(func() error {
+	g.Go(func() error {
 		defer close(subscriberPositions)
 		return pubsub.SubscribeSampledTraceIDs(ctx, initialSubscriberPosition, remoteSampledTraceIDs, subscriberPositions)
 	})
-	errgroup.Go(func() error {
+	g.Go(func() error {
+		return pubsub.PublishSampledTraceIDs(ctx, publishSampledTraceIDs)
+	})
+	g.Go(func() error {
 		ticker := time.NewTicker(p.config.FlushInterval)
 		defer ticker.Stop()
 		var traceIDs []string
@@ -412,21 +416,17 @@
 				if len(traceIDs) == 0 {
 					continue
 				}
-				if err := pubsub.PublishSampledTraceIDs(ctx, traceIDs...); err != nil {
+				var g errgroup.Group
+				g.Go(func() error { return sendTraceIDs(ctx, publishSampledTraceIDs, traceIDs) })
+				g.Go(func() error { return sendTraceIDs(ctx, localSampledTraceIDs, traceIDs) })
+				if err := g.Wait(); err != nil {
 					return err
 				}
-				for _, traceID := range traceIDs {
-					select {
-					case <-ctx.Done():
-						return ctx.Err()
-					case localSampledTraceIDs <- traceID:
-					}
-				}
 				traceIDs = traceIDs[:0]
 			}
 		}
 	})
-	errgroup.Go(func() error {
+	g.Go(func() error {
 		// TODO(axw) pace the publishing over the flush interval?
 		// Alternatively we can rely on backpressure from the reporter,
 		// removing the artificial one second timeout from publisher code
@@ -475,7 +475,7 @@
 			}
 		}
 	})
-	errgroup.Go(func() error {
+	g.Go(func() error {
 		// Write subscriber position to a file on disk, to support resuming
 		// on apm-server restart without reprocessing all indices.
 		for {
@@ -489,7 +489,7 @@
 			}
 		}
 	})
-	if err := errgroup.Wait(); err != nil && err != context.Canceled {
+	if err := g.Wait(); err != nil && err != context.Canceled {
 		return err
 	}
 	return nil
@@ -513,3 +513,14 @@
 	}
 	return ioutil.WriteFile(filepath.Join(storageDir, subscriberPositionFile), data, 0644)
 }
+
+func sendTraceIDs(ctx context.Context, out chan<- string, traceIDs []string) error {
+	for _, traceID := range traceIDs {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case out <- traceID:
+		}
+	}
+	return nil
+}
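
The processor.go change above replaces the direct `pubsub.PublishSampledTraceIDs(ctx, traceIDs...)` call with a fan-out: each flushed batch is now sent, element by element, to both the new `publishSampledTraceIDs` channel (drained by the pubsub goroutine) and the existing `localSampledTraceIDs` channel, using a short-lived inner errgroup so a slow consumer on one channel cannot starve the other. A minimal, self-contained sketch of that fan-out idiom (channel and function names here are illustrative, not part of the patch):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// sendAll mirrors sendTraceIDs above: forward every ID to out,
// giving up as soon as ctx is canceled.
func sendAll(ctx context.Context, out chan<- string, ids []string) error {
	for _, id := range ids {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case out <- id:
		}
	}
	return nil
}

func main() {
	ctx := context.Background()
	publish := make(chan string)
	local := make(chan string)

	// Stand-in consumers: drain the four sends made below.
	done := make(chan struct{})
	go func() {
		defer close(done)
		for i := 0; i < 4; i++ {
			select {
			case id := <-publish:
				fmt.Println("publish:", id)
			case id := <-local:
				fmt.Println("local:", id)
			}
		}
	}()

	// Fan the same batch out to both channels concurrently, as the
	// processor's flush loop now does.
	ids := []string{"trace-1", "trace-2"}
	var g errgroup.Group
	g.Go(func() error { return sendAll(ctx, publish, ids) })
	g.Go(func() error { return sendAll(ctx, local, ids) })
	if err := g.Wait(); err != nil {
		fmt.Println("send failed:", err)
	}
	<-done
}
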
diff --git a/x-pack/apm-server/sampling/pubsub/pubsub.go b/x-pack/apm-server/sampling/pubsub/pubsub.go
index bee321b4755..dc1a80e8614 100644
--- a/x-pack/apm-server/sampling/pubsub/pubsub.go
+++ b/x-pack/apm-server/sampling/pubsub/pubsub.go
@@ -34,12 +34,11 @@ var errIndexNotFound = errors.New("index not found")
 //
 // An independent process will periodically reap old documents in the index.
 type Pubsub struct {
-	config  Config
-	indexer elasticsearch.BulkIndexer
+	config Config
 }
 
 // New returns a new Pubsub which can publish and subscribe sampled trace IDs,
 // using Elasticsearch for storage.
 //
 // Documents are expected to be indexed through a pipeline which sets the
 // `event.ingested` timestamp field. Another process will periodically reap
@@ -51,37 +50,55 @@ func New(config Config) (*Pubsub, error) {
 	if config.Logger == nil {
 		config.Logger = logp.NewLogger(logs.Sampling)
 	}
-	indexer, err := config.Client.NewBulkIndexer(elasticsearch.BulkIndexerConfig{
-		Index:         config.DataStream.String(),
-		FlushInterval: config.FlushInterval,
+	return &Pubsub{config: config}, nil
+}
+
+// PublishSampledTraceIDs receives trace IDs from the traceIDs channel,
+// indexing them into Elasticsearch. PublishSampledTraceIDs returns when
+// ctx is canceled.
+func (p *Pubsub) PublishSampledTraceIDs(ctx context.Context, traceIDs <-chan string) error {
+	indexer, err := p.config.Client.NewBulkIndexer(elasticsearch.BulkIndexerConfig{
+		Index:         p.config.DataStream.String(),
+		FlushInterval: p.config.FlushInterval,
 		OnError: func(ctx context.Context, err error) {
-			config.Logger.With(logp.Error(err)).Debug("publishing sampled trace IDs failed")
+			p.config.Logger.With(logp.Error(err)).Debug("publishing sampled trace IDs failed")
 		},
 	})
 	if err != nil {
-		return nil, err
+		return err
 	}
-	return &Pubsub{
-		config:  config,
-		indexer: indexer,
-	}, nil
-}
-
-// PublishSampledTraceIDs bulk indexes traceIDs into Elasticsearch.
-func (p *Pubsub) PublishSampledTraceIDs(ctx context.Context, traceID ...string) error {
-	now := time.Now()
-	for _, id := range traceID {
-		var json fastjson.Writer
-		p.marshalTraceIDDocument(&json, id, now, p.config.DataStream)
-		if err := p.indexer.Add(ctx, elasticsearch.BulkIndexerItem{
-			Action:    "create",
-			Body:      bytes.NewReader(json.Bytes()),
-			OnFailure: p.onBulkIndexerItemFailure,
-		}); err != nil {
-			return err
+
+	var closeIndexerOnce sync.Once
+	var closeIndexerErr error
+	closeIndexer := func() error {
+		closeIndexerOnce.Do(func() {
+			ctx, cancel := context.WithTimeout(context.Background(), p.config.FlushInterval)
+			defer cancel()
+			closeIndexerErr = indexer.Close(ctx)
+		})
+		return closeIndexerErr
+	}
+	defer closeIndexer()
+
+	for {
+		select {
+		case <-ctx.Done():
+			if err := ctx.Err(); err != context.Canceled {
+				return err
+			}
+			return closeIndexer()
+		case id := <-traceIDs:
+			var json fastjson.Writer
+			p.marshalTraceIDDocument(&json, id, time.Now(), p.config.DataStream)
+			if err := indexer.Add(ctx, elasticsearch.BulkIndexerItem{
+				Action:    "create",
+				Body:      bytes.NewReader(json.Bytes()),
+				OnFailure: p.onBulkIndexerItemFailure,
+			}); err != nil {
+				return err
+			}
 		}
 	}
-	return nil
 }
 
 func (p *Pubsub) onBulkIndexerItemFailure(ctx context.Context, item elasticsearch.BulkIndexerItem, resp elasticsearch.BulkIndexerResponseItem, err error) {
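
Two things are worth noting in the pubsub.go change. First, the bulk indexer is now created inside PublishSampledTraceIDs rather than held on the Pubsub struct, so its lifetime matches the publishing goroutine and New can no longer fail on indexer setup. Second, closing the indexer is guarded by a sync.Once and reachable from two paths: the defer (covering early error returns) and the clean-cancellation return, where the final flush error is surfaced to the caller. A stand-alone sketch of that close-once idiom with a toy resource (closeSlowResource is hypothetical, standing in for the esutil bulk indexer's Close):

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// closeSlowResource stands in for indexer.Close: a final flush that
// should be time-bounded so shutdown cannot hang forever.
func closeSlowResource(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(10 * time.Millisecond):
		return nil
	}
}

func run(ctx context.Context) error {
	var closeOnce sync.Once
	var closeErr error
	closeResource := func() error {
		closeOnce.Do(func() {
			// Give the final flush its own bounded context.
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			closeErr = closeSlowResource(ctx)
		})
		return closeErr
	}
	defer closeResource() // safety net for early error returns

	<-ctx.Done()
	if err := ctx.Err(); !errors.Is(err, context.Canceled) {
		return err // e.g. deadline exceeded
	}
	// Plain cancellation is a clean shutdown: report the flush result.
	return closeResource()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println("run returned:", run(ctx))
}
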
diff --git a/x-pack/apm-server/sampling/pubsub/pubsub_integration_test.go b/x-pack/apm-server/sampling/pubsub/pubsub_integration_test.go
index f4d0faabacf..6cfea4b53ae 100644
--- a/x-pack/apm-server/sampling/pubsub/pubsub_integration_test.go
+++ b/x-pack/apm-server/sampling/pubsub/pubsub_integration_test.go
@@ -46,11 +46,6 @@ func TestElasticsearchIntegration_PublishSampledTraceIDs(t *testing.T) {
 	client := newElasticsearchClient(t)
 	recreateDataStream(t, client, dataStream)
 
-	var input []string
-	for i := 0; i < 50; i++ {
-		input = append(input, uuid.Must(uuid.NewV4()).String())
-	}
-
 	es, err := pubsub.New(pubsub.Config{
 		Client:     client,
 		DataStream: dataStream,
@@ -60,8 +55,25 @@
 	})
 	require.NoError(t, err)
 
-	err = es.PublishSampledTraceIDs(context.Background(), input...)
-	assert.NoError(t, err)
+	var input []string
+	for i := 0; i < 50; i++ {
+		input = append(input, uuid.Must(uuid.NewV4()).String())
+	}
+	ids := make(chan string, len(input))
+	for _, id := range input {
+		ids <- id
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		return es.PublishSampledTraceIDs(ctx, ids)
+	})
+	defer func() {
+		err := g.Wait()
+		assert.NoError(t, err)
+	}()
+	defer cancel()
 
 	var result struct {
 		Hits struct {
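
One subtlety in the rewritten integration test: deferred calls run last-in-first-out, so the `defer cancel()` registered after the `g.Wait()` wrapper fires first, unblocking PublishSampledTraceIDs before the test waits for it to return. A stripped-down illustration of that ordering (toy goroutine, not the test itself):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	var g errgroup.Group
	g.Go(func() error {
		<-ctx.Done() // stands in for PublishSampledTraceIDs blocking on ctx
		return nil
	})
	defer func() {
		// Runs second: by now cancel() has released the goroutine.
		fmt.Println("wait:", g.Wait())
	}()
	defer cancel() // deferred last, runs first (LIFO)
}
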
diff --git a/x-pack/apm-server/sampling/pubsub/pubsub_test.go b/x-pack/apm-server/sampling/pubsub/pubsub_test.go
index 5c05bb9ee15..4e6bbd3cc30 100644
--- a/x-pack/apm-server/sampling/pubsub/pubsub_test.go
+++ b/x-pack/apm-server/sampling/pubsub/pubsub_test.go
@@ -42,24 +42,36 @@ func TestPublishSampledTraceIDs(t *testing.T) {
 	srv, requests := newRequestResponseWriterServer(t)
 	pub := newPubsub(t, srv, time.Millisecond, time.Minute)
 
-	var ids []string
-	for i := 0; i < 20; i++ {
-		ids = append(ids, uuid.Must(uuid.NewV4()).String())
+	input := make([]string, 20)
+	for i := 0; i < len(input); i++ {
+		input[i] = uuid.Must(uuid.NewV4()).String()
 	}
 
 	// Publish in a separate goroutine, as it may get blocked if we don't
 	// service bulk requests.
-	go func() {
-		for i := 0; i < len(ids); i += 2 {
-			err := pub.PublishSampledTraceIDs(context.Background(), ids[i], ids[i+1])
-			assert.NoError(t, err)
-			time.Sleep(10 * time.Millisecond) // sleep to force a new request
+	ids := make(chan string)
+	ctx, cancel := context.WithCancel(context.Background())
+	var g errgroup.Group
+	defer g.Wait()
+	defer cancel()
+	g.Go(func() error {
+		return pub.PublishSampledTraceIDs(ctx, ids)
+	})
+	g.Go(func() error {
+		for _, id := range input {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case ids <- id:
+			}
+			time.Sleep(10 * time.Millisecond) // sleep to force new requests
 		}
-	}()
+		return nil
+	})
 
 	var received []string
 	deadlineTimer := time.NewTimer(10 * time.Second)
-	for len(received) < len(ids) {
+	for len(received) < len(input) {
 		select {
 		case <-deadlineTimer.C:
 			t.Fatal("timed out waiting for events to be received by server")
@@ -108,7 +120,7 @@
 
 	// The publisher uses an esutil.BulkIndexer, which may index items out
 	// of order due to having multiple goroutines picking items off a queue.
-	assert.ElementsMatch(t, ids, received)
+	assert.ElementsMatch(t, input, received)
 }
 
 func TestSubscribeSampledTraceIDs(t *testing.T) {
@@ -386,8 +398,7 @@ func newRequestResponseWriterServer(t testing.TB) (*httptest.Server, <-chan *req
 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		rrw := &requestResponseWriter{
 			Request: r,
-			w:       w,
-			done:    make(chan struct{}),
+			done:    make(chan response),
 		}
 		select {
 		case <-r.Context().Done():
@@ -398,8 +409,9 @@
 		select {
 		case <-r.Context().Done():
 			w.WriteHeader(http.StatusRequestTimeout)
-			return
-		case <-rrw.done:
+		case response := <-rrw.done:
+			w.WriteHeader(response.statusCode)
+			w.Write([]byte(response.body))
 		}
 	}))
 	t.Cleanup(srv.Close)
@@ -408,8 +420,12 @@
 
 type requestResponseWriter struct {
 	*http.Request
-	w    http.ResponseWriter
-	done chan struct{}
+	done chan response
+}
+
+type response struct {
+	statusCode int
+	body       string
 }
 
 func (w *requestResponseWriter) Write(body string) {
@@ -417,9 +433,7 @@
 }
 
 func (w *requestResponseWriter) WriteStatus(statusCode int, body string) {
-	w.w.WriteHeader(statusCode)
-	w.w.Write([]byte(body))
-	close(w.done)
+	w.done <- response{statusCode, body}
}
 
 func expectRequest(t testing.TB, ch <-chan *requestResponseWriter, path, body string) *requestResponseWriter {
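
The test-server rework above replaces the old close(done) signal with a response channel: the test goroutine picks the status code and body, but the handler goroutine performs the actual write while the request is still in flight, instead of the test writing to a saved http.ResponseWriter after the fact. Roughly, the rendezvous works like this (a simplified sketch with stand-in types, not the test's exact helpers):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type response struct {
	statusCode int
	body       string
}

func main() {
	pending := make(chan chan response)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		done := make(chan response)
		pending <- done // hand the request to the test...
		resp := <-done  // ...and block until it chooses a reply
		w.WriteHeader(resp.statusCode)
		w.Write([]byte(resp.body))
	}))
	defer srv.Close()

	// The "test" side decides how the in-flight request is answered.
	go func() {
		done := <-pending
		done <- response{statusCode: http.StatusServiceUnavailable, body: "try again"}
	}()

	resp, err := http.Get(srv.URL)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.StatusCode) // status: 503
}
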