Skip to content

Commit

Permalink
feat: support custom websocket dialer (#197)
Browse files Browse the repository at this point in the history
* feat: support custom websocket dialer

* tests
  • Loading branch information
yunyu950908 authored Nov 27, 2024
1 parent 035189a commit 578cde3
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 9 deletions.
10 changes: 10 additions & 0 deletions client_web_socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"os"
"os/signal"
"time"

"github.com/gorilla/websocket"
)

const (
Expand All @@ -28,6 +30,7 @@ type WebSocketClient struct {
baseURL string
key string
secret string
dialer *websocket.Dialer
}

func (c *WebSocketClient) debugf(format string, v ...interface{}) {
Expand Down Expand Up @@ -75,6 +78,13 @@ func (c *WebSocketClient) WithBaseURL(url string) *WebSocketClient {
return c
}

// WithDialer :
func (c *WebSocketClient) WithDialer(dialer *websocket.Dialer) *WebSocketClient {
c.dialer = dialer

return c
}

// hasAuth : check has auth key and secret
func (c *WebSocketClient) hasAuth() bool {
return c.key != "" && c.secret != ""
Expand Down
43 changes: 43 additions & 0 deletions testhelper/server_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package testhelper

import (
"encoding/json"
"net/http"
"net/url"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -40,3 +43,43 @@ func TestWebsocketServer(t *testing.T) {
assert.ErrorIs(t, err, websocket.ErrBadHandshake)
})
}

func TestWebsocketServerWithCustomDialer(t *testing.T) {
customDialer := &websocket.Dialer{
Proxy: func(req *http.Request) (*url.URL, error) {
return nil, nil
},
HandshakeTimeout: 5 * time.Second,
}

t.Run("custom dialer success", func(t *testing.T) {
path := "/custom"
respBody := struct {
Message string `json:"message"`
}{
Message: "custom dialer ok",
}
bytesBody, err := json.Marshal(respBody)
require.NoError(t, err)
server, teardown := NewWebsocketServer(WithWebsocketHandlerOption(path, bytesBody))
defer teardown()

c, _, err := customDialer.Dial(server.URL+path, nil)
require.NoError(t, err)

assert.NoError(t, c.WriteMessage(websocket.TextMessage, nil))

_, message, err := c.ReadMessage()
require.NoError(t, err)
assert.Equal(t, bytesBody, message)
})

t.Run("custom dialer failure", func(t *testing.T) {
path := "/custom"
server, teardown := NewWebsocketServer()
defer teardown()

_, _, err := customDialer.Dial(server.URL+path, nil)
assert.ErrorIs(t, err, websocket.ErrBadHandshake)
})
}
24 changes: 21 additions & 3 deletions v5_client_web_socket_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ type V5WebsocketService struct {
// Public :
func (s *V5WebsocketService) Public(category CategoryV5) (V5WebsocketPublicServiceI, error) {
url := s.client.baseURL + V5WebsocketPublicPathFor(category)
c, _, err := websocket.DefaultDialer.Dial(url, nil)
var c *websocket.Conn
var err error
if s.client.dialer != nil {
c, _, err = s.client.dialer.Dial(url, nil)
} else {
c, _, err = websocket.DefaultDialer.Dial(url, nil)
}
if err != nil {
return nil, err
}
Expand All @@ -38,7 +44,13 @@ func (s *V5WebsocketService) Public(category CategoryV5) (V5WebsocketPublicServi
// Private :
func (s *V5WebsocketService) Private() (V5WebsocketPrivateServiceI, error) {
url := s.client.baseURL + V5WebsocketPrivatePath
c, _, err := websocket.DefaultDialer.Dial(url, nil)
var c *websocket.Conn
var err error
if s.client.dialer != nil {
c, _, err = s.client.dialer.Dial(url, nil)
} else {
c, _, err = websocket.DefaultDialer.Dial(url, nil)
}
if err != nil {
return nil, err
}
Expand All @@ -55,7 +67,13 @@ func (s *V5WebsocketService) Private() (V5WebsocketPrivateServiceI, error) {
// Trade :
func (s *V5WebsocketService) Trade() (V5WebsocketTradeServiceI, error) {
url := s.client.baseURL + V5WebsocketTradePath
c, _, err := websocket.DefaultDialer.Dial(url, nil)
var c *websocket.Conn
var err error
if s.client.dialer != nil {
c, _, err = s.client.dialer.Dial(url, nil)
} else {
c, _, err = websocket.DefaultDialer.Dial(url, nil)
}
if err != nil {
return nil, err
}
Expand Down
12 changes: 11 additions & 1 deletion v5_ws_private_order_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package bybit

import (
"encoding/json"
"net/http"
"net/url"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/hirokisan/bybit/v2/testhelper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -64,7 +68,13 @@ func TestV5WebsocketPrivate_Order(t *testing.T) {

wsClient := NewTestWebsocketClient().
WithBaseURL(server.URL).
WithAuth("test", "test")
WithAuth("test", "test").
WithDialer(&websocket.Dialer{
Proxy: func(req *http.Request) (*url.URL, error) {
return nil, nil
},
HandshakeTimeout: 5 * time.Second,
})

svc, err := wsClient.V5().Private()
require.NoError(t, err)
Expand Down
12 changes: 11 additions & 1 deletion v5_ws_private_position_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package bybit

import (
"encoding/json"
"net/http"
"net/url"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/hirokisan/bybit/v2/testhelper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -56,7 +60,13 @@ func TestV5WebsocketPrivate_Position(t *testing.T) {

wsClient := NewTestWebsocketClient().
WithBaseURL(server.URL).
WithAuth("test", "test")
WithAuth("test", "test").
WithDialer(&websocket.Dialer{
Proxy: func(req *http.Request) (*url.URL, error) {
return nil, nil
},
HandshakeTimeout: 5 * time.Second,
})

svc, err := wsClient.V5().Private()
require.NoError(t, err)
Expand Down
24 changes: 21 additions & 3 deletions ws_spot_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ type SpotWebsocketV1Service struct {
// PublicV1 :
func (s *SpotWebsocketV1Service) PublicV1() (*SpotWebsocketV1PublicV1Service, error) {
url := s.client.baseURL + SpotWebsocketV1PublicV1Path
c, _, err := websocket.DefaultDialer.Dial(url, nil)
var c *websocket.Conn
var err error
if s.client.dialer != nil {
c, _, err = s.client.dialer.Dial(url, nil)
} else {
c, _, err = websocket.DefaultDialer.Dial(url, nil)
}
if err != nil {
return nil, err
}
Expand All @@ -25,7 +31,13 @@ func (s *SpotWebsocketV1Service) PublicV1() (*SpotWebsocketV1PublicV1Service, er
// PublicV2 :
func (s *SpotWebsocketV1Service) PublicV2() (*SpotWebsocketV1PublicV2Service, error) {
url := s.client.baseURL + SpotWebsocketV1PublicV2Path
c, _, err := websocket.DefaultDialer.Dial(url, nil)
var c *websocket.Conn
var err error
if s.client.dialer != nil {
c, _, err = s.client.dialer.Dial(url, nil)
} else {
c, _, err = websocket.DefaultDialer.Dial(url, nil)
}
if err != nil {
return nil, err
}
Expand All @@ -38,7 +50,13 @@ func (s *SpotWebsocketV1Service) PublicV2() (*SpotWebsocketV1PublicV2Service, er
// Private :
func (s *SpotWebsocketV1Service) Private() (*SpotWebsocketV1PrivateService, error) {
url := s.client.baseURL + SpotWebsocketV1PrivatePath
c, _, err := websocket.DefaultDialer.Dial(url, nil)
var c *websocket.Conn
var err error
if s.client.dialer != nil {
c, _, err = s.client.dialer.Dial(url, nil)
} else {
c, _, err = websocket.DefaultDialer.Dial(url, nil)
}
if err != nil {
return nil, err
}
Expand Down
12 changes: 11 additions & 1 deletion ws_spot_v1_private_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package bybit

import (
"encoding/json"
"net/http"
"net/url"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/hirokisan/bybit/v2/testhelper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -38,7 +42,13 @@ func TestSpotWebsocketV1PrivateOutboundAccountInfo(t *testing.T) {

wsClient := NewTestWebsocketClient().
WithBaseURL(server.URL).
WithAuth("test", "test")
WithAuth("test", "test").
WithDialer(&websocket.Dialer{
Proxy: func(req *http.Request) (*url.URL, error) {
return nil, nil
},
HandshakeTimeout: 5 * time.Second,
})

svc, err := wsClient.Spot().V1().Private()
require.NoError(t, err)
Expand Down

0 comments on commit 578cde3

Please sign in to comment.