diff --git a/spop/agent.go b/spop/agent.go index aade30f..42a9107 100644 --- a/spop/agent.go +++ b/spop/agent.go @@ -2,6 +2,7 @@ package spop import ( "context" + "errors" "fmt" "log" "net" @@ -46,12 +47,23 @@ func (a *Agent) Serve(l net.Listener) error { return fmt.Errorf("accepting conn: %w", err) } - p := newProtocolClient(a.BaseContext, nc, a.Handler) + if tcp, ok := nc.(*net.TCPConn); ok { + err = tcp.SetWriteBuffer(maxFrameSize * 4) + if err != nil { + return err + } + err = tcp.SetReadBuffer(maxFrameSize * 4) + if err != nil { + return err + } + } + + p := newProtocolClient(a.BaseContext, nc, as, a.Handler) go func() { defer nc.Close() defer p.Close() - if err := p.Serve(); err != nil && err != p.ctx.Err() { + if err := p.Serve(); err != nil && !errors.Is(err, p.ctx.Err()) { log.Println(err) } }() diff --git a/spop/frame.go b/spop/frame.go index e6cf1b8..c8361e3 100644 --- a/spop/frame.go +++ b/spop/frame.go @@ -28,6 +28,8 @@ func acquireFrame() *frame { func releaseFrame(f *frame) { f.buf.Reset() + f.frameType = 0 + f.meta = frameMetadata{} framePool.Put(f) } diff --git a/spop/frames.go b/spop/frames.go index 6bd9e51..eb70fe0 100644 --- a/spop/frames.go +++ b/spop/frames.go @@ -1,6 +1,7 @@ package spop import ( + "fmt" "io" "strings" @@ -30,7 +31,6 @@ const ( type frameWriter interface { io.WriterTo - Write(w io.Writer) error } var ( @@ -42,11 +42,6 @@ type AgentDisconnectFrame struct { ErrCode errorCode } -func (a *AgentDisconnectFrame) Write(w io.Writer) error { - _, err := a.WriteTo(w) - return err -} - func (a *AgentDisconnectFrame) WriteTo(w io.Writer) (int64, error) { f := acquireFrame() defer releaseFrame(f) @@ -92,11 +87,6 @@ type AgentHelloFrame struct { MaxFrameSize uint32 } -func (a *AgentHelloFrame) Write(w io.Writer) error { - _, err := a.WriteTo(w) - return err -} - func (a *AgentHelloFrame) WriteTo(w io.Writer) (int64, error) { f := acquireFrame() defer releaseFrame(f) @@ -144,12 +134,13 @@ func (a *AckFrame) WriteTo(w io.Writer) (int64, error) { f.meta.Flags = frameFlagFin if err := f.encodeHeader(); err != nil { - return 0, err + return 0, fmt.Errorf("encoding header: %w", err) } aw := encoding.AcquireActionWriter(f.buf.WriteBytes(), 0) defer encoding.ReleaseActionWriter(aw) + // TODO: errors are not correctly handled and will result in an invalid state. if err := a.ActionWriterCallback(aw); err != nil { return 0, err } @@ -158,8 +149,3 @@ func (a *AckFrame) WriteTo(w io.Writer) (int64, error) { return f.WriteTo(w) } - -func (a *AckFrame) Write(w io.ReadWriter) error { - _, err := a.WriteTo(w) - return err -} diff --git a/spop/protocol.go b/spop/protocol.go index 9c3ea76..05736be 100644 --- a/spop/protocol.go +++ b/spop/protocol.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "runtime" + "syscall" "github.com/dropmorepackets/haproxy-go/pkg/encoding" ) @@ -24,23 +25,30 @@ type protocolClient struct { handler Handler ctx context.Context - ctxCancel context.CancelFunc + ctxCancel context.CancelCauseFunc as *asyncScheduler engineID string maxFrameSize uint32 gotHello bool + lf frameType } func (c *protocolClient) Close() error { - errDisconnect := (&AgentDisconnectFrame{ + if c.ctx.Err() != nil { + return c.ctx.Err() + } + + // We ignore any error since the disconnect frame is delivered on + // best effort anyway. + _, _ = (&AgentDisconnectFrame{ ErrCode: ErrorUnknown, - }).Write(c.rw) + }).WriteTo(c.rw) - c.ctxCancel() + c.ctxCancel(fmt.Errorf("closing client")) - return errors.Join(errDisconnect, c.ctx.Err()) + return nil } func (c *protocolClient) frameHandler(f *frame) error { @@ -63,17 +71,17 @@ func (c *protocolClient) Serve() error { f := acquireFrame() if _, err := f.ReadFrom(c.rw); err != nil { if c.ctx.Err() != nil { - return c.ctx.Err() + return context.Cause(c.ctx) } - if errors.Is(err, io.EOF) { + if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) { return nil } return err } - c.as.schedule(f) + c.as.schedule(f, c) } } @@ -112,7 +120,7 @@ func (c *protocolClient) onHAProxyHello(f *frame) error { case k.NameEquals(helloKeyHealthcheck): // as described in the protocol, close connection after hello // AGENT-HELLO + close() - defer c.ctxCancel() + defer c.ctxCancel(nil) } } @@ -120,11 +128,12 @@ func (c *protocolClient) onHAProxyHello(f *frame) error { return err } - return (&AgentHelloFrame{ + _, err := (&AgentHelloFrame{ Version: version, MaxFrameSize: c.maxFrameSize, Capabilities: []string{capabilityNamePipelining, capabilityNameAsync}, - }).Write(c.rw) + }).WriteTo(c.rw) + return err } func (c *protocolClient) runHandler(ctx context.Context, w *encoding.ActionWriter, m *encoding.Message, handler HandlerFunc) (err error) { @@ -166,14 +175,50 @@ func (c *protocolClient) onNotify(f *frame) error { return s.Error() } - return (&AckFrame{ + _, err := (&AckFrame{ FrameID: f.meta.FrameID, StreamID: f.meta.StreamID, ActionWriterCallback: fn, - }).Write(c.rw) + }).WriteTo(c.rw) + return err } func (c *protocolClient) onHAProxyDisconnect(f *frame) error { - //TODO: read disconnect reason and return error if required? - return nil + if f.buf.Len() == 0 { + return fmt.Errorf("disconnect frame without content") + } + + s := encoding.AcquireKVScanner(f.buf.ReadBytes(), -1) + defer encoding.ReleaseKVScanner(s) + + k := encoding.AcquireKVEntry() + defer encoding.ReleaseKVEntry(k) + + var ( + code errorCode + ) + + for s.Next(k) { + switch name := string(k.NameBytes()); name { + case "status-code": + code = errorCode(k.ValueInt()) + case "message": + // We don't really care about the message since they should all be + // defined in the errorCode type. + default: + panic("unexpected kv entry: " + name) + } + } + + var err error + switch code { + // HAProxy returns an IO error when it doesn't require a connection + // anymore. + case ErrorIO, ErrorTimeout, ErrorNone: + default: + err = fmt.Errorf("disconnect frame with code %d: %s", code, code) + } + + c.ctxCancel(err) + return err }