From 9e760caafa8e0bf4b203e1036cc8bc3ea20dbb65 Mon Sep 17 00:00:00 2001 From: Alec Sammon Date: Wed, 22 May 2024 09:04:10 +0100 Subject: [PATCH] Fix race errors with memory tables (#3) --- memory/database.go | 57 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/memory/database.go b/memory/database.go index 6464239519..c2b0fff555 100644 --- a/memory/database.go +++ b/memory/database.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/fulltext" "github.com/dolthub/go-mysql-server/sql/types" + "sync" ) // Database is an in-memory database. @@ -34,6 +35,7 @@ type Database struct { type MemoryDatabase interface { sql.Database AddTable(name string, t MemTable) + DeleteTable(name string) Database() *BaseDatabase } @@ -60,6 +62,7 @@ type BaseDatabase struct { events []sql.EventDefinition primaryKeyIndexes bool collation sql.CollationID + tablesMu *sync.RWMutex } var _ MemoryDatabase = (*Database)(nil) @@ -76,9 +79,10 @@ func NewDatabase(name string) *Database { // NewViewlessDatabase creates a new database that doesn't persist views. Used only for testing. Use NewDatabase. func NewViewlessDatabase(name string) *BaseDatabase { return &BaseDatabase{ - name: name, - tables: map[string]MemTable{}, - fkColl: newForeignKeyCollection(), + name: name, + tables: map[string]MemTable{}, + fkColl: newForeignKeyCollection(), + tablesMu: &sync.RWMutex{}, } } @@ -103,6 +107,8 @@ func (d *BaseDatabase) Name() string { // Tables returns all tables in the database. func (d *BaseDatabase) Tables() map[string]sql.Table { + d.tablesMu.RLock() + defer d.tablesMu.RUnlock() tables := make(map[string]sql.Table, len(d.tables)) for name, table := range d.tables { tables[name] = table @@ -133,17 +139,23 @@ func (d *BaseDatabase) GetTableInsensitive(ctx *sql.Context, tblName string) (sq // putTable writes the table given into database storage. A table with this name must already be present. func (d *BaseDatabase) putTable(t *Table) { lowerName := strings.ToLower(t.name) + d.tablesMu.RLock() for name, table := range d.tables { if strings.ToLower(name) == lowerName { t.name = table.Name() - d.tables[name] = t + d.tablesMu.RUnlock() + d.AddTable(name, t) return } } + d.tablesMu.RUnlock() panic(fmt.Sprintf("table %s not found", t.name)) } func (d *BaseDatabase) GetTableNames(ctx *sql.Context) ([]string, error) { + d.tablesMu.RLock() + defer d.tablesMu.RUnlock() + tblNames := make([]string, 0, len(d.tables)) for k := range d.tables { tblNames = append(tblNames, k) @@ -153,6 +165,9 @@ func (d *BaseDatabase) GetTableNames(ctx *sql.Context) ([]string, error) { } func (d *BaseDatabase) CreateFulltextTableNames(ctx *sql.Context, parentTableName string, parentIndexName string) (fulltext.IndexTableNames, error) { + d.tablesMu.RLock() + defer d.tablesMu.RUnlock() + var tablePrefix string OuterLoop: for i := uint64(0); true; i++ { @@ -225,17 +240,28 @@ func (db *HistoryDatabase) AddTableAsOf(name string, t sql.Table, asOf interface } db.Revisions[strings.ToLower(name)][asOf] = t - db.tables[name] = t.(MemTable) + db.AddTable(name, t.(MemTable)) } // AddTable adds a new table to the database. func (d *BaseDatabase) AddTable(name string, t MemTable) { + d.tablesMu.Lock() + defer d.tablesMu.Unlock() d.tables[name] = t } +// DeleteTable deletes a table from the database. +func (d *BaseDatabase) DeleteTable(name string) { + d.tablesMu.Lock() + defer d.tablesMu.Unlock() + delete(d.tables, name) +} + // CreateTable creates a table with the given name and schema func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error { + d.tablesMu.RLock() _, ok := d.tables[name] + d.tablesMu.RUnlock() if ok { return sql.ErrTableAlreadyExists.New(name) } @@ -247,7 +273,7 @@ func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.Pri table.EnablePrimaryKeyIndexes() } - d.tables[name] = table + d.AddTable(name, table) sess := SessionFromContext(ctx) sess.putTable(table.data) @@ -256,7 +282,9 @@ func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.Pri // CreateIndexedTable creates a table with the given name and schema func (d *BaseDatabase) CreateIndexedTable(ctx *sql.Context, name string, sch sql.PrimaryKeySchema, idxDef sql.IndexDef, collation sql.CollationID) error { + d.tablesMu.RLock() _, ok := d.tables[name] + d.tablesMu.RUnlock() if ok { return sql.ErrTableAlreadyExists.New(name) } @@ -278,31 +306,40 @@ func (d *BaseDatabase) CreateIndexedTable(ctx *sql.Context, name string, sch sql } } - d.tables[name] = table + d.AddTable(name, table) return nil } // DropTable drops the table with the given name func (d *BaseDatabase) DropTable(ctx *sql.Context, name string) error { + d.tablesMu.RLock() t, ok := d.tables[name] + d.tablesMu.RUnlock() + if !ok { return sql.ErrTableNotFound.New(name) } SessionFromContext(ctx).dropTable(t.(*Table).data) - delete(d.tables, name) + d.DeleteTable(name) return nil } func (d *BaseDatabase) RenameTable(ctx *sql.Context, oldName, newName string) error { + d.tablesMu.RLock() tbl, ok := d.tables[oldName] + d.tablesMu.RUnlock() + if !ok { // Should be impossible (engine already checks this condition) return sql.ErrTableNotFound.New(oldName) } + d.tablesMu.RLock() _, ok = d.tables[newName] + d.tablesMu.RUnlock() + if ok { return sql.ErrTableAlreadyExists.New(newName) } @@ -327,8 +364,8 @@ func (d *BaseDatabase) RenameTable(ctx *sql.Context, oldName, newName string) er } memTbl.data.tableName = newName - d.tables[newName] = memTbl - delete(d.tables, oldName) + d.AddTable(newName, memTbl) + d.DeleteTable(oldName) sess.putTable(memTbl.data) return nil