From 61c984263ab14b89ec86703ef147321b4416bb7a Mon Sep 17 00:00:00 2001 From: Franz Eichhorn Date: Tue, 4 Apr 2023 09:00:33 +0200 Subject: [PATCH] Added context wrapper option for processors (#420) * Added context wrapper option for processors * remove explicit poll-timeout --- examples/docker-compose.yml | 16 ++++++ graph.go | 3 +- integrationtest/processor_test.go | 70 +++++++++++++++++++++++++++ options.go | 12 +++++ partition_processor.go | 11 +++-- processor.go | 5 +- systemtest/commit_test.go | 12 ++--- systemtest/emitter_disconnect_test.go | 2 +- systemtest/multitopic_test.go | 4 +- systemtest/proc_disconnect_test.go | 2 +- systemtest/processor_test.go | 28 +++++------ systemtest/processor_visit_test.go | 31 +++++------- systemtest/utils_test.go | 16 ++++-- systemtest/view_reconnect_test.go | 10 ++-- tester/consumergroup.go | 5 +- 15 files changed, 166 insertions(+), 61 deletions(-) diff --git a/examples/docker-compose.yml b/examples/docker-compose.yml index f586413e..3a19f462 100644 --- a/examples/docker-compose.yml +++ b/examples/docker-compose.yml @@ -10,10 +10,18 @@ services: ZOO_MY_ID: 1 ZOO_PORT: 2181 ZOO_SERVERS: server.1=zoo1:2888:3888 + ulimits: + nofile: + soft: 65536 + hard: 65536 kafka1: image: confluentinc/cp-kafka:5.4.0 hostname: kafka1 container_name: kafka1 + ulimits: + nofile: + soft: 65536 + hard: 65536 ports: - "9092:9092" environment: @@ -33,6 +41,10 @@ services: image: confluentinc/cp-kafka:5.4.0 hostname: kafka2 container_name: kafka2 + ulimits: + nofile: + soft: 65536 + hard: 65536 ports: - "9093:9093" environment: @@ -52,6 +64,10 @@ services: image: confluentinc/cp-kafka:5.4.0 hostname: kafka3 container_name: kafka3 + ulimits: + nofile: + soft: 65536 + hard: 65536 ports: - "9094:9094" environment: diff --git a/graph.go b/graph.go index 3f485fc5..4184b6af 100644 --- a/graph.go +++ b/graph.go @@ -168,7 +168,8 @@ func (gg *GroupGraph) joint(topic string) bool { // DefineGroup creates a group graph with a given group name and a list of // edges. 
 func DefineGroup(group Group, edges ...Edge) *GroupGraph {
-	gg := GroupGraph{group: string(group),
+	gg := GroupGraph{
+		group:     string(group),
 		codecs:    make(map[string]Codec),
 		callbacks: make(map[string]ProcessCallback),
 		joinCheck: make(map[string]bool),
diff --git a/integrationtest/processor_test.go b/integrationtest/processor_test.go
index 186ef525..c2c020d0 100644
--- a/integrationtest/processor_test.go
+++ b/integrationtest/processor_test.go
@@ -223,6 +223,76 @@ func TestProcessorVisit(t *testing.T) {
 	require.NoError(t, err)
 }
 
+type (
+	gokaCtx = goka.Context
+	wrapper struct {
+		gokaCtx
+		value int64
+	}
+)
+
+func (w *wrapper) SetValue(value interface{}, options ...goka.ContextOption) {
+	val := value.(int64)
+	w.value = val
+	w.gokaCtx.SetValue(val + 1)
+}
+
+func TestProcessorContextWrapper(t *testing.T) {
+	gkt := tester.New(t)
+
+	// holds the last wrapper
+	var w *wrapper
+
+	// create a new processor, registering the tester
+	proc, _ := goka.NewProcessor([]string{}, goka.DefineGroup("proc",
+		goka.Input("input", new(codec.Int64), func(ctx goka.Context, msg interface{}) {
+			ctx.SetValue(msg)
+		}),
+		goka.Visitor("visit", func(ctx goka.Context, msg interface{}) {
+			ctx.SetValue(msg.(int64))
+		}),
+		goka.Persist(new(codec.Int64)),
+	),
+		goka.WithTester(gkt),
+		goka.WithContextWrapper(func(ctx goka.Context) goka.Context {
+			w = &wrapper{
+				gokaCtx: ctx,
+			}
+			return w
+		}),
+	)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	done := make(chan struct{})
+
+	// start it
+	go func() {
+		defer close(done)
+		err := proc.Run(ctx)
+		if err != nil {
+			t.Errorf("error running processor: %v", err)
+		}
+	}()
+
+	// send a message
+	gkt.Consume("input", "key", int64(23))
+
+	// both wrapper value and real value are set
+	require.EqualValues(t, 23, w.value)
+	require.EqualValues(t, 24, gkt.TableValue("proc-table", "key"))
+
+	// also the visitor should wrap the context
+	err := proc.VisitAll(ctx, "visit", int64(815))
+	require.NoError(t, err)
+
+	// both values are set again
+	require.EqualValues(t, 815, w.value)
+	require.EqualValues(t, 816, gkt.TableValue("proc-table", "key"))
+
+	cancel()
+	<-done
+}
+
 /*
 import (
 	"context"
diff --git a/options.go b/options.go
index 073f7f15..4fba8435 100644
--- a/options.go
+++ b/options.go
@@ -129,6 +129,7 @@ type poptions struct {
 
 	updateCallback       UpdateCallback
 	rebalanceCallback    RebalanceCallback
+	contextWrapper       ContextWrapper
 	partitionChannelSize int
 	hasher               func() hash.Hash32
 	nilHandling          NilHandling
@@ -147,6 +148,15 @@ type poptions struct {
 	}
 }
 
+// WithContextWrapper allows intercepting the context passed to each callback invocation.
+// The wrapper function will be called concurrently across all partitions, so the returned
+// context must not be shared.
+func WithContextWrapper(wrapper ContextWrapper) ProcessorOption {
+	return func(o *poptions, gg *GroupGraph) {
+		o.contextWrapper = wrapper
+	}
+}
+
 // WithUpdateCallback defines the callback called upon recovering a message
 // from the log.
func WithUpdateCallback(cb UpdateCallback) ProcessorOption { @@ -329,6 +339,8 @@ func (opt *poptions) applyOptions(gg *GroupGraph, opts ...ProcessorOption) error opt.log = defaultLogger opt.hasher = DefaultHasher() opt.backoffResetTime = defaultBackoffResetTime + // default context wrapper returns the original context + opt.contextWrapper = func(ctx Context) Context { return ctx } for _, o := range opts { o(opt, gg) diff --git a/partition_processor.go b/partition_processor.go index 7ec0768c..8c1d9d4b 100644 --- a/partition_processor.go +++ b/partition_processor.go @@ -47,6 +47,8 @@ type visit struct { type commitCallback func(msg *message, meta string) +type ContextWrapper func(ctx Context) Context + // PartitionProcessor handles message processing of one partition by serializing // messages from different input topics. // It also handles joined tables as well as lookup views (managed by `Processor`). @@ -595,7 +597,6 @@ func (pp *PartitionProcessor) processVisit(ctx context.Context, wg *sync.WaitGro emitterDefaultHeaders: pp.opts.producerDefaultHeaders, table: pp.table, } - // start context and call the ProcessorCallback cb msgContext.start() @@ -610,8 +611,8 @@ func (pp *PartitionProcessor) processVisit(ctx context.Context, wg *sync.WaitGro } }() - // now call cb - cb(msgContext, v.meta) + // now call cb, wrap the context + cb(pp.opts.contextWrapper(msgContext), v.meta) msgContext.finish(nil) return } @@ -673,7 +674,7 @@ func (pp *PartitionProcessor) processMessage(ctx context.Context, wg *sync.WaitG msgContext.start() // now call cb - cb(msgContext, m) + cb(pp.opts.contextWrapper(msgContext), m) msgContext.finish(nil) return nil } @@ -706,7 +707,7 @@ func (pp *PartitionProcessor) VisitValues(ctx context.Context, name string, meta } defer it.Release() - + stopping, doneWaitingForStop := pp.stopping() defer doneWaitingForStop() diff --git a/processor.go b/processor.go index a07c3b81..1c4c752e 100644 --- a/processor.go +++ b/processor.go @@ -40,6 +40,7 @@ type Processor struct { log logger brokers []string + // hook used to be notified whenever the processor has rebalanced to a new assignment rebalanceCallback RebalanceCallback // rwmutex protecting read/write of partitions and lookuptables. @@ -737,8 +738,8 @@ func (g *Processor) Cleanup(session sarama.ConsumerGroupSession) error { // WaitForReady waits until the processor is ready to consume messages // (or is actually consuming messages) // i.e., it is done catching up all partition tables, joins and lookup tables -func (g *Processor) WaitForReady() { - g.waitForReady(context.Background()) +func (g *Processor) WaitForReady() error { + return g.waitForReady(context.Background()) } // WaitForReadyContext is context aware option of WaitForReady. diff --git a/systemtest/commit_test.go b/systemtest/commit_test.go index 3779d80e..6d4ef318 100644 --- a/systemtest/commit_test.go +++ b/systemtest/commit_test.go @@ -12,10 +12,6 @@ import ( "github.com/stretchr/testify/require" ) -const ( - pollWaitSecs = 15.0 -) - // TestAutoCommit tests/demonstrates the behavior of disabling the auto-commit functionality. // The autocommiter sends the offsets of the marked messages to the broker regularily. If the processor shuts down // (or the group rebalances), the offsets are sent one last time, so just turning it of is not enough. 
@@ -78,7 +74,7 @@ func TestAutoCommit(t *testing.T) { // run the first processor _, cancel, done := runProc(createProc()) - pollTimed(t, "all-received1", pollWaitSecs, func() bool { + pollTimed(t, "all-received1", func() bool { return len(offsets) == 10 && offsets[0] == 0 }) @@ -96,7 +92,7 @@ func TestAutoCommit(t *testing.T) { // --> we'll receive all messages again // --> i.e., no offsets were committed - pollTimed(t, "all-received2", pollWaitSecs, func() bool { + pollTimed(t, "all-received2", func() bool { return len(offsets) == 10 && offsets[0] == 0 }) @@ -153,7 +149,7 @@ func TestUnmarkedMessages(t *testing.T) { // run the first processor runProc(createProc()) - pollTimed(t, "all-received1", pollWaitSecs, func() bool { + pollTimed(t, "all-received1", func() bool { return len(values) == 2 && values[0] == 1 }) @@ -162,7 +158,7 @@ func TestUnmarkedMessages(t *testing.T) { // restart -> we'll only receive the second message runProc(createProc()) - pollTimed(t, "all-received2", pollWaitSecs, func() bool { + pollTimed(t, "all-received2", func() bool { return len(values) == 1 && values[0] == 2 }) } diff --git a/systemtest/emitter_disconnect_test.go b/systemtest/emitter_disconnect_test.go index daa61414..c6852632 100644 --- a/systemtest/emitter_disconnect_test.go +++ b/systemtest/emitter_disconnect_test.go @@ -78,7 +78,7 @@ func TestEmitter_KafkaDisconnect(t *testing.T) { } }() - pollTimed(t, "emitter emitted something successfully", 10, func() bool { + pollTimed(t, "emitter emitted something successfully", func() bool { return atomic.LoadInt64(&success) > 0 }) diff --git a/systemtest/multitopic_test.go b/systemtest/multitopic_test.go index d3bb0a2a..7778d41a 100644 --- a/systemtest/multitopic_test.go +++ b/systemtest/multitopic_test.go @@ -85,7 +85,7 @@ func TestMultiTopics(t *testing.T) { }) log.Printf("waiting for processor/view to be running") - pollTimed(t, "proc and view are recovered", 10.0, proc.Recovered, view.Recovered) + pollTimed(t, "proc and view are recovered", proc.Recovered, view.Recovered) log.Printf("...done") var sum int64 @@ -110,7 +110,7 @@ func TestMultiTopics(t *testing.T) { } // poll the view and the processor until we're sure that we have - pollTimed(t, "all messages have been transferred", 10.0, + pollTimed(t, "all messages have been transferred", func() bool { value, err := view.Get("key") require.NoError(t, err) diff --git a/systemtest/proc_disconnect_test.go b/systemtest/proc_disconnect_test.go index 79bf298b..00306882 100644 --- a/systemtest/proc_disconnect_test.go +++ b/systemtest/proc_disconnect_test.go @@ -80,7 +80,7 @@ func TestProcessorShutdown_KafkaDisconnect(t *testing.T) { errg.Go(func() error { return proc.Run(ctx) }) - pollTimed(t, "proc running", 10, proc.Recovered, func() bool { + pollTimed(t, "proc running", proc.Recovered, func() bool { if val, _ := proc.Get("key-15"); val != nil && val.(int64) > 0 { return true } diff --git a/systemtest/processor_test.go b/systemtest/processor_test.go index a3515ae7..13baa7f2 100644 --- a/systemtest/processor_test.go +++ b/systemtest/processor_test.go @@ -85,7 +85,7 @@ func TestHotStandby(t *testing.T) { return proc2.Run(ctx) }) - pollTimed(t, "procs 1&2 recovered", 25.0, proc1.Recovered, proc2.Recovered) + pollTimed(t, "procs 1&2 recovered", proc1.Recovered, proc2.Recovered) // check the storages that were initalized by the processors: // proc1 is without hotstandby -> only two storages: (1 for the table, 1 for the join) @@ -120,7 +120,7 @@ func TestHotStandby(t *testing.T) { joinStorage2 := 
proc2Storages.storages[proc2Storages.key(string(joinTable), party)] // wait until the keys are present - pollTimed(t, "key-values are present", 10, + pollTimed(t, "key-values are present", func() bool { has, _ := tableStorage1.Has("key1") return has @@ -241,7 +241,7 @@ func TestRecoverAhead(t *testing.T) { return proc2.Run(ctx) }) - pollTimed(t, "procs 1&2 recovered", 10.0, func() bool { + pollTimed(t, "procs 1&2 recovered", func() bool { return true }, proc1.Recovered, proc2.Recovered) @@ -257,7 +257,7 @@ func TestRecoverAhead(t *testing.T) { joinStorage2 := proc2Storages.storages[proc2Storages.key(string(joinTable), 0)] // wait until the keys are present - pollTimed(t, "key-values are present", 20.0, + pollTimed(t, "key-values are present", func() bool { return true @@ -436,7 +436,7 @@ func TestRebalanceSharePartitions(t *testing.T) { } p1, cancelP1, p1Done := runProc(createProc()) - pollTimed(t, "p1 started", 10, p1.Recovered) + pollTimed(t, "p1 started", p1.Recovered) // p1 has all active partitions p1Stats := p1.Stats() @@ -445,8 +445,8 @@ func TestRebalanceSharePartitions(t *testing.T) { require.Equal(t, 0, p1Passive) p2, cancelP2, p2Done := runProc(createProc()) - pollTimed(t, "p2 started", 20, p2.Recovered) - pollTimed(t, "p1 still running", 10, p1.Recovered) + pollTimed(t, "p2 started", p2.Recovered) + pollTimed(t, "p1 still running", p1.Recovered) // now p1 and p2 share the partitions p2Stats := p2.Stats() @@ -463,7 +463,7 @@ func TestRebalanceSharePartitions(t *testing.T) { require.True(t, <-p1Done == nil) // p2 should have all partitions - pollTimed(t, "p2 has all partitions", 10, func() bool { + pollTimed(t, "p2 has all partitions", func() bool { p2Stats = p2.Stats() p2Active, p2Passive := activePassivePartitions(p2Stats) return p2Active == numPartitions && p2Passive == 0 @@ -515,7 +515,7 @@ func TestCallbackFail(t *testing.T) { proc, cancel, done := runProc(proc) - pollTimed(t, "recovered", 10, proc.Recovered) + pollTimed(t, "recovered", proc.Recovered) go func() { for i := 0; i < 10000; i++ { @@ -524,7 +524,7 @@ func TestCallbackFail(t *testing.T) { }() defer cancel() - pollTimed(t, "error-response", 10, func() bool { + pollTimed(t, "error-response", func() bool { select { case err, ok := <-done: if !ok { @@ -745,7 +745,7 @@ func TestProcessorGracefulShutdownContinue(t *testing.T) { proc, cancelProc, procDone := runProc(createProc()) - pollTimed(t, "proc-running", 10, proc.Recovered) + pollTimed(t, "proc-running", proc.Recovered) // stop it cancelProc() @@ -759,7 +759,7 @@ func TestProcessorGracefulShutdownContinue(t *testing.T) { for i := 0; i < 10; i++ { log.Printf("creating proc round %d", i) proc, cancelProc, procDone := runProc(createProc()) - pollTimed(t, "proc-running", 10, proc.Recovered) + pollTimed(t, "proc-running", proc.Recovered) // stop it cancelProc() require.NoError(t, <-procDone) @@ -772,9 +772,9 @@ func TestProcessorGracefulShutdownContinue(t *testing.T) { // start one last time to check the values log.Printf("creating a proc for the last time to check values") proc, cancelProc, procDone = runProc(createProc()) - pollTimed(t, "proc-running", 10, proc.Recovered) + pollTimed(t, "proc-running", proc.Recovered) - pollTimed(t, "correct-values", 10, func() bool { + pollTimed(t, "correct-values", func() bool { for key, value := range valueSum { tableVal, err := proc.Get(key) if tableVal == nil || err != nil { diff --git a/systemtest/processor_visit_test.go b/systemtest/processor_visit_test.go index a27a0b3a..efd8f48d 100644 --- a/systemtest/processor_visit_test.go 
+++ b/systemtest/processor_visit_test.go @@ -13,11 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -const ( - waitRecoveredTimeoutSecs = 15 - emitWaitTimeoutSecs = 15 -) - // TestProcessorVisit tests the visiting functionality. func TestProcessorVisit(t *testing.T) { brokers := initSystemTest(t) @@ -93,17 +88,17 @@ func TestProcessorVisit(t *testing.T) { defer finish() proc, cancel, done := runProc(createProc(group, input, 0)) - pollTimed(t, "recovered", waitRecoveredTimeoutSecs, proc.Recovered) + pollTimed(t, "recovered", proc.Recovered) em.EmitSync("value1", int64(1)) - pollTimed(t, "value-ok", emitWaitTimeoutSecs, func() bool { + pollTimed(t, "value-ok", func() bool { val1, _ := proc.Get("value1") return val1 != nil && val1.(int64) == 1 }) require.NoError(t, proc.VisitAll(context.Background(), "visitor", int64(25))) - pollTimed(t, "values-ok", emitWaitTimeoutSecs, func() bool { + pollTimed(t, "values-ok", func() bool { val1, _ := proc.Get("value1") return val1 != nil && val1.(int64) == 25 }) @@ -117,11 +112,11 @@ func TestProcessorVisit(t *testing.T) { defer finish() proc, cancel, done := runProc(createProc(group, input, 0)) - pollTimed(t, "recovered", waitRecoveredTimeoutSecs, proc.Recovered) + pollTimed(t, "recovered", proc.Recovered) em.EmitSync("value1", int64(1)) - pollTimed(t, "value-ok", emitWaitTimeoutSecs, func() bool { + pollTimed(t, "value-ok", func() bool { val1, _ := proc.Get("value1") return val1 != nil && val1.(int64) == 1 }) @@ -139,14 +134,14 @@ func TestProcessorVisit(t *testing.T) { defer finish() proc, cancel, done := runProc(createProc(group, input, 500*time.Millisecond)) - pollTimed(t, "recovered", waitRecoveredTimeoutSecs, proc.Recovered) + pollTimed(t, "recovered", proc.Recovered) // emit two values where goka.DefaultHasher says they're in the same partition. // We need to achieve this to test that a shutdown will visit one value but not the other em.EmitSync("0", int64(1)) em.EmitSync("02", int64(1)) - pollTimed(t, "value-ok", emitWaitTimeoutSecs, func() bool { + pollTimed(t, "value-ok", func() bool { val1, _ := proc.Get("02") val2, _ := proc.Get("0") return val1 != nil && val1.(int64) == 1 && val2 != nil && val2.(int64) == 1 @@ -206,8 +201,8 @@ func TestProcessorVisit(t *testing.T) { proc, cancel, done := runProc(createProc(group, input, 500*time.Millisecond)) view, viewCancel, viewDone := runView(createView(group)) - pollTimed(t, "recovered", waitRecoveredTimeoutSecs, proc.Recovered) - pollTimed(t, "recovered", waitRecoveredTimeoutSecs, view.Recovered) + pollTimed(t, "recovered", proc.Recovered) + pollTimed(t, "recovered", view.Recovered) // emit two values where goka.DefaultHasher says they're in the same partition. // We need to achieve this to test that a shutdown will visit one value but not the other @@ -217,7 +212,7 @@ func TestProcessorVisit(t *testing.T) { // emFinish() // poll until all values are there - pollTimed(t, "value-ok", emitWaitTimeoutSecs, func() bool { + pollTimed(t, "value-ok", func() bool { for i := 0; i < 100; i++ { val, _ := proc.Get(fmt.Sprintf("value-%d", i)) if val == nil || val.(int64) != 1 { @@ -260,7 +255,7 @@ func TestProcessorVisit(t *testing.T) { // scenario: sleep in visit, processor shuts down--> visit should cancel too proc1, cancel1, done1 := runProc(createProc(group, input, 500*time.Millisecond)) - pollTimed(t, "recovered", waitRecoveredTimeoutSecs, proc1.Recovered) + pollTimed(t, "recovered", proc1.Recovered) // emit two values where goka.DefaultHasher says they're in the same partition. 
// We need to achieve this to test that a shutdown will visit one value but not the other @@ -269,7 +264,7 @@ func TestProcessorVisit(t *testing.T) { } // poll until all values are there - pollTimed(t, "value-ok", emitWaitTimeoutSecs, func() bool { + pollTimed(t, "value-ok", func() bool { for i := 0; i < 100; i++ { val, _ := proc1.Get(fmt.Sprintf("value-%d", i)) if val == nil || val.(int64) != 1 { @@ -293,7 +288,7 @@ func TestProcessorVisit(t *testing.T) { _, cancel2, done2 := runProc(createProc(group, input, 500*time.Millisecond)) // wait until the visit is aborted by the new processor (rebalance) - pollTimed(t, "visit-abort", 10, func() bool { + pollTimed(t, "visit-abort", func() bool { select { case <-visitDone: return errors.Is(visitErr, goka.ErrVisitAborted) && visited > 0 && visited < 100 diff --git a/systemtest/utils_test.go b/systemtest/utils_test.go index 564d3680..4beb33ca 100644 --- a/systemtest/utils_test.go +++ b/systemtest/utils_test.go @@ -11,9 +11,19 @@ import ( "github.com/lovoo/goka/storage" ) +var ( + pollMaxWait = 30 * time.Second + sleepTime = 100 * time.Millisecond +) + // polls all pollers until all return true or fails the test when secTimeout has passed. -func pollTimed(t *testing.T, what string, secTimeout float64, pollers ...func() bool) { - for i := 0; i < int(secTimeout/0.02); i++ { +func pollTimed(t *testing.T, what string, pollers ...func() bool) { + end := time.Now().Add(pollMaxWait) + for i := 0; ; i++ { + // we're past max wait time, let's fail + if end.Before(time.Now()) { + break + } ok := true for _, poller := range pollers { if !poller() { @@ -24,7 +34,7 @@ func pollTimed(t *testing.T, what string, secTimeout float64, pollers ...func() if ok { return } - time.Sleep(20 * time.Millisecond) + time.Sleep(sleepTime) } t.Fatalf("waiting for %s timed out", what) } diff --git a/systemtest/view_reconnect_test.go b/systemtest/view_reconnect_test.go index de4f99cf..7a93fff3 100644 --- a/systemtest/view_reconnect_test.go +++ b/systemtest/view_reconnect_test.go @@ -68,7 +68,7 @@ func TestView_Reconnect(t *testing.T) { errg.Go(func() error { return view.Run(ctx) }) - pollTimed(t, "view-recovered", 10, view.Recovered) + pollTimed(t, "view-recovered", view.Recovered) val := func() int64 { val, err := view.Get("key") @@ -79,7 +79,7 @@ func TestView_Reconnect(t *testing.T) { return val.(int64) } - pollTimed(t, "wait-first-value", 3, func() bool { + pollTimed(t, "wait-first-value", func() bool { return val() > 0 }) firstVal := val() @@ -88,7 +88,7 @@ func TestView_Reconnect(t *testing.T) { // kill kafka connection fi.SetReadError(io.EOF) - pollTimed(t, "view-reconnecting", 10, func() bool { + pollTimed(t, "view-reconnecting", func() bool { return view.CurrentState() == goka.ViewStateConnecting }) @@ -102,10 +102,10 @@ func TestView_Reconnect(t *testing.T) { // connect kafka again, wait until it's running -> the value should have changed fi.ResetErrors() - pollTimed(t, "view-running", 10, func() bool { + pollTimed(t, "view-running", func() bool { return view.CurrentState() == goka.ViewStateRunning }) - pollTimed(t, "view-running", 5, func() bool { + pollTimed(t, "value-propagated", func() bool { return val() > secondVal }) diff --git a/tester/consumergroup.go b/tester/consumergroup.go index 37259b4f..3ae08b87 100644 --- a/tester/consumergroup.go +++ b/tester/consumergroup.go @@ -45,6 +45,8 @@ func newConsumerGroup(t T, tt *Tester) *consumerGroup { } func (cg *consumerGroup) catchupAndWait() int { + cg.mu.RLock() + defer cg.mu.RUnlock() if cg.currentSession == nil 
{ panic("There is currently no session. Cannot catchup, but we shouldn't be at this point") } @@ -105,7 +107,9 @@ func (cg *consumerGroup) Consume(ctx context.Context, topics []string, handler s errs = multierror.Append(errs, handler.Cleanup(session)) // remove current sessions + cg.mu.Lock() cg.currentSession = nil + cg.mu.Unlock() return errs.ErrorOrNil() } @@ -169,7 +173,6 @@ type cgSession struct { } func newCgSession(ctx context.Context, generation int32, cg *consumerGroup, topics []string) *cgSession { - ctx, cancel := context.WithCancel(ctx) cgs := &cgSession{ ctx: ctx,