From d910d573b30e6235137f97532d94e4bcd49923b3 Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Thu, 11 Oct 2018 10:36:05 +0100 Subject: [PATCH 1/3] conn: check we can use compression in startup If the user asks to use compression check that the remote server supports it by making an options call first. Rejig conn startup to be self contained. --- conn.go | 149 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 89 insertions(+), 60 deletions(-) diff --git a/conn.go b/conn.go index ce7f502bb..eca672541 100644 --- a/conn.go +++ b/conn.go @@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa } defer cancel() - frameTicker := make(chan struct{}, 1) - startupErr := make(chan error) - go func() { - for range frameTicker { - err := c.recv() - if err != nil { - select { - case startupErr <- err: - case <-ctx.Done(): - } - - return - } - } - }() - - go func() { - defer close(frameTicker) - err := c.startup(ctx, frameTicker) - select { - case startupErr <- err: - case <-ctx.Done(): - } - }() + startup := &startupCoordinator{ + frameTicker: make(chan struct{}), + conn: c, + } - select { - case err := <-startupErr: - if err != nil { - c.Close() - return nil, err - } - case <-ctx.Done(): - c.Close() - return nil, errors.New("gocql: no response to connection startup within timeout") + if err := startup.setupConn(ctx); err != nil { + c.close() + return nil, err } // dont coalesce startup frames @@ -300,27 +274,94 @@ func (c *Conn) Read(p []byte) (n int, err error) { return } -func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { - m := map[string]string{ - "CQL_VERSION": c.cfg.CQLVersion, - } +type startupCoordinator struct { + conn *Conn + frameTicker chan struct{} +} + +func (s *startupCoordinator) setupConn(ctx context.Context) error { + startupErr := make(chan error) + go func() { + for range s.frameTicker { + err := s.conn.recv() + if err != nil { + select { + case startupErr <- err: + case <-ctx.Done(): + } - if c.compressor != nil { - m["COMPRESSION"] = c.compressor.Name() + return + } + } + }() + + go func() { + defer close(s.frameTicker) + err := s.options(ctx) + select { + case startupErr <- err: + case <-ctx.Done(): + } + }() + + select { + case err := <-startupErr: + if err != nil { + return err + } + case <-ctx.Done(): + return errors.New("gocql: no response to connection startup within timeout") } + return nil +} + +func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) { select { - case frameTicker <- struct{}{}: + case s.frameTicker <- struct{}{}: case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() } - framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil) + framer, err := s.conn.exec(ctx, frame, nil) + if err != nil { + return nil, err + } + + return framer.parseFrame() +} + +func (s *startupCoordinator) options(ctx context.Context) error { + frame, err := s.write(ctx, &writeOptionsFrame{}) if err != nil { return err } - frame, err := framer.parseFrame() + supported, ok := frame.(*supportedFrame) + if !ok { + return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + } + + return s.startup(ctx, supported.supported) +} + +func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { + m := map[string]string{ + "CQL_VERSION": s.conn.cfg.CQLVersion, + } + + if s.conn.compressor != nil { + comp := supported["COMPRESSION"] + name := s.conn.compressor.Name() + for _, compressor := range comp { + if compressor == name { + m["COMPRESSION"] = compressor + break + } + } + } + + frame, err := s.write(ctx, &writeStartupFrame{opts: m}) if err != nil { return err } @@ -331,37 +372,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { case *readyFrame: return nil case *authenticateFrame: - return c.authenticateHandshake(ctx, v, frameTicker) + return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } -func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error { - if c.auth == nil { +func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { + if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } - resp, challenger, err := c.auth.Challenge([]byte(authFrame.class)) + resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} - for { - select { - case frameTicker <- struct{}{}: - case <-ctx.Done(): - return ctx.Err() - } - - framer, err := c.exec(ctx, req, nil) - if err != nil { - return err - } - - frame, err := framer.parseFrame() + frame, err := s.write(ctx, req) if err != nil { return err } From d1c522526fd216bc845401c0dad8c118ffed92bd Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Thu, 11 Oct 2018 16:55:38 +0100 Subject: [PATCH 2/3] simplify TestFrameHeaderObserver --- conn_test.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/conn_test.go b/conn_test.go index efa6b8e81..1a2adcbf2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -940,18 +940,17 @@ func TestFrameHeaderObserver(t *testing.T) { } frames := observer.getFrames() - - if len(frames) != 2 { - t.Fatalf("Expected to receive 2 frames, instead received %d", len(frames)) - } - readyFrame := frames[0] - if readyFrame.Opcode != frameOp(opReady) { - t.Fatalf("Expected to receive ready frame, instead received frame of opcode %d", readyFrame.Opcode) + expFrames := []frameOp{opSupported, opReady, opResult} + if len(frames) != len(expFrames) { + t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames)) } - voidResultFrame := frames[1] - if voidResultFrame.Opcode != frameOp(opResult) { - t.Fatalf("Expected to receive result frame, instead received frame of opcode %d", voidResultFrame.Opcode) + + for i, op := range expFrames { + if op != frames[i].Opcode { + t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i]) + } } + voidResultFrame := frames[2] if voidResultFrame.Length != int32(4) { t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length) } From 4f7793b4428cfd75559eecc3c44049a822c0bbae Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Fri, 12 Oct 2018 10:48:22 +0100 Subject: [PATCH 3/3] disable compression if the server doesnt support it --- conn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/conn.go b/conn.go index eca672541..0bec55ccf 100644 --- a/conn.go +++ b/conn.go @@ -359,6 +359,10 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ break } } + + if _, ok := m["COMPRESSION"]; !ok { + s.conn.compressor = nil + } } frame, err := s.write(ctx, &writeStartupFrame{opts: m})