Skip to content

Commit

Permalink
Add panic recovery to mercury wsrpc client (#15846)
Browse files Browse the repository at this point in the history
* Add panic recovery to mercury wsrpc client

Wraps calls to Transmit and LatestReport with automatic panic recovery,
which triggers a redial of the underlying connection.

* Changeset
  • Loading branch information
samsondav authored Jan 7, 2025
1 parent 62b4f1e commit 6aa365d
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 65 deletions.
7 changes: 7 additions & 0 deletions .changeset/eleven-cheetahs-care.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"chainlink": patch
---

Add panic recovery to wsrpc mercury client

- Should help to make nodes running wsrpc v0.8.2 more stable #bugfix
173 changes: 124 additions & 49 deletions core/services/relay/evm/mercury/wsrpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,27 @@ type Client interface {
}

// Conn is the set of connection behavior the client depends on. It is an
// interface (rather than the concrete *wsrpc.Conn) so that a different
// implementation can be supplied through DialWithContextFunc, e.g. in tests.
type Conn interface {
// Embedded so that a Conn can be handed to pb.NewMercuryClient.
wsrpc.ClientInterface
// WaitForReady reports readiness as a bool. Presumably it blocks until the
// connection is ready or ctx is done — TODO confirm against wsrpc.
WaitForReady(ctx context.Context) bool
// GetState returns the connection's current connectivity state.
GetState() grpc_connectivity.State
// Close closes the connection. The client calls it from more than one place
// and relies on repeated calls being safe.
Close() error
}

// DialWithContextFunc dials target with the given options and returns the
// resulting connection as a Conn. It takes the same arguments as
// wsrpc.DialWithContext but returns the Conn interface instead of a concrete
// *wsrpc.Conn, which is what allows the dialer to be replaced.
type DialWithContextFunc func(ctxCaller context.Context, target string, opts ...wsrpc.DialOption) (Conn, error)

type client struct {
services.StateMachine

csaKey csakey.KeyV2
serverPubKey []byte
serverURL string

dialWithContext DialWithContextFunc

logger logger.Logger
conn Conn
rawClient pb.MercuryClient
mu sync.RWMutex

consecutiveTimeoutCnt atomic.Int32
wg sync.WaitGroup
Expand All @@ -101,25 +107,47 @@ type client struct {
connectionResetCountMetric prometheus.Counter
}

// ClientOpts bundles the arguments used to construct a client.
//
// NOTE: tests build this struct with positional (unkeyed) literals, so the
// field order below is significant; reordering or inserting fields breaks them.
type ClientOpts struct {
// Logger is the base logger; the client derives its own named logger from it.
Logger logger.Logger
// ClientPrivKey is the client's CSA key, used for transport credentials when dialing.
ClientPrivKey csakey.KeyV2
// ServerPubKey is the server's public key, used for transport credentials when dialing.
ServerPubKey []byte
// ServerURL is the dial target. It is also used to name the logger and to
// label the client's metrics.
ServerURL string
// CacheSet is stored on the client. Presumably it is where the LatestReport
// cache comes from — TODO confirm.
CacheSet cache.CacheSet

// DialWithContext allows optional dependency injection for testing.
// When nil, the client dials with wsrpc.DialWithContext.
DialWithContext DialWithContextFunc
}

// Consumers of wsrpc package should not usually call NewClient directly, but instead use the Pool
func NewClient(lggr logger.Logger, clientPrivKey csakey.KeyV2, serverPubKey []byte, serverURL string, cacheSet cache.CacheSet) Client {
return newClient(lggr, clientPrivKey, serverPubKey, serverURL, cacheSet)
func NewClient(opts ClientOpts) Client {
return newClient(opts)
}

func newClient(lggr logger.Logger, clientPrivKey csakey.KeyV2, serverPubKey []byte, serverURL string, cacheSet cache.CacheSet) *client {
func newClient(opts ClientOpts) *client {
var dialWithContext DialWithContextFunc
if opts.DialWithContext != nil {
dialWithContext = opts.DialWithContext
} else {
// NOTE: Wrap here since wsrpc.DialWithContext returns a concrete *wsrpc.Conn, not an interface
dialWithContext = func(ctxCaller context.Context, target string, opts ...wsrpc.DialOption) (Conn, error) {
conn, err := wsrpc.DialWithContext(ctxCaller, target, opts...)
return conn, err
}
}
return &client{
csaKey: clientPrivKey,
serverPubKey: serverPubKey,
serverURL: serverURL,
logger: lggr.Named("WSRPC").Named(serverURL).With("serverURL", serverURL),
dialWithContext: dialWithContext,
csaKey: opts.ClientPrivKey,
serverPubKey: opts.ServerPubKey,
serverURL: opts.ServerURL,
logger: opts.Logger.Named("WSRPC").Named(opts.ServerURL).With("serverURL", opts.ServerURL),
chResetTransport: make(chan struct{}, 1),
cacheSet: cacheSet,
cacheSet: opts.CacheSet,
chStop: make(services.StopChan),
timeoutCountMetric: timeoutCount.WithLabelValues(serverURL),
dialCountMetric: dialCount.WithLabelValues(serverURL),
dialSuccessCountMetric: dialSuccessCount.WithLabelValues(serverURL),
dialErrorCountMetric: dialErrorCount.WithLabelValues(serverURL),
connectionResetCountMetric: connectionResetCount.WithLabelValues(serverURL),
timeoutCountMetric: timeoutCount.WithLabelValues(opts.ServerURL),
dialCountMetric: dialCount.WithLabelValues(opts.ServerURL),
dialSuccessCountMetric: dialSuccessCount.WithLabelValues(opts.ServerURL),
dialErrorCountMetric: dialErrorCount.WithLabelValues(opts.ServerURL),
connectionResetCountMetric: connectionResetCount.WithLabelValues(opts.ServerURL),
}
}

Expand Down Expand Up @@ -148,7 +176,7 @@ func (w *client) Start(ctx context.Context) error {
// with error.
func (w *client) dial(ctx context.Context, opts ...wsrpc.DialOption) error {
w.dialCountMetric.Inc()
conn, err := wsrpc.DialWithContext(ctx, w.serverURL,
conn, err := w.dialWithContext(ctx, w.serverURL,
append(opts,
wsrpc.WithTransportCreds(w.csaKey.Raw().Bytes(), w.serverPubKey),
wsrpc.WithLogger(w.logger),
Expand All @@ -161,8 +189,10 @@ func (w *client) dial(ctx context.Context, opts ...wsrpc.DialOption) error {
}
w.dialSuccessCountMetric.Inc()
setLivenessMetric(true)
w.mu.Lock()
w.conn = conn
w.rawClient = pb.NewMercuryClient(conn)
w.mu.Unlock()
return nil
}

Expand All @@ -184,6 +214,8 @@ func (w *client) runloop() {
func (w *client) resetTransport() {
w.connectionResetCountMetric.Inc()
ok := w.IfStarted(func() {
w.mu.RLock()
defer w.mu.RUnlock()
w.conn.Close() // Close is safe to call multiple times
})
if !ok {
Expand Down Expand Up @@ -211,7 +243,9 @@ func (w *client) resetTransport() {
func (w *client) Close() error {
return w.StopOnce("WSRPC Client", func() error {
close(w.chStop)
w.mu.RLock()
w.conn.Close()
w.mu.RUnlock()
w.wg.Wait()
return nil
})
Expand Down Expand Up @@ -251,24 +285,46 @@ func (w *client) waitForReady(ctx context.Context) (err error) {
}

func (w *client) Transmit(ctx context.Context, req *pb.TransmitRequest) (resp *pb.TransmitResponse, err error) {
w.logger.Trace("Transmit")
start := time.Now()
if err = w.waitForReady(ctx); err != nil {
return nil, errors.Wrap(err, "Transmit call failed")
}
resp, err = w.rawClient.Transmit(ctx, req)
w.handleTimeout(err)
if err != nil {
w.logger.Warnw("Transmit call failed due to networking error", "err", err, "resp", resp)
incRequestStatusMetric(statusFailed)
} else {
w.logger.Tracew("Transmit call succeeded", "resp", resp)
incRequestStatusMetric(statusSuccess)
setRequestLatencyMetric(float64(time.Since(start).Milliseconds()))
ok := w.IfStarted(func() {
defer func() {
if r := recover(); r != nil {
w.handlePanic(r)
resp = nil
err = fmt.Errorf("Transmit: caught panic: %v", r)
}
}()
w.logger.Trace("Transmit")
start := time.Now()
if err = w.waitForReady(ctx); err != nil {
err = errors.Wrap(err, "Transmit call failed")
return
}
w.mu.RLock()
rc := w.rawClient
w.mu.RUnlock()
resp, err = rc.Transmit(ctx, req)
w.handleTimeout(err)
if err != nil {
w.logger.Warnw("Transmit call failed due to networking error", "err", err, "resp", resp)
incRequestStatusMetric(statusFailed)
} else {
w.logger.Tracew("Transmit call succeeded", "resp", resp)
incRequestStatusMetric(statusSuccess)
setRequestLatencyMetric(float64(time.Since(start).Milliseconds()))
}
})
if !ok {
err = errors.New("client is not started")
}
return
}

// handlePanic is a hacky workaround to trap panics from the buggy underlying
// wsrpc lib and restart the connection from a known good state.
//
// It is invoked from the deferred recover handlers in Transmit and
// LatestReport with the recovered panic value r. It logs that value and then
// requests a transport reset.
//
// The request is sent without blocking. chResetTransport is created with a
// buffer of one, so if a signal is already queued a reset is pending anyway
// and this request can safely be coalesced with it. A blocking send would
// instead stall the caller, mid-recover, until something received from the
// channel.
func (w *client) handlePanic(r interface{}) {
	w.logger.Errorw("Recovered panic in wsrpc client; requesting transport reset", "panic", r)
	select {
	case w.chResetTransport <- struct{}{}:
	default:
		// A reset is already queued; it covers this panic as well.
	}
}

func (w *client) handleTimeout(err error) {
if errors.Is(err, context.DeadlineExceeded) {
w.timeoutCountMetric.Inc()
Expand Down Expand Up @@ -303,27 +359,44 @@ func (w *client) handleTimeout(err error) {
}

func (w *client) LatestReport(ctx context.Context, req *pb.LatestReportRequest) (resp *pb.LatestReportResponse, err error) {
lggr := w.logger.With("req.FeedId", hexutil.Encode(req.FeedId))
lggr.Trace("LatestReport")
if err = w.waitForReady(ctx); err != nil {
return nil, errors.Wrap(err, "LatestReport failed")
}
var cached bool
if w.cache == nil {
resp, err = w.rawClient.LatestReport(ctx, req)
w.handleTimeout(err)
} else {
cached = true
resp, err = w.cache.LatestReport(ctx, req)
}
if err != nil {
lggr.Errorw("LatestReport failed", "err", err, "resp", resp, "cached", cached)
} else if resp.Error != "" {
lggr.Errorw("LatestReport failed; mercury server returned error", "err", resp.Error, "resp", resp, "cached", cached)
} else if !cached {
lggr.Debugw("LatestReport succeeded", "resp", resp, "cached", cached)
} else {
lggr.Tracew("LatestReport succeeded", "resp", resp, "cached", cached)
ok := w.IfStarted(func() {
defer func() {
if r := recover(); r != nil {
w.handlePanic(r)
resp = nil
err = fmt.Errorf("LatestReport: caught panic: %v", r)
}
}()
lggr := w.logger.With("req.FeedId", hexutil.Encode(req.FeedId))
lggr.Trace("LatestReport")
if err = w.waitForReady(ctx); err != nil {
err = errors.Wrap(err, "LatestReport failed")
return
}
var cached bool
if w.cache == nil {
w.mu.RLock()
rc := w.rawClient
w.mu.RUnlock()
resp, err = rc.LatestReport(ctx, req)
w.handleTimeout(err)
} else {
cached = true
resp, err = w.cache.LatestReport(ctx, req)
}
switch {
case err != nil:
lggr.Errorw("LatestReport failed", "err", err, "resp", resp, "cached", cached)
case resp.Error != "":
lggr.Errorw("LatestReport failed; mercury server returned error", "err", resp.Error, "resp", resp, "cached", cached)
case !cached:
lggr.Debugw("LatestReport succeeded", "resp", resp, "cached", cached)
default:
lggr.Tracew("LatestReport succeeded", "resp", resp, "cached", cached)
}
})
if !ok {
err = errors.New("client is not started")
}
return
}
Expand All @@ -333,5 +406,7 @@ func (w *client) ServerURL() string {
}

// RawClient returns the underlying pb.MercuryClient currently held by the
// client. The field is read under the read lock because dial replaces it
// while holding the write lock.
func (w *client) RawClient() pb.MercuryClient {
	w.mu.RLock()
	current := w.rawClient
	w.mu.RUnlock()
	return current
}
77 changes: 75 additions & 2 deletions core/services/relay/evm/mercury/wsrpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@ package wsrpc

import (
"context"
"math/big"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
grpc_connectivity "google.golang.org/grpc/connectivity"

"github.com/smartcontractkit/wsrpc"

"github.com/smartcontractkit/chainlink-common/pkg/services/servicetest"
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/csakey"
Expand Down Expand Up @@ -74,7 +80,15 @@ func Test_Client_Transmit(t *testing.T) {
conn := &mocks.MockConn{
Ready: true,
}
c := newClient(lggr, csakey.KeyV2{}, nil, "", noopCacheSet)
opts := ClientOpts{
lggr,
csakey.KeyV2{},
nil,
"",
noopCacheSet,
nil,
}
c := newClient(opts)
c.conn = conn
c.rawClient = wsrpcClient
require.NoError(t, c.StartOnce("Mock WSRPC Client", func() error { return nil }))
Expand Down Expand Up @@ -115,6 +129,65 @@ func Test_Client_Transmit(t *testing.T) {
}
})
})

t.Run("recovers panics in underlying client and attempts redial", func(t *testing.T) {
conn := &mocks.MockConn{
Ready: true,
State: grpc_connectivity.Ready,
InvokeF: func(ctx context.Context, method string, args interface{}, reply interface{}) error {
panic("TESTING CONN INVOKE PANIC")
},
}

ch := make(chan struct{}, 100)
cnt := 0

f := func(ctxCaller context.Context, target string, opts ...wsrpc.DialOption) (Conn, error) {
cnt++
switch cnt {
case 1:
ch <- struct{}{}
return conn, nil
case 2:
ch <- struct{}{}
return nil, nil
default:
t.Fatalf("too many dials, got: %d", cnt)
return nil, nil
}
}

clientKey := csakey.MustNewV2XXXTestingOnly(big.NewInt(rand.Int63()))
serverKey := csakey.MustNewV2XXXTestingOnly(big.NewInt(rand.Int63()))
opts := ClientOpts{
lggr,
clientKey,
serverKey.PublicKey,
"",
noopCacheSet,
f,
}
c := newClient(opts)

require.NoError(t, c.Start(tests.Context(t)))

// drain the channel
select {
case <-ch:
assert.Equal(t, 1, cnt)
default:
t.Fatalf("expected dial to be called")
}

_, err := c.Transmit(ctx, req)
require.EqualError(t, err, "Transmit: caught panic: TESTING CONN INVOKE PANIC")

// expect conn to be closed and re-dialed
<-ch
assert.Equal(t, 2, cnt)

assert.True(t, conn.Closed)
})
}

func Test_Client_LatestReport(t *testing.T) {
Expand Down Expand Up @@ -159,7 +232,7 @@ func Test_Client_LatestReport(t *testing.T) {
conn := &mocks.MockConn{
Ready: true,
}
c := newClient(lggr, csakey.KeyV2{}, nil, "", cacheSet)
c := newClient(ClientOpts{lggr, csakey.KeyV2{}, nil, "", cacheSet, nil})
c.conn = conn
c.rawClient = wsrpcClient

Expand Down
Loading

0 comments on commit 6aa365d

Please sign in to comment.