Skip to content

Commit

Permalink
feat: support websocket client reconnect (#2409)
Browse files Browse the repository at this point in the history
Signed-off-by: yisaer <[email protected]>
  • Loading branch information
Yisaer authored Nov 13, 2023
1 parent abeffe8 commit c91d1be
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 37 deletions.
9 changes: 5 additions & 4 deletions internal/topo/connection/clients/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
)

type WebSocketConnectionConfig struct {
Addr string `json:"addr"`
Path string `json:"path"`
tlsConfig *tls.Config
Addr string `json:"addr"`
Path string `json:"path"`
MaxConnRetry int `json:"maxConnRetry"`
tlsConfig *tls.Config
}

type tlsConf struct {
Expand All @@ -50,7 +51,7 @@ func (c *tlsConf) isNil() bool {
}

func NewWebSocketConnWrapper(props map[string]interface{}) (clients.ClientWrapper, error) {
config := &WebSocketConnectionConfig{}
config := &WebSocketConnectionConfig{MaxConnRetry: 3}
if err := cast.MapToStruct(props, config); err != nil {
return nil, err
}
Expand Down
94 changes: 74 additions & 20 deletions internal/topo/connection/clients/websocket/websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
package websocket

import (
"errors"
"fmt"
"strings"
"sync"
"time"

"github.com/gorilla/websocket"

Expand All @@ -32,13 +34,16 @@ type websocketClientWrapper struct {
// We maintained topicChannel for each source_node(ruleID_OpID_InstanceID)
// When source_node Subscribed, each message comes from the websocket connection will be delivered into all topic Channel.
// When source_node Released, the Topic Channel will be removed by the ID so that the websocket msg won't send data to it anymore.
chs map[string][]api.TopicChannel
errCh map[string]chan error
refCount int
conSelector string
finished bool

processDone bool
chs map[string][]api.TopicChannel
errCh map[string]chan error
refCount int
conSelector string
finished bool
maxConnRetry int
config *WebSocketConnectionConfig

// only used for test
sync.WaitGroup
}

func newWebsocketClientClientWrapper(config *WebSocketConnectionConfig) (clients.ClientWrapper, error) {
Expand All @@ -48,13 +53,28 @@ func newWebsocketClientClientWrapper(config *WebSocketConnectionConfig) (clients
}
cc := &websocketClientWrapper{
c: conn, chs: make(map[string][]api.TopicChannel),
errCh: make(map[string]chan error),
refCount: 1,
errCh: make(map[string]chan error),
refCount: 1,
config: config,
maxConnRetry: config.MaxConnRetry,
}
cc.Add(1)
go cc.process()
return cc, nil
}

func (wcw *websocketClientWrapper) getConn() *websocket.Conn {
wcw.Lock()
defer wcw.Unlock()
return wcw.c
}

func (wcw *websocketClientWrapper) setConn(conn *websocket.Conn) {
wcw.Lock()
defer wcw.Unlock()
wcw.c = conn
}

func (wcw *websocketClientWrapper) isFinished() bool {
wcw.Lock()
defer wcw.Unlock()
Expand All @@ -72,27 +92,32 @@ func (wcw *websocketClientWrapper) getID(ctx api.StreamContext) string {
}

func (wcw *websocketClientWrapper) process() {
defer func() {
wcw.processDone = true
wcw.c.Close()
}()
defer wcw.Done()
for {
if wcw.isFinished() {
return
}
msgTyp, data, err := wcw.c.ReadMessage()
msgTyp, data, err := wcw.getConn().ReadMessage()
if err != nil {
if wcw.isFinished() {
return
}
errMsg := err.Error()
if strings.Contains(errMsg, "close") {
if wcw.reconn() {
continue
}
wcw.Lock()
wcw.finished = true
wcw.Unlock()
}
for key, errCh := range wcw.errCh {
select {
case errCh <- err:
default:
conf.Log.Warnf("websocket client connection discard one error for %v", key)
}
}
if strings.Contains(err.Error(), "close") {
conf.Log.Info("websocket client closed")
return
}
continue
}
if msgTyp == websocket.TextMessage {
Expand All @@ -112,6 +137,9 @@ func (wcw *websocketClientWrapper) process() {
func (wcw *websocketClientWrapper) Subscribe(ctx api.StreamContext, subChan []api.TopicChannel, messageErrors chan error, _ map[string]interface{}) error {
wcw.Lock()
defer wcw.Unlock()
if wcw.finished {
return errors.New("websocket client connection closed")
}
subId := wcw.getID(ctx)
if _, ok := wcw.chs[subId]; ok {
return fmt.Errorf("%s subsucribe websocket client connection duplidated", subId)
Expand All @@ -132,15 +160,18 @@ func (wcw *websocketClientWrapper) Release(ctx api.StreamContext) bool {
delete(wcw.errCh, subID)
wcw.refCount--
if wcw.refCount == 0 {
wcw.c.Close()
wcw.finished = true
wcw.c.Close()
return true
}
return false
}

func (wcw *websocketClientWrapper) Publish(c api.StreamContext, topic string, message []byte, params map[string]interface{}) error {
return wcw.c.WriteMessage(websocket.TextMessage, message)
if wcw.isFinished() {
return errors.New("websocket client connection closed")
}
return wcw.getConn().WriteMessage(websocket.TextMessage, message)
}

func (wcw *websocketClientWrapper) SetConnectionSelector(conSelector string) {
Expand All @@ -156,3 +187,26 @@ func (wcw *websocketClientWrapper) AddRef() {
defer wcw.Unlock()
wcw.refCount++
}

func (wcw *websocketClientWrapper) reconn() bool {
if wcw.isFinished() {
return false
}
conf.Log.Info("websocket client closed, try to reconnect")
for i := 1; i <= wcw.maxConnRetry; i++ {
conn, err := GetWebsocketClientConn(wcw.config.Addr, wcw.config.Path, wcw.config.tlsConfig)
if err != nil {
conf.Log.Infof("websocket client connection reconnect failed, retry: %v, err:%v", i, err)
if i < wcw.maxConnRetry {
time.Sleep(10 * time.Millisecond)
}
continue
}
wcw.getConn().Close()
wcw.setConn(conn)
conf.Log.Info("websocket client reconnect success")
return true
}
conf.Log.Info("websocket client reconnect failed")
return false
}
43 changes: 30 additions & 13 deletions internal/topo/connection/clients/websocket/websocket_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,55 @@ const (

func TestWebsocketPubSub(t *testing.T) {
go mockWebSocketServer()
time.Sleep(100 * time.Millisecond)
ctx := context.NewMockContext("123", "123")
cli, err := newWebsocketClientClientWrapper(&WebSocketConnectionConfig{Addr: addr, Path: path})
cli, err := newWebsocketClientClientWrapper(&WebSocketConnectionConfig{Addr: addr, Path: path, MaxConnRetry: 3})
wsCli := cli.(*websocketClientWrapper)
// ensure goroutine closed
defer wsCli.Wait()
// wait server goroutine process running
<-handleCh

require.NoError(t, err)
cli.SetConnectionSelector("456")
require.Equal(t, "456", cli.GetConnectionSelector())
data := map[string]interface{}{"a": float64(1)}
databytes, err := json.Marshal(data)
require.NoError(t, err)

dataCh := make(chan interface{})
dataCh := make(chan interface{}, 16)
subs := []api.TopicChannel{
{
Topic: "",
Messages: dataCh,
},
}
errCh := make(chan error)
errCh := make(chan error, 16)
require.NoError(t, cli.Subscribe(ctx, subs, errCh, map[string]interface{}{}))
err = cli.Publish(ctx, "", databytes, map[string]interface{}{})
require.NoError(t, err)
// assert pub
require.Equal(t, data, <-recvDataCh)
// assert sub
require.Equal(t, databytes, <-dataCh)
processDone <- struct{}{}
// ensure connection closed
<-connCloseCh
// wait cli connection reconnect
<-handleCh

err = cli.Publish(ctx, "", databytes, map[string]interface{}{})
require.NoError(t, err)
// assert pub
require.Equal(t, data, <-recvDataCh)
// assert sub
require.Equal(t, databytes, <-dataCh)
<-connCloseCh

cli.AddRef()
cli.Release(ctx)
wsCli := cli.(*websocketClientWrapper)
require.False(t, wsCli.isFinished())
require.False(t, wsCli.processDone)
cli.Release(ctx)
require.True(t, wsCli.isFinished())
time.Sleep(100 * time.Millisecond)
require.True(t, wsCli.processDone)
}

func mockWebSocketServer() {
Expand All @@ -80,12 +94,14 @@ func mockWebSocketServer() {

var (
recvDataCh chan interface{}
processDone chan struct{}
connCloseCh chan struct{}
handleCh chan struct{}
)

func init() {
recvDataCh = make(chan interface{})
processDone = make(chan struct{})
connCloseCh = make(chan struct{})
handleCh = make(chan struct{})
}

var upgrader = websocket.Upgrader{
Expand All @@ -95,7 +111,6 @@ var upgrader = websocket.Upgrader{
}

func process(c *websocket.Conn) {
defer c.Close()
_, message, err := c.ReadMessage()
if err != nil {
recvDataCh <- err
Expand All @@ -110,7 +125,8 @@ func process(c *websocket.Conn) {
recvDataCh <- a

c.WriteMessage(websocket.TextMessage, message)
<-processDone
c.Close()
connCloseCh <- struct{}{}
}

func handler(w http.ResponseWriter, r *http.Request) {
Expand All @@ -119,6 +135,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
conf.Log.Errorf("upgrade: %v", err)
return
}

go process(c)
time.Sleep(100 * time.Millisecond)
handleCh <- struct{}{}
}

0 comments on commit c91d1be

Please sign in to comment.