Skip to content

Commit

Permalink
Merge pull request #348 from nats-io/fix_read_after_connect
Browse files Browse the repository at this point in the history
[FIXED] Protocols received right after first PONG may be processed
  • Loading branch information
kozlovic authored Mar 6, 2018
2 parents bfca8af + 15952d7 commit f90dee2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 11 deletions.
48 changes: 37 additions & 11 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
Expand Down Expand Up @@ -1243,38 +1244,40 @@ func (nc *Conn) sendConnect() error {
return err
}

// Now read the response from the server.
br := bufio.NewReaderSize(nc.conn, defaultBufSize)
line, err := br.ReadString('\n')
// We don't want to read more than we need here, otherwise
// we would need to transfer the excess read data to the readLoop.
// Since in normal situations we just are looking for a PONG\r\n,
// reading byte-by-byte here is ok.
proto, err := nc.readProto()
if err != nil {
return err
}

// If opts.Verbose is set, handle +OK
if nc.Opts.Verbose && line == okProto {
if nc.Opts.Verbose && proto == okProto {
// Read the rest now...
line, err = br.ReadString('\n')
proto, err = nc.readProto()
if err != nil {
return err
}
}

// We expect a PONG
if line != pongProto {
if proto != pongProto {
// But it could be something else, like -ERR

// Since we no longer use ReadLine(), trim the trailing "\r\n"
line = strings.TrimRight(line, "\r\n")
proto = strings.TrimRight(proto, "\r\n")

// If it's a server error...
if strings.HasPrefix(line, _ERR_OP_) {
if strings.HasPrefix(proto, _ERR_OP_) {
// Remove -ERR, trim spaces and quotes, and convert to lower case.
line = normalizeErr(line)
return errors.New("nats: " + line)
proto = normalizeErr(proto)
return errors.New("nats: " + proto)
}

// Notify that we got an unexpected protocol.
return fmt.Errorf("nats: expected '%s', got '%s'", _PONG_OP_, line)
return fmt.Errorf("nats: expected '%s', got '%s'", _PONG_OP_, proto)
}

// This is where we are truly connected.
Expand All @@ -1283,6 +1286,29 @@ func (nc *Conn) sendConnect() error {
return nil
}

// reads a protocol one byte at a time.
func (nc *Conn) readProto() (string, error) {
var (
_buf = [10]byte{}
buf = _buf[:0]
b = [1]byte{}
protoEnd = byte('\n')
)
for {
if _, err := nc.conn.Read(b[:1]); err != nil {
// Do not report EOF error
if err == io.EOF {
return string(buf), nil
}
return "", err
}
buf = append(buf, b[0])
if b[0] == protoEnd {
return string(buf), nil
}
}
}

// A control protocol line.
type control struct {
op, args string
Expand Down
77 changes: 77 additions & 0 deletions test/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -1477,12 +1478,16 @@ func TestCustomFlusherTimeout(t *testing.T) {

func TestNewServers(t *testing.T) {
s1Opts := test.DefaultTestOptions
s1Opts.Host = "127.0.0.1"
s1Opts.Port = 4222
s1Opts.Cluster.Host = "localhost"
s1Opts.Cluster.Port = 6222
s1 := test.RunServer(&s1Opts)
defer s1.Shutdown()

s2Opts := test.DefaultTestOptions
s2Opts.Host = "127.0.0.1"
s2Opts.Port = 4223
s2Opts.Port = s1Opts.Port + 1
s2Opts.Cluster.Host = "localhost"
s2Opts.Cluster.Port = 6223
Expand Down Expand Up @@ -1526,6 +1531,8 @@ func TestNewServers(t *testing.T) {

// Start a new server.
s3Opts := test.DefaultTestOptions
s1Opts.Host = "127.0.0.1"
s1Opts.Port = 4224
s3Opts.Port = s2Opts.Port + 1
s3Opts.Cluster.Host = "localhost"
s3Opts.Cluster.Port = 6224
Expand Down Expand Up @@ -1758,3 +1765,73 @@ func TestBarrier(t *testing.T) {
t.Fatal("Barrier was blocked")
}
}

func TestReceiveInfoRightAfterFirstPong(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Error on listen: %v", err)
}
tl := l.(*net.TCPListener)
defer tl.Close()
addr := tl.Addr().(*net.TCPAddr)

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()

c, err := tl.Accept()
if err != nil {
return
}
defer c.Close()
// Send the initial INFO
c.Write([]byte("INFO {}\r\n"))
buf := make([]byte, 0, 100)
b := make([]byte, 100)
for {
n, err := c.Read(b)
if err != nil {
return
}
buf = append(buf, b[:n]...)
if bytes.Contains(buf, []byte("PING\r\n")) {
break
}
}
// Send PONG and following INFO in one go (or at least try).
// The processing of PONG in sendConnect() should leave the
// rest for the readLoop to process.
c.Write([]byte(fmt.Sprintf("PONG\r\nINFO {\"connect_urls\":[\"127.0.0.1:%d\", \"me:1\"]}\r\n", addr.Port)))
// Wait for client to disconnect
for {
if _, err := c.Read(buf); err != nil {
return
}
}
}()

nc, err := nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", addr.Port))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer nc.Close()
var (
ds []string
timeout = time.Now().Add(2 * time.Second)
ok = false
)
for time.Now().Before(timeout) {
ds = nc.DiscoveredServers()
if len(ds) == 1 && ds[0] == "nats://me:1" {
ok = true
break
}
time.Sleep(50 * time.Millisecond)
}
nc.Close()
wg.Wait()
if !ok {
t.Fatalf("Unexpected discovered servers: %v", ds)
}
}

0 comments on commit f90dee2

Please sign in to comment.