Skip to content

Commit

Permalink
Fix race errors with memory tables (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
alecsammon authored May 22, 2024
1 parent 05792f2 commit 9e760ca
Showing 1 changed file with 47 additions and 10 deletions.
57 changes: 47 additions & 10 deletions memory/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/fulltext"
"github.com/dolthub/go-mysql-server/sql/types"
"sync"
)

// Database is an in-memory database.
Expand All @@ -34,6 +35,7 @@ type Database struct {
type MemoryDatabase interface {
sql.Database
AddTable(name string, t MemTable)
DeleteTable(name string)
Database() *BaseDatabase
}

Expand All @@ -60,6 +62,7 @@ type BaseDatabase struct {
events []sql.EventDefinition
primaryKeyIndexes bool
collation sql.CollationID
tablesMu *sync.RWMutex
}

var _ MemoryDatabase = (*Database)(nil)
Expand All @@ -76,9 +79,10 @@ func NewDatabase(name string) *Database {
// NewViewlessDatabase creates a new database that doesn't persist views. Used only for testing. Use NewDatabase.
func NewViewlessDatabase(name string) *BaseDatabase {
return &BaseDatabase{
name: name,
tables: map[string]MemTable{},
fkColl: newForeignKeyCollection(),
name: name,
tables: map[string]MemTable{},
fkColl: newForeignKeyCollection(),
tablesMu: &sync.RWMutex{},
}
}

Expand All @@ -103,6 +107,8 @@ func (d *BaseDatabase) Name() string {

// Tables returns all tables in the database.
func (d *BaseDatabase) Tables() map[string]sql.Table {
d.tablesMu.RLock()
defer d.tablesMu.RUnlock()
tables := make(map[string]sql.Table, len(d.tables))
for name, table := range d.tables {
tables[name] = table
Expand Down Expand Up @@ -133,17 +139,23 @@ func (d *BaseDatabase) GetTableInsensitive(ctx *sql.Context, tblName string) (sq
// putTable writes the table given into database storage. A table with this name must already be present.
func (d *BaseDatabase) putTable(t *Table) {
lowerName := strings.ToLower(t.name)
d.tablesMu.RLock()
for name, table := range d.tables {
if strings.ToLower(name) == lowerName {
t.name = table.Name()
d.tables[name] = t
d.tablesMu.RUnlock()
d.AddTable(name, t)
return
}
}
d.tablesMu.RUnlock()
panic(fmt.Sprintf("table %s not found", t.name))
}

func (d *BaseDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
d.tablesMu.RLock()
defer d.tablesMu.RUnlock()

tblNames := make([]string, 0, len(d.tables))
for k := range d.tables {
tblNames = append(tblNames, k)
Expand All @@ -153,6 +165,9 @@ func (d *BaseDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
}

func (d *BaseDatabase) CreateFulltextTableNames(ctx *sql.Context, parentTableName string, parentIndexName string) (fulltext.IndexTableNames, error) {
d.tablesMu.RLock()
defer d.tablesMu.RUnlock()

var tablePrefix string
OuterLoop:
for i := uint64(0); true; i++ {
Expand Down Expand Up @@ -225,17 +240,28 @@ func (db *HistoryDatabase) AddTableAsOf(name string, t sql.Table, asOf interface
}

db.Revisions[strings.ToLower(name)][asOf] = t
db.tables[name] = t.(MemTable)
db.AddTable(name, t.(MemTable))
}

// AddTable adds a new table to the database.
func (d *BaseDatabase) AddTable(name string, t MemTable) {
d.tablesMu.Lock()
defer d.tablesMu.Unlock()
d.tables[name] = t
}

// DeleteTable deletes a table from the database.
func (d *BaseDatabase) DeleteTable(name string) {
d.tablesMu.Lock()
defer d.tablesMu.Unlock()
delete(d.tables, name)
}

// CreateTable creates a table with the given name and schema
func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error {
d.tablesMu.RLock()
_, ok := d.tables[name]
d.tablesMu.RUnlock()
if ok {
return sql.ErrTableAlreadyExists.New(name)
}
Expand All @@ -247,7 +273,7 @@ func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.Pri
table.EnablePrimaryKeyIndexes()
}

d.tables[name] = table
d.AddTable(name, table)
sess := SessionFromContext(ctx)
sess.putTable(table.data)

Expand All @@ -256,7 +282,9 @@ func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.Pri

// CreateIndexedTable creates a table with the given name and schema
func (d *BaseDatabase) CreateIndexedTable(ctx *sql.Context, name string, sch sql.PrimaryKeySchema, idxDef sql.IndexDef, collation sql.CollationID) error {
d.tablesMu.RLock()
_, ok := d.tables[name]
d.tablesMu.RUnlock()
if ok {
return sql.ErrTableAlreadyExists.New(name)
}
Expand All @@ -278,31 +306,40 @@ func (d *BaseDatabase) CreateIndexedTable(ctx *sql.Context, name string, sch sql
}
}

d.tables[name] = table
d.AddTable(name, table)
return nil
}

// DropTable drops the table with the given name
func (d *BaseDatabase) DropTable(ctx *sql.Context, name string) error {
d.tablesMu.RLock()
t, ok := d.tables[name]
d.tablesMu.RUnlock()

if !ok {
return sql.ErrTableNotFound.New(name)
}

SessionFromContext(ctx).dropTable(t.(*Table).data)

delete(d.tables, name)
d.DeleteTable(name)
return nil
}

func (d *BaseDatabase) RenameTable(ctx *sql.Context, oldName, newName string) error {
d.tablesMu.RLock()
tbl, ok := d.tables[oldName]
d.tablesMu.RUnlock()

if !ok {
// Should be impossible (engine already checks this condition)
return sql.ErrTableNotFound.New(oldName)
}

d.tablesMu.RLock()
_, ok = d.tables[newName]
d.tablesMu.RUnlock()

if ok {
return sql.ErrTableAlreadyExists.New(newName)
}
Expand All @@ -327,8 +364,8 @@ func (d *BaseDatabase) RenameTable(ctx *sql.Context, oldName, newName string) er
}
memTbl.data.tableName = newName

d.tables[newName] = memTbl
delete(d.tables, oldName)
d.AddTable(newName, memTbl)
d.DeleteTable(oldName)
sess.putTable(memTbl.data)

return nil
Expand Down

0 comments on commit 9e760ca

Please sign in to comment.