diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 16b4e142bd91..df887f0196bd 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -256,6 +256,7 @@ ALL_TESTS = [ "//pkg/sql/catalog/seqexpr:seqexpr_disallowed_imports_test", "//pkg/sql/catalog/seqexpr:seqexpr_test", "//pkg/sql/catalog/systemschema_test:systemschema_test_test", + "//pkg/sql/catalog/tabledesc:tabledesc_disallowed_imports_test", "//pkg/sql/catalog/tabledesc:tabledesc_test", "//pkg/sql/catalog/typedesc:typedesc_test", "//pkg/sql/catalog:catalog_disallowed_imports_test", diff --git a/pkg/ccl/backupccl/restore_old_sequences_test.go b/pkg/ccl/backupccl/restore_old_sequences_test.go index 4934b4cdf28f..783f97767702 100644 --- a/pkg/ccl/backupccl/restore_old_sequences_test.go +++ b/pkg/ccl/backupccl/restore_old_sequences_test.go @@ -79,35 +79,29 @@ func restoreOldSequencesTest(exportDir string) func(t *testing.T) { t.Fatalf("expected %d rows, got %d", totalRows, importedRows) } - // Verify that sequences created in older versions cannot be renamed, nor can the - // database they are referencing. - sqlDB.ExpectErr(t, - `pq: cannot rename relation "test.public.s" because view "t1" depends on it`, - `ALTER SEQUENCE test.s RENAME TO test.s2`) - sqlDB.ExpectErr(t, - `pq: cannot rename relation "test.public.t1_i_seq" because view "t1" depends on it`, - `ALTER SEQUENCE test.t1_i_seq RENAME TO test.t1_i_seq_new`) - sqlDB.ExpectErr(t, - `pq: cannot rename database because relation "test.public.t1" depends on relation "test.public.s"`, - `ALTER DATABASE test RENAME TO new_test`) + // Verify that restored sequences are now referenced by ID. + var createTable string + sqlDB.QueryRow(t, `SHOW CREATE test.t1`).Scan(&unused, &createTable) + require.Contains(t, createTable, "i INT8 NOT NULL DEFAULT nextval('test.public.t1_i_seq'::REGCLASS)") + require.Contains(t, createTable, "j INT8 NOT NULL DEFAULT nextval('test.public.s'::REGCLASS)") + sqlDB.QueryRow(t, `SHOW CREATE test.v`).Scan(&unused, &createTable) + require.Contains(t, createTable, "SELECT nextval('test.public.s2'::REGCLASS)") + sqlDB.QueryRow(t, `SHOW CREATE test.v2`).Scan(&unused, &createTable) + require.Contains(t, createTable, "SELECT nextval('test.public.s2'::REGCLASS) AS k") - sequenceResults := [][]string{ + // Verify that, as a result, all sequences can now be renamed. + sqlDB.Exec(t, `ALTER SEQUENCE test.t1_i_seq RENAME TO test.t1_i_seq_new`) + sqlDB.Exec(t, `ALTER SEQUENCE test.s RENAME TO test.s_new`) + sqlDB.Exec(t, `ALTER SEQUENCE test.s2 RENAME TO test.s2_new`) + + // Finally, verify that sequences are correctly restored and can be used in tables/views. + sqlDB.Exec(t, `INSERT INTO test.t1 VALUES (default, default)`) + expectedRows := [][]string{ {"1", "1"}, {"2", "2"}, } - - // Verify that tables with old sequences aren't corrupted. - sqlDB.Exec(t, `SET database = test; INSERT INTO test.t1 VALUES (default, default)`) - sqlDB.CheckQueryResults(t, `SELECT * FROM test.t1 ORDER BY i`, sequenceResults) - - // Verify that the views are okay, and the sequences it depends on cannot be renamed. - sqlDB.CheckQueryResults(t, `SET database = test; SELECT * FROM test.v`, [][]string{{"1"}}) - sqlDB.CheckQueryResults(t, `SET database = test; SELECT * FROM test.v2`, [][]string{{"2"}}) - sqlDB.ExpectErr(t, - `pq: cannot rename relation "s2" because view "v" depends on it`, - `ALTER SEQUENCE s2 RENAME TO s3`) - sqlDB.CheckQueryResults(t, `SET database = test; SHOW CREATE VIEW test.v`, [][]string{{ - "test.public.v", "CREATE VIEW public.v (\n\tnextval\n) AS (SELECT nextval('s2':::STRING))", - }}) + sqlDB.CheckQueryResults(t, `SELECT * FROM test.t1 ORDER BY i`, expectedRows) + sqlDB.CheckQueryResults(t, `SELECT * FROM test.v`, [][]string{{"1"}}) + sqlDB.CheckQueryResults(t, `SELECT * FROM test.v2`, [][]string{{"2"}}) } } diff --git a/pkg/ccl/backupccl/restore_planning.go b/pkg/ccl/backupccl/restore_planning.go index e4c6ba223cd2..c818b34ac986 100644 --- a/pkg/ccl/backupccl/restore_planning.go +++ b/pkg/ccl/backupccl/restore_planning.go @@ -42,6 +42,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" "github.com/cockroachdb/cockroach/pkg/sql/catalog/multiregion" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/nstree" "github.com/cockroachdb/cockroach/pkg/sql/catalog/rewrite" "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemadesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/systemschema" @@ -863,6 +864,12 @@ func resolveTargetDB( // the set provided are omitted during the upgrade, instead of causing an error // to be returned. func maybeUpgradeDescriptors(descs []catalog.Descriptor, skipFKsWithNoMatchingTable bool) error { + // A data structure for efficient descriptor lookup by ID or by name. + descCatalog := &nstree.MutableCatalog{} + for _, d := range descs { + descCatalog.UpsertDescriptorEntry(d) + } + for j, desc := range descs { var b catalog.DescriptorBuilder if tableDesc, isTable := desc.(catalog.TableDescriptor); isTable { @@ -873,14 +880,7 @@ func maybeUpgradeDescriptors(descs []catalog.Descriptor, skipFKsWithNoMatchingTa if err := b.RunPostDeserializationChanges(); err != nil { return errors.NewAssertionErrorWithWrappedErrf(err, "error during RunPostDeserializationChanges") } - err := b.RunRestoreChanges(func(id descpb.ID) catalog.Descriptor { - for _, d := range descs { - if d.GetID() == id { - return d - } - } - return nil - }) + err := b.RunRestoreChanges(descCatalog.LookupDescriptorEntry) if err != nil { return err } diff --git a/pkg/sql/catalog/BUILD.bazel b/pkg/sql/catalog/BUILD.bazel index 904e78b574b6..753b4843e7e4 100644 --- a/pkg/sql/catalog/BUILD.bazel +++ b/pkg/sql/catalog/BUILD.bazel @@ -10,7 +10,7 @@ go_library( "descriptor.go", "descriptor_id_set.go", "errors.go", - "post_derserialization_changes.go", + "post_deserialization_changes.go", "privilege_object.go", "schema.go", "synthetic_privilege.go", diff --git a/pkg/sql/catalog/post_derserialization_changes.go b/pkg/sql/catalog/post_deserialization_changes.go similarity index 94% rename from pkg/sql/catalog/post_derserialization_changes.go rename to pkg/sql/catalog/post_deserialization_changes.go index b1e7fa74dc86..6d2af1b5c257 100644 --- a/pkg/sql/catalog/post_derserialization_changes.go +++ b/pkg/sql/catalog/post_deserialization_changes.go @@ -86,4 +86,8 @@ const ( // dropping a schema, we'd mark the database itself as though it was the // schema which was dropped. RemovedSelfEntryInSchemas + + // UpgradedSequenceReference indicates that the table/view had upgraded + // their sequence references, if any, from by-name to by-ID, if not already. + UpgradedSequenceReference ) diff --git a/pkg/sql/catalog/seqexpr/BUILD.bazel b/pkg/sql/catalog/seqexpr/BUILD.bazel index 1f37b67708ce..6471ff742afc 100644 --- a/pkg/sql/catalog/seqexpr/BUILD.bazel +++ b/pkg/sql/catalog/seqexpr/BUILD.bazel @@ -8,11 +8,14 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/sql/catalog/seqexpr", visibility = ["//visibility:public"], deps = [ + "//pkg/sql/catalog/descpb", + "//pkg/sql/parser", "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/sql/sem/builtins/builtinconstants", "//pkg/sql/sem/tree", "//pkg/sql/types", + "@com_github_cockroachdb_errors//:errors", ], ) @@ -21,11 +24,13 @@ go_test( srcs = ["sequence_test.go"], deps = [ ":seqexpr", + "//pkg/sql/catalog/descpb", "//pkg/sql/parser", "//pkg/sql/sem/builtins", "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/tree", "//pkg/sql/types", + "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/sql/catalog/seqexpr/sequence.go b/pkg/sql/catalog/seqexpr/sequence.go index 1c25dfb239bd..89c7fab5a059 100644 --- a/pkg/sql/catalog/seqexpr/sequence.go +++ b/pkg/sql/catalog/seqexpr/sequence.go @@ -18,11 +18,14 @@ package seqexpr import ( "go/constant" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/errors" ) // SeqIdentifier wraps together different ways of identifying a sequence. @@ -43,7 +46,7 @@ func (si *SeqIdentifier) IsByID() bool { // a sequence name or an ID), wrapped in the SeqIdentifier type. // Returns the identifier of the sequence or nil if no sequence was found. // -// `getBuiltinProperties` argument is commonly builtins.GetBuiltinProperties. +// `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. func GetSequenceFromFunc( funcExpr *tree.FuncExpr, getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), @@ -131,7 +134,7 @@ func getSequenceIdentifier(expr tree.Expr) *SeqIdentifier { // identifiers are found. The identifier is wrapped in a SeqIdentifier. // e.g. nextval('foo') => "foo"; nextval(123::regclass) => 123; => nil // -// `getBuiltinProperties` argument is commonly builtins.GetBuiltinProperties. +// `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. func GetUsedSequences( defaultExpr tree.Expr, getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), @@ -163,10 +166,10 @@ func GetUsedSequences( // any sequence names in the expression by their IDs instead. // e.g. nextval('foo') => nextval(123::regclass) // -// `getBuiltinProperties` argument is commonly builtins.GetBuiltinProperties. +// `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. func ReplaceSequenceNamesWithIDs( defaultExpr tree.Expr, - nameToID map[string]int64, + nameToID map[string]descpb.ID, getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), ) (tree.Expr, error) { replaceFn := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { @@ -190,7 +193,7 @@ func ReplaceSequenceNamesWithIDs( &tree.AnnotateTypeExpr{ Type: types.RegClass, SyntaxMode: tree.AnnotateShort, - Expr: tree.NewNumVal(constant.MakeInt64(id), "", false), + Expr: tree.NewNumVal(constant.MakeInt64(int64(id)), "", false), }, }, }, nil @@ -201,3 +204,162 @@ func ReplaceSequenceNamesWithIDs( newExpr, err := tree.SimpleVisit(defaultExpr, replaceFn) return newExpr, err } + +// UpgradeSequenceReferenceInExpr upgrades all by-name sequence +// reference in `expr` to by-ID with a provided id-to-name +// mapping `usedSequenceIDsToNames`, from which we should be able +// to uniquely determine the ID of each by-name seq reference. +// +// Such a mapping can often be constructed if we know the sequence IDs +// used in a particular expression, e.g. a column descriptor's +// `usesSequenceIDs` field or a view descriptor's `dependsOn` field if +// the column DEFAULT/ON-UPDATE or the view's query references sequences. +// +// `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. +func UpgradeSequenceReferenceInExpr( + expr *string, + usedSequenceIDsToNames map[descpb.ID]*tree.TableName, + getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), +) (hasUpgraded bool, err error) { + // Find the "reverse" mapping from sequence name to their IDs for those + // sequences referenced by-name in `expr`. + usedSequenceNamesToIDs, err := seqNameToIDMappingInExpr(*expr, usedSequenceIDsToNames, getBuiltinProperties) + if err != nil { + return false, err + } + + // With this "reverse" mapping, we can simply replace each by-name + // seq reference in `expr` with the sequence's ID. + parsedExpr, err := parser.ParseExpr(*expr) + if err != nil { + return false, err + } + + newExpr, err := ReplaceSequenceNamesWithIDs(parsedExpr, usedSequenceNamesToIDs, getBuiltinProperties) + if err != nil { + return false, err + } + + // Modify `expr` in place, if any upgrade. + if *expr != tree.Serialize(newExpr) { + hasUpgraded = true + *expr = tree.Serialize(newExpr) + } + + return hasUpgraded, nil +} + +// seqNameToIDMappingInExpr attempts to find the seq ID for +// every by-name seq reference in `expr` from `seqIDToNameMapping`. +// This process can be thought of as a "reverse mapping" process +// where, given an id-to-seq-name mapping, for each by-name seq reference +// in `expr`, we attempt to find the entry in that mapping such that +// the entry's name "best matches" the by-name seq reference. +// See comments of findUniqueBestMatchingForTableName for "best matching" definition. +// +// It returns a non-nill error if zero or multiple entries +// in `seqIDToNameMapping` have a name that "best matches" +// the by-name seq reference. +// +// See its unit test for some examples. +func seqNameToIDMappingInExpr( + expr string, + seqIDToNameMapping map[descpb.ID]*tree.TableName, + getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), +) (map[string]descpb.ID, error) { + parsedExpr, err := parser.ParseExpr(expr) + if err != nil { + return nil, err + } + seqRefs, err := GetUsedSequences(parsedExpr, getBuiltinProperties) + if err != nil { + return nil, err + } + + // Construct the key mapping from seq-by-name-reference to their IDs. + result := make(map[string]descpb.ID) + for _, seqIdentifier := range seqRefs { + if seqIdentifier.IsByID() { + continue + } + + parsedSeqName, err := parser.ParseQualifiedTableName(seqIdentifier.SeqName) + if err != nil { + return nil, err + } + + // Pairing: find out which sequence name in the id-to-name mapping + // (i.e. `seqIDToNameMapping`) matches `parsedSeqName` so we + // know the ID of it. + idOfSeqIdentifier, err := findUniqueBestMatchingForTableName(seqIDToNameMapping, *parsedSeqName) + if err != nil { + return nil, err + } + + // Put it to the reverse mapping. + result[seqIdentifier.SeqName] = idOfSeqIdentifier + } + return result, nil +} + +// findUniqueBestMatchingForTableName picks the "best-matching" name from +// `allTableNamesByID` for `targetTableName`. The best-matching name is the +// one that matches all parts of `targetTableName`, if that part exists +// in both names. +// Example 1: +// allTableNamesByID = {23 : 'db.sc1.t', 25 : 'db.sc2.t'} +// tableName = 'sc2.t' +// return = 25 (because `db.sc2.t` best-matches `sc2.t`) +// Example 2: +// allTableNamesByID = {23 : 'db.sc1.t', 25 : 'sc2.t'} +// tableName = 'sc2.t' +// return = 25 (because `sc2.t` best-matches `sc2.t`) +// Example 3: +// allTableNamesByID = {23 : 'db.sc1.t', 25 : 'sc2.t'} +// tableName = 'db.sc2.t' +// return = 25 (because `sc2.t` best-matches `db.sc2.t`) +// +// Example 4: +// allTableNamesByID = {23 : 'sc1.t', 25 : 'sc2.t'} +// tableName = 't' +// return = non-nil error (because both 'sc1.t' and 'sc2.t' are equally good matches +// for 't' and we cannot decide, i.e., >1 valid candidates left.) +// Example 5: +// allTableNamesByID = {23 : 'sc1.t', 25 : 'sc2.t'} +// tableName = 't2' +// return = non-nil error (because neither 'sc1.t' nor 'sc2.t' matches 't2', that is, 0 valid candidate left) +func findUniqueBestMatchingForTableName( + allTableNamesByID map[descpb.ID]*tree.TableName, targetTableName tree.TableName, +) (match descpb.ID, err error) { + t := targetTableName.Table() + if t == "" { + return descpb.InvalidID, errors.AssertionFailedf("input tableName does not have a Table field.") + } + + for id, candidateTableName := range allTableNamesByID { + ct, tt := candidateTableName.Table(), targetTableName.Table() + cs, ts := candidateTableName.Schema(), targetTableName.Schema() + cdb, tdb := candidateTableName.Catalog(), targetTableName.Catalog() + if (ct != "" && tt != "" && ct != tt) || + (cs != "" && ts != "" && cs != ts) || + (cdb != "" && tdb != "" && cdb != tdb) { + // not a match -- there is a part, either db or schema or table name, + // that exists in both names but they don't match. + continue + } + + // id passes the check; consider it as the result + // If already found a valid result, report error! + if match != descpb.InvalidID { + return descpb.InvalidID, errors.AssertionFailedf("more than 1 matches found for %q", + targetTableName.String()) + } + match = id + } + + if match == descpb.InvalidID { + return descpb.InvalidID, errors.AssertionFailedf("no table name found to match input %q", t) + } + + return match, nil +} diff --git a/pkg/sql/catalog/seqexpr/sequence_test.go b/pkg/sql/catalog/seqexpr/sequence_test.go index 19f14348a150..04d9494e1618 100644 --- a/pkg/sql/catalog/seqexpr/sequence_test.go +++ b/pkg/sql/catalog/seqexpr/sequence_test.go @@ -15,12 +15,14 @@ import ( "fmt" "testing" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/seqexpr" "github.com/cockroachdb/cockroach/pkg/sql/parser" _ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" // register all builtins in builtins:init() for seqexpr package "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/stretchr/testify/require" ) func TestGetSequenceFromFunc(t *testing.T) { @@ -122,7 +124,7 @@ func TestGetUsedSequences(t *testing.T) { } func TestReplaceSequenceNamesWithIDs(t *testing.T) { - namesToID := map[string]int64{ + namesToID := map[string]descpb.ID{ "seq": 123, } @@ -158,3 +160,126 @@ func TestReplaceSequenceNamesWithIDs(t *testing.T) { }) } } + +func TestUpgradeSequenceReferenceInExpr(t *testing.T) { + t.Run("test name-matching -- fully resolved candidate names", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("testdb", "sc1", "t") + tbl2 := tree.MakeTableNameWithSchema("testdb", "sc2", "t") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + expr := "nextval('testdb.sc1.t') + nextval('sc1.t')" + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.NoError(t, err) + require.True(t, hasUpgraded) + require.Equal(t, + "nextval(1:::REGCLASS) + nextval(1:::REGCLASS)", + expr) + }) + + t.Run("test name-matching -- partially resolved candidate names", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("", "sc1", "t") + tbl2 := tree.MakeTableNameWithSchema("testdb", "sc2", "t") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + expr := "nextval('testdb.sc1.t') + nextval('sc1.t')" + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.NoError(t, err) + require.True(t, hasUpgraded) + require.Equal(t, + "nextval(1:::REGCLASS) + nextval(1:::REGCLASS)", + expr) + }) + + t.Run("test name-matching -- public schema will be assumed when it's missing in candidate names", + func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("testdb", "", "t") + tbl2 := tree.MakeTableNameWithSchema("", "sc2", "t") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + expr := "nextval('testdb.public.t') + nextval('testdb.t')" + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.NoError(t, err) + require.True(t, hasUpgraded) + require.Equal(t, + "nextval(1:::REGCLASS) + nextval(1:::REGCLASS)", + expr) + }) + + t.Run("test name-matching -- ambiguous name matching, >1 candidates", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("", "sc1", "t") + tbl2 := tree.MakeTableNameWithSchema("", "sc2", "t") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + expr := "nextval('t')" + _, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.Error(t, err, "ambiguous name matching for 't'; both 'sc1.t' and 'sc2.t' match it.") + require.Equal(t, "more than 1 matches found for \"t\"", err.Error()) + }) + + t.Run("test name-matching -- no matching name, 0 candidate", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("", "sc1", "t") + tbl2 := tree.MakeTableNameWithSchema("", "sc2", "t") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + expr := "nextval('t2')" + _, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.Error(t, err, "no matching name for 't2'; neither 'sc1.t' nor 'sc2.t' match it.") + require.Equal(t, "no table name found to match input \"t2\"", err.Error()) + }) + + t.Run("all seq references are by-ID (no upgrades)", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("testdb", "public", "s1") + tbl2 := tree.MakeTableNameWithSchema("testdb", "public", "s2") + tbl3 := tree.MakeTableNameWithSchema("testdb", "sc1", "s3") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + usedSequenceIDsToNames[3] = &tbl3 + expr := "((nextval(1::REGCLASS) + nextval(2::REGCLASS)) + currval(3::REGCLASS)) + nextval(3::REGCLASS)" + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.NoError(t, err) + require.False(t, hasUpgraded) + require.Equal(t, + "((nextval(1::REGCLASS) + nextval(2::REGCLASS)) + currval(3::REGCLASS)) + nextval(3::REGCLASS)", + expr) + }) + + t.Run("all seq references are by-name", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("testdb", "public", "s1") + tbl2 := tree.MakeTableNameWithSchema("testdb", "public", "s2") + tbl3 := tree.MakeTableNameWithSchema("testdb", "sc1", "s3") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + usedSequenceIDsToNames[3] = &tbl3 + expr := "nextval('testdb.public.s1') + nextval('testdb.public.s2') + currval('testdb.sc1.s3') + nextval('testdb.sc1.s3')" + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.NoError(t, err) + require.True(t, hasUpgraded) + require.Equal(t, + "((nextval(1:::REGCLASS) + nextval(2:::REGCLASS)) + currval(3:::REGCLASS)) + nextval(3:::REGCLASS)", + expr) + }) + + t.Run("mixed by-name and by-ID seq references", func(t *testing.T) { + usedSequenceIDsToNames := make(map[descpb.ID]*tree.TableName) + tbl1 := tree.MakeTableNameWithSchema("testdb", "public", "s1") + tbl2 := tree.MakeTableNameWithSchema("testdb", "public", "s2") + tbl3 := tree.MakeTableNameWithSchema("testdb", "sc1", "s3") + usedSequenceIDsToNames[1] = &tbl1 + usedSequenceIDsToNames[2] = &tbl2 + usedSequenceIDsToNames[3] = &tbl3 + expr := "nextval('testdb.public.s1') + nextval(2::REGCLASS) + currval('testdb.sc1.s3') + nextval('testdb.sc1.s3')" + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + require.NoError(t, err) + require.True(t, hasUpgraded) + require.Equal(t, + "((nextval(1:::REGCLASS) + nextval(2::REGCLASS)) + currval(3:::REGCLASS)) + nextval(3:::REGCLASS)", + expr) + }) +} diff --git a/pkg/sql/catalog/tabledesc/BUILD.bazel b/pkg/sql/catalog/tabledesc/BUILD.bazel index 3acd3b4341b5..775e08b955d0 100644 --- a/pkg/sql/catalog/tabledesc/BUILD.bazel +++ b/pkg/sql/catalog/tabledesc/BUILD.bazel @@ -1,5 +1,6 @@ load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//pkg/testutils/buildutil:buildutil.bzl", "disallowed_imports_test") go_library( name = "tabledesc", @@ -33,6 +34,7 @@ go_library( "//pkg/sql/catalog/internal/validate", "//pkg/sql/catalog/multiregion", "//pkg/sql/catalog/schemaexpr", + "//pkg/sql/catalog/seqexpr", "//pkg/sql/catalog/typedesc", "//pkg/sql/lexbase", "//pkg/sql/parser", @@ -41,6 +43,7 @@ go_library( "//pkg/sql/privilege", "//pkg/sql/rowenc", "//pkg/sql/schemachanger/scpb", + "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/catconstants", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", @@ -115,4 +118,11 @@ go_test( ], ) +disallowed_imports_test( + "tabledesc", + disallowed_list = [ + "//pkg/sql/sem/builtins", + ], +) + get_x_data(name = "get_x_data") diff --git a/pkg/sql/catalog/tabledesc/table_desc_builder.go b/pkg/sql/catalog/tabledesc/table_desc_builder.go index a79003c4a715..0ca4122cd3cc 100644 --- a/pkg/sql/catalog/tabledesc/table_desc_builder.go +++ b/pkg/sql/catalog/tabledesc/table_desc_builder.go @@ -15,8 +15,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/catalog/catprivilege" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/seqexpr" + "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/errors" ) @@ -113,14 +117,28 @@ func (tdb *tableDescriptorBuilder) RunPostDeserializationChanges() error { func (tdb *tableDescriptorBuilder) RunRestoreChanges( descLookupFn func(id descpb.ID) catalog.Descriptor, ) (err error) { + // Upgrade FK representation upgradedFK, err := maybeUpgradeForeignKeyRepresentation( descLookupFn, tdb.skipFKsWithNoMatchingTable, tdb.maybeModified, ) + if err != nil { + return err + } if upgradedFK { tdb.changes.Add(catalog.UpgradedForeignKeyRepresentation) } + + // Upgrade sequence reference + upgradedSequenceReference, err := maybeUpgradeSequenceReference(descLookupFn, tdb.maybeModified) + if err != nil { + return err + } + if upgradedSequenceReference { + tdb.changes.Add(catalog.UpgradedSequenceReference) + } + return err } @@ -772,3 +790,185 @@ func maybeAddConstraintIDs(desc *descpb.TableDescriptor) (hasChanged bool) { } return desc.NextConstraintID != initialConstraintID } + +// maybeUpgradeSequenceReference attempts to upgrade by-name sequence references. +// If `rel` is a table: upgrade seq reference in each column; +// If `rel` is a view: upgrade seq reference in its view query; +// If `rel` is a sequence: upgrade its back-references to relations as "ByID". +// All these attempts are on a best-effort basis. +func maybeUpgradeSequenceReference( + descLookupFn func(id descpb.ID) catalog.Descriptor, rel *descpb.TableDescriptor, +) (hasUpgraded bool, err error) { + if rel.IsTable() { + hasUpgraded, err = maybeUpgradeSequenceReferenceForTable(descLookupFn, rel) + if err != nil { + return hasUpgraded, err + } + } else if rel.IsView() { + hasUpgraded, err = maybeUpgradeSequenceReferenceForView(descLookupFn, rel) + if err != nil { + return hasUpgraded, err + } + } else if rel.IsSequence() { + // Upgrade all references to this sequence to "by-ID". + for i, ref := range rel.DependedOnBy { + if ref.ID != descpb.InvalidID && !ref.ByID { + rel.DependedOnBy[i].ByID = true + hasUpgraded = true + } + } + } else { + return hasUpgraded, errors.AssertionFailedf("table descriptor %v (%d) is not a "+ + "table, view, or sequence.", rel.Name, rel.ID) + } + + return hasUpgraded, err +} + +// maybeUpgradeSequenceReferenceForTable upgrades all by-name sequence references +// in `tableDesc` to by-ID. +func maybeUpgradeSequenceReferenceForTable( + descLookupFn func(id descpb.ID) catalog.Descriptor, tableDesc *descpb.TableDescriptor, +) (hasUpgraded bool, err error) { + if !tableDesc.IsTable() { + return hasUpgraded, nil + } + + for _, col := range tableDesc.Columns { + // Find sequence names for all sequences used in this column. + usedSequenceIDToNames, err := resolveTableNamesForIDs(descLookupFn, col.UsesSequenceIds) + if err != nil { + return hasUpgraded, err + } + + // Upgrade sequence reference in DEFAULT expression, if any. + if col.HasDefault() { + hasUpgradedInDefault, err := seqexpr.UpgradeSequenceReferenceInExpr(col.DefaultExpr, usedSequenceIDToNames, builtinsregistry.GetBuiltinProperties) + if err != nil { + return hasUpgraded, err + } + hasUpgraded = hasUpgraded || hasUpgradedInDefault + } + + // Upgrade sequence reference in ON UPDATE expression, if any. + if col.HasOnUpdate() { + hasUpgradedInOnUpdate, err := seqexpr.UpgradeSequenceReferenceInExpr(col.OnUpdateExpr, usedSequenceIDToNames, builtinsregistry.GetBuiltinProperties) + if err != nil { + return hasUpgraded, err + } + hasUpgraded = hasUpgraded || hasUpgradedInOnUpdate + } + } + + return hasUpgraded, nil +} + +// maybeUpgradeSequenceReferenceForView similarily upgrades all by-name sequence references +// in `viewDesc` to by-ID. +func maybeUpgradeSequenceReferenceForView( + descLookupFn func(id descpb.ID) catalog.Descriptor, viewDesc *descpb.TableDescriptor, +) (hasUpgraded bool, err error) { + if !viewDesc.IsView() { + return hasUpgraded, err + } + + // Find sequence names for all those used sequences. + usedSequenceIDToNames, err := resolveTableNamesForIDs(descLookupFn, viewDesc.DependsOn) + if err != nil { + return hasUpgraded, err + } + + // A function that looks at an expression and replace any by-name sequence reference with + // by-ID reference. It, of course, also append replaced sequence IDs to `upgradedSeqIDs`. + replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + newExprStr := expr.String() + hasUpgradedInExpr, err := seqexpr.UpgradeSequenceReferenceInExpr(&newExprStr, usedSequenceIDToNames, builtinsregistry.GetBuiltinProperties) + if err != nil { + return false, expr, err + } + newExpr, err = parser.ParseExpr(newExprStr) + if err != nil { + return false, expr, err + } + + hasUpgraded = hasUpgraded || hasUpgradedInExpr + return false, newExpr, err + } + + stmt, err := parser.ParseOne(viewDesc.GetViewQuery()) + if err != nil { + return hasUpgraded, err + } + + newStmt, err := tree.SimpleStmtVisit(stmt.AST, replaceSeqFunc) + if err != nil { + return hasUpgraded, err + } + + viewDesc.ViewQuery = newStmt.String() + + return hasUpgraded, err +} + +// Attempt to fully resolve table names for `ids` from a list of descriptors. +// IDs that do not exist or do not identify a table descriptor will be skipped. +// +// This is done on a best-effort basis, meaning if we cannot find a table's +// schema or database name from `descLookupFn`, they will be set to empty. +// Consumers of the return of this function should hence expect non-fully resolved +// table names. +func resolveTableNamesForIDs( + descLookupFn func(id descpb.ID) catalog.Descriptor, ids []descpb.ID, +) (map[descpb.ID]*tree.TableName, error) { + result := make(map[descpb.ID]*tree.TableName) + + for _, id := range ids { + if _, exists := result[id]; exists { + continue + } + + // Attempt to retrieve the table descriptor for `id`; Skip if it does not exist or it does not + // identify a table descriptor. + d := descLookupFn(id) + tableDesc, ok := d.(catalog.TableDescriptor) + if !ok { + continue + } + + // Attempt to get its database and schema name on a best-effort basis. + dbName := "" + d = descLookupFn(tableDesc.GetParentID()) + if dbDesc, ok := d.(catalog.DatabaseDescriptor); ok { + dbName = dbDesc.GetName() + } + + scName := "" + d = descLookupFn(tableDesc.GetParentSchemaID()) + if d != nil { + if scDesc, ok := d.(catalog.SchemaDescriptor); ok { + scName = scDesc.GetName() + } + } else { + if tableDesc.GetParentSchemaID() == keys.PublicSchemaIDForBackup { + // For backups created in 21.2 and prior, the "public" schema is descriptorless, + // and always uses the const `keys.PublicSchemaIDForBackUp` as the "public" + // schema ID. + scName = tree.PublicSchema + } + } + + result[id] = tree.NewTableNameWithSchema( + tree.Name(dbName), + tree.Name(scName), + tree.Name(tableDesc.GetName()), + ) + if dbName == "" { + result[id].ExplicitCatalog = false + } + if scName == "" { + result[id].ExplicitSchema = false + } + } + + return result, nil +} diff --git a/pkg/sql/create_view.go b/pkg/sql/create_view.go index bab9c18e1ce8..bebf5931993c 100644 --- a/pkg/sql/create_view.go +++ b/pkg/sql/create_view.go @@ -439,13 +439,13 @@ func replaceSeqNamesWithIDs( if err != nil { return false, expr, err } - seqNameToID := make(map[string]int64) + seqNameToID := make(map[string]descpb.ID) for _, seqIdentifier := range seqIdentifiers { seqDesc, err := GetSequenceDescFromIdentifier(ctx, sc, seqIdentifier) if err != nil { return false, expr, err } - seqNameToID[seqIdentifier.SeqName] = int64(seqDesc.ID) + seqNameToID[seqIdentifier.SeqName] = seqDesc.ID } newExpr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID, builtinsregistry.GetBuiltinProperties) if err != nil { diff --git a/pkg/sql/schemachanger/scbuild/builder_state.go b/pkg/sql/schemachanger/scbuild/builder_state.go index 00a9dc58bbf7..6243f1803c90 100644 --- a/pkg/sql/schemachanger/scbuild/builder_state.go +++ b/pkg/sql/schemachanger/scbuild/builder_state.go @@ -415,7 +415,7 @@ func (b *builderState) WrapExpression(parentID catid.DescID, expr tree.Expr) *sc if err != nil { panic(err) } - seqNameToID := make(map[string]int64) + seqNameToID := make(map[string]descpb.ID) for _, seqIdentifier := range seqIdentifiers { if seqIdentifier.IsByID() { seqIDs.Add(catid.DescID(seqIdentifier.SeqID)) @@ -430,7 +430,7 @@ func (b *builderState) WrapExpression(parentID catid.DescID, expr tree.Expr) *sc RequiredPrivilege: privilege.SELECT, }) _, _, seq := scpb.FindSequence(elts) - seqNameToID[seqIdentifier.SeqName] = int64(seq.SequenceID) + seqNameToID[seqIdentifier.SeqName] = seq.SequenceID seqIDs.Add(seq.SequenceID) } if len(seqNameToID) > 0 { diff --git a/pkg/sql/sequence.go b/pkg/sql/sequence.go index 53bc836a2c27..6f7fa940ec2c 100644 --- a/pkg/sql/sequence.go +++ b/pkg/sql/sequence.go @@ -833,7 +833,7 @@ func maybeAddSequenceDependencies( } var seqDescs []*tabledesc.Mutable - seqNameToID := make(map[string]int64) + seqNameToID := make(map[string]descpb.ID) for _, seqIdentifier := range seqIdentifiers { seqDesc, err := GetSequenceDescFromIdentifier(ctx, sc, seqIdentifier) if err != nil { @@ -850,7 +850,7 @@ func maybeAddSequenceDependencies( ) } - seqNameToID[seqIdentifier.SeqName] = int64(seqDesc.ID) + seqNameToID[seqIdentifier.SeqName] = seqDesc.ID // If we had already modified this Sequence as part of this transaction, // we only want to modify a single instance of it instead of overwriting it. diff --git a/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go b/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go index 6bfbed0a8d1c..66ee04bb8ed3 100644 --- a/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go +++ b/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go @@ -283,8 +283,8 @@ func maybeUpdateBackRefsAndBuildMap( t *tabledesc.Mutable, seqIdentifiers []seqexpr.SeqIdentifier, changedSeqDescs *[]*tabledesc.Mutable, -) (map[string]int64, error) { - seqNameToID := make(map[string]int64) +) (map[string]descpb.ID, error) { + seqNameToID := make(map[string]descpb.ID) for _, seqIdentifier := range seqIdentifiers { seqDesc, err := sql.GetSequenceDescFromIdentifier(ctx, sc, seqIdentifier) if err != nil { @@ -309,7 +309,7 @@ func maybeUpdateBackRefsAndBuildMap( *changedSeqDescs = append(*changedSeqDescs, seqDesc) } } - seqNameToID[seqDesc.GetName()] = int64(seqDesc.ID) + seqNameToID[seqDesc.GetName()] = seqDesc.ID } return seqNameToID, nil