Skip to content

Commit

Permalink
conn: check we can use compression in startup (#1215)
Browse files Browse the repository at this point in the history
* conn: check we can use compression in startup

If the user asks to use compression check that the remote server
supports it by making an options call first.

Rejig conn startup to be self contained.

* simplify TestFrameHeaderObserver

* disable compression if the server doesnt support it
  • Loading branch information
Zariel authored Oct 12, 2018
1 parent 26b68fd commit 44e29ed
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 70 deletions.
153 changes: 93 additions & 60 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
}
defer cancel()

frameTicker := make(chan struct{}, 1)
startupErr := make(chan error)
go func() {
for range frameTicker {
err := c.recv()
if err != nil {
select {
case startupErr <- err:
case <-ctx.Done():
}

return
}
}
}()

go func() {
defer close(frameTicker)
err := c.startup(ctx, frameTicker)
select {
case startupErr <- err:
case <-ctx.Done():
}
}()
startup := &startupCoordinator{
frameTicker: make(chan struct{}),
conn: c,
}

select {
case err := <-startupErr:
if err != nil {
c.Close()
return nil, err
}
case <-ctx.Done():
c.Close()
return nil, errors.New("gocql: no response to connection startup within timeout")
if err := startup.setupConn(ctx); err != nil {
c.close()
return nil, err
}

// dont coalesce startup frames
Expand Down Expand Up @@ -300,27 +274,98 @@ func (c *Conn) Read(p []byte) (n int, err error) {
return
}

func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
m := map[string]string{
"CQL_VERSION": c.cfg.CQLVersion,
}
type startupCoordinator struct {
conn *Conn
frameTicker chan struct{}
}

if c.compressor != nil {
m["COMPRESSION"] = c.compressor.Name()
func (s *startupCoordinator) setupConn(ctx context.Context) error {
startupErr := make(chan error)
go func() {
for range s.frameTicker {
err := s.conn.recv()
if err != nil {
select {
case startupErr <- err:
case <-ctx.Done():
}

return
}
}
}()

go func() {
defer close(s.frameTicker)
err := s.options(ctx)
select {
case startupErr <- err:
case <-ctx.Done():
}
}()

select {
case err := <-startupErr:
if err != nil {
return err
}
case <-ctx.Done():
return errors.New("gocql: no response to connection startup within timeout")
}

return nil
}

func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
select {
case frameTicker <- struct{}{}:
case s.frameTicker <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
return nil, ctx.Err()
}

framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
framer, err := s.conn.exec(ctx, frame, nil)
if err != nil {
return nil, err
}

return framer.parseFrame()
}

func (s *startupCoordinator) options(ctx context.Context) error {
frame, err := s.write(ctx, &writeOptionsFrame{})
if err != nil {
return err
}

frame, err := framer.parseFrame()
supported, ok := frame.(*supportedFrame)
if !ok {
return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
}

return s.startup(ctx, supported.supported)
}

func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error {
m := map[string]string{
"CQL_VERSION": s.conn.cfg.CQLVersion,
}

if s.conn.compressor != nil {
comp := supported["COMPRESSION"]
name := s.conn.compressor.Name()
for _, compressor := range comp {
if compressor == name {
m["COMPRESSION"] = compressor
break
}
}

if _, ok := m["COMPRESSION"]; !ok {
s.conn.compressor = nil
}
}

frame, err := s.write(ctx, &writeStartupFrame{opts: m})
if err != nil {
return err
}
Expand All @@ -331,37 +376,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
case *readyFrame:
return nil
case *authenticateFrame:
return c.authenticateHandshake(ctx, v, frameTicker)
return s.authenticateHandshake(ctx, v)
default:
return NewErrProtocol("Unknown type of response to startup frame: %s", v)
}
}

func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
if c.auth == nil {
func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error {
if s.conn.auth == nil {
return fmt.Errorf("authentication required (using %q)", authFrame.class)
}

resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class))
if err != nil {
return err
}

req := &writeAuthResponseFrame{data: resp}

for {
select {
case frameTicker <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}

framer, err := c.exec(ctx, req, nil)
if err != nil {
return err
}

frame, err := framer.parseFrame()
frame, err := s.write(ctx, req)
if err != nil {
return err
}
Expand Down
19 changes: 9 additions & 10 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -940,18 +940,17 @@ func TestFrameHeaderObserver(t *testing.T) {
}

frames := observer.getFrames()

if len(frames) != 2 {
t.Fatalf("Expected to receive 2 frames, instead received %d", len(frames))
}
readyFrame := frames[0]
if readyFrame.Opcode != frameOp(opReady) {
t.Fatalf("Expected to receive ready frame, instead received frame of opcode %d", readyFrame.Opcode)
expFrames := []frameOp{opSupported, opReady, opResult}
if len(frames) != len(expFrames) {
t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames))
}
voidResultFrame := frames[1]
if voidResultFrame.Opcode != frameOp(opResult) {
t.Fatalf("Expected to receive result frame, instead received frame of opcode %d", voidResultFrame.Opcode)

for i, op := range expFrames {
if op != frames[i].Opcode {
t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i])
}
}
voidResultFrame := frames[2]
if voidResultFrame.Length != int32(4) {
t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length)
}
Expand Down

0 comments on commit 44e29ed

Please sign in to comment.