Skip to content

Commit

Permalink
feat: Refactor polling transport and add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisgibson committed Aug 16, 2024
1 parent cc6fa1e commit 6caaeed
Show file tree
Hide file tree
Showing 2 changed files with 1,454 additions and 102 deletions.
129 changes: 70 additions & 59 deletions transport_polling.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package engineio
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
Expand All @@ -20,6 +20,8 @@ type PollingTransport struct {
header http.Header
// state is the state of the transport.
state TransportState
// polling is a wait group that completes when the in-flight poll is complete.
polling *sync.WaitGroup
// onOpenHandler is the handler for when the transport opens.
onOpenHandler TransportOpenHandler
// onCloseHandler is the handler for when the transport closes.
Expand All @@ -28,12 +30,14 @@ type PollingTransport struct {
onPacketHandler TransportPacketHandler
// onErrorHandler is the handler for when the transport encounters an error.
onErrorHandler TransportErrorHandler
// polling is a wait group that completes when the in-flight poll is complete.
polling *sync.WaitGroup
}

// NewPollingTransport creates a new PollingTransport.
func NewPollingTransport(url *url.URL, opts TransportOptions) Transport {
func NewPollingTransport(url *url.URL, opts TransportOptions) (Transport, error) {
if url == nil {
return nil, ErrURLRequired
}

var client TransportClient = http.DefaultClient
if opts.Client != nil {
client = opts.Client
Expand All @@ -45,11 +49,12 @@ func NewPollingTransport(url *url.URL, opts TransportOptions) Transport {
}

return &PollingTransport{
url: url,
client: client,
header: header,
state: TransportStateClosed,
}
url: url,
client: client,
header: header,
state: TransportStateClosed,
polling: &sync.WaitGroup{},
}, nil
}

// Type returns the type of the transport.
Expand All @@ -73,22 +78,6 @@ func (t *PollingTransport) Open(ctx context.Context) {
t.poll(ctx)
}

// Close closes the transport.
func (t *PollingTransport) Close(ctx context.Context) {
if t.state == TransportStateOpening || t.state == TransportStateOpen {
// Set the state to closing to prevent further polling and Close being called again.
t.state = TransportStateClosing

// Send a close packet
t.Send(ctx, []Packet{{Type: PacketClose}})

// Wait for polling to be complete
if t.polling != nil {
t.polling.Wait()
}
}
}

// Pause pauses the transport.
func (t *PollingTransport) Pause(ctx context.Context) {
switch t.state {
Expand All @@ -104,54 +93,78 @@ func (t *PollingTransport) Pause(ctx context.Context) {
t.state = TransportStatePausing

// Wait for polling to be complete.
if t.polling != nil {
t.polling.Wait()
}
t.polling.Wait()

// Set the state to paused.
t.state = TransportStatePaused
}

// Send sends packets through the transport.
func (t *PollingTransport) Send(ctx context.Context, packets []Packet) error {
// The state must be open, to send data; or closing to send the close packet.
if t.state != TransportStateOpen && t.state != TransportStateClosing {
switch t.state {
// These states are valid for sending packets.
case TransportStateOpen:
break

default:
return nil
}

b, err := EncodePayload(packets)
if err != nil {
return err
return fmt.Errorf("encode error: %w", err)
}

return t.write(ctx, b)
}

// Close closes the transport.
func (t *PollingTransport) Close(ctx context.Context) {
switch t.state {
// These states are valid for closing the transport.
case TransportStateOpening, TransportStateOpen:
break

default:
return
}

// Send a close packet
t.Send(ctx, []Packet{{Type: PacketClose}})

// Set the state to closing to prevent further polling and Close being called again.
t.state = TransportStateClosing

// Wait for polling to be complete
t.polling.Wait()
}

// poll requests data from the server.
func (t *PollingTransport) poll(ctx context.Context) {
// Store a wait group to wait for the in-flight poll to complete.
t.polling = &sync.WaitGroup{}
t.polling.Add(1)
// If polling is still being held, wait for it to complete.
t.polling.Wait()

// Increment the polling wait group.
t.polling.Add(1)
// Polling is complete when the function returns.
defer t.polling.Done()

res, err := t.request(ctx, nil)
switch {
case err != nil:
t.onError(ctx, err)
t.onError(ctx, fmt.Errorf("polling error: %w", err))
return

case res.StatusCode != http.StatusOK:
t.onError(ctx, err)
t.onError(ctx, fmt.Errorf("polling error: %d", res.StatusCode))
return
}

defer res.Body.Close()
b, err := io.ReadAll(res.Body)
switch {
case err != nil:
t.onError(ctx, err)
t.onError(ctx, fmt.Errorf("read error: %w", err))

case len(b) != 0:
t.onData(ctx, b)
Expand All @@ -163,10 +176,10 @@ func (t *PollingTransport) write(ctx context.Context, data []byte) error {
res, err := t.request(ctx, data)
switch {
case err != nil:
return err
return fmt.Errorf("write error: %w", err)

case res.StatusCode != http.StatusOK:
return errors.New("polling error")
return fmt.Errorf("write error: %d", res.StatusCode)

default:
return nil
Expand All @@ -188,54 +201,45 @@ func (t *PollingTransport) request(ctx context.Context, data []byte) (*http.Resp

u, err := url.Parse(t.url.String())
if err != nil {
return nil, err
return nil, fmt.Errorf("parse error: %w", err)
}

req, err := http.NewRequestWithContext(ctx, method, u.String(), body)
if err != nil {
return nil, err
return nil, fmt.Errorf("request error: %w", err)
}
req.Header = header

return t.client.Do(req)
}

// onError calls the onError handler.
func (t *PollingTransport) onError(ctx context.Context, err error) {
if t.onErrorHandler != nil {
go t.onErrorHandler(ctx, err)
}
}

// onData processes data received from the server.
func (t *PollingTransport) onData(ctx context.Context, data []byte) error {
func (t *PollingTransport) onData(ctx context.Context, data []byte) {
packets, err := DecodePayload(data)
if err != nil {
return err
t.onError(ctx, fmt.Errorf("decode error: %w", err))
return
}

// Process each packet.
for _, packet := range packets {
switch {
// If the packet is an open packet and the transport is opening, call the onOpen method.
case packet.Type == PacketOpen && t.state == TransportStateOpening:
go t.onOpen(ctx)
t.onOpen(ctx)

// If the packet is a close packet, call the onClose method.
case packet.Type == PacketClose:
// If the packet is a close packet and the transport is not closed, call the onClose method.
case packet.Type == PacketClose && t.state != TransportStateClosed:
t.onClose(ctx)
continue
}

t.onPacket(ctx, packet)
}

// Poll again if the transport is open, pausing, or paused.
// Poll again if the transport is open or pausing.
if t.state == TransportStateOpen || t.state == TransportStatePausing {
go t.poll(ctx)
}

return nil
}

// onOpen sets the state of the transport to open.
Expand All @@ -252,14 +256,21 @@ func (t *PollingTransport) onClose(ctx context.Context) {
t.state = TransportStateClosed

if t.onCloseHandler != nil {
go t.onCloseHandler(ctx)
t.onCloseHandler(ctx)
}
}

// onPacket calls the onPacket handler.
func (t *PollingTransport) onPacket(ctx context.Context, packet Packet) {
if t.onPacketHandler != nil {
go t.onPacketHandler(ctx, packet)
t.onPacketHandler(ctx, packet)
}
}

// onError calls the onError handler.
func (t *PollingTransport) onError(ctx context.Context, err error) {
if t.onErrorHandler != nil {
t.onErrorHandler(ctx, err)
}
}

Expand Down
Loading

0 comments on commit 6caaeed

Please sign in to comment.