diff --git a/conn.go b/conn.go index ce7f502bb..0bec55ccf 100644 --- a/conn.go +++ b/conn.go @@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa } defer cancel() - frameTicker := make(chan struct{}, 1) - startupErr := make(chan error) - go func() { - for range frameTicker { - err := c.recv() - if err != nil { - select { - case startupErr <- err: - case <-ctx.Done(): - } - - return - } - } - }() - - go func() { - defer close(frameTicker) - err := c.startup(ctx, frameTicker) - select { - case startupErr <- err: - case <-ctx.Done(): - } - }() + startup := &startupCoordinator{ + frameTicker: make(chan struct{}), + conn: c, + } - select { - case err := <-startupErr: - if err != nil { - c.Close() - return nil, err - } - case <-ctx.Done(): - c.Close() - return nil, errors.New("gocql: no response to connection startup within timeout") + if err := startup.setupConn(ctx); err != nil { + c.close() + return nil, err } // dont coalesce startup frames @@ -300,27 +274,98 @@ func (c *Conn) Read(p []byte) (n int, err error) { return } -func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { - m := map[string]string{ - "CQL_VERSION": c.cfg.CQLVersion, - } +type startupCoordinator struct { + conn *Conn + frameTicker chan struct{} +} - if c.compressor != nil { - m["COMPRESSION"] = c.compressor.Name() +func (s *startupCoordinator) setupConn(ctx context.Context) error { + startupErr := make(chan error) + go func() { + for range s.frameTicker { + err := s.conn.recv() + if err != nil { + select { + case startupErr <- err: + case <-ctx.Done(): + } + + return + } + } + }() + + go func() { + defer close(s.frameTicker) + err := s.options(ctx) + select { + case startupErr <- err: + case <-ctx.Done(): + } + }() + + select { + case err := <-startupErr: + if err != nil { + return err + } + case <-ctx.Done(): + return errors.New("gocql: no response to connection startup within timeout") } + return nil +} + +func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) { select { - case frameTicker <- struct{}{}: + case s.frameTicker <- struct{}{}: case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() } - framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil) + framer, err := s.conn.exec(ctx, frame, nil) + if err != nil { + return nil, err + } + + return framer.parseFrame() +} + +func (s *startupCoordinator) options(ctx context.Context) error { + frame, err := s.write(ctx, &writeOptionsFrame{}) if err != nil { return err } - frame, err := framer.parseFrame() + supported, ok := frame.(*supportedFrame) + if !ok { + return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + } + + return s.startup(ctx, supported.supported) +} + +func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { + m := map[string]string{ + "CQL_VERSION": s.conn.cfg.CQLVersion, + } + + if s.conn.compressor != nil { + comp := supported["COMPRESSION"] + name := s.conn.compressor.Name() + for _, compressor := range comp { + if compressor == name { + m["COMPRESSION"] = compressor + break + } + } + + if _, ok := m["COMPRESSION"]; !ok { + s.conn.compressor = nil + } + } + + frame, err := s.write(ctx, &writeStartupFrame{opts: m}) if err != nil { return err } @@ -331,37 +376,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { case *readyFrame: return nil case *authenticateFrame: - return c.authenticateHandshake(ctx, v, frameTicker) + return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } -func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error { - if c.auth == nil { +func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { + if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } - resp, challenger, err := c.auth.Challenge([]byte(authFrame.class)) + resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} - for { - select { - case frameTicker <- struct{}{}: - case <-ctx.Done(): - return ctx.Err() - } - - framer, err := c.exec(ctx, req, nil) - if err != nil { - return err - } - - frame, err := framer.parseFrame() + frame, err := s.write(ctx, req) if err != nil { return err } diff --git a/conn_test.go b/conn_test.go index efa6b8e81..1a2adcbf2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -940,18 +940,17 @@ func TestFrameHeaderObserver(t *testing.T) { } frames := observer.getFrames() - - if len(frames) != 2 { - t.Fatalf("Expected to receive 2 frames, instead received %d", len(frames)) - } - readyFrame := frames[0] - if readyFrame.Opcode != frameOp(opReady) { - t.Fatalf("Expected to receive ready frame, instead received frame of opcode %d", readyFrame.Opcode) + expFrames := []frameOp{opSupported, opReady, opResult} + if len(frames) != len(expFrames) { + t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames)) } - voidResultFrame := frames[1] - if voidResultFrame.Opcode != frameOp(opResult) { - t.Fatalf("Expected to receive result frame, instead received frame of opcode %d", voidResultFrame.Opcode) + + for i, op := range expFrames { + if op != frames[i].Opcode { + t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i]) + } } + voidResultFrame := frames[2] if voidResultFrame.Length != int32(4) { t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length) }