From 83dfa9d9d0525381010a4ef1dcc87e1758a8ccbd Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Sun, 27 Feb 2022 18:49:02 -0700 Subject: [PATCH] client: add EndAndBeginTransaction For situations where you just want to end a transaction and begin a new one, without caring about what *exactly* is flushed, it can be beneficial to allow concurrent producing while end / begin are ongoing. In particular, the restriction to flush may add unexpected latency. This adds EndAndBeginTransaction that can be called concurrent with producing. There are two safety options to choose: one that does not really increase throughput **that** much, but does keep things safe, and one that relaxes throughput at the expense of safety. We delete some logic from txn_test to use our new function. Previously, we would know what all buffered records are flushed when ending a transaction. That is no longer true, so we can no longer count on the transactional marker being where we expect. Instead, we ensure our offsets are at least monotonically increasing and call it good. --- pkg/kgo/atomic_maybe_work.go | 8 ++ pkg/kgo/producer.go | 84 +++++++++++-- pkg/kgo/sink.go | 26 ++-- pkg/kgo/txn.go | 236 ++++++++++++++++++++++++++++++++++- pkg/kgo/txn_test.go | 23 ++-- 5 files changed, 336 insertions(+), 41 deletions(-) diff --git a/pkg/kgo/atomic_maybe_work.go b/pkg/kgo/atomic_maybe_work.go index 2f517160..807b6713 100644 --- a/pkg/kgo/atomic_maybe_work.go +++ b/pkg/kgo/atomic_maybe_work.go @@ -15,6 +15,14 @@ func (b *atomicBool) set(v bool) { func (b *atomicBool) get() bool { return atomic.LoadUint32((*uint32)(b)) == 1 } +func (b *atomicBool) swap(v bool) bool { + var swap uint32 + if v { + swap = 1 + } + return atomic.SwapUint32((*uint32)(b), swap) == 1 +} + const ( stateUnstarted = iota stateWorking diff --git a/pkg/kgo/producer.go b/pkg/kgo/producer.go index 8f0fcb51..9e90df00 100644 --- a/pkg/kgo/producer.go +++ b/pkg/kgo/producer.go @@ -48,11 +48,15 @@ type producer struct { idVersion int16 waitBuffer chan struct{} - // notifyMu and notifyCond are used for flush and drain notifications. - notifyMu sync.Mutex - notifyCond *sync.Cond + // mu and c are used for flush and drain notifications; mu is used for + // a few other tight locks. + mu sync.Mutex + c *sync.Cond + + inflight int64 // high 16: # waiters, low 48: # inflight batchPromises ringBatchPromise + promisesMu sync.Mutex txnMu sync.Mutex inTxn bool @@ -84,7 +88,7 @@ func (p *producer) init(cl *Client) { epoch: -1, err: errReloadProducerID, }) - p.notifyCond = sync.NewCond(&p.notifyMu) + p.c = sync.NewCond(&p.mu) inithooks := func() { if p.hooks == nil { @@ -415,6 +419,7 @@ func (p *producer) finishPromises(b batchPromise) { cl := p.cl var more bool start: + p.promisesMu.Lock() for i, pr := range b.recs { pr.Offset = b.baseOffset + int64(i) pr.Partition = b.partition @@ -424,6 +429,7 @@ start: cl.finishRecordPromise(pr, b.err) b.recs[i] = promisedRec{} } + p.promisesMu.Unlock() if cap(b.recs) > 4 { cl.prsPool.put(b.recs) } @@ -452,9 +458,9 @@ func (cl *Client) finishRecordPromise(pr promisedRec, err error) { if buffered >= cl.cfg.maxBufferedRecords { p.waitBuffer <- struct{}{} } else if buffered == 0 && atomic.LoadInt32(&p.flushing) > 0 { - p.notifyMu.Lock() - p.notifyMu.Unlock() // nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe. - p.notifyCond.Broadcast() + p.mu.Lock() + p.mu.Unlock() // nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe. + p.c.Broadcast() } } @@ -867,12 +873,12 @@ func (cl *Client) Flush(ctx context.Context) error { quit := false done := make(chan struct{}) go func() { - p.notifyMu.Lock() - defer p.notifyMu.Unlock() + p.mu.Lock() + defer p.mu.Unlock() defer close(done) for !quit && atomic.LoadInt64(&p.bufferedRecords) > 0 { - p.notifyCond.Wait() + p.c.Wait() } }() @@ -880,14 +886,66 @@ func (cl *Client) Flush(ctx context.Context) error { case <-done: return nil case <-ctx.Done(): - p.notifyMu.Lock() + p.mu.Lock() quit = true - p.notifyMu.Unlock() - p.notifyCond.Broadcast() + p.mu.Unlock() + p.c.Broadcast() return ctx.Err() } } +func (p *producer) pause(ctx context.Context) error { + atomic.AddInt64(&p.inflight, 1<<48) + + quit := false + done := make(chan struct{}) + go func() { + p.mu.Lock() + defer p.mu.Unlock() + defer close(done) + for !quit && atomic.LoadInt64(&p.inflight)&0x0000ffffffffffff != 0 { + p.c.Wait() + } + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + p.mu.Lock() + quit = true + p.mu.Unlock() + p.c.Broadcast() + p.resume() // dec our inflight + return ctx.Err() + } +} + +func (p *producer) resume() { + if atomic.AddInt64(&p.inflight, -1<<48) == 0 { + p.cl.allSinksAndSources(func(sns sinkAndSource) { + sns.sink.maybeDrain() + }) + } +} + +func (p *producer) maybeAddInflight() bool { + if atomic.LoadInt64(&p.inflight)>>48 > 0 { + return false + } + if atomic.AddInt64(&p.inflight, 1)>>48 > 0 { + p.decInflight() + return false + } + return true +} + +func (p *producer) decInflight() { + if atomic.AddInt64(&p.inflight, -1)>>48 > 0 { + p.c.Broadcast() + } +} + // Bumps the tries for all buffered records in the client. // // This is called whenever there is a problematic error that would affect the diff --git a/pkg/kgo/sink.go b/pkg/kgo/sink.go index 9c13f50c..0a685565 100644 --- a/pkg/kgo/sink.go +++ b/pkg/kgo/sink.go @@ -148,10 +148,9 @@ func (t *txnReqBuilder) add(rb *recBuf) { if t.txnID == nil { return } - if rb.addedToTxn { + if rb.addedToTxn.swap(true) { return } - rb.addedToTxn = true if t.req == nil { req := kmsg.NewPtrAddPartitionsToTxnRequest() req.TransactionalID = *t.txnID @@ -282,6 +281,15 @@ func (s *sink) produce(sem <-chan struct{}) bool { return false } + if !s.cl.producer.maybeAddInflight() { // must do before marking recBufs on a txn + return false + } + defer func() { + if !produced { + s.cl.producer.decInflight() + } + }() + // NOTE: we create the req AFTER getting our producer ID! // // If a prior response caused errReloadProducerID, then calling @@ -335,6 +343,7 @@ func (s *sink) produce(sem <-chan struct{}) bool { batches := req.batches.sliced() s.doSequenced(req, func(br *broker, resp kmsg.Response, err error) { + s.cl.producer.decInflight() s.handleReqResp(br, req, resp, err) batches.eachOwnerLocked((*recBatch).decInflight) <-sem @@ -409,7 +418,7 @@ func (s *sink) doTxnReq( // inflight, and that it was not added to the txn and that we need to reset the // drain index. func (b *recBatch) removeFromTxn() { - b.owner.addedToTxn = false + b.owner.addedToTxn.set(false) b.owner.resetBatchDrainIdx() b.decInflight() } @@ -426,7 +435,7 @@ func (s *sink) issueTxnReq( for _, topic := range resp.Topics { topicBatches, ok := req.batches[topic.Topic] if !ok { - s.cl.cfg.logger.Log(LogLevelError, "Kafka replied with topic in AddPartitionsToTxnResponse that was not in request", "broker", logID(s.nodeID), "topic", topic.Topic) + s.cl.cfg.logger.Log(LogLevelError, "Kafka replied with topic in AddPartitionsToTxnResponse that was not in request", "topic", topic.Topic) continue } for _, partition := range topic.Partitions { @@ -440,7 +449,7 @@ func (s *sink) issueTxnReq( batch, ok := topicBatches[partition.Partition] if !ok { - s.cl.cfg.logger.Log(LogLevelError, "Kafka replied with partition in AddPartitionsToTxnResponse that was not in request", "broker", logID(s.nodeID), "topic", topic.Topic, "partition", partition.Partition) + s.cl.cfg.logger.Log(LogLevelError, "Kafka replied with partition in AddPartitionsToTxnResponse that was not in request", "topic", topic.Topic, "partition", partition.Partition) continue } @@ -950,12 +959,7 @@ type recBuf struct { // addedToTxn, for transactions only, signifies whether this partition // has been added to the transaction yet or not. - // - // This does not need to be under the mu since it is updated either - // serially in building a req (the first time) or after failing to add - // the partition to a txn (again serially), or in EndTransaction after - // all buffered records are flushed (if the API is used correctly). - addedToTxn bool + addedToTxn atomicBool // For LoadTopicPartitioner partitioning; atomically tracks the number // of records buffered in total on this recBuf. diff --git a/pkg/kgo/txn.go b/pkg/kgo/txn.go index cb7088c5..0723b7db 100644 --- a/pkg/kgo/txn.go +++ b/pkg/kgo/txn.go @@ -440,7 +440,7 @@ func (cl *Client) BeginTransaction() error { needRecover, didRecover, err := cl.maybeRecoverProducerID() if needRecover && !didRecover { cl.cfg.logger.Log(LogLevelInfo, "unable to begin transaction due to unrecoverable producer id error", "err", err) - return fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %v", err) + return fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %w", err) } cl.producer.inTxn = true @@ -450,6 +450,230 @@ func (cl *Client) BeginTransaction() error { return nil } +// EndBeginTxnHow controls the safety of how EndAndBeginTransaction executes. +type EndBeginTxnHow uint8 + +const ( + // EndBeginTxnSafe ensures a "safe" execution of EndAndBeginTransaction + // at the expense of speed. This option blocks all produce requests and + // only resumes produce requests when onEnd finishes. Note that some + // produce requests may have finished successfully and records that + // were a part of a transaction may have their promises waiting to be + // called: not all promises are guaranteed to be called. + EndBeginTxnSafe EndBeginTxnHow = iota + + // EndBeginTxnUnsafe opts for less safe EndAndBeginTransaction flow to + // achieve higher throughput. This option allows produce requests to + // continue while EndTxn actually commits. This is unsafe because a + // produce request itself only half begins a transaction. Internally, + // AddPartitionsToTxn actually begins a transaction. If your + // application dies before the client is able to successfully issue + // AddPartitionsToTxn, then a transaction will have partially begun + // within Kafka: the partial transaction will prevent the partition + // from being consumable past where the transaction begun, and the + // transaction will not timeout. You will have to restart your + // application with the SAME transactional ID and produce to all the + // same partitions to ensure to resume the transaction and unstick the + // partitions. + EndBeginTxnUnsafe +) + +// EndAndBeginTransaction is a combination of EndTransaction and +// BeginTransaction, and relaxes the restriction that the client must have no +// buffered records. This function does not flush nor abort any buffered +// records. It is ok to concurrently produce while this function executes. +// +// This function has different safety guarantees which are up to the user to +// decide. See the documentation on EndBeginTxnHow for which you would like to +// choose. +// +// The onEnd function is called with your input context and the result of +// EndTransaction. Promises are paused while onEnd executes. If onEnd returns +// an error, BeginTransaction is not called and this function returns the +// result of onEnd. Otherwise, this function returns the result of +// BeginTransaction. See the documentation on EndTransaction and +// BeginTransaction for further details. It is invalid to call this function +// more than once at a time, and it is invalid to call concurrent with +// EndTransaction or BeginTransaction. +func (cl *Client) EndAndBeginTransaction( + ctx context.Context, + how EndBeginTxnHow, + commit TransactionEndTry, + onEnd func(context.Context, error) error, +) (rerr error) { + if g := cl.consumer.g; g != nil { + return errors.New("cannot use EndAndBeginTransaction with EOS") + } + + cl.producer.txnMu.Lock() + defer cl.producer.txnMu.Unlock() + + // From BeginTransaction: if we return with no error, we begin. Unlike + // BeginTransaction, we do not error if in a transaction, because we + // expect to be in one. + defer func() { + if rerr == nil { + needRecover, didRecover, err := cl.maybeRecoverProducerID() + if needRecover && !didRecover { + cl.cfg.logger.Log(LogLevelInfo, "unable to begin transaction due to unrecoverable producer id error", "err", err) + rerr = fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %w", err) + return + } + cl.producer.inTxn = true + cl.cfg.logger.Log(LogLevelInfo, "beginning transaction", "transactional_id", *cl.cfg.txnID) + } + }() + + // If end/beginning safely, we have to pause AddPartitionsToTxn and + // ProduceRequest, and we only resume after the user's onEnd has been + // called. + if how == EndBeginTxnSafe { + if err := cl.producer.pause(ctx); err != nil { + return err + } + defer cl.producer.resume() + } + + // Before BeginTransaction, we block promises & call onEnd with whatever + // the return error is. + cl.producer.promisesMu.Lock() + var promisesUnblocked bool + unblockPromises := func() { + if promisesUnblocked { + return + } + promisesUnblocked = true + defer cl.producer.promisesMu.Unlock() + rerr = onEnd(ctx, rerr) + } + defer unblockPromises() + + if !cl.producer.inTxn { + return nil + } + + var anyAdded bool + var readd map[string][]int32 + for topic, parts := range cl.producer.topics.load() { + for i, part := range parts.load().partitions { + if part.records.addedToTxn.swap(false) { + if how == EndBeginTxnUnsafe { + if readd == nil { + readd = make(map[string][]int32) + } + readd[topic] = append(readd[topic], int32(i)) + } + anyAdded = true + } + } + } + + // EndTxn when no txn was started returns INVALID_TXN_STATE. + if !anyAdded { + cl.cfg.logger.Log(LogLevelInfo, "no records were produced during the commit; thus no transaction was began; ending without doing anything") + return nil + } + + // From EndTransaction: if the pid has an error, we may try to recover. + id, epoch, err := cl.producerID() + if err != nil { + if commit { + return kerr.OperationNotAttempted + } + if _, didRecover, _ := cl.maybeRecoverProducerID(); didRecover { + return nil + } + } + cl.cfg.logger.Log(LogLevelInfo, "ending transaction", + "transactional_id", *cl.cfg.txnID, + "producer_id", id, + "epoch", epoch, + "commit", commit, + ) + err = cl.doWithConcurrentTransactions("EndTxn", func() error { + req := kmsg.NewPtrEndTxnRequest() + req.TransactionalID = *cl.cfg.txnID + req.ProducerID = id + req.ProducerEpoch = epoch + req.Commit = bool(commit) + resp, err := req.RequestWith(ctx, cl) + if err != nil { + return err + } + return kerr.ErrorForCode(resp.ErrorCode) + }) + var ke *kerr.Error + if errors.As(err, &ke) && !ke.Retriable { + cl.failProducerID(id, epoch, err) + } + if err != nil || how != EndBeginTxnUnsafe { + return err + } + unblockPromises() + + // If we are end/beginning unsafely, then we need to re-add all + // partitions to a new transaction immediately. Timing makes it + // impossible to know what was truly added before EndTxn, so we + // pessimistically assume that every partition must be re-added. + // + // We track readd before the txn and swap those to un-added, but we + // also need to track anything that is newly added that raced with our + // EndTxn. We swap before the txn to ensure that *eventually*, + // partitions will be tracked as not in a transaction if people stop + // producing. + // + // We do this before the user callback because we *need* to start a new + // transaction within Kafka to ensure there will be a timeout. Per the + // unsafe aspect, the client could die or this request could error and + // there could be a stranded txn within Kafka's ProducerStateManager, + // but ideally the user will reconnect with the same txnal id. + return cl.doWithConcurrentTransactions("AddPartitionsToTxn", func() error { + req := kmsg.NewPtrAddPartitionsToTxnRequest() + req.TransactionalID = *cl.cfg.txnID + req.ProducerID = id + req.ProducerEpoch = epoch + + for topic, parts := range cl.producer.topics.load() { + for i, part := range parts.load().partitions { + if part.records.addedToTxn.get() { + readd[topic] = append(readd[topic], int32(i)) + } + } + } + + ps := make(map[int32]struct{}) + for topic, parts := range readd { + t := kmsg.NewAddPartitionsToTxnRequestTopic() + t.Topic = topic + for _, part := range parts { + ps[part] = struct{}{} + } + for p := range ps { + t.Partitions = append(t.Partitions, p) + delete(ps, p) + } + if len(t.Partitions) > 0 { + req.Topics = append(req.Topics, t) + } + } + + resp, err := req.RequestWith(ctx, cl) + if err != nil { + return err + } + for i := range resp.Topics { + t := &resp.Topics[i] + for j := range t.Partitions { + p := &t.Partitions[j] + if err := kerr.ErrorForCode(p.ErrorCode); err != nil { + return err + } + } + } + return nil + }) +} + // AbortBufferedRecords fails all unflushed records with ErrAborted and waits // for there to be no buffered records. // @@ -521,6 +745,8 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) // also produced. var anyAdded bool if g := cl.consumer.g; g != nil { + // We do not lock because we expect commitTransactionOffsets to + // be called *before* ending a transaction. if g.offsetsAddedToTxn { g.offsetsAddedToTxn = false anyAdded = true @@ -538,10 +764,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) // addedToTxn to false outside of any mutex. for _, parts := range cl.producer.topics.load() { for _, part := range parts.load().partitions { - if part.records.addedToTxn { - part.records.addedToTxn = false - anyAdded = true - } + anyAdded = part.records.addedToTxn.swap(false) || anyAdded } } @@ -605,6 +828,9 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) // // We call this when beginning a transaction or when ending with an abort. func (cl *Client) maybeRecoverProducerID() (necessary, did bool, err error) { + cl.producer.mu.Lock() + defer cl.producer.mu.Unlock() + id, epoch, err := cl.producerID() if err == nil { return false, false, nil diff --git a/pkg/kgo/txn_test.go b/pkg/kgo/txn_test.go index 85d3a07f..89888d26 100644 --- a/pkg/kgo/txn_test.go +++ b/pkg/kgo/txn_test.go @@ -53,23 +53,22 @@ func TestTxnEtl(t *testing.T) { errs <- fmt.Errorf("unable to end transaction: %v", err) } }() + var safeUnsafe bool for i := 0; i < testRecordLimit; i++ { // We start with a transaction, and every 10k records // we commit and begin a new one. if i > 0 && i%10000 == 0 { - if err := cl.Flush(context.Background()); err != nil { - errs <- fmt.Errorf("unable to flush: %v", err) + how := EndBeginTxnSafe + if safeUnsafe { + how = EndBeginTxnUnsafe } - // Control markers ending a transaction take up - // one record offset, so for all partitions that - // were used in the txn, we bump their offset. - for partition := range partsUsed { - offsets[partition]++ - } - if err := cl.EndTransaction(context.Background(), true); err != nil { - errs <- fmt.Errorf("unable to end transaction: %v", err) - } - if err := cl.BeginTransaction(); err != nil { + safeUnsafe = !safeUnsafe + if err := cl.EndAndBeginTransaction(context.Background(), how, TryCommit, func(_ context.Context, endErr error) error { + if err != nil { + errs <- fmt.Errorf("unable to end transaction: %v", err) + } + return err + }); err != nil { errs <- fmt.Errorf("unable to begin transaction: %v", err) } }