Skip to content

Commit

Permalink
spop: implement correct disconnect handling
Browse files Browse the repository at this point in the history
  • Loading branch information
fionera committed Sep 30, 2024
1 parent 1d49c24 commit 347b227
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 34 deletions.
16 changes: 14 additions & 2 deletions spop/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package spop

import (
"context"
"errors"
"fmt"
"log"
"net"
Expand Down Expand Up @@ -46,12 +47,23 @@ func (a *Agent) Serve(l net.Listener) error {
return fmt.Errorf("accepting conn: %w", err)
}

p := newProtocolClient(a.BaseContext, nc, a.Handler)
if tcp, ok := nc.(*net.TCPConn); ok {
err = tcp.SetWriteBuffer(maxFrameSize * 4)
if err != nil {
return err
}
err = tcp.SetReadBuffer(maxFrameSize * 4)
if err != nil {
return err
}
}

p := newProtocolClient(a.BaseContext, nc, as, a.Handler)
go func() {
defer nc.Close()
defer p.Close()

if err := p.Serve(); err != nil && err != p.ctx.Err() {
if err := p.Serve(); err != nil && !errors.Is(err, p.ctx.Err()) {
log.Println(err)
}
}()
Expand Down
2 changes: 2 additions & 0 deletions spop/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ func acquireFrame() *frame {

func releaseFrame(f *frame) {
f.buf.Reset()
f.frameType = 0
f.meta = frameMetadata{}

framePool.Put(f)
}
Expand Down
20 changes: 3 additions & 17 deletions spop/frames.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package spop

import (
"fmt"
"io"
"strings"

Expand Down Expand Up @@ -30,7 +31,6 @@ const (

type frameWriter interface {
io.WriterTo
Write(w io.Writer) error
}

var (
Expand All @@ -42,11 +42,6 @@ type AgentDisconnectFrame struct {
ErrCode errorCode
}

func (a *AgentDisconnectFrame) Write(w io.Writer) error {
_, err := a.WriteTo(w)
return err
}

func (a *AgentDisconnectFrame) WriteTo(w io.Writer) (int64, error) {
f := acquireFrame()
defer releaseFrame(f)
Expand Down Expand Up @@ -92,11 +87,6 @@ type AgentHelloFrame struct {
MaxFrameSize uint32
}

func (a *AgentHelloFrame) Write(w io.Writer) error {
_, err := a.WriteTo(w)
return err
}

func (a *AgentHelloFrame) WriteTo(w io.Writer) (int64, error) {
f := acquireFrame()
defer releaseFrame(f)
Expand Down Expand Up @@ -144,12 +134,13 @@ func (a *AckFrame) WriteTo(w io.Writer) (int64, error) {
f.meta.Flags = frameFlagFin

if err := f.encodeHeader(); err != nil {
return 0, err
return 0, fmt.Errorf("encoding header: %w", err)
}

aw := encoding.AcquireActionWriter(f.buf.WriteBytes(), 0)
defer encoding.ReleaseActionWriter(aw)

// TODO: errors are not correctly handled and will result in an invalid state.
if err := a.ActionWriterCallback(aw); err != nil {
return 0, err
}
Expand All @@ -158,8 +149,3 @@ func (a *AckFrame) WriteTo(w io.Writer) (int64, error) {

return f.WriteTo(w)
}

func (a *AckFrame) Write(w io.ReadWriter) error {
_, err := a.WriteTo(w)
return err
}
75 changes: 60 additions & 15 deletions spop/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"runtime"
"syscall"

"github.com/dropmorepackets/haproxy-go/pkg/encoding"
)
Expand All @@ -24,23 +25,30 @@ type protocolClient struct {
handler Handler
ctx context.Context

ctxCancel context.CancelFunc
ctxCancel context.CancelCauseFunc
as *asyncScheduler

engineID string
maxFrameSize uint32

gotHello bool
lf frameType
}

func (c *protocolClient) Close() error {
errDisconnect := (&AgentDisconnectFrame{
if c.ctx.Err() != nil {
return c.ctx.Err()
}

// We ignore any error since the disconnect frame is delivered on
// best effort anyway.
_, _ = (&AgentDisconnectFrame{
ErrCode: ErrorUnknown,
}).Write(c.rw)
}).WriteTo(c.rw)

c.ctxCancel()
c.ctxCancel(fmt.Errorf("closing client"))

return errors.Join(errDisconnect, c.ctx.Err())
return nil
}

func (c *protocolClient) frameHandler(f *frame) error {
Expand All @@ -63,17 +71,17 @@ func (c *protocolClient) Serve() error {
f := acquireFrame()
if _, err := f.ReadFrom(c.rw); err != nil {
if c.ctx.Err() != nil {
return c.ctx.Err()
return context.Cause(c.ctx)
}

if errors.Is(err, io.EOF) {
if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) {
return nil
}

return err
}

c.as.schedule(f)
c.as.schedule(f, c)
}
}

Expand Down Expand Up @@ -112,19 +120,20 @@ func (c *protocolClient) onHAProxyHello(f *frame) error {
case k.NameEquals(helloKeyHealthcheck):
// as described in the protocol, close connection after hello
// AGENT-HELLO + close()
defer c.ctxCancel()
defer c.ctxCancel(nil)
}
}

if err := s.Error(); err != nil {
return err
}

return (&AgentHelloFrame{
_, err := (&AgentHelloFrame{
Version: version,
MaxFrameSize: c.maxFrameSize,
Capabilities: []string{capabilityNamePipelining, capabilityNameAsync},
}).Write(c.rw)
}).WriteTo(c.rw)
return err
}

func (c *protocolClient) runHandler(ctx context.Context, w *encoding.ActionWriter, m *encoding.Message, handler HandlerFunc) (err error) {
Expand Down Expand Up @@ -166,14 +175,50 @@ func (c *protocolClient) onNotify(f *frame) error {
return s.Error()
}

return (&AckFrame{
_, err := (&AckFrame{
FrameID: f.meta.FrameID,
StreamID: f.meta.StreamID,
ActionWriterCallback: fn,
}).Write(c.rw)
}).WriteTo(c.rw)
return err
}

func (c *protocolClient) onHAProxyDisconnect(f *frame) error {
//TODO: read disconnect reason and return error if required?
return nil
if f.buf.Len() == 0 {
return fmt.Errorf("disconnect frame without content")
}

s := encoding.AcquireKVScanner(f.buf.ReadBytes(), -1)
defer encoding.ReleaseKVScanner(s)

k := encoding.AcquireKVEntry()
defer encoding.ReleaseKVEntry(k)

var (
code errorCode
)

for s.Next(k) {
switch name := string(k.NameBytes()); name {
case "status-code":
code = errorCode(k.ValueInt())
case "message":
// We don't really care about the message since they should all be
// defined in the errorCode type.
default:
panic("unexpected kv entry: " + name)
}
}

var err error
switch code {
// HAProxy returns an IO error when it doesn't require a connection
// anymore.
case ErrorIO, ErrorTimeout, ErrorNone:
default:
err = fmt.Errorf("disconnect frame with code %d: %s", code, code)
}

c.ctxCancel(err)
return err
}

0 comments on commit 347b227

Please sign in to comment.