Skip to content

Commit

Permalink
refactor #1
Browse files Browse the repository at this point in the history
  • Loading branch information
fionera committed Feb 16, 2024
1 parent 483345c commit 0d2674e
Show file tree
Hide file tree
Showing 21 changed files with 1,098 additions and 1,069 deletions.
315 changes: 128 additions & 187 deletions peers/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,27 @@ package peers

import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"syscall"
"time"

"github.com/dropmorepackets/haproxy-go/peers/sticktable"
"github.com/dropmorepackets/haproxy-go/pkg/encoding"
)

type Conn struct {
conn net.Conn
r *bufio.Reader
ctx context.Context
conn net.Conn
r *bufio.Reader

nextHeartbeat *time.Ticker
lastMessageTimer *time.Timer
lastTableDefinition *StickTableDefinition
lastEntryUpdate *EntryUpdate
lastTableDefinition *sticktable.Definition
lastEntryUpdate *sticktable.EntryUpdate

handler Handler
}
Expand All @@ -27,185 +32,18 @@ func (c *Conn) Close() error {
}

func (c *Conn) peerHandshake() error {
scanner := bufio.NewScanner(c.r)
// protocol identifier : HAProxyS
// version : 2.1
// remote peer identifier: the peer name this "hello" message is sent to.
// local peer identifier : the name of the peer which sends this "hello" message.
// process ID : the ID of the process handling this peer session.
// relative process ID : the haproxy's relative process ID (0 if nbproc == 1).

type handshake struct {
protocolIdentifier string
version string
remotePeer string
localPeerIdentifier string
processID int
relativeProcessID int
}

var h handshake
scanner.Scan()
_, _ = fmt.Sscanf(scanner.Text(), "%s %s", &h.protocolIdentifier, &h.version)

scanner.Scan()
h.remotePeer = scanner.Text()

scanner.Scan()
_, _ = fmt.Sscanf(scanner.Text(), "%s %d %d", &h.localPeerIdentifier, &h.processID, &h.relativeProcessID)

log.Printf("%+v", h)

_, err := c.conn.Write([]byte(fmt.Sprintf("%d\n", StatusHandshakeSucceeded)))
if err != nil {
_ = c.conn.Close()
return fmt.Errorf("handshake failed: %v", err)
}

return nil
}

func (c *Conn) Handshake() error {
if err := c.peerHandshake(); err != nil {
return err
}

c.resetHeartbeat()
c.resetLastMessage()
go c.heartbeat()
go c.lastMessage()

return nil
}

var unknownBuf []byte

