From b0dc1006738f234a8201e9be73282f2dda1cad45 Mon Sep 17 00:00:00 2001 From: Vytenis Darulis Date: Wed, 2 Sep 2020 23:31:10 -0400 Subject: [PATCH] Fix m3msg races --- src/msg/integration/integration_test.go | 16 ++++++------- src/msg/integration/setup.go | 16 ++++++++++++- src/msg/producer/ref_counted.go | 26 ++++++++++----------- src/msg/producer/writer/message.go | 30 ++++++++++++++++++++----- 4 files changed, 60 insertions(+), 28 deletions(-) diff --git a/src/msg/integration/integration_test.go b/src/msg/integration/integration_test.go index 68cbaac8cf..0c9cdf2092 100644 --- a/src/msg/integration/integration_test.go +++ b/src/msg/integration/integration_test.go @@ -57,7 +57,7 @@ func TestSharedConsumer(t *testing.T) { } } -func TestReplicatedConsumerx(t *testing.T) { +func TestReplicatedConsumer(t *testing.T) { if testing.Short() { t.SkipNow() // Just skip if we're doing a short run } @@ -129,9 +129,9 @@ func TestSharedConsumerWithDeadInstance(t *testing.T) { s.Run(t, ctrl) s.VerifyConsumers(t) testConsumers := s.consumerServices[0].testConsumers - require.True(t, testConsumers[len(testConsumers)-1].consumed <= s.TotalMessages()*10/100) + require.True(t, testConsumers[len(testConsumers)-1].numConsumed() <= s.TotalMessages()*10/100) testConsumers = s.consumerServices[1].testConsumers - require.True(t, testConsumers[len(testConsumers)-1].consumed <= s.TotalMessages()*20/100) + require.True(t, testConsumers[len(testConsumers)-1].numConsumed() <= s.TotalMessages()*20/100) } } @@ -546,8 +546,8 @@ func TestRemoveConsumerService(t *testing.T) { ) s.Run(t, ctrl) s.VerifyConsumers(t) - require.Equal(t, msgPerShard*numberOfShards, len(s.consumerServices[0].consumed)) - require.Equal(t, msgPerShard*numberOfShards, len(s.consumerServices[1].consumed)) + require.Equal(t, msgPerShard*numberOfShards, s.consumerServices[0].numConsumed()) + require.Equal(t, msgPerShard*numberOfShards, s.consumerServices[1].numConsumed()) } } @@ -574,8 +574,8 @@ func TestAddConsumerService(t *testing.T) { }, ) s.Run(t, ctrl) - require.Equal(t, s.ExpectedNumMessages(), len(s.consumerServices[0].consumed)) - require.Equal(t, s.ExpectedNumMessages(), len(s.consumerServices[1].consumed)) - require.True(t, len(s.consumerServices[2].consumed) <= s.ExpectedNumMessages()*80/100) + require.Equal(t, s.ExpectedNumMessages(), s.consumerServices[0].numConsumed()) + require.Equal(t, s.ExpectedNumMessages(), s.consumerServices[1].numConsumed()) + require.True(t, s.consumerServices[2].numConsumed() <= s.ExpectedNumMessages()*80/100) } } diff --git a/src/msg/integration/setup.go b/src/msg/integration/setup.go index c1200f58e9..ee4e57743e 100644 --- a/src/msg/integration/setup.go +++ b/src/msg/integration/setup.go @@ -237,7 +237,7 @@ func (s *setup) Run( func (s *setup) VerifyConsumers(t *testing.T) { numWritesPerProducer := s.ExpectedNumMessages() for _, cs := range s.consumerServices { - require.Equal(t, numWritesPerProducer, len(cs.consumed)) + require.Equal(t, numWritesPerProducer, cs.numConsumed()) } } @@ -407,6 +407,13 @@ func (cs *testConsumerService) markConsumed(b []byte) { cs.consumed[string(b)] = struct{}{} } +func (cs *testConsumerService) numConsumed() int { + cs.Lock() + defer cs.Unlock() + + return len(cs.consumed) +} + func (cs *testConsumerService) Close() { for _, c := range cs.testConsumers { c.Close() @@ -437,6 +444,13 @@ func (c *testConsumer) Close() { close(c.doneCh) } +func (c *testConsumer) numConsumed() int { + c.Lock() + defer c.Unlock() + + return c.consumed +} + func newTestConsumer(t *testing.T, cs *testConsumerService) *testConsumer { consumerListener, err := consumer.NewListener("127.0.0.1:0", testConsumerOptions(t)) require.NoError(t, err) diff --git a/src/msg/producer/ref_counted.go b/src/msg/producer/ref_counted.go index 93a17fd00e..18ecb24869 100644 --- a/src/msg/producer/ref_counted.go +++ b/src/msg/producer/ref_counted.go @@ -31,24 +31,24 @@ type OnFinalizeFn func(rm *RefCountedMessage) // RefCountedMessage is a reference counted message. type RefCountedMessage struct { - sync.RWMutex + mu sync.RWMutex Message size uint64 onFinalizeFn OnFinalizeFn - refCount *atomic.Int32 - isDroppedOrConsumed *atomic.Bool + // RefCountedMessage must not be copied by value due to RWMutex, + // safe to store values here and not just pointers + refCount atomic.Int32 + isDroppedOrConsumed atomic.Bool } // NewRefCountedMessage creates RefCountedMessage. func NewRefCountedMessage(m Message, fn OnFinalizeFn) *RefCountedMessage { return &RefCountedMessage{ - Message: m, - refCount: atomic.NewInt32(0), - size: uint64(m.Size()), - onFinalizeFn: fn, - isDroppedOrConsumed: atomic.NewBool(false), + Message: m, + size: uint64(m.Size()), + onFinalizeFn: fn, } } @@ -76,12 +76,12 @@ func (rm *RefCountedMessage) DecRef() { // IncReads increments the reads count. func (rm *RefCountedMessage) IncReads() { - rm.RLock() + rm.mu.RLock() } // DecReads decrements the reads count. func (rm *RefCountedMessage) DecReads() { - rm.RUnlock() + rm.mu.RUnlock() } // NumRef returns the number of references remaining. @@ -107,13 +107,13 @@ func (rm *RefCountedMessage) IsDroppedOrConsumed() bool { func (rm *RefCountedMessage) finalize(r FinalizeReason) bool { // NB: This lock prevents the message from being finalized when its still // being read. - rm.Lock() + rm.mu.Lock() if rm.isDroppedOrConsumed.Load() { - rm.Unlock() + rm.mu.Unlock() return false } rm.isDroppedOrConsumed.Store(true) - rm.Unlock() + rm.mu.Unlock() if rm.onFinalizeFn != nil { rm.onFinalizeFn(rm) } diff --git a/src/msg/producer/writer/message.go b/src/msg/producer/writer/message.go index 6c9501ff36..535ea32e97 100644 --- a/src/msg/producer/writer/message.go +++ b/src/msg/producer/writer/message.go @@ -21,6 +21,9 @@ package writer import ( + stdatomic "sync/atomic" + "unsafe" + "github.com/m3db/m3/src/msg/generated/proto/msgpb" "github.com/m3db/m3/src/msg/producer" "github.com/m3db/m3/src/msg/protocol/proto" @@ -38,14 +41,14 @@ type message struct { retried int // NB(cw) isAcked could be accessed concurrently by the background thread // in message writer and acked by consumer service writers. - isAcked *atomic.Bool + // Safe to store value inside struct, as message is never copied by value + isAcked atomic.Bool } func newMessage() *message { return &message{ retryAtNanos: 0, retried: 0, - isAcked: atomic.NewBool(false), } } @@ -53,7 +56,7 @@ func newMessage() *message { func (m *message) Set(meta metadata, rm *producer.RefCountedMessage, initNanos int64) { m.initNanos = initNanos m.meta = meta - m.RefCountedMessage = rm + m.storeRefCountedMessage(rm) m.ToProto(&m.pb) } @@ -98,7 +101,7 @@ func (m *message) IsAcked() bool { // Ack acknowledges the message. Duplicated acks on the same message might cause panic. func (m *message) Ack() { m.isAcked.Store(true) - m.RefCountedMessage.DecRef() + m.loadRefCountedMessage().DecRef() } // Metadata returns the metadata. @@ -108,14 +111,29 @@ func (m *message) Metadata() metadata { // Marshaler returns the marshaler and a bool to indicate whether the marshaler is valid. func (m *message) Marshaler() (proto.Marshaler, bool) { - return &m.pb, !m.RefCountedMessage.IsDroppedOrConsumed() + return &m.pb, !m.loadRefCountedMessage().IsDroppedOrConsumed() } func (m *message) ToProto(pb *msgpb.Message) { m.meta.ToProto(&pb.Metadata) - pb.Value = m.RefCountedMessage.Bytes() + pb.Value = m.loadRefCountedMessage().Bytes() } func (m *message) ResetProto(pb *msgpb.Message) { pb.Value = nil } + +func (m *message) storeRefCountedMessage(rm *producer.RefCountedMessage) { + stdatomic.StorePointer( + (*unsafe.Pointer)(unsafe.Pointer(&m.RefCountedMessage)), + unsafe.Pointer(rm), + ) +} + +func (m *message) loadRefCountedMessage() *producer.RefCountedMessage { + return (*producer.RefCountedMessage)( + stdatomic.LoadPointer( + (*unsafe.Pointer)(unsafe.Pointer(&m.RefCountedMessage)), + ), + ) +}