Skip to content

Commit

Permalink
Use stream to copy websocket data for improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
alinz committed Aug 29, 2024
1 parent 4537796 commit e816308
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 8 deletions.
13 changes: 13 additions & 0 deletions examples/websocket/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
### build stage ###
FROM golang:1.23-alpine AS builder

WORKDIR /websocket

COPY . .

RUN go build -o websocket-server /websocket/server/main.go

### run stage ###
FROM alpine:latest
COPY --from=builder /websocket/websocket-server ./websocket-server
CMD ["./websocket-server"]
22 changes: 22 additions & 0 deletions examples/websocket/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
services:
example-static-service:
build: .

labels:
- "baker.enable=true"
- "baker.network=baker"
- "baker.service.port=8000"
- "baker.service.static.domain=example.com"
- "baker.service.static.path=/*"
- "baker.service.static.headers.host=xyz.example.com"

networks:
- baker

ports:
- "3000:3000"

networks:
baker:
name: baker
external: true
8 changes: 8 additions & 0 deletions examples/websocket/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module helloworld

go 1.23

require (
github.com/coder/websocket v1.8.12
golang.org/x/time v0.6.0
)
4 changes: 4 additions & 0 deletions examples/websocket/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
130 changes: 130 additions & 0 deletions examples/websocket/server/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package main

import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"time"

"github.com/coder/websocket"
"golang.org/x/time/rate"
)

// echoServer is the WebSocket echo server implementation.
// It ensures the client speaks the echo subprotocol and
// only allows one message every 100ms with a 10 message burst.
type echoServer struct {
// logf controls where logs are sent.
logf func(f string, v ...interface{})
}

func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
//Subprotocols: []string{"echo"},
})
if err != nil {
s.logf("%v", err)
return
}
defer c.CloseNow()

// if c.Subprotocol() != "echo" {
// c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol, not "+c.Subprotocol())
// return
// }

l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
for {
err = echo(r.Context(), c, l)
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
return
}
if err != nil {
s.logf("failed to echo with %v: %v", r.RemoteAddr, err)
return
}
}
}

// echo reads from the WebSocket connection and then writes
// the received message back to it.
// The entire function has 10s to complete.
func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()

err := l.Wait(ctx)
if err != nil {
return err
}

typ, r, err := c.Reader(ctx)
if err != nil {
return err
}

w, err := c.Writer(ctx, typ)
if err != nil {
return err
}

_, err = io.Copy(w, r)
if err != nil {
return fmt.Errorf("failed to io.Copy: %w", err)
}

err = w.Close()
return err
}

func main() {
log.SetFlags(0)

err := run()
if err != nil {
log.Fatal(err)
}
}

// run starts a http.Server for the passed in address
// with all requests handled by echoServer.
func run() error {
addr := "0.0.0.0:3000"

l, err := net.Listen("tcp", addr)
if err != nil {
return err
}
log.Printf("listening on ws://%v", l.Addr())

s := &http.Server{
Handler: echoServer{
logf: log.Printf,
},
ReadTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
}
errc := make(chan error, 1)
go func() {
errc <- s.Serve(l)
}()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
select {
case err := <-errc:
log.Printf("failed to serve: %v", err)
case sig := <-sigs:
log.Printf("terminating: %v", sig)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

return s.Shutdown(ctx)
}
48 changes: 40 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
Expand Down Expand Up @@ -138,30 +139,61 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request, contain
}
defer serverConn.Close(websocket.StatusNormalClosure, "")

ctx, cancel := context.WithCancel(r.Context())
defer cancel()

// Proxy data between client and server
go func() {
defer cancel()

for {
messageType, p, err := serverConn.Read(r.Context())
err = copyWebsocketStream(ctx, clientConn, serverConn)
if err != nil {
return
}

if err := clientConn.Write(r.Context(), messageType, p); err != nil {
slog.Error("failed to copy data between server and client", "error", err)
return
}
}
}()

for {
messageType, p, err := clientConn.Read(r.Context())
err = copyWebsocketStream(ctx, serverConn, clientConn)
if err != nil {
slog.Error("failed to copy data between client and server", "error", err)
return
}
}
}

if err := serverConn.Write(r.Context(), messageType, p); err != nil {
return
func copyWebsocketStream(ctx context.Context, dst, src *websocket.Conn) error {
var msgType websocket.MessageType
var r io.Reader
var w io.WriteCloser
var err error

for {
msgType, r, err = src.Reader(ctx)
if err != nil {
break
}

w, err = dst.Writer(ctx, msgType)
if err != nil {
break
}

_, err = io.Copy(w, r)
if err != nil {
break
}
}

if errors.Is(err, context.Canceled) {
return nil
} else if errors.Is(err, io.EOF) {
return nil
}

return err
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit e816308

Please sign in to comment.