// Read should be called in a loop. It handles all Messages and returns errors,
// which can be safely ignored. They are mostly for Informational purposes.
func (c *Conn) Read() error {
defer func() {
if len(unknownBuf) != 0 {
log.Println(unknownBuf)
}
}()

// All the messages are made at least of a two bytes length header.
header := make([]byte, 2)
_, err := c.r.Read(header)
if err != nil {
var h Handshake
if _, err := h.ReadFrom(c.r); err != nil {
return err
}

c.resetLastMessage()

switch m := MessageClass(header[0]); m {
case MessageClassControl:
unknownBuf = unknownBuf[:0]
return c.controlMessage(ControlMessageType(header[1]))
case MessageClassError:
unknownBuf = unknownBuf[:0]
return c.errorMessage(ErrorMessageType(header[1]))
case MessageClassStickTableUpdates:
unknownBuf = unknownBuf[:0]
return c.stickTableUpdate(StickTableMessageType(header[1]))
default:
unknownBuf = append(unknownBuf, header...)
return fmt.Errorf("unknown message class: %s", m)
}
}

func (c *Conn) controlMessage(t ControlMessageType) error {
switch t {
case ControlMessageSyncRequest:
_, _ = c.conn.Write([]byte{byte(MessageClassControl), byte(ControlMessageSyncPartial)})
return nil
case ControlMessageSyncFinished:
return nil
case ControlMessageSyncPartial:
return nil
case ControlMessageSyncConfirmed:
return nil
case ControlMessageHeartbeat:
return nil
}

return fmt.Errorf("unknown control message type: %s", t)
}

func (c *Conn) stickTableUpdate(t StickTableMessageType) error {
switch t {
case StickTableMessageStickTableDefinition:
var std StickTableDefinition
if err := std.Unmarshal(c.r); err != nil {
return err
}

c.lastTableDefinition = &std

//log.Printf("%+v", std)

return nil
case StickTableMessageStickTableSwitch:
panic(t)
return nil
case StickTableMessageUpdateAcknowledge:
panic(t)
return nil
case StickTableMessageEntryUpdate,
StickTableMessageUpdateTimed,
StickTableMessageIncrementalEntryUpdate,
StickTableMessageIncrementalEntryUpdateTimed:
return c.stickTableEntryUpdate(t)
// Just continue to the next switch statement
default:
return fmt.Errorf("unknown stick-table message type: %s", t)
}

return nil
}

func (c *Conn) stickTableEntryUpdate(t StickTableMessageType) error {
e := EntryUpdate{
StickTable: c.lastTableDefinition,
}

if c.lastEntryUpdate != nil {
e.LocalUpdateID = c.lastEntryUpdate.LocalUpdateID + 1
}

switch t {
case StickTableMessageEntryUpdate:
e.withLocalUpdateID = true
case StickTableMessageUpdateTimed:
e.withLocalUpdateID = true
e.withExpiry = true
case StickTableMessageIncrementalEntryUpdate:
case StickTableMessageIncrementalEntryUpdateTimed:
e.withExpiry = true
}

if err := e.Unmarshal(c.r); err != nil {
return err
if _, err := c.conn.Write([]byte(fmt.Sprintf("%d\n", HandshakeStatusHandshakeSucceeded))); err != nil {
return fmt.Errorf("handshake failed: %v", err)
}

c.lastEntryUpdate = &e

c.handler.Update(&e)

return nil
}

func (c *Conn) errorMessage(t ErrorMessageType) error {
switch t {
case ErrorMessageProtocol:
return fmt.Errorf("protocol error")
case ErrorMessageSizeLimit:
return fmt.Errorf("message size limit")
}

return fmt.Errorf("unknown error message type: %s", t)
}

func (c *Conn) resetHeartbeat() {
// a peer sends heartbeat messages to peers it is
// connected to after periods of 3s of inactivity (i.e. when there is no
Expand Down Expand Up @@ -237,7 +75,7 @@ func (c *Conn) heartbeat() {
for range c.nextHeartbeat.C {
_, err := c.conn.Write([]byte{byte(MessageClassControl), byte(ControlMessageHeartbeat)})
if err != nil {
_ = c.conn.Close()
_ = c.Close()
return
}
}
Expand All @@ -246,24 +84,127 @@ func (c *Conn) heartbeat() {
func (c *Conn) lastMessage() {
<-c.lastMessageTimer.C
log.Println("last message timer expired: closing connection")
_ = c.conn.Close()
_ = c.Close()
}

func (c *Conn) serve() {
defer c.Close()

if err := c.Handshake(); err != nil {
panic(err)
func (c *Conn) Serve() error {
if err := c.peerHandshake(); err != nil {
return fmt.Errorf("handshake: %v", err)
}

c.resetHeartbeat()
c.resetLastMessage()
go c.heartbeat()
go c.lastMessage()

for {
err := c.Read()
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) {
return
var m rawMessage

if _, err := m.ReadFrom(c.r); err != nil {
if c.ctx.Err() != nil {
return c.ctx.Err()
}

if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return nil
}

return fmt.Errorf("reading message: %v", err)
}

c.resetLastMessage()
if err := c.messageHandler(&m); err != nil {
return fmt.Errorf("message handler: %v", err)
}
}
}

func (c *Conn) messageHandler(m *rawMessage) error {
switch m.MessageClass {
case MessageClassControl:
return ControlMessageType(m.MessageType).OnMessage(m, c)
case MessageClassError:
return ErrorMessageType(m.MessageType).OnMessage(m, c)
case MessageClassStickTableUpdates:
return StickTableUpdateMessageType(m.MessageType).OnMessage(m, c)
default:
return fmt.Errorf("unknown message class: %s", m.MessageClass)
}
}

type byteReader interface {
io.ByteReader
io.Reader
}

type rawMessage struct {
MessageClass MessageClass
MessageType byte

Data []byte
}

func (m *rawMessage) ReadFrom(r byteReader) (int64, error) {
// All the messages are made at least of a two bytes length header.
header := make([]byte, 2)
n, err := r.Read(header)
if err != nil {
return int64(n), err
}

m.MessageClass = MessageClass(header[0])
m.MessageType = header[1]

var readData int
// All messages with type >= 128 have a payload
if m.MessageType >= 128 {
dataLength, err := encoding.ReadVarint(r)
if err != nil {
return int64(n), fmt.Errorf("failed decoding data length: %v", err)
}

m.Data = make([]byte, dataLength)
readData, err = r.Read(m.Data)
if err != nil {
panic(err)
return int64(n + readData), fmt.Errorf("failed reading message data: %v", err)
}
if int64(readData) != dataLength {
return int64(n + readData), fmt.Errorf("invalid amount read: %d != %d", dataLength, readData)
}
}

return int64(n + readData), nil
}

// Handshake is composed by these fields:
//
// protocol identifier : HAProxyS
// version : 2.1
// remote peer identifier: the peer name this "hello" message is sent to.
// local peer identifier : the name of the peer which sends this "hello" message.
// process ID : the ID of the process handling this peer session.
// relative process ID : the haproxy's relative process ID (0 if nbproc == 1).
type Handshake struct {
ProtocolIdentifier string
Version string
RemotePeer string
LocalPeerIdentifier string
ProcessID int
RelativeProcessID int
}

func (h *Handshake) ReadFrom(r io.Reader) (n int64, err error) {
scanner := bufio.NewScanner(r)

scanner.Scan()
_, err = fmt.Sscanf(scanner.Text(), "%s %s", &h.ProtocolIdentifier, &h.Version)

Check failure on line 200 in peers/conn.go

View workflow job for this annotation

GitHub Actions / lint

this value of err is never used (SA4006)

scanner.Scan()
h.RemotePeer = scanner.Text()

scanner.Scan()
_, err = fmt.Sscanf(scanner.Text(), "%s %d %d", &h.LocalPeerIdentifier, &h.ProcessID, &h.RelativeProcessID)

Check failure on line 206 in peers/conn.go

View workflow job for this annotation

GitHub Actions / lint

this value of err is never used (SA4006)

//TODO: find out how many bytes where read.
return -1, scanner.Err()
}
Loading

0 comments on commit 0d2674e

Please sign in to comment.