From 9c852c0e0a57b4bd1ad16d79965f6ac7b0803cf6 Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Tue, 10 Jul 2018 04:44:37 -0400 Subject: [PATCH] importccl: support unchecked foreign keys in IMPORT PGDUMP We achieve this by implementing sql.SchemaResolver with a map from the found tables in the IMPORT and using that to resolve table names during FK creation. Release note (sql change): support foreign keys in IMPORT PGDUMP. --- pkg/ccl/importccl/csv_internal_test.go | 6 +- pkg/ccl/importccl/import_stmt.go | 97 +++++++++++++++++-- pkg/ccl/importccl/import_stmt_test.go | 102 ++++++++++++++++++-- pkg/ccl/importccl/read_import_mysql_test.go | 2 +- pkg/ccl/importccl/read_import_pgdump.go | 12 ++- pkg/ccl/partitionccl/partition_test.go | 2 +- pkg/sql/create_table.go | 69 +++++++------ 7 files changed, 238 insertions(+), 52 deletions(-) diff --git a/pkg/ccl/importccl/csv_internal_test.go b/pkg/ccl/importccl/csv_internal_test.go index 5dbccba3798f..331f35775202 100644 --- a/pkg/ccl/importccl/csv_internal_test.go +++ b/pkg/ccl/importccl/csv_internal_test.go @@ -40,11 +40,11 @@ func TestMakeSimpleTableDescriptorErrors(t *testing.T) { }, { stmt: "create table a (i int references b (id))", - error: `foreign keys not supported: FOREIGN KEY \(i\) REFERENCES b \(id\)`, + error: `table "b" not found`, }, { stmt: "create table a (i int, constraint a foreign key (i) references c (id))", - error: `foreign keys not supported: CONSTRAINT a FOREIGN KEY \(i\) REFERENCES c \(id\)`, + error: `table "c" not found`, }, { stmt: `create table a ( @@ -71,7 +71,7 @@ func TestMakeSimpleTableDescriptorErrors(t *testing.T) { if !ok { t.Fatal("expected CREATE TABLE statement in table file") } - _, err = MakeSimpleTableDescriptor(ctx, st, create, defaultCSVParentID, defaultCSVTableID, 0) + _, err = MakeSimpleTableDescriptor(ctx, st, create, defaultCSVParentID, defaultCSVTableID, nil, 0) if !testutils.IsError(err, tc.error) { t.Fatalf("expected %v, got %+v", tc.error, err) } diff --git a/pkg/ccl/importccl/import_stmt.go b/pkg/ccl/importccl/import_stmt.go index 6d0fffe8e76c..9687df9458a8 100644 --- a/pkg/ccl/importccl/import_stmt.go +++ b/pkg/ccl/importccl/import_stmt.go @@ -33,6 +33,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/hlc" @@ -129,6 +130,7 @@ func MakeSimpleTableDescriptor( create *tree.CreateTable, parentID, tableID sqlbase.ID, + otherTables fkResolver, walltime int64, ) (*sqlbase.TableDescriptor, error) { sql.HoistConstraints(create) @@ -153,7 +155,8 @@ func MakeSimpleTableDescriptor( return nil, errors.Errorf("computed columns not supported: %s", tree.AsString(def)) } case *tree.ForeignKeyConstraintTableDef: - return nil, errors.Errorf("foreign keys not supported: %s", tree.AsString(def)) + n := tree.MakeTableName("", tree.Name(def.Table.TableNameReference.String())) + def.Table.TableNameReference = &n default: return nil, errors.Errorf("unsupported table definition: %s", tree.AsString(def)) } @@ -163,28 +166,42 @@ func MakeSimpleTableDescriptor( CtxProvider: ctxProvider{ctx}, Sequence: &importSequenceOperators{}, } + affected := make(map[sqlbase.ID]*sqlbase.TableDescriptor) tableDesc, err := sql.MakeTableDesc( ctx, nil, /* txn */ - nil, /* vt */ + otherTables, st, create, parentID, tableID, hlc.Timestamp{WallTime: walltime}, sqlbase.NewDefaultPrivilegeDescriptor(), - nil, /* affected */ + affected, &semaCtx, &evalCtx, ) if err != nil { return nil, err } + // If the table had a FK, it was put into the ADD state and its references were marked as validated. We need to undo those changes. + tableDesc.State = sqlbase.TableDescriptor_PUBLIC + if err := tableDesc.ForeachNonDropIndex(func(idx *sqlbase.IndexDescriptor) error { + if idx.ForeignKey.IsSet() { + idx.ForeignKey.Validity = sqlbase.ConstraintValidity_Unvalidated + } + return nil + }); err != nil { + return nil, err + } return &tableDesc, nil } -var errSequenceOperators = errors.New("sequence operations unsupported") +var ( + errSequenceOperators = errors.New("sequence operations unsupported") + errSchemaResolver = errors.New("schema resolver unsupported") +) // Implements the tree.SequenceOperators interface. type importSequenceOperators struct { @@ -230,6 +247,58 @@ func (so *importSequenceOperators) SetSequenceValue( return errSequenceOperators } +type fkResolver map[string]*sqlbase.TableDescriptor + +var _ sql.SchemaResolver = fkResolver{} + +// Implements the sql.SchemaResolver interface. +func (r fkResolver) Txn() *client.Txn { + return nil +} + +// Implements the sql.SchemaResolver interface. +func (r fkResolver) LogicalSchemaAccessor() sql.SchemaAccessor { + return nil +} + +// Implements the sql.SchemaResolver interface. +func (r fkResolver) CurrentDatabase() string { + return "" +} + +// Implements the sql.SchemaResolver interface. +func (r fkResolver) CurrentSearchPath() sessiondata.SearchPath { + return sessiondata.SearchPath{} +} + +// Implements the sql.SchemaResolver interface. +func (r fkResolver) CommonLookupFlags(ctx context.Context, required bool) sql.CommonLookupFlags { + return sql.CommonLookupFlags{} +} + +// Implements the sql.SchemaResolver interface. +func (r fkResolver) ObjectLookupFlags(ctx context.Context, required bool) sql.ObjectLookupFlags { + return sql.ObjectLookupFlags{} +} + +// Implements the tree.TableNameExistingResolver interface. +func (r fkResolver) LookupObject( + ctx context.Context, dbName, scName, obName string, +) (found bool, objMeta tree.NameResolutionResult, err error) { + tbl, ok := r[obName] + if ok { + return true, tbl, nil + } + return false, nil, errors.Errorf("table %q not found in tables previously defined in the same IMPORT", obName) +} + +// Implements the tree.TableNameTargetResolver interface. +func (r fkResolver) LookupSchema( + ctx context.Context, dbName, scName string, +) (found bool, scMeta tree.SchemaMeta, err error) { + return false, nil, errSchemaResolver +} + const csvDatabaseName = "csv" func finalizeCSVBackup( @@ -626,7 +695,7 @@ func importPlanHook( } tbl, err := MakeSimpleTableDescriptor( - ctx, p.ExecCfg().Settings, create, parentID, defaultCSVTableID, walltime) + ctx, p.ExecCfg().Settings, create, parentID, defaultCSVTableID, nil, walltime) if err != nil { return err } @@ -659,12 +728,28 @@ func importPlanHook( // restoring. We do this last because we want to avoid calling // GenerateUniqueDescID if there's any kind of error above. // Reserving a table ID now means we can avoid the rekey work during restore. + tableRewrites := make(map[sqlbase.ID]sqlbase.ID) for _, tableDesc := range tableDescs { - tableDesc.ID, err = sql.GenerateUniqueDescID(ctx, p.ExecCfg().DB) + tableRewrites[tableDesc.ID], err = sql.GenerateUniqueDescID(ctx, p.ExecCfg().DB) if err != nil { return err } } + // Now that we have all the new table IDs rewrite them along with FKs. + for _, tableDesc := range tableDescs { + tableDesc.ID = tableRewrites[tableDesc.ID] + if err := tableDesc.ForeachNonDropIndex(func(idx *sqlbase.IndexDescriptor) error { + if idx.ForeignKey.IsSet() { + idx.ForeignKey.Table = tableRewrites[idx.ForeignKey.Table] + } + for i, fk := range idx.ReferencedBy { + idx.ReferencedBy[i].Table = tableRewrites[fk.Table] + } + return nil + }); err != nil { + return err + } + } } tableDetails := make([]jobspb.ImportDetails_Table, 0, len(tableDescs)) diff --git a/pkg/ccl/importccl/import_stmt_test.go b/pkg/ccl/importccl/import_stmt_test.go index a7080c7d306a..8aa9d829b001 100644 --- a/pkg/ccl/importccl/import_stmt_test.go +++ b/pkg/ccl/importccl/import_stmt_test.go @@ -62,13 +62,14 @@ func TestImportData(t *testing.T) { sqlDB.Exec(t, `CREATE DATABASE d; USE d`) tests := []struct { - name string - create string - with string - typ string - data string - err string - query map[string][][]string + name string + create string + with string + typ string + data string + err string + cleanup string + query map[string][][]string }{ { name: "duplicate unique index key", @@ -412,6 +413,37 @@ COPY t (a, b, c) FROM stdin; `, err: "expected 2 columns, got 3", }, + { + name: "fk", + typ: "PGDUMP", + data: testPgdumpFk, + query: map[string][][]string{ + `SHOW TABLES`: {{"cities"}, {"weather"}}, + `SELECT city FROM cities`: {{"Berkeley"}}, + `SELECT city FROM weather`: {{"Berkeley"}}, + + `SELECT dependson_name + FROM crdb_internal.backward_dependencies + `: {{"weather_city_fkey"}}, + + `SELECT create_statement + FROM crdb_internal.create_statements + WHERE descriptor_name in ('cities', 'weather') + ORDER BY descriptor_name + `: {{testPgdumpCreateCities}, {testPgdumpCreateWeather}}, + + // Verify the constraint is unvalidated. + `SHOW CONSTRAINTS FROM weather + `: {{"weather", "weather_city_fkey", "FOREIGN KEY", "FOREIGN KEY (city) REFERENCES cities (city)", "false"}}, + }, + cleanup: `DROP TABLE cities, weather`, + }, + { + name: "fk unreferenced", + typ: "TABLE weather FROM PGDUMP", + data: testPgdumpFk, + err: `table "cities" not found`, + }, // Error { @@ -438,7 +470,7 @@ COPY t (a, b, c) FROM stdin; for _, tc := range tests { t.Run(fmt.Sprintf("%s: %s", tc.typ, tc.name), func(t *testing.T) { - sqlDB.Exec(t, `DROP TABLE IF EXISTS d.t`) + sqlDB.Exec(t, `DROP TABLE IF EXISTS d.t, t`) var q string if tc.create != "" { q = fmt.Sprintf(`IMPORT TABLE d.t (%s) %s DATA ($1) %s`, tc.create, tc.typ, tc.with) @@ -454,6 +486,9 @@ COPY t (a, b, c) FROM stdin; for query, res := range tc.query { sqlDB.CheckQueryResults(t, query, res) } + if tc.cleanup != "" { + sqlDB.Exec(t, tc.cleanup) + } }) } @@ -465,6 +500,55 @@ COPY t (a, b, c) FROM stdin; }) } +const ( + testPgdumpCreateCities = `CREATE TABLE cities ( + city STRING(80) NOT NULL, + CONSTRAINT cities_pkey PRIMARY KEY (city ASC), + FAMILY "primary" (city) +)` + testPgdumpCreateWeather = `CREATE TABLE weather ( + city STRING(80) NULL, + temp_lo INTEGER NULL, + temp_hi INTEGER NULL, + prcp REAL NULL, + date DATE NULL, + CONSTRAINT weather_city_fkey FOREIGN KEY (city) REFERENCES cities (city), + INDEX weather_auto_index_weather_city_fkey (city ASC), + FAMILY "primary" (city, temp_lo, temp_hi, prcp, date, rowid) +)` + testPgdumpFk = ` +CREATE TABLE cities ( + city character varying(80) NOT NULL +); + +ALTER TABLE cities OWNER TO postgres; + +CREATE TABLE weather ( + city character varying(80), + temp_lo integer, + temp_hi integer, + prcp real, + date date +); + +ALTER TABLE weather OWNER TO postgres; + +COPY cities (city) FROM stdin; +Berkeley +\. + +COPY weather (city, temp_lo, temp_hi, prcp, date) FROM stdin; +Berkeley 45 53 0 1994-11-28 +\. + +ALTER TABLE ONLY cities + ADD CONSTRAINT cities_pkey PRIMARY KEY (city); + +ALTER TABLE ONLY weather + ADD CONSTRAINT weather_city_fkey FOREIGN KEY (city) REFERENCES cities(city); +` +) + // TODO(dt): switch to a helper in sampledataccl. func makeCSVData( t testing.TB, in string, numFiles, rowsPerFile int, @@ -1136,7 +1220,7 @@ func BenchmarkConvertRecord(b *testing.B) { create := stmt.(*tree.CreateTable) st := cluster.MakeTestingClusterSettings() - tableDesc, err := MakeSimpleTableDescriptor(ctx, st, create, sqlbase.ID(100), sqlbase.ID(100), 1) + tableDesc, err := MakeSimpleTableDescriptor(ctx, st, create, sqlbase.ID(100), sqlbase.ID(100), nil, 1) if err != nil { b.Fatal(err) } diff --git a/pkg/ccl/importccl/read_import_mysql_test.go b/pkg/ccl/importccl/read_import_mysql_test.go index c5673afcfa8d..54944a431f92 100644 --- a/pkg/ccl/importccl/read_import_mysql_test.go +++ b/pkg/ccl/importccl/read_import_mysql_test.go @@ -39,7 +39,7 @@ func descForTable(t *testing.T, create string, parent, id sqlbase.ID) *sqlbase.T t.Fatal(err) } stmt := parsed.(*tree.CreateTable) - table, err := MakeSimpleTableDescriptor(context.TODO(), nil, stmt, parent, id, testEvalCtx.StmtTimestamp.UnixNano()) + table, err := MakeSimpleTableDescriptor(context.TODO(), nil, stmt, parent, id, nil, testEvalCtx.StmtTimestamp.UnixNano()) if err != nil { t.Fatal(err) } diff --git a/pkg/ccl/importccl/read_import_pgdump.go b/pkg/ccl/importccl/read_import_pgdump.go index 3b8d6d28afa9..288c6baf5c85 100644 --- a/pkg/ccl/importccl/read_import_pgdump.go +++ b/pkg/ccl/importccl/read_import_pgdump.go @@ -169,18 +169,25 @@ func readPostgresCreateTable( // we'd have to delete the index and row and modify the column family. This // is much easier and probably safer too. createTbl := make(map[string]*tree.CreateTable) + // We need to run MakeSimpleTableDescriptor on tables in the same order as + // seen in the SQL file to guarantee that dependencies exist before being used + // (for FKs and sequences). + var tableOrder []string ps := newPostgreStream(input, max) for { stmt, err := ps.Next() if err == io.EOF { ret := make([]*sqlbase.TableDescriptor, 0, len(createTbl)) - for _, create := range createTbl { + seenDescs := make(fkResolver) + for _, name := range tableOrder { + create := createTbl[name] if create != nil { id := sqlbase.ID(int(defaultCSVTableID) + len(ret)) - desc, err := MakeSimpleTableDescriptor(evalCtx.Ctx(), settings, create, parentID, id, walltime) + desc, err := MakeSimpleTableDescriptor(evalCtx.Ctx(), settings, create, parentID, id, seenDescs, walltime) if err != nil { return nil, err } + seenDescs[desc.Name] = desc ret = append(ret, desc) } } @@ -210,6 +217,7 @@ func readPostgresCreateTable( } else { createTbl[name] = stmt } + tableOrder = append(tableOrder, name) case *tree.CreateIndex: name, err := getTableName(stmt.Table) if err != nil { diff --git a/pkg/ccl/partitionccl/partition_test.go b/pkg/ccl/partitionccl/partition_test.go index f9cb192272f2..b4a0391aa4b2 100644 --- a/pkg/ccl/partitionccl/partition_test.go +++ b/pkg/ccl/partitionccl/partition_test.go @@ -125,7 +125,7 @@ func (t *partitioningTest) parse() error { st := cluster.MakeTestingClusterSettings() const parentID, tableID = keys.MinUserDescID, keys.MinUserDescID + 1 t.parsed.tableDesc, err = importccl.MakeSimpleTableDescriptor( - ctx, st, createTable, parentID, tableID, hlc.UnixNano()) + ctx, st, createTable, parentID, tableID, nil, hlc.UnixNano()) if err != nil { return err } diff --git a/pkg/sql/create_table.go b/pkg/sql/create_table.go index a86c72e4e20a..e8f3dd10a2ea 100644 --- a/pkg/sql/create_table.go +++ b/pkg/sql/create_table.go @@ -411,10 +411,25 @@ func (p *planner) resolveFK( backrefs map[sqlbase.ID]*sqlbase.TableDescriptor, mode sqlbase.ConstraintValidity, ) error { - return resolveFK(ctx, p.txn, p, tbl, d, backrefs, mode) + return ResolveFK(ctx, p.txn, p, tbl, d, backrefs, mode) } -// resolveFK looks up the tables and columns mentioned in a `REFERENCES` +func qualifyFKColErrorWithDB( + ctx context.Context, txn *client.Txn, tbl *sqlbase.TableDescriptor, col string, +) string { + if txn == nil { + return tree.ErrString(tree.NewUnresolvedName(tbl.Name, col)) + } + + // TODO(whomever): this ought to use a database cache. + db, err := sqlbase.GetDatabaseDescFromID(ctx, txn, tbl.ParentID) + if err != nil { + return tree.ErrString(tree.NewUnresolvedName(tbl.Name, col)) + } + return tree.ErrString(tree.NewUnresolvedName(db.Name, tree.PublicSchema, tbl.Name, col)) +} + +// ResolveFK looks up the tables and columns mentioned in a `REFERENCES` // constraint and adds metadata representing that constraint to the descriptor. // It may, in doing so, add to or alter descriptors in the passed in `backrefs` // map of other tables that need to be updated when this table is created. @@ -432,7 +447,10 @@ func (p *planner) resolveFK( // If there are any FKs, the descriptor of the depended-on table must // be looked up uncached, and we'll allow FK dependencies on tables // that were just added. -func resolveFK( +// +// The passed Txn is used to lookup databases to qualify names in error messages +// but if nil, will result in unqualified names in those errors. +func ResolveFK( ctx context.Context, txn *client.Txn, sc SchemaResolver, @@ -441,6 +459,16 @@ func resolveFK( backrefs map[sqlbase.ID]*sqlbase.TableDescriptor, mode sqlbase.ConstraintValidity, ) error { + for _, col := range d.FromCols { + col, _, err := tbl.FindColumnByName(col) + if err != nil { + return err + } + if err := col.CheckCanBeFKRef(); err != nil { + return err + } + } + targetTable := d.Table.TableName() target, err := ResolveExistingObject(ctx, sc, targetTable, true /*required*/, requireTableDesc) @@ -528,7 +556,7 @@ func resolveFK( return pgerror.NewErrorf( pgerror.CodeInvalidForeignKeyError, "there is no unique constraint matching given keys for referenced table %s", - targetTable.String(), + target.Name, ) } } @@ -538,15 +566,10 @@ func resolveFK( if d.Actions.Delete == tree.SetNull || d.Actions.Update == tree.SetNull { for _, sourceColumn := range srcCols { if !sourceColumn.Nullable { - // TODO(whomever): this ought to use a database cache. - database, err := sqlbase.GetDatabaseDescFromID(ctx, txn, tbl.ParentID) - if err != nil { - return err - } + col := qualifyFKColErrorWithDB(ctx, txn, tbl, sourceColumn.Name) return pgerror.NewErrorf(pgerror.CodeInvalidForeignKeyError, - "cannot add a SET NULL cascading action on column %q which has a NOT NULL constraint", - tree.ErrString(tree.NewUnresolvedName( - database.Name, tree.PublicSchema, tbl.Name, sourceColumn.Name))) + "cannot add a SET NULL cascading action on column %q which has a NOT NULL constraint", col, + ) } } } @@ -556,15 +579,10 @@ func resolveFK( if d.Actions.Delete == tree.SetDefault || d.Actions.Update == tree.SetDefault { for _, sourceColumn := range srcCols { if sourceColumn.DefaultExpr == nil { - // TODO(whomever): this ought to use a database cache. - database, err := sqlbase.GetDatabaseDescFromID(ctx, txn, tbl.ParentID) - if err != nil { - return err - } + col := qualifyFKColErrorWithDB(ctx, txn, tbl, sourceColumn.Name) return pgerror.NewErrorf(pgerror.CodeInvalidForeignKeyError, - "cannot add a SET DEFAULT cascading action on column %q which has no DEFAULT expression", - tree.ErrString(tree.NewUnresolvedName( - database.Name, tree.PublicSchema, tbl.Name, sourceColumn.Name))) + "cannot add a SET DEFAULT cascading action on column %q which has no DEFAULT expression", col, + ) } } } @@ -1218,16 +1236,7 @@ func MakeTableDesc( desc.Checks = append(desc.Checks, ck) case *tree.ForeignKeyConstraintTableDef: - for _, col := range d.FromCols { - col, _, err := desc.FindColumnByName(col) - if err != nil { - return desc, err - } - if err := col.CheckCanBeFKRef(); err != nil { - return desc, err - } - } - if err := resolveFK(ctx, txn, fkResolver, &desc, d, affected, sqlbase.ConstraintValidity_Validated); err != nil { + if err := ResolveFK(ctx, txn, fkResolver, &desc, d, affected, sqlbase.ConstraintValidity_Validated); err != nil { return desc, err } default: