diff --git a/libpod/sqlite_state.go b/libpod/sqlite_state.go index 8289a31d32..d435f10727 100644 --- a/libpod/sqlite_state.go +++ b/libpod/sqlite_state.go @@ -78,18 +78,11 @@ func NewSqliteState(runtime *Runtime) (_ State, defErr error) { } }() - state.conn = conn - - // Migrate schema (if necessary) - if err := state.migrateSchemaIfNecessary(); err != nil { + if err := initSQLiteDB(conn); err != nil { return nil, err } - // Set up tables - if err := sqliteInitTables(state.conn); err != nil { - return nil, fmt.Errorf("creating tables: %w", err) - } - + state.conn = conn state.valid = true state.runtime = runtime diff --git a/libpod/sqlite_state_internal.go b/libpod/sqlite_state_internal.go index 184b96971e..7143978a3a 100644 --- a/libpod/sqlite_state_internal.go +++ b/libpod/sqlite_state_internal.go @@ -15,52 +15,85 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func (s *SQLiteState) migrateSchemaIfNecessary() (defErr error) { +func initSQLiteDB(conn *sql.DB) (defErr error) { + // Start with a transaction to avoid "database locked" errors. + // See https://github.com/mattn/go-sqlite3/issues/274#issuecomment-1429054597 + tx, err := conn.Begin() + if err != nil { + return fmt.Errorf("beginning transaction: %w", err) + } + defer func() { + if defErr != nil { + if err := tx.Rollback(); err != nil { + logrus.Errorf("Rolling back transaction to create tables: %v", err) + } + } + }() + + sameSchema, err := migrateSchemaIfNecessary(tx) + if err != nil { + return err + } + if !sameSchema { + if err := createSQLiteTables(tx); err != nil { + return err + } + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + return nil +} + +func migrateSchemaIfNecessary(tx *sql.Tx) (bool, error) { // First, check if the DBConfig table exists - checkRow := s.conn.QueryRow("SELECT 1 FROM sqlite_master WHERE type='table' AND name='DBConfig';") + checkRow := tx.QueryRow("SELECT 1 FROM sqlite_master WHERE type='table' AND name='DBConfig';") var check int if err := checkRow.Scan(&check); err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil + return false, nil } - return fmt.Errorf("checking if DB config table exists: %w", err) + return false, fmt.Errorf("checking if DB config table exists: %w", err) } if check != 1 { // Table does not exist, fresh database, no need to migrate. - return nil + return false, nil } - row := s.conn.QueryRow("SELECT SchemaVersion FROM DBConfig;") + row := tx.QueryRow("SELECT SchemaVersion FROM DBConfig;") var schemaVer int if err := row.Scan(&schemaVer); err != nil { if errors.Is(err, sql.ErrNoRows) { // Brand-new, unpopulated DB. // Schema was just created, so it has to be the latest. - return nil + return false, nil } - return fmt.Errorf("scanning schema version from DB config: %w", err) + return false, fmt.Errorf("scanning schema version from DB config: %w", err) } // If the schema version 0 or less, it's invalid if schemaVer <= 0 { - return fmt.Errorf("database schema version %d is invalid: %w", schemaVer, define.ErrInternal) + return false, fmt.Errorf("database schema version %d is invalid: %w", schemaVer, define.ErrInternal) } - if schemaVer != schemaVersion { - // If the DB is a later schema than we support, we have to error - if schemaVer > schemaVersion { - return fmt.Errorf("database has schema version %d while this libpod version only supports version %d: %w", - schemaVer, schemaVersion, define.ErrInternal) - } + // Same schema -> nothing do to. + if schemaVer == schemaVersion { + return true, nil + } - // Perform schema migration here, one version at a time. + // If the DB is a later schema than we support, we have to error + if schemaVer > schemaVersion { + return false, fmt.Errorf("database has schema version %d while this libpod version only supports version %d: %w", + schemaVer, schemaVersion, define.ErrInternal) } - return nil + // Perform schema migration here, one version at a time. + + return false, nil } // Initialize all required tables for the SQLite state -func sqliteInitTables(conn *sql.DB) (defErr error) { +func createSQLiteTables(tx *sql.Tx) error { // Technically we could split the "CREATE TABLE IF NOT EXISTS" and ");" // bits off each command and add them in the for loop where we actually // run the SQL, but that seems unnecessary. @@ -186,28 +219,11 @@ func sqliteInitTables(conn *sql.DB) (defErr error) { "VolumeState": volumeState, } - tx, err := conn.Begin() - if err != nil { - return fmt.Errorf("beginning transaction: %w", err) - } - defer func() { - if defErr != nil { - if err := tx.Rollback(); err != nil { - logrus.Errorf("Rolling back transaction to create tables: %v", err) - } - } - }() - for tblName, cmd := range tables { if _, err := tx.Exec(cmd); err != nil { return fmt.Errorf("creating table %s: %w", tblName, err) } } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - return nil }