Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conn: check we can use compression in startup #1215

Merged
merged 3 commits into from
Oct 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 93 additions & 60 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
}
defer cancel()

frameTicker := make(chan struct{}, 1)
startupErr := make(chan error)
go func() {
for range frameTicker {
err := c.recv()
if err != nil {
select {
case startupErr <- err:
case <-ctx.Done():
}

return
}
}
}()

go func() {
defer close(frameTicker)
err := c.startup(ctx, frameTicker)
select {
case startupErr <- err:
case <-ctx.Done():
}
}()
startup := &startupCoordinator{
frameTicker: make(chan struct{}),
conn: c,
}

select {
case err := <-startupErr:
if err != nil {
c.Close()
return nil, err
}
case <-ctx.Done():
c.Close()
return nil, errors.New("gocql: no response to connection startup within timeout")
if err := startup.setupConn(ctx); err != nil {
c.close()
return nil, err
}

// dont coalesce startup frames
Expand Down Expand Up @@ -300,27 +274,98 @@ func (c *Conn) Read(p []byte) (n int, err error) {
return
}

func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
m := map[string]string{
"CQL_VERSION": c.cfg.CQLVersion,
}
type startupCoordinator struct {
conn *Conn
frameTicker chan struct{}
}

if c.compressor != nil {
m["COMPRESSION"] = c.compressor.Name()
func (s *startupCoordinator) setupConn(ctx context.Context) error {
startupErr := make(chan error)
go func() {
for range s.frameTicker {
err := s.conn.recv()
if err != nil {
select {
case startupErr <- err:
case <-ctx.Done():
}

return
}
}
}()

go func() {
defer close(s.frameTicker)
err := s.options(ctx)
select {
case startupErr <- err:
case <-ctx.Done():
}
}()

select {
case err := <-startupErr:
if err != nil {
return err
}
case <-ctx.Done():
return errors.New("gocql: no response to connection startup within timeout")
}

return nil
}

func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
select {
case frameTicker <- struct{}{}:
case s.frameTicker <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
return nil, ctx.Err()
}

framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
framer, err := s.conn.exec(ctx, frame, nil)
if err != nil {
return nil, err
}

return framer.parseFrame()
}

func (s *startupCoordinator) options(ctx context.Context) error {
frame, err := s.write(ctx, &writeOptionsFrame{})
if err != nil {
return err
}

frame, err := framer.parseFrame()
supported, ok := frame.(*supportedFrame)
if !ok {
return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
}

return s.startup(ctx, supported.supported)
}

func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error {
m := map[string]string{
"CQL_VERSION": s.conn.cfg.CQLVersion,
}

if s.conn.compressor != nil {
comp := supported["COMPRESSION"]
name := s.conn.compressor.Name()
for _, compressor := range comp {
if compressor == name {
m["COMPRESSION"] = compressor
break
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we're out of the for loop, but m["COMPRESSION"] is not set, then probably not supported by the backend. So, I reckon, reset s.conn.compressor here to nil, and print something appropriate in the log.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, no, ignore. Not overly important to reset it, as probably the coordinator does startup only once.


if _, ok := m["COMPRESSION"]; !ok {
s.conn.compressor = nil
}
}

frame, err := s.write(ctx, &writeStartupFrame{opts: m})
if err != nil {
return err
}
Expand All @@ -331,37 +376,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
case *readyFrame:
return nil
case *authenticateFrame:
return c.authenticateHandshake(ctx, v, frameTicker)
return s.authenticateHandshake(ctx, v)
default:
return NewErrProtocol("Unknown type of response to startup frame: %s", v)
}
}

func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
if c.auth == nil {
func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error {
if s.conn.auth == nil {
return fmt.Errorf("authentication required (using %q)", authFrame.class)
}

resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class))
if err != nil {
return err
}

req := &writeAuthResponseFrame{data: resp}

for {
select {
case frameTicker <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}

framer, err := c.exec(ctx, req, nil)
if err != nil {
return err
}

frame, err := framer.parseFrame()
frame, err := s.write(ctx, req)
if err != nil {
return err
}
Expand Down
19 changes: 9 additions & 10 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -940,18 +940,17 @@ func TestFrameHeaderObserver(t *testing.T) {
}

frames := observer.getFrames()

if len(frames) != 2 {
t.Fatalf("Expected to receive 2 frames, instead received %d", len(frames))
}
readyFrame := frames[0]
if readyFrame.Opcode != frameOp(opReady) {
t.Fatalf("Expected to receive ready frame, instead received frame of opcode %d", readyFrame.Opcode)
expFrames := []frameOp{opSupported, opReady, opResult}
if len(frames) != len(expFrames) {
t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames))
}
voidResultFrame := frames[1]
if voidResultFrame.Opcode != frameOp(opResult) {
t.Fatalf("Expected to receive result frame, instead received frame of opcode %d", voidResultFrame.Opcode)

for i, op := range expFrames {
if op != frames[i].Opcode {
t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i])
}
}
voidResultFrame := frames[2]
if voidResultFrame.Length != int32(4) {
t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length)
}
Expand Down