diff --git a/api/api.go b/api/api.go index 88e25ee6..a28987e9 100644 --- a/api/api.go +++ b/api/api.go @@ -1,11 +1,10 @@ package api import ( - "database/sql" "encoding/json" "github.com/vocdoni/census3/census" - queries "github.com/vocdoni/census3/db/sqlc" + "github.com/vocdoni/census3/db" "go.vocdoni.io/dvote/httprouter" api "go.vocdoni.io/dvote/httprouter/apirest" "go.vocdoni.io/dvote/log" @@ -21,18 +20,16 @@ type Census3APIConf struct { type census3API struct { conf Census3APIConf - db *sql.DB - sqlc *queries.Queries + db *db.DB endpoint *api.API censusDB *census.CensusDB w3p map[int64]string } -func Init(db *sql.DB, q *queries.Queries, conf Census3APIConf) error { +func Init(db *db.DB, conf Census3APIConf) error { newAPI := &census3API{ conf: conf, db: db, - sqlc: q, w3p: conf.Web3Providers, } // get the current chainID diff --git a/api/censuses.go b/api/censuses.go index c7c49ec3..9bcdb037 100644 --- a/api/censuses.go +++ b/api/censuses.go @@ -39,18 +39,7 @@ func (capi *census3API) getCensus(msg *api.APIdata, ctx *httprouter.HTTPContext) } internalCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // begin a transaction for group sql queries - tx, err := capi.db.BeginTx(internalCtx, nil) - if err != nil { - return ErrCantGetCensus - } - defer func() { - if err := tx.Rollback(); err != nil { - log.Errorw(err, "holders transaction rollback failed") - } - }() - qtx := capi.sqlc.WithTx(tx) - currentCensus, err := qtx.CensusByID(internalCtx, int64(censusID)) + currentCensus, err := capi.db.QueriesRO.CensusByID(internalCtx, int64(censusID)) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNotFoundCensus @@ -90,7 +79,7 @@ func (capi *census3API) createAndPublishCensus(msg *api.APIdata, ctx *httprouter defer cancel() // begin a transaction for group sql queries - tx, err := capi.db.BeginTx(internalCtx, nil) + tx, err := capi.db.RW.BeginTx(internalCtx, nil) if err != nil { return ErrCantCreateCensus } @@ -99,7 +88,7 @@ func (capi *census3API) createAndPublishCensus(msg *api.APIdata, ctx *httprouter log.Errorw(err, "holders transaction rollback failed") } }() - qtx := capi.sqlc.WithTx(tx) + qtx := capi.db.QueriesRW.WithTx(tx) strategyTokens, err := qtx.TokensByStrategyID(internalCtx, int64(req.StrategyID)) if err != nil { @@ -210,7 +199,7 @@ func (capi *census3API) getStrategyCensuses(msg *api.APIdata, ctx *httprouter.HT // get censuses by this strategy ID internalCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - rows, err := capi.sqlc.CensusByStrategyID(internalCtx, int64(strategyID)) + rows, err := capi.db.QueriesRO.CensusByStrategyID(internalCtx, int64(strategyID)) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNotFoundCensus diff --git a/api/debug.go b/api/debug.go index bd89b082..71c6a337 100644 --- a/api/debug.go +++ b/api/debug.go @@ -37,7 +37,7 @@ func (capi *census3API) getTokenHolders(msg *api.APIdata, ctx *httprouter.HTTPCo // get token holders from the database addr := common.HexToAddress(ctx.URLParam("address")) - dbHolders, err := capi.sqlc.TokenHoldersByTokenID(ctx2, addr.Bytes()) + dbHolders, err := capi.db.QueriesRO.TokenHoldersByTokenID(ctx2, addr.Bytes()) if err != nil { // if database does not contain any token holder for this token, return // no content, else return generic error. @@ -77,7 +77,7 @@ func (capi *census3API) countHolders(msg *api.APIdata, ctx *httprouter.HTTPConte defer cancel() addr := common.HexToAddress(ctx.URLParam("address")) - numberOfHolders, err := capi.sqlc.CountTokenHoldersByTokenID(ctx2, addr.Bytes()) + numberOfHolders, err := capi.db.QueriesRO.CountTokenHoldersByTokenID(ctx2, addr.Bytes()) if err != nil { if errors.Is(sql.ErrNoRows, err) { log.Errorf("no holders found for address %s: %s", addr, err.Error()) diff --git a/api/strategy.go b/api/strategy.go index 30b63216..df69bb2a 100644 --- a/api/strategy.go +++ b/api/strategy.go @@ -37,7 +37,7 @@ func (capi *census3API) initStrategiesHandlers() error { func (capi *census3API) createDummyStrategy(tokenID []byte) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - res, err := capi.sqlc.CreateStategy(ctx, "test") + res, err := capi.db.QueriesRW.CreateStategy(ctx, "test") if err != nil { return err } @@ -45,7 +45,7 @@ func (capi *census3API) createDummyStrategy(tokenID []byte) error { if err != nil { return err } - _, err = capi.sqlc.CreateStrategyToken(ctx, queries.CreateStrategyTokenParams{ + _, err = capi.db.QueriesRW.CreateStrategyToken(ctx, queries.CreateStrategyTokenParams{ StrategyID: strategyID, TokenID: tokenID, MinBalance: big.NewInt(0).Bytes(), @@ -62,7 +62,7 @@ func (capi *census3API) getStrategies(msg *api.APIdata, ctx *httprouter.HTTPCont defer cancel() // TODO: Support for pagination // get strategies from the database - rows, err := capi.sqlc.ListStrategies(internalCtx) + rows, err := capi.db.QueriesRO.ListStrategies(internalCtx) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNoStrategies @@ -100,7 +100,7 @@ func (capi *census3API) getStrategy(msg *api.APIdata, ctx *httprouter.HTTPContex // get strategy from the database internalCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - strategyData, err := capi.sqlc.StrategyByID(internalCtx, int64(strategyID)) + strategyData, err := capi.db.QueriesRO.StrategyByID(internalCtx, int64(strategyID)) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNotFoundStrategy @@ -115,7 +115,7 @@ func (capi *census3API) getStrategy(msg *api.APIdata, ctx *httprouter.HTTPContex Tokens: []GetStrategyToken{}, } // get information of the strategy related tokens - tokensData, err := capi.sqlc.TokensByStrategyID(internalCtx, strategyData.ID) + tokensData, err := capi.db.QueriesRO.TokensByStrategyID(internalCtx, strategyData.ID) if err != nil && !errors.Is(err, sql.ErrNoRows) { log.Errorw(ErrCantGetTokens, err.Error()) return ErrCantGetTokens @@ -147,7 +147,7 @@ func (capi *census3API) getTokenStrategies(msg *api.APIdata, ctx *httprouter.HTT internalCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() // get strategies associated to the token provided - rows, err := capi.sqlc.StrategiesByTokenID(internalCtx, common.HexToAddress(tokenID).Bytes()) + rows, err := capi.db.QueriesRO.StrategiesByTokenID(internalCtx, common.HexToAddress(tokenID).Bytes()) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNoStrategies diff --git a/api/tokens.go b/api/tokens.go index 78110c3b..2e1078d5 100644 --- a/api/tokens.go +++ b/api/tokens.go @@ -42,7 +42,7 @@ func (capi *census3API) getTokens(msg *api.APIdata, ctx *httprouter.HTTPContext) defer cancel() // TODO: Support for pagination // get tokens from the database - rows, err := capi.sqlc.ListTokens(internalCtx) + rows, err := capi.db.QueriesRO.ListTokens(internalCtx) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNoTokens @@ -134,7 +134,7 @@ func (capi *census3API) createToken(msg *api.APIdata, ctx *httprouter.HTTPContex return ErrCantGetToken } } - _, err = capi.sqlc.CreateToken(internalCtx, queries.CreateTokenParams{ + _, err = capi.db.QueriesRW.CreateToken(internalCtx, queries.CreateTokenParams{ ID: info.Address.Bytes(), Name: *name, Symbol: *symbol, @@ -168,7 +168,7 @@ func (capi *census3API) getToken(msg *api.APIdata, ctx *httprouter.HTTPContext) address := common.HexToAddress(ctx.URLParam("tokenID")) internalCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - tokenData, err := capi.sqlc.TokenByID(internalCtx, address.Bytes()) + tokenData, err := capi.db.QueriesRO.TokenByID(internalCtx, address.Bytes()) if err != nil { if errors.Is(err, sql.ErrNoRows) { log.Errorw(ErrNotFoundToken, err.Error()) @@ -178,7 +178,7 @@ func (capi *census3API) getToken(msg *api.APIdata, ctx *httprouter.HTTPContext) return ErrCantGetToken } // TODO: Only for the MVP, consider to remove it - tokenStrategies, err := capi.sqlc.StrategiesByTokenID(internalCtx, tokenData.ID) + tokenStrategies, err := capi.db.QueriesRO.StrategiesByTokenID(internalCtx, tokenData.ID) if err != nil && !errors.Is(err, sql.ErrNoRows) { log.Errorw(ErrCantGetToken, err.Error()) return ErrCantGetToken @@ -188,7 +188,7 @@ func (capi *census3API) getToken(msg *api.APIdata, ctx *httprouter.HTTPContext) defaultStrategyID = uint64(tokenStrategies[0].ID) } // get last block with token information - atBlock, err := capi.sqlc.LastBlockByTokenID(internalCtx, address.Bytes()) + atBlock, err := capi.db.QueriesRO.LastBlockByTokenID(internalCtx, address.Bytes()) if err != nil { if !errors.Is(err, sql.ErrNoRows) { log.Errorw(ErrCantGetToken, err.Error()) @@ -221,7 +221,7 @@ func (capi *census3API) getToken(msg *api.APIdata, ctx *httprouter.HTTPContext) // get token holders count countHoldersCtx, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) defer cancel2() - holders, err := capi.sqlc.CountTokenHoldersByTokenID(countHoldersCtx, address.Bytes()) + holders, err := capi.db.QueriesRO.CountTokenHoldersByTokenID(countHoldersCtx, address.Bytes()) if err != nil { return ErrCantGetTokenCount } diff --git a/cmd/census3/main.go b/cmd/census3/main.go index 432bab31..51bfd347 100644 --- a/cmd/census3/main.go +++ b/cmd/census3/main.go @@ -30,7 +30,7 @@ func main() { flag.Parse() log.Init(*logLevel, "stdout", nil) - db, q, err := db.Init(*dataDir) + database, err := db.Init(*dataDir) if err != nil { log.Fatal(err) } @@ -43,13 +43,13 @@ func main() { log.Info(w3p) // Start the holder scanner - hc, err := service.NewHoldersScanner(db, q, w3p) + hc, err := service.NewHoldersScanner(database, w3p) if err != nil { log.Fatal(err) } // Start the API - err = api.Init(db, q, api.Census3APIConf{ + err = api.Init(database, api.Census3APIConf{ Hostname: "0.0.0.0", Port: *port, DataDir: *dataDir, @@ -69,6 +69,12 @@ func main() { log.Warnf("received SIGTERM, exiting at %s", time.Now().Format(time.RFC850)) cancel() log.Infof("waiting for routines to end gracefully...") + // closing database + go func() { + if err := database.Close(); err != nil { + log.Fatal(err) + } + }() time.Sleep(5 * time.Second) os.Exit(0) } diff --git a/db/db.go b/db/db.go index a8cacc79..5b583fba 100644 --- a/db/db.go +++ b/db/db.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "time" _ "github.com/mattn/go-sqlite3" "github.com/pressly/goose/v3" @@ -15,27 +16,73 @@ import ( //go:embed migrations/*.sql var migrationsFS embed.FS -func Init(dataDir string) (*sql.DB, *queries.Queries, error) { +// DB struct abstact a safe connection with the database using sqlc queries, +// sqlite as a database engine and go-sqlite3 as a driver. +type DB struct { + RW *sql.DB + RO *sql.DB + + QueriesRW *queries.Queries + QueriesRO *queries.Queries +} + +// Close function stops all internal connections to the database +func (db *DB) Close() error { + if err := db.RW.Close(); err != nil { + return err + } + return db.RO.Close() +} + +// Init function starts a database using the data path provided as argument. It +// opens two different connections, one for read only, and another for read and +// write, with different configurations, optimized for each use case. +func Init(dataDir string) (*DB, error) { dbFile := filepath.Join(dataDir, "census3.sql") if _, err := os.Stat(dbFile); os.IsNotExist(err) { if err := os.MkdirAll(dataDir, os.ModePerm); err != nil { - return nil, nil, fmt.Errorf("error creating a new database file: %w", err) + return nil, fmt.Errorf("error creating a new database file: %w", err) } } - // open database file - database, err := sql.Open("sqlite3", dbFile) + // sqlite doesn't support multiple concurrent writers. + // For that reason, rwDB is limited to one open connection. + // Per https://github.com/mattn/go-sqlite3/issues/1022#issuecomment-1067353980, + // we use WAL to allow multiple concurrent readers at the same time. + rwDB, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=rwc&_journal_mode=wal&_txlock=immediate&_synchronous=normal", dbFile)) if err != nil { - return nil, nil, fmt.Errorf("error opening database: %w", err) + return nil, fmt.Errorf("error opening database: %w", err) } + rwDB.SetMaxOpenConns(1) + rwDB.SetMaxIdleConns(2) + rwDB.SetConnMaxIdleTime(10 * time.Minute) + rwDB.SetConnMaxLifetime(time.Hour) + + roDB, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=ro&_journal_mode=wal", dbFile)) + if err != nil { + return nil, fmt.Errorf("error opening database: %w", err) + } + // Increasing these numbers can allow for more queries to run concurrently, + // but it also increases the memory used by sqlite and our connection pool. + // Most read-only queries we run are quick enough, so a small number seems OK. + roDB.SetMaxOpenConns(10) + roDB.SetMaxIdleConns(20) + roDB.SetConnMaxIdleTime(5 * time.Minute) + roDB.SetConnMaxLifetime(time.Hour) + // get census3 goose migrations and setup for sqlite3 if err := goose.SetDialect("sqlite3"); err != nil { - return nil, nil, fmt.Errorf("error setting up driver for sqlite: %w", err) + return nil, fmt.Errorf("error setting up driver for sqlite: %w", err) } goose.SetBaseFS(migrationsFS) // perform goose up - if err := goose.Up(database, "migrations"); err != nil { - return nil, nil, fmt.Errorf("error during goose up: %w", err) + if err := goose.Up(rwDB, "migrations"); err != nil { + return nil, fmt.Errorf("error during goose up: %w", err) } // init sqlc - return database, queries.New(database), nil + return &DB{ + RW: rwDB, + RO: roDB, + QueriesRW: queries.New(rwDB), + QueriesRO: queries.New(roDB), + }, nil } diff --git a/service/helper_test.go b/service/helper_test.go index 614bdc40..90bdf95f 100644 --- a/service/helper_test.go +++ b/service/helper_test.go @@ -43,23 +43,23 @@ var ( ) type TestDB struct { - dir string - db *sql.DB - queries *queries.Queries + dir string + db *db.DB } func StartTestDB(t *testing.T) *TestDB { c := qt.New(t) dir := t.TempDir() - db, q, err := db.Init(dir) + db, err := db.Init(dir) c.Assert(err, qt.IsNil) - return &TestDB{dir, db, q} + return &TestDB{dir, db} } func (testdb *TestDB) Close(t *testing.T) { c := qt.New(t) - c.Assert(testdb.db.Close(), qt.IsNil) + c.Assert(testdb.db.RW.Close(), qt.IsNil) + c.Assert(testdb.db.RO.Close(), qt.IsNil) c.Assert(os.RemoveAll(testdb.dir), qt.IsNil) } diff --git a/service/holder_scanner_test.go b/service/holder_scanner_test.go index 1db1d242..ae63e1da 100644 --- a/service/holder_scanner_test.go +++ b/service/holder_scanner_test.go @@ -25,24 +25,24 @@ func TestNewHolderScanner(t *testing.T) { w3p, err := state.CheckWeb3Providers([]string{web3uri}) c.Assert(err, qt.IsNil) - hs, err := NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err := NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) c.Assert(hs.lastBlock, qt.Equals, uint64(0)) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - _, err = testdb.queries.CreateBlock(ctx, queries.CreateBlockParams{ + _, err = testdb.db.QueriesRW.CreateBlock(ctx, queries.CreateBlockParams{ ID: 1000, Timestamp: "test", RootHash: []byte("test"), }) c.Assert(err, qt.IsNil) - hs, err = NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err = NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) c.Assert(hs.lastBlock, qt.Equals, uint64(1000)) - _, err = NewHoldersScanner(nil, nil, w3p) + _, err = NewHoldersScanner(nil, w3p) c.Assert(err, qt.IsNotNil) } @@ -58,7 +58,7 @@ func TestHolderScannerStart(t *testing.T) { defer testdb.Close(t) twg.Add(1) - hs, err := NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err := NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) go func() { hs.Start(ctx) @@ -78,7 +78,7 @@ func Test_tokenAddresses(t *testing.T) { w3p, err := state.CheckWeb3Providers([]string{web3uri}) c.Assert(err, qt.IsNil) - hs, err := NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err := NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) res, err := hs.tokenAddresses() @@ -87,7 +87,7 @@ func Test_tokenAddresses(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - _, err = testdb.queries.CreateToken(ctx, testTokenParams("0x1", "test0", + _, err = testdb.db.QueriesRW.CreateToken(ctx, testTokenParams("0x1", "test0", "test0", MonkeysDecimals, 0, MonkeysTotalSupply.Uint64(), uint64(state.CONTRACT_TYPE_ERC20), false, 5)) c.Assert(err, qt.IsNil) @@ -96,7 +96,7 @@ func Test_tokenAddresses(t *testing.T) { c.Assert(err, qt.IsNil) c.Assert(res[common.HexToAddress("0x1")], qt.IsFalse) - _, err = testdb.queries.CreateToken(ctx, testTokenParams("0x2", "test2", + _, err = testdb.db.QueriesRW.CreateToken(ctx, testTokenParams("0x2", "test2", "test3", MonkeysDecimals, 10, MonkeysTotalSupply.Uint64(), uint64(state.CONTRACT_TYPE_ERC20), false, 5)) c.Assert(err, qt.IsNil) @@ -115,13 +115,13 @@ func Test_saveHolders(t *testing.T) { w3p, err := state.CheckWeb3Providers([]string{web3uri}) c.Assert(err, qt.IsNil) - hs, err := NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err := NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) th := new(state.TokenHolders).Init(MonkeysAddress, state.CONTRACT_TYPE_ERC20, MonkeysCreationBlock, 5) // no registered token c.Assert(hs.saveHolders(th), qt.ErrorIs, ErrTokenNotExists) - _, err = testdb.queries.CreateToken(context.Background(), testTokenParams( + _, err = testdb.db.QueriesRW.CreateToken(context.Background(), testTokenParams( MonkeysAddress.String(), MonkeysName, MonkeysSymbol, MonkeysDecimals, MonkeysCreationBlock, MonkeysTotalSupply.Uint64(), uint64(state.CONTRACT_TYPE_ERC20), false, 5)) @@ -136,7 +136,7 @@ func Test_saveHolders(t *testing.T) { // check web3 c.Assert(hs.saveHolders(th), qt.IsNil) // check new holders - res, err := testdb.queries.TokenHolderByTokenIDAndHolderID(context.Background(), + res, err := testdb.db.QueriesRO.TokenHolderByTokenIDAndHolderID(context.Background(), queries.TokenHolderByTokenIDAndHolderIDParams{ TokenID: MonkeysAddress.Bytes(), HolderID: holderAddr.Bytes(), @@ -146,7 +146,7 @@ func Test_saveHolders(t *testing.T) { // check update holders th.Append(holderAddr, holderBalance) c.Assert(hs.saveHolders(th), qt.IsNil) - res, err = testdb.queries.TokenHolderByTokenIDAndHolderID(context.Background(), + res, err = testdb.db.QueriesRO.TokenHolderByTokenIDAndHolderID(context.Background(), queries.TokenHolderByTokenIDAndHolderIDParams{ TokenID: MonkeysAddress.Bytes(), HolderID: holderAddr.Bytes(), @@ -157,7 +157,7 @@ func Test_saveHolders(t *testing.T) { // check delete holders th.Append(holderAddr, big.NewInt(-24)) c.Assert(hs.saveHolders(th), qt.IsNil) - _, err = testdb.queries.TokenHolderByTokenIDAndHolderID(context.Background(), + _, err = testdb.db.QueriesRO.TokenHolderByTokenIDAndHolderID(context.Background(), queries.TokenHolderByTokenIDAndHolderIDParams{ TokenID: MonkeysAddress.Bytes(), HolderID: holderAddr.Bytes(), @@ -174,7 +174,7 @@ func Test_scanHolders(t *testing.T) { w3p, err := state.CheckWeb3Providers([]string{web3uri}) c.Assert(err, qt.IsNil) - hs, err := NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err := NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) // token does not exists @@ -183,7 +183,7 @@ func Test_scanHolders(t *testing.T) { _, err = hs.scanHolders(ctx1, MonkeysAddress) c.Assert(err, qt.IsNotNil) - _, err = testdb.queries.CreateToken(context.Background(), testTokenParams( + _, err = testdb.db.QueriesRW.CreateToken(context.Background(), testTokenParams( MonkeysAddress.String(), MonkeysName, MonkeysSymbol, MonkeysDecimals, MonkeysCreationBlock, 10, uint64(state.CONTRACT_TYPE_ERC20), false, 5)) c.Assert(err, qt.IsNil) @@ -193,7 +193,7 @@ func Test_scanHolders(t *testing.T) { _, err = hs.scanHolders(ctx2, MonkeysAddress) c.Assert(err, qt.IsNil) - res, err := testdb.queries.TokenHoldersByTokenID(context.Background(), MonkeysAddress.Bytes()) + res, err := testdb.db.QueriesRW.TokenHoldersByTokenID(context.Background(), MonkeysAddress.Bytes()) c.Assert(err, qt.IsNil) for _, holder := range res { balance, ok := MonkeysHolders[common.BytesToAddress(holder.ID)] @@ -211,18 +211,18 @@ func Test_calcTokenCreationBlock(t *testing.T) { w3p, err := state.CheckWeb3Providers([]string{web3uri}) c.Assert(err, qt.IsNil) - hs, err := NewHoldersScanner(testdb.db, testdb.queries, w3p) + hs, err := NewHoldersScanner(testdb.db, w3p) c.Assert(err, qt.IsNil) c.Assert(hs.calcTokenCreationBlock(context.Background(), MonkeysAddress), qt.IsNotNil) - _, err = testdb.queries.CreateToken(context.Background(), testTokenParams( + _, err = testdb.db.QueriesRW.CreateToken(context.Background(), testTokenParams( MonkeysAddress.String(), MonkeysName, MonkeysSymbol, MonkeysDecimals, MonkeysCreationBlock, MonkeysTotalSupply.Uint64(), uint64(state.CONTRACT_TYPE_ERC20), false, 5)) c.Assert(err, qt.IsNil) c.Assert(hs.calcTokenCreationBlock(context.Background(), MonkeysAddress), qt.IsNil) - token, err := testdb.queries.TokenByID(context.Background(), MonkeysAddress.Bytes()) + token, err := testdb.db.QueriesRW.TokenByID(context.Background(), MonkeysAddress.Bytes()) c.Assert(err, qt.IsNil) c.Assert(uint64(token.CreationBlock.Int32), qt.Equals, MonkeysCreationBlock) } diff --git a/service/holders_scanner.go b/service/holders_scanner.go index 06a7619d..fdd42ab7 100644 --- a/service/holders_scanner.go +++ b/service/holders_scanner.go @@ -13,6 +13,7 @@ import ( "github.com/ethereum/go-ethereum/common" _ "github.com/mattn/go-sqlite3" + "github.com/vocdoni/census3/db" queries "github.com/vocdoni/census3/db/sqlc" "github.com/vocdoni/census3/state" "go.vocdoni.io/dvote/log" @@ -31,16 +32,15 @@ type HoldersScanner struct { w3p map[int64]string tokens map[common.Address]*state.TokenHolders mutex sync.RWMutex - db *sql.DB - sqlc *queries.Queries + db *db.DB lastBlock uint64 } // NewHoldersScanner function creates a new HolderScanner using the dataDir path // and the web3 endpoint URI provided. It sets up a sqlite3 database instance // and gets the number of last block scanned from it. -func NewHoldersScanner(db *sql.DB, q *queries.Queries, w3p map[int64]string) (*HoldersScanner, error) { - if db == nil || q == nil { +func NewHoldersScanner(db *db.DB, w3p map[int64]string) (*HoldersScanner, error) { + if db == nil { return nil, ErrNoDB } // create an empty scanner @@ -48,12 +48,11 @@ func NewHoldersScanner(db *sql.DB, q *queries.Queries, w3p map[int64]string) (*H w3p: w3p, tokens: make(map[common.Address]*state.TokenHolders), db: db, - sqlc: q, } // get latest analyzed block ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - lastBlock, err := s.sqlc.LastBlock(ctx) + lastBlock, err := s.db.QueriesRO.LastBlock(ctx) if err == nil { s.lastBlock = uint64(lastBlock) } @@ -120,7 +119,7 @@ func (s *HoldersScanner) tokenAddresses() (map[common.Address]bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // get tokens from the database - tokens, err := s.sqlc.ListTokens(ctx) + tokens, err := s.db.QueriesRO.ListTokens(ctx) // if error raises and is no rows error return nil results, if it is not // return the error. if err != nil { @@ -149,14 +148,14 @@ func (s *HoldersScanner) saveHolders(th *state.TokenHolders) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // begin a transaction for group sql queries - tx, err := s.db.BeginTx(ctx, nil) + tx, err := s.db.RW.BeginTx(ctx, nil) if err != nil { return err } defer func() { _ = tx.Rollback() }() - qtx := s.sqlc.WithTx(tx) + qtx := s.db.QueriesRW.WithTx(tx) if exists, err := qtx.ExistsToken(ctx, th.Address().Bytes()); err != nil { return fmt.Errorf("error checking if token exists: %w", err) } else if !exists { @@ -308,13 +307,13 @@ func (s *HoldersScanner) scanHolders(ctx context.Context, addr common.Address) ( if !ok { log.Infof("initializing contract %s", addr.Hex()) // get token information from the database - tokenInfo, err := s.sqlc.TokenByID(ctx, addr.Bytes()) + tokenInfo, err := s.db.QueriesRO.TokenByID(ctx, addr.Bytes()) if err != nil { return false, err } ttype := state.TokenType(tokenInfo.TypeID) tokenLastBlock := uint64(tokenInfo.CreationBlock.Int32) - if blockNumber, err := s.sqlc.LastBlockByTokenID(ctx, addr.Bytes()); err == nil { + if blockNumber, err := s.db.QueriesRO.LastBlockByTokenID(ctx, addr.Bytes()); err == nil { tokenLastBlock = uint64(blockNumber) } th = new(state.TokenHolders).Init(addr, ttype, tokenLastBlock, tokenInfo.ChainID) @@ -373,7 +372,7 @@ func (s *HoldersScanner) calcTokenCreationBlock(ctx context.Context, addr common ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() // get the token type - tokenInfo, err := s.sqlc.TokenByID(ctx, addr.Bytes()) + tokenInfo, err := s.db.QueriesRO.TokenByID(ctx, addr.Bytes()) if err != nil { return fmt.Errorf("error getting token from database: %w", err) } @@ -398,7 +397,7 @@ func (s *HoldersScanner) calcTokenCreationBlock(ctx context.Context, addr common return fmt.Errorf("error getting token creation block value: %w", err) } // save the creation block into the database - _, err = s.sqlc.UpdateTokenCreationBlock(ctx, + _, err = s.db.QueriesRW.UpdateTokenCreationBlock(ctx, queries.UpdateTokenCreationBlockParams{ ID: addr.Bytes(), CreationBlock: *dbCreationBlock,