From 74cbe95f647fa89c921219aa662be73c47dcb603 Mon Sep 17 00:00:00 2001 From: Sam Liokumovich <65994425+samliok@users.noreply.github.com> Date: Mon, 3 Apr 2023 13:43:33 -0700 Subject: [PATCH] [Pubsub] Adding a pubsub package for server/client connections (#120) * from avalanchego * server test * pubsub test lock * check connection ends on client close * lint * pubsub stripped down, and server test added * server callback function * callback function in server and test * use connections in server * server comments * github lint * lint + comments + cleanup * read callback to connection * change callback function type * callback function change * server tests multiple connects + callback * fixed race condition in tests * partial PR comments * consts, and server start + shutdown * review * naming and pr comments --- go.mod | 3 + go.sum | 6 + pubsub/connection.go | 163 +++++++++++++++++++++++++ pubsub/connections.go | 60 +++++++++ pubsub/consts.go | 21 ++++ pubsub/server.go | 165 +++++++++++++++++++++++++ pubsub/server_test.go | 278 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 696 insertions(+) create mode 100644 pubsub/connection.go create mode 100644 pubsub/connections.go create mode 100644 pubsub/consts.go create mode 100644 pubsub/server.go create mode 100644 pubsub/server_test.go diff --git a/go.mod b/go.mod index acb790a754..e99e8b27b3 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,9 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/btree v1.1.2 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 // indirect + github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/klauspost/compress v1.15.15 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/kr/text v0.2.0 // indirect @@ -50,6 +52,7 @@ require ( github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect + github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/supranational/blst v0.3.11-0.20220920110316-f72618070295 // indirect github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a // indirect go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.11.2 // indirect diff --git a/go.sum b/go.sum index 252ea31ec4..89e29d1ca6 100644 --- a/go.sum +++ b/go.sum @@ -251,6 +251,8 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gorilla/rpc v1.2.0 h1:WvvdC2lNeT1SP32zrIce5l0ECBfbAlmrmSBsuc57wfk= github.com/gorilla/rpc v1.2.0/go.mod h1:V4h9r+4sF5HnzqbwIez0fKSpANP0zlYd3qR7p36jkTQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 h1:kr3j8iIMR4ywO/O0rvksXaJvauGGCMg2zAZIiNZ9uIQ= @@ -259,6 +261,8 @@ github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hydrogen18/memlistener v0.0.0-20141126152155-54553eb933fb/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= @@ -413,6 +417,8 @@ github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrf github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= diff --git a/pubsub/connection.go b/pubsub/connection.go new file mode 100644 index 0000000000..373f0bec4f --- /dev/null +++ b/pubsub/connection.go @@ -0,0 +1,163 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pubsub + +import ( + "errors" + "io" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +var ( + ErrFilterNotInitialized = errors.New("filter not initialized") + ErrAddressLimit = errors.New("address limit exceeded") + ErrInvalidFilterParam = errors.New("invalid bloom filter params") + ErrInvalidCommand = errors.New("invalid command") +) + +// Callback type is used as a callback function for the +// WebSocket server to process incoming messages. +// Accepts a byte message, the connection and any additional information. +type Callback func([]byte, *Connection) []byte + +// connection is a representation of the websocket connection. +type Connection struct { + s *Server + + // The websocket connection. + conn *websocket.Conn + + // Buffered channel of outbound messages. + send chan []byte + + // Represents if the connection can receive new messages. + active atomic.Bool +} + +// isActive returns whether the connection is active +func (c *Connection) isActive() bool { + return c.active.Load() +} + +// deactivate deactivates the connection. +func (c *Connection) deactivate() { + c.active.Store(false) +} + +// Send sends [msg] to c's send channel and returns whether the message was sent. +func (c *Connection) Send(msg []byte) bool { + if !c.isActive() { + return false + } + select { + case c.send <- msg: + return true + default: + c.s.log.Debug("msg was dropped") + } + return false +} + +// readPump pumps messages from the websocket connection to the hub. +// +// The application runs readPump in a per-connection goroutine. The application +// ensures that there is at most one reader on a connection by executing all +// reads from this goroutine. +func (c *Connection) readPump() { + defer func() { + c.deactivate() + c.s.removeConnection(c) + + // close is called by both the writePump and the readPump so one of them + // will always error + _ = c.conn.Close() + }() + + c.conn.SetReadLimit(c.s.config.MaxMessageSize) + // SetReadDeadline returns an error if the connection is corrupted + if err := c.conn.SetReadDeadline(time.Now().Add(c.s.config.PongWait)); err != nil { + return + } + c.conn.SetPongHandler(func(string) error { + return c.conn.SetReadDeadline(time.Now().Add(c.s.config.PongWait)) + }) + for { + _, reader, err := c.conn.NextReader() + if err != nil { + if websocket.IsUnexpectedCloseError( + err, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + c.s.log.Debug("unexpected close in websockets", + zap.Error(err), + ) + } + break + } + if c.s.callback != nil { + responseBytes, err := io.ReadAll(reader) + if err == nil { + c.s.log.Debug("unexpected error reading bytes from websockets", + zap.Error(err), + ) + } + c.Send(c.s.callback(responseBytes, c)) + } + } +} + +// writePump pumps messages from the hub to the websocket connection. +// +// A goroutine running writePump is started for each connection. The +// application ensures that there is at most one writer to a connection by +// executing all writes from this goroutine. +func (c *Connection) writePump() { + ticker := time.NewTicker(c.s.config.PingPeriod) + defer func() { + c.deactivate() + ticker.Stop() + c.s.removeConnection(c) + + // close is called by both the writePump and the readPump so one of them + // will always error + _ = c.conn.Close() + }() + for { + select { + case message, ok := <-c.send: + if err := c.conn.SetWriteDeadline(time.Now().Add(c.s.config.WriteWait)); err != nil { + c.s.log.Debug("closing the connection", + zap.String("reason", "failed to set the write deadline"), + zap.Error(err), + ) + return + } + if !ok { + // The hub closed the channel. Attempt to close the connection + // gracefully. + _ = c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + if err := c.conn.WriteMessage(websocket.BinaryMessage, message); err != nil { + return + } + case <-ticker.C: + if err := c.conn.SetWriteDeadline(time.Now().Add(c.s.config.WriteWait)); err != nil { + c.s.log.Debug("closing the connection", + zap.String("reason", "failed to set the write deadline"), + zap.Error(err), + ) + return + } + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} diff --git a/pubsub/connections.go b/pubsub/connections.go new file mode 100644 index 0000000000..72c7e5c361 --- /dev/null +++ b/pubsub/connections.go @@ -0,0 +1,60 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pubsub + +import ( + "sync" + + "github.com/ava-labs/avalanchego/utils/set" +) + +// connections represents a collection of connections to clients. +type Connections struct { + lock sync.RWMutex + conns set.Set[*Connection] +} + +// NewConnections returns a new Connections instance. +func NewConnections() *Connections { + return &Connections{} +} + +// Conns returns a list of all connections in [c]. +func (c *Connections) Conns() []*Connection { + c.lock.RLock() + defer c.lock.RUnlock() + return c.conns.List() +} + +// Has returns if the connection [conn] is in [c]. +func (c *Connections) Has(conn *Connection) bool { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.conns.Contains(conn) +} + +// Remove removes [conn] from [c]. +func (c *Connections) Remove(conn *Connection) { + c.lock.Lock() + defer c.lock.Unlock() + + c.conns.Remove(conn) +} + +// Add adds [conn] to the [c]. +func (c *Connections) Add(conn *Connection) { + c.lock.Lock() + defer c.lock.Unlock() + + c.conns.Add(conn) +} + +// Len returns the number of connections in [c]. +func (c *Connections) Len() int { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.conns.Len() +} diff --git a/pubsub/consts.go b/pubsub/consts.go new file mode 100644 index 0000000000..e736b7c3a9 --- /dev/null +++ b/pubsub/consts.go @@ -0,0 +1,21 @@ +// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pubsub + +import ( + "time" + + "github.com/ava-labs/avalanchego/utils/units" +) + +const ( + ReadBufferSize = units.KiB + WriteBufferSize = units.KiB + WriteWait = 10 * time.Second + PongWait = 60 * time.Second + PingPeriod = (PongWait * 9) / 10 + MaxMessageSize = 10 * units.KiB // bytes + MaxPendingMessages = 1024 + ReadHeaderTimeout = 5 * time.Second +) diff --git a/pubsub/server.go b/pubsub/server.go new file mode 100644 index 0000000000..f69e69f946 --- /dev/null +++ b/pubsub/server.go @@ -0,0 +1,165 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pubsub + +import ( + "context" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/utils/logging" +) + +type ServerConfig struct { + // Size of the ws read buffer + ReadBufferSize int + // Size of the ws write buffer + WriteBufferSize int + // Maximum number of pending messages to send to a peer. + MaxPendingMessages int + // Maximum message size in bytes allowed from peer. + MaxMessageSize int64 + // Time allowed to write a message to the peer. + WriteWait time.Duration + // Time allowed to read the next pong message from the peer. + PongWait time.Duration + // Send pings to peer with this period. Must be less than pongWait. + PingPeriod time.Duration + // ReadHeaderTimeout is the maximum duration for reading a request. + ReadHeaderTimeout time.Duration +} + +func NewDefaultServerConfig() *ServerConfig { + return &ServerConfig{ + ReadBufferSize: ReadBufferSize, + WriteBufferSize: WriteBufferSize, + MaxPendingMessages: MaxPendingMessages, + MaxMessageSize: MaxMessageSize, + WriteWait: WriteWait, + PongWait: PongWait, + PingPeriod: (9 * PongWait) / 10, + ReadHeaderTimeout: ReadHeaderTimeout, + } +} + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(*http.Request) bool { + return true + }, +} + +// Server maintains the set of active clients and sends messages to the clients. +// +// Connect to the server after starting using websocket.DefaultDialer.Dial(). +type Server struct { + // The http server + s *http.Server + // The address to listen on + addr string + log logging.Logger + lock sync.RWMutex + // conns a set of all our connections + conns *Connections + // Callback function when server receives a message + callback Callback + // Config variables + config *ServerConfig +} + +// New returns a new Server instance. The callback function [f] is called +// by the server in response to messages if not nil. +func New( + addr string, + r Callback, + log logging.Logger, + config *ServerConfig, +) *Server { + upgrader.ReadBufferSize = config.ReadBufferSize + upgrader.WriteBufferSize = config.WriteBufferSize + return &Server{ + log: log, + addr: addr, + callback: r, + conns: NewConnections(), + config: config, + } +} + +// ServeHTTP adds a connection to the server, and starts go routines for +// reading and writing. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Upgrader.upgrade() is called to upgrade the HTTP connection. + // No nead to set any headers so we pass nil as the last argument. + wsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + s.log.Debug("failed to upgrade", + zap.Error(err), + ) + return + } + s.addConnection(&Connection{ + s: s, + conn: wsConn, + send: make(chan []byte, s.config.MaxPendingMessages), + active: atomic.Bool{}, + }) +} + +// Publish sends msg from [s] to [toConns]. +func (s *Server) Publish(msg []byte, toConns *Connections) { + for _, conn := range toConns.Conns() { + // check server has connection O(1) + if !s.conns.Has(conn) { + continue + } + if !conn.Send(msg) { + s.log.Verbo( + "dropping message to subscribed connection due to too many pending messages", + ) + } + } +} + +// addConnection adds [conn] to the servers connection set and starts go +// routines for reading and writing messages for the connection. +func (s *Server) addConnection(conn *Connection) { + s.lock.Lock() + defer s.lock.Unlock() + + conn.active.Store(true) + s.conns.Add(conn) + + go conn.writePump() + go conn.readPump() +} + +// removeConnection removes [conn] from the servers connection set. +func (s *Server) removeConnection(conn *Connection) { + s.conns.Remove(conn) +} + +// Start starts the server. Returns an error if the server fails to start or +// when the server is stopped. +func (s *Server) Start() error { + s.lock.Lock() + s.s = &http.Server{ + Addr: s.addr, + Handler: s, + ReadHeaderTimeout: s.config.ReadHeaderTimeout, + } + s.lock.Unlock() + err := s.s.ListenAndServe() + return err +} + +// Shutdown shuts down the server and returns the associated error. +func (s *Server) Shutdown(c context.Context) error { + return s.s.Shutdown(c) +} diff --git a/pubsub/server_test.go b/pubsub/server_test.go new file mode 100644 index 0000000000..cf215e0174 --- /dev/null +++ b/pubsub/server_test.go @@ -0,0 +1,278 @@ +// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pubsub + +import ( + "context" + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +const dummyAddr = "localhost:8080" + +var ( + callbackEmptyResponse = "EMPTY_ID" + callbackResponse = "ID_RECEIVED" +) + +// This is a dummy struct to test the callback function +type counter struct { + val int +} + +func (x *counter) dummyProcessTXCallback(b []byte, _ *Connection) []byte { + x.val++ + id, err := ids.ToID(b) + if err != nil { + return []byte("ERROR") + } + if ids.Empty == id { + return []byte(callbackEmptyResponse) + } else { + return []byte(callbackResponse) + } +} + +// This also makes sure the callback function executed properly. +// TestServerPublish adds a connection to a server then publishes +// a msg to be sent to all connections. Checks the message was delivered properly +// and the connection is properly handled when closed. +func TestServerPublish(t *testing.T) { + require := require.New(t) + // Create a new logger for the test + logger := logging.NoLog{} + // Create a new pubsub server + server := New(dummyAddr, nil, logger, NewDefaultServerConfig()) + // Channels for ensuring if connections/server are closed + closeConnection := make(chan bool) + serverDone := make(chan struct{}) + dummyMsg := "dummy_msg" + // Go routine that listens on dummyAddress for connections + go func() { + defer close(serverDone) + err := server.Start() + require.ErrorIs(err, http.ErrServerClosed, "Incorrect error closing server.") + }() + // Connect to pubsub server + u := url.URL{Scheme: "ws", Host: dummyAddr} + // Wait for server to start accepting requests + <-time.After(10 * time.Millisecond) + webCon, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) + require.NoError(err, "Error connecting to the server.") + defer resp.Body.Close() + // Publish to subscribed connections + server.lock.Lock() + server.Publish([]byte(dummyMsg), server.conns) + server.lock.Unlock() + // Receive the message from the publish + _, msg, err := webCon.ReadMessage() + require.NoError(err, "Error receiveing message.") + // Verify that the received message is the expected dummy message + require.Equal([]byte(dummyMsg), msg, "Response from server not correct.") + // Close the connection and wait for it to be closed on the server side + go func() { + webCon.Close() + for { + server.lock.Lock() + len := server.conns.Len() + if len == 0 { + server.lock.Unlock() + closeConnection <- true + return + } + server.lock.Unlock() + time.Sleep(10 * time.Millisecond) + } + }() + // Wait for the connection to be closed or for a timeout to occur + select { + case <-closeConnection: + // Connection was closed on the server side, test passed + case <-time.After(time.Second): + // Timeout occurred, connection was not closed on the server side, test failed + require.Fail("connection was not closed on the server side") + } + // Gracefully shutdown the server + err = server.Shutdown(context.TODO()) + require.NoError(err, "Error shuting down server") + // Wait for the server to finish shutting down + <-serverDone +} + +// TestServerPublish pumps messages into a dummy server and waits for +// the servers response. Requires the server handled the messages correctly. +func TestServerRead(t *testing.T) { + require := require.New(t) + // Create a new logger for the test + logger := logging.NoLog{} + counter := &counter{ + val: 10, + } + // Create a new pubsub server + server := New(dummyAddr, counter.dummyProcessTXCallback, + logger, NewDefaultServerConfig()) + // Channels for ensuring if connections/server are closed + closeConnection := make(chan bool) + serverDone := make(chan struct{}) + // Go routine that listens on dummyAddress for connections + go func() { + defer close(serverDone) + err := server.Start() + require.ErrorIs(err, http.ErrServerClosed, "Incorrect error closing server.") + }() + // Connect to pubsub server + u := url.URL{Scheme: "ws", Host: dummyAddr} + // Wait for server to start accepting requests + <-time.After(10 * time.Millisecond) + webCon, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) + require.NoError(err, "Error connecting to the server.") + defer resp.Body.Close() + id := ids.GenerateTestID() + err = webCon.WriteMessage(websocket.TextMessage, id[:]) + require.NoError(err, "Error writing message to server.") + // Receive the message from the publish + _, msg, err := webCon.ReadMessage() + require.NoError(err, "Error reading from connection.") + // Callback was correctly called + require.Equal(11, counter.val, "Callback not called correctly.") + // Verify that the received message is the expected dummy message + require.Equal(callbackResponse, string(msg), "Response is unexpected.") + // Close the connection and wait for it to be closed on the server side + go func() { + webCon.Close() + for { + server.lock.Lock() + len := server.conns.Len() + if len == 0 { + server.lock.Unlock() + closeConnection <- true + return + } + server.lock.Unlock() + time.Sleep(10 * time.Millisecond) + } + }() + // Wait for the connection to be closed or for a timeout to occur + select { + case <-closeConnection: + // Connection was closed on the server side, test passed + case <-time.After(time.Second): + // Timeout occurred, connection was not closed on the server side, test failed + require.Fail("connection was not closed on the server side") + } + // Gracefully shutdown the server + err = server.Shutdown(context.TODO()) + require.NoError(err, "Error shutting down server.") + // Wait for the server to finish shutting down + <-serverDone +} + +// TestServerPublishSpecific adds two connections to a pubsub server then publishes +// a msg to be sent to only one of the connections. Checks the message was +// delivered properly and the connection is properly handled when closed. +func TestServerPublishSpecific(t *testing.T) { + require := require.New(t) + // Create a new logger for the test + logger := logging.NoLog{} + counter := &counter{ + val: 10, + } + // Create a new pubsub server + server := New(dummyAddr, counter.dummyProcessTXCallback, + logger, NewDefaultServerConfig()) + + // Channels for ensuring if connections/server are closed + closeConnection := make(chan bool) + serverDone := make(chan struct{}) + dummyMsg := "dummy_msg" + // Go routine that listens on dummyAddress for connections + go func() { + defer close(serverDone) + err := server.Start() + require.ErrorIs(err, http.ErrServerClosed, "Incorrect error closing server.") + }() + // Connect to pubsub server + u := url.URL{Scheme: "ws", Host: dummyAddr} + // Wait for server to start accepting requests + <-time.After(10 * time.Millisecond) + webCon1, resp1, err := websocket.DefaultDialer.Dial(u.String(), nil) + require.NoError(err, "Error connecting to the server.") + defer resp1.Body.Close() + sendConns := NewConnections() + server.lock.Lock() + peekCon, _ := server.conns.conns.Peek() + server.lock.Unlock() + sendConns.Add(peekCon) + webCon2, resp2, err := websocket.DefaultDialer.Dial(u.String(), nil) + require.NoError(err, "Error connecting to the server.") + defer resp2.Body.Close() + require.Equal(2, server.conns.Len(), "Server didn't add connection correctly.") + // Publish to subscribed connections + server.lock.Lock() + server.Publish([]byte(dummyMsg), sendConns) + server.lock.Unlock() + go func() { + // Receive the message from the publish + _, msg, err := webCon1.ReadMessage() + require.NoError(err, "Error reading to connection.") + // Verify that the received message is the expected dummy message + require.Equal([]byte(dummyMsg), msg, "Message not as expected.") + webCon1.Close() + for { + server.lock.Lock() + len := server.conns.Len() + if len == 0 { + server.lock.Unlock() + closeConnection <- true + return + } + server.lock.Unlock() + time.Sleep(10 * time.Millisecond) + } + }() + // not receive from the other + go func() { + err := webCon2.SetReadDeadline(time.Now().Add(time.Second)) + require.NoError(err, "Error setting connection deadline.") + // Make sure connection wasn't written too + _, _, err = webCon2.ReadMessage() + require.Error(err, "Error not thrown.") + netErr, ok := err.(net.Error) + require.True(ok, "Error is not a net.Error") + require.True(netErr.Timeout(), "Error is not a timeout error") + webCon2.Close() + for { + server.lock.Lock() + len := server.conns.Len() + if len == 0 { + server.lock.Unlock() + closeConnection <- true + return + } + server.lock.Unlock() + time.Sleep(10 * time.Millisecond) + } + }() + // Wait for the connection to be closed or for a timeout to occur + select { + case <-closeConnection: + // Connection was closed on the server side, test passed + case <-time.After(2 * time.Second): + // Timeout occurred, connection was not closed on the server side, test failed + require.Fail("connection was not closed on the server side") + } + // Gracefully shutdown the server + err = server.Shutdown(context.TODO()) + require.NoError(err, "Error shuting down server.") + // Wait for the server to finish shutting down + <-serverDone +}