From 44e29ed5b8a4b4fff39d7ebaa19976c6f852075b Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Fri, 12 Oct 2018 11:03:15 +0100 Subject: [PATCH] conn: check we can use compression in startup (#1215) * conn: check we can use compression in startup If the user asks to use compression check that the remote server supports it by making an options call first. Rejig conn startup to be self contained. * simplify TestFrameHeaderObserver * disable compression if the server doesnt support it --- conn.go | 153 +++++++++++++++++++++++++++++++-------------------- conn_test.go | 19 +++---- 2 files changed, 102 insertions(+), 70 deletions(-) diff --git a/conn.go b/conn.go index ce7f502bb..0bec55ccf 100644 --- a/conn.go +++ b/conn.go @@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa } defer cancel() - frameTicker := make(chan struct{}, 1) - startupErr := make(chan error) - go func() { - for range frameTicker { - err := c.recv() - if err != nil { - select { - case startupErr <- err: - case <-ctx.Done(): - } - - return - } - } - }() - - go func() { - defer close(frameTicker) - err := c.startup(ctx, frameTicker) - select { - case startupErr <- err: - case <-ctx.Done(): - } - }() + startup := &startupCoordinator{ + frameTicker: make(chan struct{}), + conn: c, + } - select { - case err := <-startupErr: - if err != nil { - c.Close() - return nil, err - } - case <-ctx.Done(): - c.Close() - return nil, errors.New("gocql: no response to connection startup within timeout") + if err := startup.setupConn(ctx); err != nil { + c.close() + return nil, err } // dont coalesce startup frames @@ -300,27 +274,98 @@ func (c *Conn) Read(p []byte) (n int, err error) { return } -func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { - m := map[string]string{ - "CQL_VERSION": c.cfg.CQLVersion, - } +type startupCoordinator struct { + conn *Conn + frameTicker chan struct{} +} - if c.compressor != nil { - m["COMPRESSION"] = c.compressor.Name() +func (s *startupCoordinator) setupConn(ctx context.Context) error { + startupErr := make(chan error) + go func() { + for range s.frameTicker { + err := s.conn.recv() + if err != nil { + select { + case startupErr <- err: + case <-ctx.Done(): + } + + return + } + } + }() + + go func() { + defer close(s.frameTicker) + err := s.options(ctx) + select { + case startupErr <- err: + case <-ctx.Done(): + } + }() + + select { + case err := <-startupErr: + if err != nil { + return err + } + case <-ctx.Done(): + return errors.New("gocql: no response to connection startup within timeout") } + return nil +} + +func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) { select { - case frameTicker <- struct{}{}: + case s.frameTicker <- struct{}{}: case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() } - framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil) + framer, err := s.conn.exec(ctx, frame, nil) + if err != nil { + return nil, err + } + + return framer.parseFrame() +} + +func (s *startupCoordinator) options(ctx context.Context) error { + frame, err := s.write(ctx, &writeOptionsFrame{}) if err != nil { return err } - frame, err := framer.parseFrame() + supported, ok := frame.(*supportedFrame) + if !ok { + return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + } + + return s.startup(ctx, supported.supported) +} + +func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { + m := map[string]string{ + "CQL_VERSION": s.conn.cfg.CQLVersion, + } + + if s.conn.compressor != nil { + comp := supported["COMPRESSION"] + name := s.conn.compressor.Name() + for _, compressor := range comp { + if compressor == name { + m["COMPRESSION"] = compressor + break + } + } + + if _, ok := m["COMPRESSION"]; !ok { + s.conn.compressor = nil + } + } + + frame, err := s.write(ctx, &writeStartupFrame{opts: m}) if err != nil { return err } @@ -331,37 +376,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { case *readyFrame: return nil case *authenticateFrame: - return c.authenticateHandshake(ctx, v, frameTicker) + return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } -func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error { - if c.auth == nil { +func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { + if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } - resp, challenger, err := c.auth.Challenge([]byte(authFrame.class)) + resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} - for { - select { - case frameTicker <- struct{}{}: - case <-ctx.Done(): - return ctx.Err() - } - - framer, err := c.exec(ctx, req, nil) - if err != nil { - return err - } - - frame, err := framer.parseFrame() + frame, err := s.write(ctx, req) if err != nil { return err } diff --git a/conn_test.go b/conn_test.go index efa6b8e81..1a2adcbf2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -940,18 +940,17 @@ func TestFrameHeaderObserver(t *testing.T) { } frames := observer.getFrames() - - if len(frames) != 2 { - t.Fatalf("Expected to receive 2 frames, instead received %d", len(frames)) - } - readyFrame := frames[0] - if readyFrame.Opcode != frameOp(opReady) { - t.Fatalf("Expected to receive ready frame, instead received frame of opcode %d", readyFrame.Opcode) + expFrames := []frameOp{opSupported, opReady, opResult} + if len(frames) != len(expFrames) { + t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames)) } - voidResultFrame := frames[1] - if voidResultFrame.Opcode != frameOp(opResult) { - t.Fatalf("Expected to receive result frame, instead received frame of opcode %d", voidResultFrame.Opcode) + + for i, op := range expFrames { + if op != frames[i].Opcode { + t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i]) + } } + voidResultFrame := frames[2] if voidResultFrame.Length != int32(4) { t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length) }