diff --git a/batch.go b/batch.go index 012cf1609..f9f3e5227 100644 --- a/batch.go +++ b/batch.go @@ -2,6 +2,7 @@ package kafka import ( "bufio" + "errors" "io" "sync" "time" @@ -82,7 +83,7 @@ func (batch *Batch) close() (err error) { batch.msgs.discard() } - if err = batch.err; err == io.EOF { + if err = batch.err; errors.Is(batch.err, io.EOF) { err = nil } @@ -93,7 +94,8 @@ func (batch *Batch) close() (err error) { conn.mutex.Unlock() if err != nil { - if _, ok := err.(Error); !ok && err != io.ErrShortBuffer { + var kafkaError Error + if !errors.As(err, &kafkaError) && !errors.Is(err, io.ErrShortBuffer) { conn.Close() } } @@ -238,11 +240,11 @@ func (batch *Batch) readMessage( var lastOffset int64 offset, lastOffset, timestamp, headers, err = batch.msgs.readMessage(batch.offset, key, val) - switch err { - case nil: + switch { + case err == nil: batch.offset = offset + 1 batch.lastOffset = lastOffset - case errShortRead: + case errors.Is(err, errShortRead): // As an "optimization" kafka truncates the returned response after // producing MaxBytes, which could then cause the code to return // errShortRead. @@ -272,7 +274,7 @@ func (batch *Batch) readMessage( // to MaxBytes truncation // - `batch.lastOffset` to ensure that the message format contains // `lastOffset` - if batch.err == io.EOF && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 { + if errors.Is(batch.err, io.EOF) && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 { // Log compaction can create batches that end with compacted // records so the normal strategy that increments the "next" // offset as records are read doesn't work as the compacted diff --git a/batch_test.go b/batch_test.go index d7bc48d47..2a52752c6 100644 --- a/batch_test.go +++ b/batch_test.go @@ -2,6 +2,7 @@ package kafka import ( "context" + "errors" "io" "net" "strconv" @@ -30,11 +31,11 @@ func TestBatchDontExpectEOF(t *testing.T) { batch := conn.ReadBatch(1024, 8192) - if _, err := batch.ReadMessage(); err != io.ErrUnexpectedEOF { + if _, err := batch.ReadMessage(); !errors.Is(err, io.ErrUnexpectedEOF) { t.Error("bad error when reading message:", err) } - if err := batch.Close(); err != io.ErrUnexpectedEOF { + if err := batch.Close(); !errors.Is(err, io.ErrUnexpectedEOF) { t.Error("bad error when closing the batch:", err) } } diff --git a/client.go b/client.go index 29602d212..d965040e8 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package kafka import ( "context" "errors" + "fmt" "net" "time" @@ -67,7 +68,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get topic metadata :%w", err) } topic := metadata.Topics[0] @@ -85,7 +86,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get offsets: %w", err) } topicOffsets := offsets.Topics[topic.Name] diff --git a/compress/snappy/xerial.go b/compress/snappy/xerial.go index 6445d7486..e2725af9c 100644 --- a/compress/snappy/xerial.go +++ b/compress/snappy/xerial.go @@ -3,6 +3,7 @@ package snappy import ( "bytes" "encoding/binary" + "errors" "io" "github.com/klauspost/compress/snappy" @@ -64,7 +65,7 @@ func (x *xerialReader) WriteTo(w io.Writer) (int64, error) { } if _, err := x.readChunk(nil); err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } return wn, err @@ -128,7 +129,7 @@ func (x *xerialReader) readChunk(dst []byte) (int, error) { n, err := x.read(x.input[len(x.input):cap(x.input)]) x.input = x.input[:len(x.input)+n] if err != nil { - if err == io.EOF && len(x.input) > 0 { + if errors.Is(err, io.EOF) && len(x.input) > 0 { break } return 0, err @@ -212,7 +213,7 @@ func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) { } if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } return wn, err diff --git a/conn.go b/conn.go index 2a2411adb..de6fc9b4b 100644 --- a/conn.go +++ b/conn.go @@ -853,7 +853,7 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch { default: throttle, highWaterMark, remain, err = readFetchResponseHeaderV2(&c.rbuf, size) } - if err == errShortRead { + if errors.Is(err, errShortRead) { err = checkTimeoutErr(adjustedDeadline) } @@ -865,9 +865,10 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch { msgs, err = newMessageSetReader(&c.rbuf, remain) } } - if err == errShortRead { + if errors.Is(err, errShortRead) { err = checkTimeoutErr(adjustedDeadline) } + return &Batch{ conn: c, msgs: msgs, diff --git a/conn_test.go b/conn_test.go index f3921818a..7c4186856 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package kafka import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -640,10 +641,13 @@ func testConnReadBatchWithMaxWait(t *testing.T, conn *Conn) { conn.Seek(0, SeekAbsolute) conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) batch = conn.ReadBatchWith(cfg) + var netErr net.Error if err := batch.Err(); err == nil { t.Fatal("should have timed out, but got no error") - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { - t.Fatalf("should have timed out, but got: %v", err) + } else if errors.As(err, &netErr) { + if !netErr.Timeout() { + t.Fatalf("should have timed out, but got: %v", err) + } } } @@ -761,7 +765,7 @@ func testConnFindCoordinator(t *testing.T, conn *Conn) { func testConnJoinGroupInvalidGroupID(t *testing.T, conn *Conn) { _, err := conn.joinGroup(joinGroupRequestV1{}) - if err != InvalidGroupId && err != NotCoordinatorForGroup { + if !errors.Is(err, InvalidGroupId) && !errors.Is(err, NotCoordinatorForGroup) { t.Fatalf("expected %v or %v; got %v", InvalidGroupId, NotCoordinatorForGroup, err) } } @@ -773,7 +777,7 @@ func testConnJoinGroupInvalidSessionTimeout(t *testing.T, conn *Conn) { _, err := conn.joinGroup(joinGroupRequestV1{ GroupID: groupID, }) - if err != InvalidSessionTimeout && err != NotCoordinatorForGroup { + if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) { t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err) } } @@ -786,7 +790,7 @@ func testConnJoinGroupInvalidRefreshTimeout(t *testing.T, conn *Conn) { GroupID: groupID, SessionTimeout: int32(3 * time.Second / time.Millisecond), }) - if err != InvalidSessionTimeout && err != NotCoordinatorForGroup { + if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) { t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err) } } @@ -798,7 +802,7 @@ func testConnHeartbeatErr(t *testing.T, conn *Conn) { _, err := conn.syncGroup(syncGroupRequestV0{ GroupID: groupID, }) - if err != UnknownMemberId && err != NotCoordinatorForGroup { + if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) { t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err) } } @@ -810,7 +814,7 @@ func testConnLeaveGroupErr(t *testing.T, conn *Conn) { _, err := conn.leaveGroup(leaveGroupRequestV0{ GroupID: groupID, }) - if err != UnknownMemberId && err != NotCoordinatorForGroup { + if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) { t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err) } } @@ -822,7 +826,7 @@ func testConnSyncGroupErr(t *testing.T, conn *Conn) { _, err := conn.syncGroup(syncGroupRequestV0{ GroupID: groupID, }) - if err != UnknownMemberId && err != NotCoordinatorForGroup { + if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) { t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err) } } @@ -985,7 +989,7 @@ func testConnReadShortBuffer(t *testing.T, conn *Conn) { b[3] = 0 n, err := conn.Read(b) - if err != io.ErrShortBuffer { + if !errors.Is(err, io.ErrShortBuffer) { t.Error("bad error:", i, err) } if n != 4 { @@ -1061,7 +1065,7 @@ func testDeleteTopicsInvalidTopic(t *testing.T, conn *Conn) { } conn.SetDeadline(time.Now().Add(5 * time.Second)) err = conn.DeleteTopics("invalid-topic", topic) - if err != UnknownTopicOrPartition { + if !errors.Is(err, UnknownTopicOrPartition) { t.Fatalf("expected UnknownTopicOrPartition error, but got %v", err) } partitions, err := conn.ReadPartitions(topic) @@ -1154,7 +1158,7 @@ func TestUnsupportedSASLMechanism(t *testing.T) { } defer conn.Close() - if err := conn.saslHandshake("FOO"); err != UnsupportedSASLMechanism { + if err := conn.saslHandshake("FOO"); !errors.Is(err, UnsupportedSASLMechanism) { t.Errorf("Expected UnsupportedSASLMechanism but got %v", err) } } diff --git a/consumergroup.go b/consumergroup.go index b6c5af3b6..ab6595b3c 100644 --- a/consumergroup.go +++ b/consumergroup.go @@ -1026,7 +1026,7 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup // assignments for the topic. this matches the behavior of the official // clients: java, python, and librdkafka. // a topic watcher can trigger a rebalance when the topic comes into being. - if err != nil && err != UnknownTopicOrPartition { + if err != nil && !errors.Is(err, UnknownTopicOrPartition) { return nil, err } diff --git a/consumergroup_test.go b/consumergroup_test.go index 62add3dd8..2c48423e6 100644 --- a/consumergroup_test.go +++ b/consumergroup_test.go @@ -285,7 +285,7 @@ func TestConsumerGroup(t *testing.T) { if gen != nil { t.Errorf("expected generation to be nil") } - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Errorf("expected context.Canceled, but got %+v", err) } }, @@ -301,7 +301,7 @@ func TestConsumerGroup(t *testing.T) { if gen != nil { t.Errorf("expected generation to be nil") } - if err != ErrGroupClosed { + if !errors.Is(err, ErrGroupClosed) { t.Errorf("expected ErrGroupClosed, but got %+v", err) } }, @@ -398,7 +398,7 @@ func TestConsumerGroupErrors(t *testing.T) { gen, err := group.Next(ctx) if err == nil { t.Errorf("expected an error") - } else if err != NotCoordinatorForGroup { + } else if !errors.Is(err, NotCoordinatorForGroup) { t.Errorf("got wrong error: %+v", err) } if gen != nil { @@ -460,7 +460,7 @@ func TestConsumerGroupErrors(t *testing.T) { gen, err := group.Next(ctx) if err == nil { t.Errorf("expected an error") - } else if err != InvalidTopic { + } else if !errors.Is(err, InvalidTopic) { t.Errorf("got wrong error: %+v", err) } if gen != nil { @@ -540,7 +540,7 @@ func TestConsumerGroupErrors(t *testing.T) { gen, err := group.Next(ctx) if err == nil { t.Errorf("expected an error") - } else if err != InvalidTopic { + } else if !errors.Is(err, InvalidTopic) { t.Errorf("got wrong error: %+v", err) } if gen != nil { @@ -672,7 +672,7 @@ func TestGenerationStartsFunctionAfterClosed(t *testing.T) { case <-time.After(time.Second): t.Fatal("timed out waiting for func to run") case err := <-ch: - if err != ErrGenerationEnded { + if !errors.Is(err, ErrGenerationEnded) { t.Fatalf("expected %v but got %v", ErrGenerationEnded, err) } } diff --git a/dialer.go b/dialer.go index a733f7eeb..e8d02839e 100644 --- a/dialer.go +++ b/dialer.go @@ -3,6 +3,7 @@ package kafka import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -276,7 +277,7 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C c, err := d.dialContext(ctx, network, address) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to dial: %w", err) } conn := NewConnWith(c, connCfg) @@ -284,7 +285,7 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C if d.SASLMechanism != nil { host, port, err := splitHostPortNumber(address) if err != nil { - return nil, err + return nil, fmt.Errorf("could not determine host/port for SASL authentication: %w", err) } metadata := &sasl.Metadata{ Host: host, @@ -292,7 +293,7 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C } if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil { _ = conn.Close() - return nil, err + return nil, fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err) } } @@ -307,19 +308,19 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C // responsibility of the caller. func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil { - return err + return fmt.Errorf("SASL handshake failed: %w", err) } sess, state, err := d.SASLMechanism.Start(ctx) if err != nil { - return err + return fmt.Errorf("SASL authentication process could not be started: %w", err) } for completed := false; !completed; { challenge, err := conn.saslAuthenticate(state) - switch err { - case nil: - case io.EOF: + switch { + case err == nil: + case errors.Is(err, io.EOF): // the broker may communicate a failed exchange by closing the // connection (esp. in the case where we're passing opaque sasl // data over the wire since there's no protocol info). @@ -330,7 +331,7 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { completed, state, err = sess.Next(ctx, challenge) if err != nil { - return err + return fmt.Errorf("SASL authentication process has failed: %w", err) } } @@ -340,7 +341,7 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) { address, err := lookupHost(ctx, addr, d.Resolver) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to resolve host: %w", err) } dial := d.DialFunc @@ -355,7 +356,7 @@ func (d *Dialer) dialContext(ctx context.Context, network string, addr string) ( conn, err := dial(ctx, network, address) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open connection to %s: %w", address, err) } if d.TLS != nil { @@ -469,7 +470,7 @@ func lookupHost(ctx context.Context, address string, resolver Resolver) (string, if resolver != nil { resolved, err := resolver.LookupHost(ctx, host) if err != nil { - return "", err + return "", fmt.Errorf("failed to resolve host %s: %w", host, err) } // if the resolver doesn't return anything, we'll fall back on the provided diff --git a/error.go b/error.go index a1146bcae..23e099ffc 100644 --- a/error.go +++ b/error.go @@ -516,15 +516,15 @@ func isTransientNetworkError(err error) bool { } func silentEOF(err error) error { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } return err } func dontExpectEOF(err error) error { - if err == io.EOF { - err = io.ErrUnexpectedEOF + if errors.Is(err, io.EOF) { + return io.ErrUnexpectedEOF } return err } diff --git a/fetch_test.go b/fetch_test.go index 950e82278..c610d82e9 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -273,7 +273,7 @@ func TestClientPipeline(t *testing.T) { for { r, err := res.Records.ReadRecord() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } t.Fatal(err) diff --git a/protocol/buffer.go b/protocol/buffer.go index 901fdf3e2..23909c692 100644 --- a/protocol/buffer.go +++ b/protocol/buffer.go @@ -2,6 +2,7 @@ package protocol import ( "bytes" + "errors" "fmt" "io" "math" @@ -151,7 +152,7 @@ func (p *page) ReadAt(b []byte, off int64) (int, error) { func (p *page) ReadFrom(r io.Reader) (int64, error) { n, err := io.ReadFull(r, p.buffer[p.length:]) - if err == io.EOF || err == io.ErrUnexpectedEOF { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { err = nil } p.length += n diff --git a/read.go b/read.go index d2dcb299a..71a7b9205 100644 --- a/read.go +++ b/read.go @@ -79,7 +79,7 @@ func readVarInt(r *bufio.Reader, sz int, v *int64) (remain int, err error) { // Fill the buffer: ask for one more byte, but in practice the reader // will load way more from the underlying stream. if _, err := r.Peek(1); err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = errShortRead } return sz, err diff --git a/reader_test.go b/reader_test.go index 3862b2bfa..d4f2e5519 100644 --- a/reader_test.go +++ b/reader_test.go @@ -3,6 +3,7 @@ package kafka import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -1622,7 +1623,7 @@ func TestConsumerGroupWithGroupTopicsMultple(t *testing.T) { for { msg, err := r.ReadMessage(ctx) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { t.Log("reader closed") break } diff --git a/writer_test.go b/writer_test.go index da95423fa..4b9918766 100644 --- a/writer_test.go +++ b/writer_test.go @@ -411,7 +411,7 @@ func readPartition(topic string, partition int, offset int64) (msgs []Message, e var msg Message if msg, err = batch.ReadMessage(); err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } return