Skip to content

Commit

Permalink
[CT-1321] subscribe to market prices streaming services (#2592)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayy04 authored Nov 22, 2024
1 parent 395f448 commit fefd10e
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 12 deletions.
103 changes: 102 additions & 1 deletion protocol/streaming/full_node_streaming_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/dydxprotocol/v4-chain/protocol/lib"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"

"cosmossdk.io/log"
Expand Down Expand Up @@ -52,6 +53,8 @@ type FullNodeStreamingManagerImpl struct {
clobPairIdToSubscriptionIdMapping map[uint32][]uint32
// map from subaccount id to subscription ids.
subaccountIdToSubscriptionIdMapping map[satypes.SubaccountId][]uint32
// map from market id to subscription ids.
marketIdToSubscriptionIdMapping map[uint32][]uint32

maxUpdatesInCache uint32
maxSubscriptionChannelSize uint32
Expand Down Expand Up @@ -79,6 +82,9 @@ type OrderbookSubscription struct {
// Subaccount ids to subscribe to.
subaccountIds []satypes.SubaccountId

// market ids to subscribe to.
marketIds []uint32

// Stream
messageSender types.OutgoingMessageSender

Expand Down Expand Up @@ -114,6 +120,7 @@ func NewFullNodeStreamingManager(
streamUpdateSubscriptionCache: make([][]uint32, 0),
clobPairIdToSubscriptionIdMapping: make(map[uint32][]uint32),
subaccountIdToSubscriptionIdMapping: make(map[satypes.SubaccountId][]uint32),
marketIdToSubscriptionIdMapping: make(map[uint32][]uint32),

maxUpdatesInCache: maxUpdatesInCache,
maxSubscriptionChannelSize: maxSubscriptionChannelSize,
Expand Down Expand Up @@ -184,6 +191,7 @@ func (sm *FullNodeStreamingManagerImpl) getNextAvailableSubscriptionId() uint32
func (sm *FullNodeStreamingManagerImpl) Subscribe(
clobPairIds []uint32,
subaccountIds []*satypes.SubaccountId,
marketIds []uint32,
messageSender types.OutgoingMessageSender,
) (
err error,
Expand All @@ -206,6 +214,7 @@ func (sm *FullNodeStreamingManagerImpl) Subscribe(
initialized: &atomic.Bool{}, // False by default.
clobPairIds: clobPairIds,
subaccountIds: sIds,
marketIds: marketIds,
messageSender: messageSender,
updatesChannel: make(chan []clobtypes.StreamUpdate, sm.maxSubscriptionChannelSize),
}
Expand All @@ -231,6 +240,17 @@ func (sm *FullNodeStreamingManagerImpl) Subscribe(
subscription.subscriptionId,
)
}
for _, marketId := range marketIds {
// if subaccountId exists in the map, append the subscription id to the slice
// otherwise, create a new slice with the subscription id
if _, ok := sm.marketIdToSubscriptionIdMapping[marketId]; !ok {
sm.marketIdToSubscriptionIdMapping[marketId] = []uint32{}
}
sm.marketIdToSubscriptionIdMapping[marketId] = append(
sm.marketIdToSubscriptionIdMapping[marketId],
subscription.subscriptionId,
)
}

sm.logger.Info(
fmt.Sprintf(
Expand Down Expand Up @@ -325,6 +345,21 @@ func (sm *FullNodeStreamingManagerImpl) removeSubscription(
}
}

// Iterate over the marketIdToSubscriptionIdMapping to remove the subscriptionIdToRemove
for marketId, subscriptionIds := range sm.marketIdToSubscriptionIdMapping {
for i, id := range subscriptionIds {
if id == subscriptionIdToRemove {
// Remove the subscription ID from the slice
sm.marketIdToSubscriptionIdMapping[marketId] = append(subscriptionIds[:i], subscriptionIds[i+1:]...)
break
}
}
// If the list is empty after removal, delete the key from the map
if len(sm.marketIdToSubscriptionIdMapping[marketId]) == 0 {
delete(sm.marketIdToSubscriptionIdMapping, marketId)
}
}

sm.logger.Info(
fmt.Sprintf("Removed streaming subscription id %+v", subscriptionIdToRemove),
)
Expand Down Expand Up @@ -372,6 +407,24 @@ func toSubaccountStreamUpdates(
return streamUpdates
}

func toPriceStreamUpdates(
priceUpdates []*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
) []clobtypes.StreamUpdate {
streamUpdates := make([]clobtypes.StreamUpdate, 0)
for _, update := range priceUpdates {
streamUpdates = append(streamUpdates, clobtypes.StreamUpdate{
UpdateMessage: &clobtypes.StreamUpdate_PriceUpdate{
PriceUpdate: update,
},
BlockHeight: blockHeight,
ExecMode: uint32(execMode),
})
}
return streamUpdates
}

func (sm *FullNodeStreamingManagerImpl) sendStreamUpdates(
subscriptionId uint32,
streamUpdates []clobtypes.StreamUpdate,
Expand Down Expand Up @@ -466,6 +519,7 @@ func (sm *FullNodeStreamingManagerImpl) GetStagedFinalizeBlockEvents(
func (sm *FullNodeStreamingManagerImpl) SendCombinedSnapshot(
offchainUpdates *clobtypes.OffchainUpdates,
saUpdates []*satypes.StreamSubaccountUpdate,
priceUpdates []*pricestypes.StreamPriceUpdate,
subscriptionId uint32,
blockHeight uint32,
execMode sdk.ExecMode,
Expand All @@ -479,6 +533,7 @@ func (sm *FullNodeStreamingManagerImpl) SendCombinedSnapshot(
var streamUpdates []clobtypes.StreamUpdate
streamUpdates = append(streamUpdates, toOrderbookStreamUpdate(offchainUpdates, blockHeight, execMode)...)
streamUpdates = append(streamUpdates, toSubaccountStreamUpdates(saUpdates, blockHeight, execMode)...)
streamUpdates = append(streamUpdates, toPriceStreamUpdates(priceUpdates, blockHeight, execMode)...)
sm.sendStreamUpdates(subscriptionId, streamUpdates)
}

Expand Down Expand Up @@ -863,6 +918,30 @@ func (sm *FullNodeStreamingManagerImpl) GetSubaccountSnapshotsForInitStreams(
return ret
}

func (sm *FullNodeStreamingManagerImpl) GetPriceSnapshotsForInitStreams(
getPriceSnapshot func(marketId uint32) *pricestypes.StreamPriceUpdate,
) map[uint32]*pricestypes.StreamPriceUpdate {
sm.Lock()
defer sm.Unlock()

ret := make(map[uint32]*pricestypes.StreamPriceUpdate)
for _, subscription := range sm.orderbookSubscriptions {
// If the subscription has been initialized, no need to grab the price snapshot.
if alreadyInitialized := subscription.initialized.Load(); alreadyInitialized {
continue
}

for _, marketId := range subscription.marketIds {
if _, exists := ret[marketId]; exists {
continue
}

ret[marketId] = getPriceSnapshot(marketId)
}
}
return ret
}

// cacheStreamUpdatesByClobPairWithLock adds stream updates to cache,
// and store corresponding clob pair Ids.
// This method requires the lock and assumes that the lock has already been
Expand Down Expand Up @@ -1003,6 +1082,7 @@ func (sm *FullNodeStreamingManagerImpl) getStagedEventsFromFinalizeBlock(
func (sm *FullNodeStreamingManagerImpl) InitializeNewStreams(
getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates,
subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate,
pricesSnapshots map[uint32]*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
) {
Expand Down Expand Up @@ -1038,7 +1118,28 @@ func (sm *FullNodeStreamingManagerImpl) InitializeNewStreams(
}
}

sm.SendCombinedSnapshot(allUpdates, saUpdates, subscriptionId, blockHeight, execMode)
priceUpdates := []*pricestypes.StreamPriceUpdate{}
for _, marketId := range subscription.marketIds {
if priceUpdate, ok := pricesSnapshots[marketId]; ok {
priceUpdates = append(priceUpdates, priceUpdate)
} else {
sm.logger.Error(
fmt.Sprintf(
"Price update not found for market id %v. This should not happen.",
marketId,
),
)
}
}

sm.SendCombinedSnapshot(
allUpdates,
saUpdates,
priceUpdates,
subscriptionId,
blockHeight,
execMode,
)

if sm.snapshotBlockInterval != 0 {
subscription.nextSnapshotBlock = blockHeight + sm.snapshotBlockInterval
Expand Down
9 changes: 9 additions & 0 deletions protocol/streaming/noop_streaming_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/streaming/types"
clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
)

Expand All @@ -22,6 +23,7 @@ func (sm *NoopGrpcStreamingManager) Enabled() bool {
func (sm *NoopGrpcStreamingManager) Subscribe(
_ []uint32,
_ []*satypes.SubaccountId,
_ []uint32,
_ types.OutgoingMessageSender,
) (
err error,
Expand Down Expand Up @@ -58,9 +60,16 @@ func (sm *NoopGrpcStreamingManager) GetSubaccountSnapshotsForInitStreams(
return nil
}

func (sm *NoopGrpcStreamingManager) GetPriceSnapshotsForInitStreams(
_ func(_ uint32) *pricestypes.StreamPriceUpdate,
) map[uint32]*pricestypes.StreamPriceUpdate {
return nil
}

func (sm *NoopGrpcStreamingManager) InitializeNewStreams(
getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates,
subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate,
priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
) {
Expand Down
6 changes: 6 additions & 0 deletions protocol/streaming/types/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package types
import (
sdk "github.com/cosmos/cosmos-sdk/types"
clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
)

Expand All @@ -14,6 +15,7 @@ type FullNodeStreamingManager interface {
Subscribe(
clobPairIds []uint32,
subaccountIds []*satypes.SubaccountId,
marketIds []uint32,
srv OutgoingMessageSender,
) (
err error,
Expand All @@ -23,12 +25,16 @@ type FullNodeStreamingManager interface {
InitializeNewStreams(
getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates,
subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate,
priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
)
GetSubaccountSnapshotsForInitStreams(
getSubaccountSnapshot func(subaccountId satypes.SubaccountId) *satypes.StreamSubaccountUpdate,
) map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate
GetPriceSnapshotsForInitStreams(
getPriceSnapshot func(marketId uint32) *pricestypes.StreamPriceUpdate,
) map[uint32]*pricestypes.StreamPriceUpdate
SendOrderbookUpdates(
offchainUpdates *clobtypes.OffchainUpdates,
ctx sdk.Context,
Expand Down
40 changes: 29 additions & 11 deletions protocol/streaming/ws/websocket_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
"github.com/gorilla/websocket"
)

const (
CLOB_PAIR_IDS_QUERY_PARAM = "clobPairIds"
MARKET_IDS_QUERY_PARAM = "marketIds"
)

var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Expand Down Expand Up @@ -61,7 +66,7 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) {
conn.SetReadLimit(10 * 1024 * 1024)

// Parse clobPairIds from query parameters
clobPairIds, err := parseClobPairIds(r)
clobPairIds, err := parseUint32(r, CLOB_PAIR_IDS_QUERY_PARAM)
if err != nil {
ws.logger.Error(
"Error parsing clobPairIds",
Expand All @@ -70,6 +75,18 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Parse marketIds from query parameters
marketIds, err := parseUint32(r, MARKET_IDS_QUERY_PARAM)
if err != nil {
ws.logger.Error(
"Error parsing marketIds",
"err", err,
)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Parse subaccountIds from query parameters
subaccountIds, err := parseSubaccountIds(r)
if err != nil {
Expand All @@ -93,6 +110,7 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) {
err = ws.streamingManager.Subscribe(
clobPairIds,
subaccountIds,
marketIds,
websocketMessageSender,
)
if err != nil {
Expand Down Expand Up @@ -136,26 +154,26 @@ func parseSubaccountIds(r *http.Request) ([]*satypes.SubaccountId, error) {
return subaccountIds, nil
}

// parseClobPairIds is a helper function to parse the clobPairIds from the query parameters.
func parseClobPairIds(r *http.Request) ([]uint32, error) {
clobPairIdsParam := r.URL.Query().Get("clobPairIds")
if clobPairIdsParam == "" {
// parseUint32 is a helper function to parse the uint32 from the query parameters.
func parseUint32(r *http.Request, queryParam string) ([]uint32, error) {
param := r.URL.Query().Get(queryParam)
if param == "" {
return []uint32{}, nil
}
idStrs := strings.Split(clobPairIdsParam, ",")
clobPairIds := make([]uint32, 0)
idStrs := strings.Split(param, ",")
ids := make([]uint32, 0)
for _, idStr := range idStrs {
id, err := strconv.Atoi(idStr)
if err != nil {
return nil, fmt.Errorf("invalid clobPairId: %s", idStr)
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
if id < 0 || id > math.MaxInt32 {
return nil, fmt.Errorf("invalid clob pair id: %s", idStr)
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
clobPairIds = append(clobPairIds, uint32(id))
ids = append(ids, uint32(id))
}

return clobPairIds, nil
return ids, nil
}

// Start the websocket server in a separate goroutine.
Expand Down
1 change: 1 addition & 0 deletions protocol/x/clob/keeper/grpc_stream_orderbook.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ func (k Keeper) StreamOrderbookUpdates(
err := k.GetFullNodeStreamingManager().Subscribe(
req.GetClobPairId(),
req.GetSubaccountIds(),
req.GetMarketIds(),
stream,
)
if err != nil {
Expand Down
Loading

0 comments on commit fefd10e

Please sign in to comment.