Skip to content

Commit

Permalink
sql: fix deadlock in querycache (#5646)
Browse files Browse the repository at this point in the history
## Motivation

Query cache may cause deadlock in some cases, such as a new ATX arriving while an epoch info request is being processed.
  • Loading branch information
ivan4th committed Mar 6, 2024
1 parent 0ce03f2 commit efea4b2
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 97 deletions.
2 changes: 1 addition & 1 deletion activation/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error {
} else if err != nil {
return err
}
blob, err := atxs.GetBlob(b.cdb, atx.Bytes())
blob, err := atxs.GetBlob(ctx, b.cdb, atx.Bytes())
if err != nil {
return fmt.Errorf("get blob %s: %w", atx.ShortString(), err)
}
Expand Down
2 changes: 1 addition & 1 deletion activation/activation_multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func TestRegossip(t *testing.T) {
}
}

blob, err := atxs.GetBlob(tab.cdb, refAtx.ID().Bytes())
blob, err := atxs.GetBlob(context.Background(), tab.cdb, refAtx.ID().Bytes())
require.NoError(t, err)

// atx will be regossiped once (by the smesher)
Expand Down
4 changes: 2 additions & 2 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx)
}

// GetEpochAtxs returns all valid ATXs received in the epoch epochID.
func (h *Handler) GetEpochAtxs(epochID types.EpochID) (ids []types.ATXID, err error) {
ids, err = atxs.GetIDsByEpoch(h.cdb, epochID)
func (h *Handler) GetEpochAtxs(ctx context.Context, epochID types.EpochID) (ids []types.ATXID, err error) {
ids, err = atxs.GetIDsByEpoch(ctx, h.cdb, epochID)
h.log.With().Debug("returned epoch atxs", epochID,
log.Int("count", len(ids)),
log.String("atxs", fmt.Sprint(ids)))
Expand Down
3 changes: 2 additions & 1 deletion cmd/activeset/activeset.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"errors"
"flag"
"fmt"
Expand Down Expand Up @@ -32,7 +33,7 @@ Example:
db, err := sql.Open("file:" + dbpath)
must(err, "can't open db at dbpath=%v. err=%s\n", dbpath, err)

ids, err := atxs.GetIDsByEpoch(db, types.EpochID(publish))
ids, err := atxs.GetIDsByEpoch(context.Background(), db, types.EpochID(publish))
must(err, "get ids by epoch %d. dbpath=%v. err=%s\n", publish, dbpath, err)
var weight uint64
for _, id := range ids {
Expand Down
8 changes: 4 additions & 4 deletions datastore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (db *CachedDB) IterateEpochATXHeaders(
epoch types.EpochID,
iter func(*types.ActivationTxHeader) error,
) error {
ids, err := atxs.GetIDsByEpoch(db, epoch-1)
ids, err := atxs.GetIDsByEpoch(context.Background(), db, epoch-1)
if err != nil {
return err
}
Expand Down Expand Up @@ -370,10 +370,10 @@ type BlobStore struct {
}

// Get gets an ATX as bytes by an ATX ID as bytes.
func (bs *BlobStore) Get(hint Hint, key []byte) ([]byte, error) {
func (bs *BlobStore) Get(ctx context.Context, hint Hint, key []byte) ([]byte, error) {
switch hint {
case ATXDB:
return atxs.GetBlob(bs.DB, key)
return atxs.GetBlob(ctx, bs.DB, key)
case ProposalDB:
return bs.proposals.GetBlob(types.ProposalID(types.BytesToHash(key).ToHash20()))
case BallotDB:
Expand Down Expand Up @@ -407,7 +407,7 @@ func (bs *BlobStore) Get(hint Hint, key []byte) ([]byte, error) {
case Malfeasance:
return identities.GetMalfeasanceBlob(bs.DB, key)
case ActiveSet:
return activesets.GetBlob(bs.DB, key)
return activesets.GetBlob(ctx, bs.DB, key)
}
return nil, fmt.Errorf("blob store not found %s", hint)
}
Expand Down
52 changes: 31 additions & 21 deletions datastore/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package datastore_test

import (
"bytes"
"context"
"os"
"testing"
"time"
Expand Down Expand Up @@ -196,6 +197,7 @@ func TestStore_GetAtxByNodeID(t *testing.T) {
func TestBlobStore_GetATXBlob(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

atx := &types.ActivationTx{
InnerActivationTx: types.InnerActivationTx{
Expand All @@ -218,14 +220,15 @@ func TestBlobStore_GetATXBlob(t *testing.T) {
require.NoError(t, err)
require.False(t, has)

_, err = bs.Get(datastore.ATXDB, atx.ID().Bytes())
_, err = bs.Get(ctx, datastore.ATXDB, atx.ID().Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)

require.NoError(t, atxs.Add(db, vAtx))

has, err = bs.Has(datastore.ATXDB, atx.ID().Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.ATXDB, atx.ID().Bytes())
got, err := bs.Get(ctx, datastore.ATXDB, atx.ID().Bytes())
require.NoError(t, err)

var gotA types.ActivationTx
Expand All @@ -235,13 +238,14 @@ func TestBlobStore_GetATXBlob(t *testing.T) {
gotA.SetReceived(atx.Received())
require.Equal(t, *atx, gotA)

_, err = bs.Get(datastore.BallotDB, atx.ID().Bytes())
_, err = bs.Get(ctx, datastore.BallotDB, atx.ID().Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)
}

func TestBlobStore_GetBallotBlob(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

sig, err := signing.NewEdSigner()
require.NoError(t, err)
Expand All @@ -254,28 +258,29 @@ func TestBlobStore_GetBallotBlob(t *testing.T) {
has, err := bs.Has(datastore.BallotDB, blt.ID().Bytes())
require.NoError(t, err)
require.False(t, has)
_, err = bs.Get(datastore.BallotDB, blt.ID().Bytes())
_, err = bs.Get(ctx, datastore.BallotDB, blt.ID().Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)

require.NoError(t, ballots.Add(db, blt))
has, err = bs.Has(datastore.BallotDB, blt.ID().Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.BallotDB, blt.ID().Bytes())
got, err := bs.Get(ctx, datastore.BallotDB, blt.ID().Bytes())
require.NoError(t, err)
var gotB types.Ballot
require.NoError(t, codec.Decode(got, &gotB))

require.NoError(t, gotB.Initialize())
require.Equal(t, *blt, gotB)

_, err = bs.Get(datastore.BlockDB, blt.ID().Bytes())
_, err = bs.Get(ctx, datastore.BlockDB, blt.ID().Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)
}

func TestBlobStore_GetBlockBlob(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

blk := types.Block{
InnerBlock: types.InnerBlock{
Expand All @@ -289,27 +294,28 @@ func TestBlobStore_GetBlockBlob(t *testing.T) {
require.NoError(t, err)
require.False(t, has)

_, err = bs.Get(datastore.BlockDB, blk.ID().Bytes())
_, err = bs.Get(ctx, datastore.BlockDB, blk.ID().Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)

require.NoError(t, blocks.Add(db, &blk))
has, err = bs.Has(datastore.BlockDB, blk.ID().Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.BlockDB, blk.ID().Bytes())
got, err := bs.Get(ctx, datastore.BlockDB, blk.ID().Bytes())
require.NoError(t, err)
var gotB types.Block
require.NoError(t, codec.Decode(got, &gotB))
gotB.Initialize()
require.Equal(t, blk, gotB)

_, err = bs.Get(datastore.ProposalDB, blk.ID().Bytes())
_, err = bs.Get(ctx, datastore.ProposalDB, blk.ID().Bytes())
require.ErrorIs(t, err, store.ErrNotFound)
}

func TestBlobStore_GetPoetBlob(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

ref := []byte("ref0")
poet := []byte("proof0")
Expand All @@ -320,26 +326,27 @@ func TestBlobStore_GetPoetBlob(t *testing.T) {
require.NoError(t, err)
require.False(t, has)

_, err = bs.Get(datastore.POETDB, ref)
_, err = bs.Get(ctx, datastore.POETDB, ref)
require.ErrorIs(t, err, sql.ErrNotFound)
var poetRef types.PoetProofRef
copy(poetRef[:], ref)
require.NoError(t, poets.Add(db, poetRef, poet, sid, rid))
has, err = bs.Has(datastore.POETDB, ref)
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.POETDB, ref)
got, err := bs.Get(ctx, datastore.POETDB, ref)
require.NoError(t, err)
require.True(t, bytes.Equal(poet, got))

_, err = bs.Get(datastore.BlockDB, ref)
_, err = bs.Get(ctx, datastore.BlockDB, ref)
require.ErrorIs(t, err, sql.ErrNotFound)
}

func TestBlobStore_GetProposalBlob(t *testing.T) {
db := sql.InMemory()
proposals := store.New()
bs := datastore.NewBlobStore(db, proposals)
ctx := context.Background()

signer, err := signing.NewEdSigner()
require.NoError(t, err)
Expand All @@ -358,14 +365,14 @@ func TestBlobStore_GetProposalBlob(t *testing.T) {
has, err := bs.Has(datastore.ProposalDB, p.ID().Bytes())
require.NoError(t, err)
require.False(t, has)
_, err = bs.Get(datastore.ProposalDB, p.ID().Bytes())
_, err = bs.Get(ctx, datastore.ProposalDB, p.ID().Bytes())
require.ErrorIs(t, err, store.ErrNotFound)

require.NoError(t, proposals.Add(&p))
has, err = bs.Has(datastore.ProposalDB, p.ID().Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.ProposalDB, p.ID().Bytes())
got, err := bs.Get(ctx, datastore.ProposalDB, p.ID().Bytes())
require.NoError(t, err)
var gotP types.Proposal
require.NoError(t, codec.Decode(got, &gotP))
Expand All @@ -376,6 +383,7 @@ func TestBlobStore_GetProposalBlob(t *testing.T) {
func TestBlobStore_GetTXBlob(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

tx := &types.Transaction{}
tx.Raw = []byte{1, 1, 1}
Expand All @@ -385,24 +393,25 @@ func TestBlobStore_GetTXBlob(t *testing.T) {
require.NoError(t, err)
require.False(t, has)

_, err = bs.Get(datastore.TXDB, tx.ID.Bytes())
_, err = bs.Get(ctx, datastore.TXDB, tx.ID.Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)

require.NoError(t, transactions.Add(db, tx, time.Now()))
has, err = bs.Has(datastore.TXDB, tx.ID.Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.TXDB, tx.ID.Bytes())
got, err := bs.Get(ctx, datastore.TXDB, tx.ID.Bytes())
require.NoError(t, err)
require.Equal(t, tx.Raw, got)

_, err = bs.Get(datastore.BlockDB, tx.ID.Bytes())
_, err = bs.Get(ctx, datastore.BlockDB, tx.ID.Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)
}

func TestBlobStore_GetMalfeasanceBlob(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

proof := &types.MalfeasanceProof{
Layer: types.LayerID(11),
Expand All @@ -421,21 +430,22 @@ func TestBlobStore_GetMalfeasanceBlob(t *testing.T) {
require.NoError(t, err)
require.False(t, has)

_, err = bs.Get(datastore.Malfeasance, nodeID.Bytes())
_, err = bs.Get(ctx, datastore.Malfeasance, nodeID.Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)

require.NoError(t, identities.SetMalicious(db, nodeID, encoded, time.Now()))
has, err = bs.Has(datastore.Malfeasance, nodeID.Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.Malfeasance, nodeID.Bytes())
got, err := bs.Get(ctx, datastore.Malfeasance, nodeID.Bytes())
require.NoError(t, err)
require.Equal(t, encoded, got)
}

func TestBlobStore_GetActiveSet(t *testing.T) {
db := sql.InMemory()
bs := datastore.NewBlobStore(db, store.New())
ctx := context.Background()

as := &types.EpochActiveSet{Epoch: 7}
hash := types.ATXIDList(as.Set).Hash()
Expand All @@ -444,14 +454,14 @@ func TestBlobStore_GetActiveSet(t *testing.T) {
require.NoError(t, err)
require.False(t, has)

_, err = bs.Get(datastore.ActiveSet, hash.Bytes())
_, err = bs.Get(ctx, datastore.ActiveSet, hash.Bytes())
require.ErrorIs(t, err, sql.ErrNotFound)

require.NoError(t, activesets.Add(db, hash, as))
has, err = bs.Has(datastore.ActiveSet, hash.Bytes())
require.NoError(t, err)
require.True(t, has)
got, err := bs.Get(datastore.ActiveSet, hash.Bytes())
got, err := bs.Get(ctx, datastore.ActiveSet, hash.Bytes())
require.NoError(t, err)
require.Equal(t, codec.MustEncode(as), got)
}
6 changes: 3 additions & 3 deletions fetch/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func (h *handler) handleEpochInfoReq(ctx context.Context, msg []byte) ([]byte, e
}

cacheKey := sql.QueryCacheKey(atxs.CacheKindEpochATXs, epoch.String())
return sql.WithCachedSubKey(h.cdb, cacheKey, fetchSubKey, func() ([]byte, error) {
atxids, err := atxs.GetIDsByEpoch(h.cdb, epoch)
return sql.WithCachedSubKey(ctx, h.cdb, cacheKey, fetchSubKey, func(ctx context.Context) ([]byte, error) {
atxids, err := atxs.GetIDsByEpoch(ctx, h.cdb, epoch)
if err != nil {
h.logger.With().Warning("serve: failed to get epoch atx IDs",
epoch, log.Err(err), log.Context(ctx))
Expand Down Expand Up @@ -189,7 +189,7 @@ func (h *handler) handleHashReq(ctx context.Context, data []byte) ([]byte, error
// be included in the response at all
for _, r := range requestBatch.Requests {
totalHashReqs.WithLabelValues(string(r.Hint)).Add(1)
res, err := h.bs.Get(r.Hint, r.Hash.Bytes())
res, err := h.bs.Get(ctx, r.Hint, r.Hash.Bytes())
if err != nil {
h.logger.With().Debug("serve: remote peer requested nonexistent hash",
log.Context(ctx),
Expand Down
5 changes: 3 additions & 2 deletions sql/activesets/activesets.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package activesets

import (
"context"
"fmt"

"github.com/spacemeshos/go-spacemesh/codec"
Expand Down Expand Up @@ -53,9 +54,9 @@ func Get(db sql.Executor, id types.Hash32) (*types.EpochActiveSet, error) {
return &rst, nil
}

func GetBlob(db sql.Executor, id []byte) ([]byte, error) {
func GetBlob(ctx context.Context, db sql.Executor, id []byte) ([]byte, error) {
cacheKey := sql.QueryCacheKey(CacheKindActiveSetBlob, string(id))
return sql.WithCachedValue(db, cacheKey, func() ([]byte, error) {
return sql.WithCachedValue(ctx, db, cacheKey, func(context.Context) ([]byte, error) {
var rst []byte
rows, err := db.Exec("select active_set from activesets where id = ?1;",
func(stmt *sql.Statement) {
Expand Down
Loading

0 comments on commit efea4b2

Please sign in to comment.