sampling/pubsub: minor fixes (#5568) (#5587)
* sampling/pubsub: close BulkIndexer

Revise the pubsub API so that the PublishSampledTraceIDs
method is long-lived and accepts a channel of trace IDs;
this brings it closer to the API of SubscribeSampledTraceIDs,
and ensures we properly close the BulkIndexer used by the
publisher when it exits.
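
For illustration, a minimal sketch of how a caller might
drive the revised channel-based API (the publisher
interface and the publishAll helper here are hypothetical;
the real caller lives in processor.go):

    package example

    import (
        "context"

        "golang.org/x/sync/errgroup"
    )

    // publisher models the revised long-lived API; the real
    // method is (*pubsub.Pubsub).PublishSampledTraceIDs.
    type publisher interface {
        PublishSampledTraceIDs(ctx context.Context, traceIDs <-chan string) error
    }

    // publishAll is a hypothetical caller: it feeds trace IDs
    // to the long-lived publisher, then cancels the context so
    // the publisher flushes and closes its BulkIndexer on exit.
    func publishAll(pub publisher, sampled []string) error {
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()
        g, ctx := errgroup.WithContext(ctx)
        ids := make(chan string)
        g.Go(func() error {
            return pub.PublishSampledTraceIDs(ctx, ids)
        })
        for _, id := range sampled {
            select {
            case <-ctx.Done(): // publisher failed early
                return g.Wait()
            case ids <- id:
            }
        }
        cancel() // treated by the publisher as a clean shutdown
        if err := g.Wait(); err != nil && err != context.Canceled {
            return err
        }
        return nil
    }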

* sampling/pubsub: fix double WriteHeader in test

Write the HTTP response from the handler goroutine,
rather than from another goroutine that waits for
the request. This avoids a race in the existing
test code, where the client could receive the
response and close the connection, cancelling the
server's handler context before the handler
returns.
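
A condensed sketch of the fixed pattern (names simplified
from the test helper in pubsub_test.go): the handler
goroutine owns the ResponseWriter, and exactly one of the
two select cases runs, so WriteHeader is called once.

    package example

    import "net/http"

    // response is what the test supplies to the handler.
    type response struct {
        statusCode int
        body       string
    }

    // newHandler writes the response from the handler goroutine
    // itself: either the client cancels the request first (408),
    // or the test supplies a response over the channel. No other
    // goroutine ever touches w, so a double WriteHeader is
    // impossible.
    func newHandler(responses <-chan response) http.HandlerFunc {
        return func(w http.ResponseWriter, r *http.Request) {
            select {
            case <-r.Context().Done():
                w.WriteHeader(http.StatusRequestTimeout)
            case resp := <-responses:
                w.WriteHeader(resp.statusCode)
                w.Write([]byte(resp.body))
            }
        }
    }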

(cherry picked from commit 0a985aa)

Co-authored-by: Andrew Wilkins <[email protected]>
mergify[bot] and axw authored Jun 30, 2021
1 parent 26a957b commit c378838
Showing 4 changed files with 129 additions and 69 deletions.
43 changes: 27 additions & 16 deletions x-pack/apm-server/sampling/processor.go
@@ -367,16 +367,17 @@ func (p *Processor) Run() error {

remoteSampledTraceIDs := make(chan string)
localSampledTraceIDs := make(chan string)
errgroup, ctx := errgroup.WithContext(context.Background())
errgroup.Go(func() error {
publishSampledTraceIDs := make(chan string)
g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error {
select {
case <-ctx.Done():
return ctx.Err()
case <-p.stopping:
return context.Canceled
}
})
errgroup.Go(func() error {
g.Go(func() error {
// This goroutine is responsible for periodically garbage
// collecting the Badger value log, using the recommended
// discard ratio of 0.5.
@@ -394,11 +395,14 @@ func (p *Processor) Run() error {
}
}
})
errgroup.Go(func() error {
g.Go(func() error {
defer close(subscriberPositions)
return pubsub.SubscribeSampledTraceIDs(ctx, initialSubscriberPosition, remoteSampledTraceIDs, subscriberPositions)
})
errgroup.Go(func() error {
g.Go(func() error {
return pubsub.PublishSampledTraceIDs(ctx, publishSampledTraceIDs)
})
g.Go(func() error {
ticker := time.NewTicker(p.config.FlushInterval)
defer ticker.Stop()
var traceIDs []string
@@ -412,21 +416,17 @@
if len(traceIDs) == 0 {
continue
}
if err := pubsub.PublishSampledTraceIDs(ctx, traceIDs...); err != nil {
var g errgroup.Group
g.Go(func() error { return sendTraceIDs(ctx, publishSampledTraceIDs, traceIDs) })
g.Go(func() error { return sendTraceIDs(ctx, localSampledTraceIDs, traceIDs) })
if err := g.Wait(); err != nil {
return err
}
for _, traceID := range traceIDs {
select {
case <-ctx.Done():
return ctx.Err()
case localSampledTraceIDs <- traceID:
}
}
traceIDs = traceIDs[:0]
}
}
})
errgroup.Go(func() error {
g.Go(func() error {
// TODO(axw) pace the publishing over the flush interval?
// Alternatively we can rely on backpressure from the reporter,
// removing the artificial one second timeout from publisher code
@@ -475,7 +475,7 @@
}
}
})
errgroup.Go(func() error {
g.Go(func() error {
// Write subscriber position to a file on disk, to support resuming
// on apm-server restart without reprocessing all indices.
for {
@@ -489,7 +489,7 @@
}
}
})
if err := errgroup.Wait(); err != nil && err != context.Canceled {
if err := g.Wait(); err != nil && err != context.Canceled {
return err
}
return nil
@@ -513,3 +513,14 @@ func writeSubscriberPosition(storageDir string, pos pubsub.SubscriberPosition) error {
}
return ioutil.WriteFile(filepath.Join(storageDir, subscriberPositionFile), data, 0644)
}

func sendTraceIDs(ctx context.Context, out chan<- string, traceIDs []string) error {
for _, traceID := range traceIDs {
select {
case <-ctx.Done():
return ctx.Err()
case out <- traceID:
}
}
return nil
}
73 changes: 47 additions & 26 deletions x-pack/apm-server/sampling/pubsub/pubsub.go
@@ -27,19 +27,22 @@ import (
logs "github.com/elastic/apm-server/log"
)

// ErrClosed may be returned by Pubsub methods after the Close method is called.
var ErrClosed = errors.New("pubsub closed")

var errIndexNotFound = errors.New("index not found")

// Pubsub provides a means of publishing and subscribing to sampled trace IDs,
// using Elasticsearch for temporary storage.
//
// An independent process will periodically reap old documents in the index.
type Pubsub struct {
config Config
indexer elasticsearch.BulkIndexer
config Config
}

// New returns a new Pubsub which can publish and subscribe sampled trace IDs,
// using Elasticsearch for storage.
// using Elasticsearch for storage. The Pubsub.Close method must be called when
// it is no longer needed.
//
// Documents are expected to be indexed through a pipeline which sets the
// `event.ingested` timestamp field. Another process will periodically reap
@@ -51,37 +54,55 @@ func New(config Config) (*Pubsub, error) {
if config.Logger == nil {
config.Logger = logp.NewLogger(logs.Sampling)
}
indexer, err := config.Client.NewBulkIndexer(elasticsearch.BulkIndexerConfig{
Index: config.DataStream.String(),
FlushInterval: config.FlushInterval,
return &Pubsub{config: config}, nil
}

// PublishSampledTraceIDs receives trace IDs from the traceIDs channel,
// indexing them into Elasticsearch. PublishSampledTraceIDs returns when
// ctx is canceled.
func (p *Pubsub) PublishSampledTraceIDs(ctx context.Context, traceIDs <-chan string) error {
indexer, err := p.config.Client.NewBulkIndexer(elasticsearch.BulkIndexerConfig{
Index: p.config.DataStream.String(),
FlushInterval: p.config.FlushInterval,
OnError: func(ctx context.Context, err error) {
config.Logger.With(logp.Error(err)).Debug("publishing sampled trace IDs failed")
p.config.Logger.With(logp.Error(err)).Debug("publishing sampled trace IDs failed")
},
})
if err != nil {
return nil, err
return err
}
return &Pubsub{
config: config,
indexer: indexer,
}, nil
}

// PublishSampledTraceIDs bulk indexes traceIDs into Elasticsearch.
func (p *Pubsub) PublishSampledTraceIDs(ctx context.Context, traceID ...string) error {
now := time.Now()
for _, id := range traceID {
var json fastjson.Writer
p.marshalTraceIDDocument(&json, id, now, p.config.DataStream)
if err := p.indexer.Add(ctx, elasticsearch.BulkIndexerItem{
Action: "create",
Body: bytes.NewReader(json.Bytes()),
OnFailure: p.onBulkIndexerItemFailure,
}); err != nil {
return err
var closeIndexerOnce sync.Once
var closeIndexerErr error
closeIndexer := func() error {
closeIndexerOnce.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), p.config.FlushInterval)
defer cancel()
closeIndexerErr = indexer.Close(ctx)
})
return closeIndexerErr
}
defer closeIndexer()

for {
select {
case <-ctx.Done():
if err := ctx.Err(); err != context.Canceled {
return err
}
return closeIndexer()
case id := <-traceIDs:
var json fastjson.Writer
p.marshalTraceIDDocument(&json, id, time.Now(), p.config.DataStream)
if err := indexer.Add(ctx, elasticsearch.BulkIndexerItem{
Action: "create",
Body: bytes.NewReader(json.Bytes()),
OnFailure: p.onBulkIndexerItemFailure,
}); err != nil {
return err
}
}
}
return nil
}

func (p *Pubsub) onBulkIndexerItemFailure(ctx context.Context, item elasticsearch.BulkIndexerItem, resp elasticsearch.BulkIndexerResponseItem, err error) {
28 changes: 21 additions & 7 deletions x-pack/apm-server/sampling/pubsub/pubsub_integration_test.go
@@ -46,11 +46,6 @@ func TestElasticsearchIntegration_PublishSampledTraceIDs(t *testing.T) {
client := newElasticsearchClient(t)
recreateDataStream(t, client, dataStream)

var input []string
for i := 0; i < 50; i++ {
input = append(input, uuid.Must(uuid.NewV4()).String())
}

es, err := pubsub.New(pubsub.Config{
Client: client,
DataStream: dataStream,
@@ -60,8 +55,27 @@
})
require.NoError(t, err)

err = es.PublishSampledTraceIDs(context.Background(), input...)
assert.NoError(t, err)
var input []string
for i := 0; i < 50; i++ {
input = append(input, uuid.Must(uuid.NewV4()).String())
}
ids := make(chan string, len(input))
for _, id := range input {
ids <- id
}

ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return es.PublishSampledTraceIDs(ctx, ids)
})
defer func() {
err := g.Wait()
assert.NoError(t, err)
}()
defer cancel()

var result struct {
Hits struct {
54 changes: 34 additions & 20 deletions x-pack/apm-server/sampling/pubsub/pubsub_test.go
@@ -42,24 +42,36 @@ func TestPublishSampledTraceIDs(t *testing.T) {
srv, requests := newRequestResponseWriterServer(t)
pub := newPubsub(t, srv, time.Millisecond, time.Minute)

var ids []string
for i := 0; i < 20; i++ {
ids = append(ids, uuid.Must(uuid.NewV4()).String())
input := make([]string, 20)
for i := 0; i < len(input); i++ {
input[i] = uuid.Must(uuid.NewV4()).String()
}

// Publish in a separate goroutine, as it may get blocked if we don't
// service bulk requests.
go func() {
for i := 0; i < len(ids); i += 2 {
err := pub.PublishSampledTraceIDs(context.Background(), ids[i], ids[i+1])
assert.NoError(t, err)
time.Sleep(10 * time.Millisecond) // sleep to force a new request
ids := make(chan string)
ctx, cancel := context.WithCancel(context.Background())
var g errgroup.Group
defer g.Wait()
defer cancel()
g.Go(func() error {
return pub.PublishSampledTraceIDs(ctx, ids)
})
g.Go(func() error {
for _, id := range input {
select {
case <-ctx.Done():
return ctx.Err()
case ids <- id:
}
time.Sleep(10 * time.Millisecond) // sleep to force new requests
}
}()
return nil
})

var received []string
deadlineTimer := time.NewTimer(10 * time.Second)
for len(received) < len(ids) {
for len(received) < len(input) {
select {
case <-deadlineTimer.C:
t.Fatal("timed out waiting for events to be received by server")
@@ -108,7 +120,7 @@

// The publisher uses an esutil.BulkIndexer, which may index items out
// of order due to having multiple goroutines picking items off a queue.
assert.ElementsMatch(t, ids, received)
assert.ElementsMatch(t, input, received)
}

func TestSubscribeSampledTraceIDs(t *testing.T) {
@@ -386,8 +398,7 @@ func newRequestResponseWriterServer(t testing.TB) (*httptest.Server, <-chan *requestResponseWriter) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rrw := &requestResponseWriter{
Request: r,
w: w,
done: make(chan struct{}),
done: make(chan response),
}
select {
case <-r.Context().Done():
@@ -398,8 +409,9 @@
select {
case <-r.Context().Done():
w.WriteHeader(http.StatusRequestTimeout)
return
case <-rrw.done:
case response := <-rrw.done:
w.WriteHeader(response.statusCode)
w.Write([]byte(response.body))
}
}))
t.Cleanup(srv.Close)
@@ -408,18 +420,20 @@

type requestResponseWriter struct {
*http.Request
w http.ResponseWriter
done chan struct{}
done chan response
}

type response struct {
statusCode int
body string
}

func (w *requestResponseWriter) Write(body string) {
w.WriteStatus(http.StatusOK, body)
}

func (w *requestResponseWriter) WriteStatus(statusCode int, body string) {
w.w.WriteHeader(statusCode)
w.w.Write([]byte(body))
close(w.done)
w.done <- response{statusCode, body}
}

func expectRequest(t testing.TB, ch <-chan *requestResponseWriter, path, body string) *requestResponseWriter {
