diff --git a/dev b/dev index 7cb73c198025..bd0c9fc6a6fe 100755 --- a/dev +++ b/dev @@ -8,7 +8,7 @@ fi set -euo pipefail # Bump this counter to force rebuilding `dev` on all machines. -DEV_VERSION=47 +DEV_VERSION=48 THIS_DIR=$(cd "$(dirname "$0")" && pwd) BINARY_DIR=$THIS_DIR/bin/dev-versions diff --git a/pkg/ccl/backupccl/show.go b/pkg/ccl/backupccl/show.go index 0697945544ac..542d1aaed3c9 100644 --- a/pkg/ccl/backupccl/show.go +++ b/pkg/ccl/backupccl/show.go @@ -105,9 +105,9 @@ func (m manifestInfoReader) header() colinfo.ResultColumns { func (m manifestInfoReader) showBackup( ctx context.Context, mem *mon.BoundAccount, - mkStore cloud.ExternalStorageFromURIFactory, + _ cloud.ExternalStorageFromURIFactory, info backupInfo, - user username.SQLUsername, + _ username.SQLUsername, resultsCh chan<- tree.Datums, ) error { var memReserved int64 @@ -129,7 +129,7 @@ func (m manifestInfoReader) showBackup( return err } - datums, err := m.shower.fn(info) + datums, err := m.shower.fn(ctx, info) if err != nil { return err } @@ -158,7 +158,7 @@ func (m metadataSSTInfoReader) header() colinfo.ResultColumns { func (m metadataSSTInfoReader) showBackup( ctx context.Context, - mem *mon.BoundAccount, + _ *mon.BoundAccount, mkStore cloud.ExternalStorageFromURIFactory, info backupInfo, user username.SQLUsername, @@ -249,15 +249,14 @@ func showBackupPlanHook( case tree.BackupFileDetails: shower = backupShowerFileSetup(backup.InCollection) case tree.BackupSchemaDetails: - shower = backupShowerDefault(ctx, p, true, opts) + shower = backupShowerDefault(p, true, opts) default: - shower = backupShowerDefault(ctx, p, false, opts) + shower = backupShowerDefault(p, false, opts) } infoReader = manifestInfoReader{shower: shower} } fn := func(ctx context.Context, _ []sql.PlanNode, resultsCh chan<- tree.Datums) error { - // TODO(dan): Move this span into sql. ctx, span := tracing.ChildSpan(ctx, stmt.StatementTag()) defer span.Finish() @@ -612,7 +611,7 @@ type backupShower struct { // fn is the specific implementation of the shower that can either be a default, ranges, files, // or JSON shower. - fn func(info backupInfo) ([]tree.Datums, error) + fn func(ctx context.Context, info backupInfo) ([]tree.Datums, error) } // backupShowerHeaders defines the schema for the table presented to the user. @@ -656,11 +655,11 @@ func backupShowerHeaders(showSchemas bool, opts map[string]string) colinfo.Resul } func backupShowerDefault( - ctx context.Context, p sql.PlanHookState, showSchemas bool, opts map[string]string, + p sql.PlanHookState, showSchemas bool, opts map[string]string, ) backupShower { return backupShower{ header: backupShowerHeaders(showSchemas, opts), - fn: func(info backupInfo) ([]tree.Datums, error) { + fn: func(ctx context.Context, info backupInfo) ([]tree.Datums, error) { var rows []tree.Datums for layer, manifest := range info.manifests { // Map database ID to descriptor name. @@ -998,7 +997,7 @@ var backupShowerRanges = backupShower{ {Name: "end_key", Typ: types.Bytes}, }, - fn: func(info backupInfo) (rows []tree.Datums, err error) { + fn: func(ctx context.Context, info backupInfo) (rows []tree.Datums, err error) { for _, manifest := range info.manifests { for _, span := range manifest.Spans { rows = append(rows, tree.Datums{ @@ -1027,7 +1026,7 @@ func backupShowerFileSetup(inCol tree.StringOrPlaceholderOptList) backupShower { {Name: "file_bytes", Typ: types.Int}, }, - fn: func(info backupInfo) (rows []tree.Datums, err error) { + fn: func(ctx context.Context, info backupInfo) (rows []tree.Datums, err error) { var manifestDirs []string var localityAware bool @@ -1161,7 +1160,7 @@ var jsonShower = backupShower{ {Name: "manifest", Typ: types.Jsonb}, }, - fn: func(info backupInfo) ([]tree.Datums, error) { + fn: func(ctx context.Context, info backupInfo) ([]tree.Datums, error) { rows := make([]tree.Datums, len(info.manifests)) for i, manifest := range info.manifests { j, err := protoreflect.MessageToJSON( diff --git a/pkg/ccl/changefeedccl/cdceval/expr_eval.go b/pkg/ccl/changefeedccl/cdceval/expr_eval.go index 08a295847733..c53d618a3cc9 100644 --- a/pkg/ccl/changefeedccl/cdceval/expr_eval.go +++ b/pkg/ccl/changefeedccl/cdceval/expr_eval.go @@ -624,8 +624,7 @@ func checkFunctionSupported( fnVolatility = fnCall.ResolvedOverload().Volatility } else { // Pick highest volatility overload. - for _, o := range fn.Definition { - overload := o.(*tree.Overload) + for _, overload := range fn.Definition { if overload.Volatility > fnVolatility { fnVolatility = overload.Volatility } diff --git a/pkg/cmd/dev/test.go b/pkg/cmd/dev/test.go index 49d0c72f76cb..721b5dbf3750 100644 --- a/pkg/cmd/dev/test.go +++ b/pkg/cmd/dev/test.go @@ -454,7 +454,11 @@ func (d *dev) determineAffectedTargets(ctx context.Context) ([]string, error) { if err != nil { return nil, err } - changedFilesList := strings.Split(strings.TrimSpace(string(changedFiles)), "\n") + trimmedOutput := strings.TrimSpace(string(changedFiles)) + if trimmedOutput == "" { + return nil, nil + } + changedFilesList := strings.Split(trimmedOutput, "\n") // Each file in this list needs to be munged somewhat to match up to the // Bazel target syntax. for idx, file := range changedFilesList { diff --git a/pkg/cmd/roachtest/tests/cdc.go b/pkg/cmd/roachtest/tests/cdc.go index 9ea8f12fafce..73ac113acf74 100644 --- a/pkg/cmd/roachtest/tests/cdc.go +++ b/pkg/cmd/roachtest/tests/cdc.go @@ -529,7 +529,7 @@ func runCDCSchemaRegistry(ctx context.Context, t test.Test, c cluster.Cluster) { updatedMap := make(map[string]struct{}) var resolved []string pagesFetched := 0 - pageSize := 14 + pageSize := 7 for len(updatedMap) < 10 && pagesFetched < 5 { result, err := c.RunWithDetailsSingleNode(ctx, t.L(), kafkaNode, @@ -569,11 +569,8 @@ func runCDCSchemaRegistry(ctx context.Context, t test.Test, c cluster.Cluster) { `{"before":null,"after":{"foo":{"a":{"long":3},"b":{"string":"3"},"c":{"long":3}}},"updated":{"string":""}}`, `{"before":null,"after":{"foo":{"a":{"long":4},"c":{"long":4}}},"updated":{"string":""}}`, `{"before":{"foo_before":{"a":{"long":1},"b":null,"c":null}},"after":{"foo":{"a":{"long":1},"c":null}},"updated":{"string":""}}`, - `{"before":{"foo_before":{"a":{"long":1},"c":null}},"after":{"foo":{"a":{"long":1},"c":null}},"updated":{"string":""}}`, `{"before":{"foo_before":{"a":{"long":2},"b":{"string":"2"},"c":null}},"after":{"foo":{"a":{"long":2},"c":null}},"updated":{"string":""}}`, - `{"before":{"foo_before":{"a":{"long":2},"c":null}},"after":{"foo":{"a":{"long":2},"c":null}},"updated":{"string":""}}`, `{"before":{"foo_before":{"a":{"long":3},"b":{"string":"3"},"c":{"long":3}}},"after":{"foo":{"a":{"long":3},"c":{"long":3}}},"updated":{"string":""}}`, - `{"before":{"foo_before":{"a":{"long":3},"c":{"long":3}}},"after":{"foo":{"a":{"long":3},"c":{"long":3}}},"updated":{"string":""}}`, } } else { expected = []string{ diff --git a/pkg/internal/sqlsmith/relational.go b/pkg/internal/sqlsmith/relational.go index 15820266ea89..2089eeb0a98f 100644 --- a/pkg/internal/sqlsmith/relational.go +++ b/pkg/internal/sqlsmith/relational.go @@ -752,7 +752,7 @@ var countStar = func() tree.TypedExpr { nil, /* window */ typ, &fn.FunctionProperties, - fn.Definition[0].(*tree.Overload), + fn.Definition[0], ) }() diff --git a/pkg/internal/sqlsmith/schema.go b/pkg/internal/sqlsmith/schema.go index a090e95d5e92..a8e7bdd07b49 100644 --- a/pkg/internal/sqlsmith/schema.go +++ b/pkg/internal/sqlsmith/schema.go @@ -534,7 +534,6 @@ var functions = func() map[tree.FunctionClass]map[oid.Oid][]function { continue } for _, ov := range def.Definition { - ov := ov.(*tree.Overload) // Ignore documented unusable functions. if strings.Contains(ov.Info, "Not usable") { continue diff --git a/pkg/sql/catalog/seqexpr/BUILD.bazel b/pkg/sql/catalog/seqexpr/BUILD.bazel index 6471ff742afc..7e7e73cf5130 100644 --- a/pkg/sql/catalog/seqexpr/BUILD.bazel +++ b/pkg/sql/catalog/seqexpr/BUILD.bazel @@ -27,7 +27,6 @@ go_test( "//pkg/sql/catalog/descpb", "//pkg/sql/parser", "//pkg/sql/sem/builtins", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/tree", "//pkg/sql/types", "@com_github_stretchr_testify//require", diff --git a/pkg/sql/catalog/seqexpr/sequence.go b/pkg/sql/catalog/seqexpr/sequence.go index 24ef5b3cd136..a0191e5c34d8 100644 --- a/pkg/sql/catalog/seqexpr/sequence.go +++ b/pkg/sql/catalog/seqexpr/sequence.go @@ -47,24 +47,27 @@ func (si *SeqIdentifier) IsByID() bool { // Returns the identifier of the sequence or nil if no sequence was found. // // `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. -func GetSequenceFromFunc( - funcExpr *tree.FuncExpr, - getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), -) (*SeqIdentifier, error) { +func GetSequenceFromFunc(funcExpr *tree.FuncExpr) (*SeqIdentifier, error) { // Resolve doesn't use the searchPath for resolving FunctionDefinitions // so we can pass in an empty SearchPath. // TODO(mgartner): Plumb a function resolver here, or determine that the // function should have already been resolved. + // TODO(chengxiong): Since we have funcExpr here, it's possible to narrow down + // overloads by using input types. def, err := funcExpr.Func.Resolve(tree.EmptySearchPath, nil /* resolver */) if err != nil { return nil, err } - fnProps, overloads := getBuiltinProperties(def.Name) - if fnProps != nil && fnProps.HasSequenceArguments { + hasSequenceArguments, err := def.GetHasSequenceArguments() + if err != nil { + return nil, err + } + + if hasSequenceArguments { found := false - for _, overload := range overloads { + for _, overload := range def.Definition { // Find the overload that matches funcExpr. if len(funcExpr.Exprs) == overload.Types.Length() { found = true @@ -137,17 +140,14 @@ func getSequenceIdentifier(expr tree.Expr) *SeqIdentifier { // e.g. nextval('foo') => "foo"; nextval(123::regclass) => 123; => nil // // `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. -func GetUsedSequences( - defaultExpr tree.Expr, - getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), -) ([]SeqIdentifier, error) { +func GetUsedSequences(defaultExpr tree.Expr) ([]SeqIdentifier, error) { var seqIdentifiers []SeqIdentifier _, err := tree.SimpleVisit( defaultExpr, func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { switch t := expr.(type) { case *tree.FuncExpr: - identifier, err := GetSequenceFromFunc(t, getBuiltinProperties) + identifier, err := GetSequenceFromFunc(t) if err != nil { return false, nil, err } @@ -170,14 +170,12 @@ func GetUsedSequences( // // `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. func ReplaceSequenceNamesWithIDs( - defaultExpr tree.Expr, - nameToID map[string]descpb.ID, - getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), + defaultExpr tree.Expr, nameToID map[string]descpb.ID, ) (tree.Expr, error) { replaceFn := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { switch t := expr.(type) { case *tree.FuncExpr: - identifier, err := GetSequenceFromFunc(t, getBuiltinProperties) + identifier, err := GetSequenceFromFunc(t) if err != nil { return false, nil, err } @@ -219,13 +217,11 @@ func ReplaceSequenceNamesWithIDs( // // `getBuiltinProperties` argument is commonly builtinsregistry.GetBuiltinProperties. func UpgradeSequenceReferenceInExpr( - expr *string, - usedSequenceIDsToNames map[descpb.ID]*tree.TableName, - getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), + expr *string, usedSequenceIDsToNames map[descpb.ID]*tree.TableName, ) (hasUpgraded bool, err error) { // Find the "reverse" mapping from sequence name to their IDs for those // sequences referenced by-name in `expr`. - usedSequenceNamesToIDs, err := seqNameToIDMappingInExpr(*expr, usedSequenceIDsToNames, getBuiltinProperties) + usedSequenceNamesToIDs, err := seqNameToIDMappingInExpr(*expr, usedSequenceIDsToNames) if err != nil { return false, err } @@ -237,7 +233,7 @@ func UpgradeSequenceReferenceInExpr( return false, err } - newExpr, err := ReplaceSequenceNamesWithIDs(parsedExpr, usedSequenceNamesToIDs, getBuiltinProperties) + newExpr, err := ReplaceSequenceNamesWithIDs(parsedExpr, usedSequenceNamesToIDs) if err != nil { return false, err } @@ -265,15 +261,13 @@ func UpgradeSequenceReferenceInExpr( // // See its unit test for some examples. func seqNameToIDMappingInExpr( - expr string, - seqIDToNameMapping map[descpb.ID]*tree.TableName, - getBuiltinProperties func(name string) (*tree.FunctionProperties, []tree.Overload), + expr string, seqIDToNameMapping map[descpb.ID]*tree.TableName, ) (map[string]descpb.ID, error) { parsedExpr, err := parser.ParseExpr(expr) if err != nil { return nil, err } - seqRefs, err := GetUsedSequences(parsedExpr, getBuiltinProperties) + seqRefs, err := GetUsedSequences(parsedExpr) if err != nil { return nil, err } diff --git a/pkg/sql/catalog/seqexpr/sequence_test.go b/pkg/sql/catalog/seqexpr/sequence_test.go index 04d9494e1618..b29c3192fb88 100644 --- a/pkg/sql/catalog/seqexpr/sequence_test.go +++ b/pkg/sql/catalog/seqexpr/sequence_test.go @@ -19,7 +19,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/seqexpr" "github.com/cockroachdb/cockroach/pkg/sql/parser" _ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" // register all builtins in builtins:init() for seqexpr package - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/stretchr/testify/require" @@ -53,7 +52,7 @@ func TestGetSequenceFromFunc(t *testing.T) { if !ok { t.Fatal("Expr is not a FuncExpr") } - identifier, err := seqexpr.GetSequenceFromFunc(funcExpr, builtinsregistry.GetBuiltinProperties) + identifier, err := seqexpr.GetSequenceFromFunc(funcExpr) if err != nil { t.Fatal(err) } @@ -99,7 +98,7 @@ func TestGetUsedSequences(t *testing.T) { if err != nil { t.Fatal(err) } - identifiers, err := seqexpr.GetUsedSequences(typedExpr, builtinsregistry.GetBuiltinProperties) + identifiers, err := seqexpr.GetUsedSequences(typedExpr) if err != nil { t.Fatal(err) } @@ -150,7 +149,7 @@ func TestReplaceSequenceNamesWithIDs(t *testing.T) { if err != nil { t.Fatal(err) } - newExpr, err := seqexpr.ReplaceSequenceNamesWithIDs(typedExpr, namesToID, builtinsregistry.GetBuiltinProperties) + newExpr, err := seqexpr.ReplaceSequenceNamesWithIDs(typedExpr, namesToID) if err != nil { t.Fatal(err) } @@ -169,7 +168,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[1] = &tbl1 usedSequenceIDsToNames[2] = &tbl2 expr := "nextval('testdb.sc1.t') + nextval('sc1.t')" - hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.NoError(t, err) require.True(t, hasUpgraded) require.Equal(t, @@ -184,7 +183,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[1] = &tbl1 usedSequenceIDsToNames[2] = &tbl2 expr := "nextval('testdb.sc1.t') + nextval('sc1.t')" - hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.NoError(t, err) require.True(t, hasUpgraded) require.Equal(t, @@ -200,7 +199,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[1] = &tbl1 usedSequenceIDsToNames[2] = &tbl2 expr := "nextval('testdb.public.t') + nextval('testdb.t')" - hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.NoError(t, err) require.True(t, hasUpgraded) require.Equal(t, @@ -215,7 +214,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[1] = &tbl1 usedSequenceIDsToNames[2] = &tbl2 expr := "nextval('t')" - _, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + _, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.Error(t, err, "ambiguous name matching for 't'; both 'sc1.t' and 'sc2.t' match it.") require.Equal(t, "more than 1 matches found for \"t\"", err.Error()) }) @@ -227,7 +226,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[1] = &tbl1 usedSequenceIDsToNames[2] = &tbl2 expr := "nextval('t2')" - _, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + _, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.Error(t, err, "no matching name for 't2'; neither 'sc1.t' nor 'sc2.t' match it.") require.Equal(t, "no table name found to match input \"t2\"", err.Error()) }) @@ -241,7 +240,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[2] = &tbl2 usedSequenceIDsToNames[3] = &tbl3 expr := "((nextval(1::REGCLASS) + nextval(2::REGCLASS)) + currval(3::REGCLASS)) + nextval(3::REGCLASS)" - hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.NoError(t, err) require.False(t, hasUpgraded) require.Equal(t, @@ -258,7 +257,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[2] = &tbl2 usedSequenceIDsToNames[3] = &tbl3 expr := "nextval('testdb.public.s1') + nextval('testdb.public.s2') + currval('testdb.sc1.s3') + nextval('testdb.sc1.s3')" - hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.NoError(t, err) require.True(t, hasUpgraded) require.Equal(t, @@ -275,7 +274,7 @@ func TestUpgradeSequenceReferenceInExpr(t *testing.T) { usedSequenceIDsToNames[2] = &tbl2 usedSequenceIDsToNames[3] = &tbl3 expr := "nextval('testdb.public.s1') + nextval(2::REGCLASS) + currval('testdb.sc1.s3') + nextval('testdb.sc1.s3')" - hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames, builtinsregistry.GetBuiltinProperties) + hasUpgraded, err := seqexpr.UpgradeSequenceReferenceInExpr(&expr, usedSequenceIDsToNames) require.NoError(t, err) require.True(t, hasUpgraded) require.Equal(t, diff --git a/pkg/sql/catalog/tabledesc/BUILD.bazel b/pkg/sql/catalog/tabledesc/BUILD.bazel index 944afcf73ec9..1f94aec128e6 100644 --- a/pkg/sql/catalog/tabledesc/BUILD.bazel +++ b/pkg/sql/catalog/tabledesc/BUILD.bazel @@ -43,7 +43,6 @@ go_library( "//pkg/sql/privilege", "//pkg/sql/rowenc", "//pkg/sql/schemachanger/scpb", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/catconstants", "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", diff --git a/pkg/sql/catalog/tabledesc/table_desc_builder.go b/pkg/sql/catalog/tabledesc/table_desc_builder.go index 0ca4122cd3cc..fb73b975aeaf 100644 --- a/pkg/sql/catalog/tabledesc/table_desc_builder.go +++ b/pkg/sql/catalog/tabledesc/table_desc_builder.go @@ -18,7 +18,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/seqexpr" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/util/protoutil" @@ -843,7 +842,7 @@ func maybeUpgradeSequenceReferenceForTable( // Upgrade sequence reference in DEFAULT expression, if any. if col.HasDefault() { - hasUpgradedInDefault, err := seqexpr.UpgradeSequenceReferenceInExpr(col.DefaultExpr, usedSequenceIDToNames, builtinsregistry.GetBuiltinProperties) + hasUpgradedInDefault, err := seqexpr.UpgradeSequenceReferenceInExpr(col.DefaultExpr, usedSequenceIDToNames) if err != nil { return hasUpgraded, err } @@ -852,7 +851,7 @@ func maybeUpgradeSequenceReferenceForTable( // Upgrade sequence reference in ON UPDATE expression, if any. if col.HasOnUpdate() { - hasUpgradedInOnUpdate, err := seqexpr.UpgradeSequenceReferenceInExpr(col.OnUpdateExpr, usedSequenceIDToNames, builtinsregistry.GetBuiltinProperties) + hasUpgradedInOnUpdate, err := seqexpr.UpgradeSequenceReferenceInExpr(col.OnUpdateExpr, usedSequenceIDToNames) if err != nil { return hasUpgraded, err } @@ -882,7 +881,7 @@ func maybeUpgradeSequenceReferenceForView( // by-ID reference. It, of course, also append replaced sequence IDs to `upgradedSeqIDs`. replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { newExprStr := expr.String() - hasUpgradedInExpr, err := seqexpr.UpgradeSequenceReferenceInExpr(&newExprStr, usedSequenceIDToNames, builtinsregistry.GetBuiltinProperties) + hasUpgradedInExpr, err := seqexpr.UpgradeSequenceReferenceInExpr(&newExprStr, usedSequenceIDToNames) if err != nil { return false, expr, err } diff --git a/pkg/sql/create_view.go b/pkg/sql/create_view.go index 8ddf5c58dd37..6b1f4f288555 100644 --- a/pkg/sql/create_view.go +++ b/pkg/sql/create_view.go @@ -33,7 +33,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" "github.com/cockroachdb/cockroach/pkg/sql/privilege" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" @@ -436,7 +435,7 @@ func replaceSeqNamesWithIDs( ctx context.Context, sc resolver.SchemaResolver, queryStr string, multiStmt bool, ) (string, error) { replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { - seqIdentifiers, err := seqexpr.GetUsedSequences(expr, builtinsregistry.GetBuiltinProperties) + seqIdentifiers, err := seqexpr.GetUsedSequences(expr) if err != nil { return false, expr, err } @@ -448,7 +447,7 @@ func replaceSeqNamesWithIDs( } seqNameToID[seqIdentifier.SeqName] = seqDesc.ID } - newExpr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID, builtinsregistry.GetBuiltinProperties) + newExpr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID) if err != nil { return false, expr, err } diff --git a/pkg/sql/opt/optbuilder/groupby.go b/pkg/sql/opt/optbuilder/groupby.go index 7572a4b0cff5..e03de4a97f70 100644 --- a/pkg/sql/opt/optbuilder/groupby.go +++ b/pkg/sql/opt/optbuilder/groupby.go @@ -876,19 +876,23 @@ func (b *Builder) constructAggregate(name string, args []opt.ScalarExpr) opt.Sca } func isAggregate(def *tree.FunctionDefinition) bool { - return def.Class == tree.AggregateClass -} - -func isWindow(def *tree.FunctionDefinition) bool { - return def.Class == tree.WindowClass + return isClass(def, tree.AggregateClass) } func isGenerator(def *tree.FunctionDefinition) bool { - return def.Class == tree.GeneratorClass + return isClass(def, tree.GeneratorClass) } func isSQLFn(def *tree.FunctionDefinition) bool { - return def.Class == tree.SQLClass + return isClass(def, tree.SQLClass) +} + +func isClass(def *tree.FunctionDefinition, want tree.FunctionClass) bool { + cls, err := def.GetClass() + if err != nil { + panic(err) + } + return cls == want } func newGroupingError(name tree.Name) error { diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go index dd62308d42eb..bee9671e88ad 100644 --- a/pkg/sql/opt/optbuilder/scalar.go +++ b/pkg/sql/opt/optbuilder/scalar.go @@ -25,7 +25,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treebin" @@ -540,11 +539,11 @@ func (b *Builder) buildFunction( return b.buildUDF(f, def, inScope, outScope, outCol) } - if isAggregate(def) { + if f.ResolvedOverload().Class == tree.AggregateClass { panic(errors.AssertionFailedf("aggregate function should have been replaced")) } - if isWindow(def) { + if f.ResolvedOverload().Class == tree.WindowClass { panic(errors.AssertionFailedf("window function should have been replaced")) } @@ -557,17 +556,17 @@ func (b *Builder) buildFunction( out = b.factory.ConstructFunction(args, &memo.FunctionPrivate{ Name: def.Name, Typ: f.ResolvedType(), - Properties: &def.FunctionProperties, + Properties: &f.ResolvedOverload().FunctionProperties, Overload: f.ResolvedOverload(), }) - if isGenerator(def) { + if f.ResolvedOverload().Class == tree.GeneratorClass { return b.finishBuildGeneratorFunction(f, out, inScope, outScope, outCol) } // Add a dependency on sequences that are used as a string argument. if b.trackSchemaDeps { - seqIdentifier, err := seqexpr.GetSequenceFromFunc(f, builtinsregistry.GetBuiltinProperties) + seqIdentifier, err := seqexpr.GetSequenceFromFunc(f) if err != nil { panic(err) } diff --git a/pkg/sql/opt/optbuilder/scope.go b/pkg/sql/opt/optbuilder/scope.go index d30a867f3d1d..ad6d1c446447 100644 --- a/pkg/sql/opt/optbuilder/scope.go +++ b/pkg/sql/opt/optbuilder/scope.go @@ -1015,6 +1015,9 @@ func (s *scope) VisitPre(expr tree.Expr) (recurse bool, newExpr tree.Expr) { case *tree.FuncExpr: semaCtx := s.builder.semaCtx + // TODO(mgartner): At this point the the function has not been type checked + // and resolved to one overload yet. Consider refactoring this so that it + // can handle overloads with the same name. def, err := t.Func.Resolve(semaCtx.SearchPath, semaCtx.FunctionResolver) if err != nil { panic(err) @@ -1133,14 +1136,25 @@ func (s *scope) replaceSRF(f *tree.FuncExpr, def *tree.FunctionDefinition) *srf func isOrderedSetAggregate(def *tree.FunctionDefinition) (*tree.FunctionDefinition, bool) { // The impl functions are private because they should never be run directly. // Thus, they need to be marked as non-private before using them. + + // FunctionProperties exist in function definitions and their overloads, so we + // unset all private fields here. + unsetPrivate := func(def *tree.FunctionDefinition) { + def.Private = true + for i := range def.Definition { + newOverload := *def.Definition[i] + newOverload.Private = false + def.Definition[i] = &newOverload + } + } switch def { case tree.FunDefs["percentile_disc"]: newDef := *tree.FunDefs["percentile_disc_impl"] - newDef.Private = false + unsetPrivate(&newDef) return &newDef, true case tree.FunDefs["percentile_cont"]: newDef := *tree.FunDefs["percentile_cont_impl"] - newDef.Private = false + unsetPrivate(&newDef) return &newDef, true } return def, false @@ -1227,7 +1241,7 @@ func (s *scope) replaceAggregate(f *tree.FuncExpr, def *tree.FunctionDefinition) private := memo.FunctionPrivate{ Name: def.Name, - Properties: &def.FunctionProperties, + Properties: &f.ResolvedOverload().FunctionProperties, Overload: f.ResolvedOverload(), } @@ -1339,7 +1353,7 @@ func (s *scope) replaceWindowFn(f *tree.FuncExpr, def *tree.FunctionDefinition) FuncExpr: f, def: memo.FunctionPrivate{ Name: def.Name, - Properties: &def.FunctionProperties, + Properties: &f.ResolvedOverload().FunctionProperties, Overload: f.ResolvedOverload(), }, } @@ -1394,7 +1408,7 @@ func (s *scope) replaceSQLFn(f *tree.FuncExpr, def *tree.FunctionDefinition) tre FuncExpr: f, def: memo.FunctionPrivate{ Name: def.Name, - Properties: &def.FunctionProperties, + Properties: &f.ResolvedOverload().FunctionProperties, Overload: f.ResolvedOverload(), }, args: args, diff --git a/pkg/sql/opt/optbuilder/srfs.go b/pkg/sql/opt/optbuilder/srfs.go index 1cd06fba4651..35625bc82eb5 100644 --- a/pkg/sql/opt/optbuilder/srfs.go +++ b/pkg/sql/opt/optbuilder/srfs.go @@ -97,7 +97,8 @@ func (b *Builder) buildZip(exprs tree.Exprs, inScope *scope) (outScope *scope) { texpr := inScope.resolveType(expr, types.Any) var def *tree.FunctionDefinition - if funcExpr, ok := texpr.(*tree.FuncExpr); ok { + funcExpr, ok := texpr.(*tree.FuncExpr) + if ok { if def, err = funcExpr.Func.Resolve( b.semaCtx.SearchPath, b.semaCtx.FunctionResolver, ); err != nil { @@ -107,13 +108,13 @@ func (b *Builder) buildZip(exprs tree.Exprs, inScope *scope) (outScope *scope) { var outCol *scopeColumn startCols := len(outScope.cols) - if def == nil || def.Class != tree.GeneratorClass || b.shouldCreateDefaultColumn(texpr) { - if def != nil && len(def.ReturnLabels) > 0 { + if def == nil || funcExpr.ResolvedOverload().Class != tree.GeneratorClass || b.shouldCreateDefaultColumn(texpr) { + if def != nil && len(funcExpr.ResolvedOverload().ReturnLabels) > 0 { // Override the computed alias with the one defined in the ReturnLabels. This // satisfies a Postgres quirk where some json functions use different labels // when used in a from clause. - alias = def.ReturnLabels[0] + alias = funcExpr.ResolvedOverload().ReturnLabels[0] } outCol = outScope.addColumn(scopeColName(tree.Name(alias)), texpr) } diff --git a/pkg/sql/opt/testutils/testcat/function.go b/pkg/sql/opt/testutils/testcat/function.go index ca20e192feb8..6d0a5f2cce22 100644 --- a/pkg/sql/opt/testutils/testcat/function.go +++ b/pkg/sql/opt/testutils/testcat/function.go @@ -150,7 +150,7 @@ func formatFunction(fn *tree.FunctionDefinition) string { if len(fn.Definition) != 1 { panic(fmt.Errorf("functions with multiple overloads not supported")) } - o := fn.Definition[0].(*tree.Overload) + o := fn.Definition[0] tp := treeprinter.New() nullStr := "" if !o.NullableArgs { diff --git a/pkg/sql/parser/help.go b/pkg/sql/parser/help.go index efff558507a0..7adace6a481f 100644 --- a/pkg/sql/parser/help.go +++ b/pkg/sql/parser/help.go @@ -117,8 +117,7 @@ func helpWithFunction(sqllex sqlLexer, f tree.ResolvableFunctionReference) int { // documentation, so we need to also combine the descriptions // together. lastInfo := "" - for i, overload := range d.Definition { - b := overload.(*tree.Overload) + for i, b := range d.Definition { if b.Info != "" && b.Info != lastInfo { if i > 0 { fmt.Fprintln(w, "---") diff --git a/pkg/sql/pg_catalog.go b/pkg/sql/pg_catalog.go index 90d6f495251c..14975c3b3f80 100644 --- a/pkg/sql/pg_catalog.go +++ b/pkg/sql/pg_catalog.go @@ -4224,13 +4224,11 @@ func init() { h := makeOidHasher() tree.OidToBuiltinName = make(map[oid.Oid]string, len(tree.FunDefs)) for name, def := range tree.FunDefs { - for _, o := range def.Definition { - if overload, ok := o.(*tree.Overload); ok { - builtinOid := h.BuiltinOid(name, overload) - id := builtinOid.Oid - tree.OidToBuiltinName[id] = name - overload.Oid = id - } + for _, overload := range def.Definition { + builtinOid := h.BuiltinOid(name, overload) + id := builtinOid.Oid + tree.OidToBuiltinName[id] = name + overload.Oid = id } } } diff --git a/pkg/sql/rename_database.go b/pkg/sql/rename_database.go index 8df2f57d2aa6..f92308f4b340 100644 --- a/pkg/sql/rename_database.go +++ b/pkg/sql/rename_database.go @@ -23,7 +23,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/util" @@ -298,7 +297,7 @@ func isAllowedDependentDescInRenameDatabase( if err != nil { return false, "", err } - seqIdentifiers, err := seqexpr.GetUsedSequences(typedExpr, builtinsregistry.GetBuiltinProperties) + seqIdentifiers, err := seqexpr.GetUsedSequences(typedExpr) if err != nil { return false, "", err } diff --git a/pkg/sql/row/BUILD.bazel b/pkg/sql/row/BUILD.bazel index 3e0673b13bd5..d548d0f188ca 100644 --- a/pkg/sql/row/BUILD.bazel +++ b/pkg/sql/row/BUILD.bazel @@ -53,7 +53,6 @@ go_library( "//pkg/sql/rowinfra", "//pkg/sql/scrub", "//pkg/sql/sem/builtins/builtinconstants", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/eval", "//pkg/sql/sem/transform", "//pkg/sql/sem/tree", diff --git a/pkg/sql/row/expr_walker.go b/pkg/sql/row/expr_walker.go index ffef6a99f6e2..96326aaa4fc1 100644 --- a/pkg/sql/row/expr_walker.go +++ b/pkg/sql/row/expr_walker.go @@ -26,7 +26,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinconstants" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/volatility" @@ -629,7 +628,7 @@ var supportedImportFuncOverrides = map[string]*customFunc{ visitorSideEffect: func(annot *tree.Annotations, fn *tree.FuncExpr) error { // Get sequence name so that we can update the annotation with the number // of nextval calls to this sequence in a row. - seqIdentifier, err := seqexpr.GetSequenceFromFunc(fn, builtinsregistry.GetBuiltinProperties) + seqIdentifier, err := seqexpr.GetSequenceFromFunc(fn) if err != nil { return err } diff --git a/pkg/sql/schemachanger/scbuild/BUILD.bazel b/pkg/sql/schemachanger/scbuild/BUILD.bazel index 1240c5fca8ac..bbcecc46c9c1 100644 --- a/pkg/sql/schemachanger/scbuild/BUILD.bazel +++ b/pkg/sql/schemachanger/scbuild/BUILD.bazel @@ -39,7 +39,6 @@ go_library( "//pkg/sql/schemachanger/scerrors", "//pkg/sql/schemachanger/scpb", "//pkg/sql/schemachanger/screl", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/catconstants", "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", diff --git a/pkg/sql/schemachanger/scbuild/builder_state.go b/pkg/sql/schemachanger/scbuild/builder_state.go index 86e300a6a198..a6de43e29442 100644 --- a/pkg/sql/schemachanger/scbuild/builder_state.go +++ b/pkg/sql/schemachanger/scbuild/builder_state.go @@ -34,7 +34,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scerrors" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/screl" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" @@ -412,7 +411,7 @@ func (b *builderState) WrapExpression(parentID catid.DescID, expr tree.Expr) *sc // Collect sequence IDs. var seqIDs catalog.DescriptorIDSet { - seqIdentifiers, err := seqexpr.GetUsedSequences(expr, builtinsregistry.GetBuiltinProperties) + seqIdentifiers, err := seqexpr.GetUsedSequences(expr) if err != nil { panic(err) } @@ -435,7 +434,7 @@ func (b *builderState) WrapExpression(parentID catid.DescID, expr tree.Expr) *sc seqIDs.Add(seq.SequenceID) } if len(seqNameToID) > 0 { - expr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID, builtinsregistry.GetBuiltinProperties) + expr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID) if err != nil { panic(err) } diff --git a/pkg/sql/schemachanger/scdecomp/BUILD.bazel b/pkg/sql/schemachanger/scdecomp/BUILD.bazel index d79cad126202..e3cd8449f8b6 100644 --- a/pkg/sql/schemachanger/scdecomp/BUILD.bazel +++ b/pkg/sql/schemachanger/scdecomp/BUILD.bazel @@ -20,7 +20,6 @@ go_library( "//pkg/sql/parser", "//pkg/sql/schemachanger/scerrors", "//pkg/sql/schemachanger/scpb", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/catconstants", "//pkg/sql/sem/catid", "//pkg/sql/sem/tree", diff --git a/pkg/sql/schemachanger/scdecomp/helpers.go b/pkg/sql/schemachanger/scdecomp/helpers.go index c29eeed4ecb7..6b52ba8e4b1e 100644 --- a/pkg/sql/schemachanger/scdecomp/helpers.go +++ b/pkg/sql/schemachanger/scdecomp/helpers.go @@ -20,7 +20,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scerrors" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/lib/pq/oid" @@ -68,7 +67,7 @@ func (w *walkCtx) newExpression(expr string) (*scpb.Expression, error) { } var seqIDs catalog.DescriptorIDSet { - seqIdents, err := seqexpr.GetUsedSequences(e, builtinsregistry.GetBuiltinProperties) + seqIdents, err := seqexpr.GetUsedSequences(e) if err != nil { return nil, err } diff --git a/pkg/sql/schemachanger/scexec/scmutationexec/BUILD.bazel b/pkg/sql/schemachanger/scexec/scmutationexec/BUILD.bazel index a659f6754d86..110a6e3b5a05 100644 --- a/pkg/sql/schemachanger/scexec/scmutationexec/BUILD.bazel +++ b/pkg/sql/schemachanger/scexec/scmutationexec/BUILD.bazel @@ -35,7 +35,6 @@ go_library( "//pkg/sql/schemachanger/scop", "//pkg/sql/schemachanger/scpb", "//pkg/sql/schemachanger/screl", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/catid", "//pkg/sql/sem/tree", "//pkg/sql/types", diff --git a/pkg/sql/schemachanger/scexec/scmutationexec/helpers.go b/pkg/sql/schemachanger/scexec/scmutationexec/helpers.go index 84a751f0aacc..9640b3206da6 100644 --- a/pkg/sql/schemachanger/scexec/scmutationexec/helpers.go +++ b/pkg/sql/schemachanger/scexec/scmutationexec/helpers.go @@ -22,7 +22,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/util/log/eventpb" "github.com/cockroachdb/errors" ) @@ -255,7 +254,7 @@ func sequenceIDsInExpr(expr string) (ids catalog.DescriptorIDSet, _ error) { if err != nil { return ids, err } - seqIdents, err := seqexpr.GetUsedSequences(e, builtinsregistry.GetBuiltinProperties) + seqIdents, err := seqexpr.GetUsedSequences(e) if err != nil { return ids, err } diff --git a/pkg/sql/sem/eval/parse_doid.go b/pkg/sql/sem/eval/parse_doid.go index d8b7c69601b3..8734ff9cc8ef 100644 --- a/pkg/sql/sem/eval/parse_doid.go +++ b/pkg/sql/sem/eval/parse_doid.go @@ -18,7 +18,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" - "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -79,11 +78,7 @@ func ParseDOid(ctx *Context, s string, t *types.T) (*tree.DOid, error) { return nil, pgerror.Newf(pgcode.AmbiguousAlias, "more than one function named '%s'", funcDef.Name) } - def := funcDef.Definition[0] - overload, ok := def.(*tree.Overload) - if !ok { - return nil, errors.AssertionFailedf("invalid non-overload regproc %s", funcDef.Name) - } + overload := funcDef.Definition[0] return tree.NewDOidWithTypeAndName(overload.Oid, t, funcDef.Name), nil case oid.T_regtype: parsedTyp, err := ctx.Planner.GetTypeFromValidSQLSyntax(s) diff --git a/pkg/sql/sem/transform/aggregates.go b/pkg/sql/sem/transform/aggregates.go index 2e186f42055e..0844d057ad16 100644 --- a/pkg/sql/sem/transform/aggregates.go +++ b/pkg/sql/sem/transform/aggregates.go @@ -15,17 +15,17 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" ) -// IsAggregateVisitor checks if walked expressions contain aggregate functions. -type IsAggregateVisitor struct { +// isAggregateVisitor checks if walked expressions contain aggregate functions. +type isAggregateVisitor struct { Aggregated bool // searchPath is used to search for unqualified function names. searchPath sessiondata.SearchPath } -var _ tree.Visitor = &IsAggregateVisitor{} +var _ tree.Visitor = &isAggregateVisitor{} // VisitPre satisfies the Visitor interface. -func (v *IsAggregateVisitor) VisitPre(expr tree.Expr) (recurse bool, newExpr tree.Expr) { +func (v *isAggregateVisitor) VisitPre(expr tree.Expr) (recurse bool, newExpr tree.Expr) { switch t := expr.(type) { case *tree.FuncExpr: if t.IsWindowFunctionApplication() { @@ -39,7 +39,12 @@ func (v *IsAggregateVisitor) VisitPre(expr tree.Expr) (recurse bool, newExpr tre if err != nil { return false, expr } - if fd.Class == tree.AggregateClass { + funcCls, err := fd.GetClass() + if err != nil { + return false, expr + } + + if funcCls == tree.AggregateClass { v.Aggregated = true return false, expr } @@ -51,4 +56,4 @@ func (v *IsAggregateVisitor) VisitPre(expr tree.Expr) (recurse bool, newExpr tre } // VisitPost satisfies the Visitor interface. -func (*IsAggregateVisitor) VisitPost(expr tree.Expr) tree.Expr { return expr } +func (*isAggregateVisitor) VisitPost(expr tree.Expr) tree.Expr { return expr } diff --git a/pkg/sql/sem/transform/expr_transform.go b/pkg/sql/sem/transform/expr_transform.go index 140176eae869..04c1b0fb7549 100644 --- a/pkg/sql/sem/transform/expr_transform.go +++ b/pkg/sql/sem/transform/expr_transform.go @@ -23,7 +23,7 @@ import ( // visitors between uses. type ExprTransformContext struct { normalizeVisitor normalize.Visitor - isAggregateVisitor IsAggregateVisitor + isAggregateVisitor isAggregateVisitor } // NormalizeExpr is a wrapper around EvalContex.Expr which @@ -55,7 +55,7 @@ func (t *ExprTransformContext) AggregateInExpr( return false } - t.isAggregateVisitor = IsAggregateVisitor{ + t.isAggregateVisitor = isAggregateVisitor{ searchPath: searchPath, } tree.WalkExprConst(&t.isAggregateVisitor, expr) diff --git a/pkg/sql/sem/tree/function_definition.go b/pkg/sql/sem/tree/function_definition.go index cb83d94a8ced..420e3dc6a6fb 100644 --- a/pkg/sql/sem/tree/function_definition.go +++ b/pkg/sql/sem/tree/function_definition.go @@ -10,19 +10,23 @@ package tree -import "github.com/lib/pq/oid" +import ( + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/lib/pq/oid" +) // FunctionDefinition implements a reference to the (possibly several) // overloads for a built-in function. +// TODO(Chengxiong): Remove this struct entirely. Instead, use overloads from +// function resolution or use "GetBuiltinProperties" if the need is to only look +// at builtin functions(there are such existing use cases). type FunctionDefinition struct { // Name is the short name of the function. Name string // Definition is the set of overloads for this function name. - // We use []overloadImpl here although all the uses of this struct - // could actually write a []Overload, because we want to share - // the code with typeCheckOverloadedExprs(). - Definition []overloadImpl + Definition []*Overload // FunctionProperties are the properties common to all overloads. FunctionProperties @@ -136,7 +140,7 @@ var _ = NormalClass func NewFunctionDefinition( name string, props *FunctionProperties, def []Overload, ) *FunctionDefinition { - overloads := make([]overloadImpl, len(def)) + overloads := make([]*Overload, len(def)) for i := range def { if def[i].PreferredOverload { @@ -144,6 +148,7 @@ func NewFunctionDefinition( props.AmbiguousReturnType = true } + def[i].FunctionProperties = *props overloads[i] = &def[i] } return &FunctionDefinition{ @@ -170,4 +175,74 @@ var OidToBuiltinName map[oid.Oid]string func (fd *FunctionDefinition) Format(ctx *FmtCtx) { ctx.WriteString(fd.Name) } + +// String implements the Stringer interface. func (fd *FunctionDefinition) String() string { return AsString(fd) } + +// TODO(Chengxiong): Remove this method after we moved the +// "UnsupportedWithIssue" check into function resolver implementation. +func (fd *FunctionDefinition) undefined() bool { + return fd.UnsupportedWithIssue != 0 +} + +// GetClass returns function class by checking each overload's Class and returns +// the homogeneous Class value if all overloads are the same Class. Ambiguous +// error is returned if there is any overload with different Class. +func (fd *FunctionDefinition) GetClass() (FunctionClass, error) { + if fd.undefined() { + return fd.Class, nil + } + return getFuncClass(fd.Name, fd.Definition) +} + +// GetReturnLabel returns function ReturnLabel by checking each overload and +// returns a ReturnLabel if all overloads have a ReturnLabel of the same length. +// Ambiguous error is returned if there is any overload has ReturnLabel of a +// different length. This is good enough since we don't create UDF with +// ReturnLabel. +func (fd *FunctionDefinition) GetReturnLabel() ([]string, error) { + if fd.undefined() { + return fd.ReturnLabels, nil + } + return getFuncReturnLabels(fd.Name, fd.Definition) +} + +// GetHasSequenceArguments returns function's HasSequenceArguments flag by +// checking each overload's HasSequenceArguments flag. Ambiguous error is +// returned if there is any overload has a different flag. +func (fd *FunctionDefinition) GetHasSequenceArguments() (bool, error) { + if fd.undefined() { + return fd.HasSequenceArguments, nil + } + return getHasSequenceArguments(fd.Name, fd.Definition) +} + +func getFuncClass(fnName string, fns []*Overload) (FunctionClass, error) { + ret := fns[0].Class + for _, o := range fns { + if o.Class != ret { + return 0, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function class on %s", fnName) + } + } + return ret, nil +} + +func getFuncReturnLabels(fnName string, fns []*Overload) ([]string, error) { + ret := fns[0].ReturnLabels + for _, o := range fns { + if len(ret) != len(o.ReturnLabels) { + return nil, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function return label on %s", fnName) + } + } + return ret, nil +} + +func getHasSequenceArguments(fnName string, fns []*Overload) (bool, error) { + ret := fns[0].HasSequenceArguments + for _, o := range fns { + if ret != o.HasSequenceArguments { + return false, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function sequence argument on %s", fnName) + } + } + return ret, nil +} diff --git a/pkg/sql/sem/tree/function_name.go b/pkg/sql/sem/tree/function_name.go index d8a38ee90213..fb7c61bb7309 100644 --- a/pkg/sql/sem/tree/function_name.go +++ b/pkg/sql/sem/tree/function_name.go @@ -29,6 +29,12 @@ import ( // FunctionReferenceResolver is the interface that provides the ability to // resolve built-in or user-defined function definitions from unresolved names. type FunctionReferenceResolver interface { + // ResolveFunction resolves a group of overloads with the given function name + // within a search path. + // TODO(Chengxiong): Consider adding an optional slice of argument types to + // the input of this method, so that we can try to narrow down the scope of + // overloads a bit earlier and decrease the possibility of ambiguous error + // on function properties. ResolveFunction(name *UnresolvedName, path SearchPath) (*FunctionDefinition, error) } diff --git a/pkg/sql/sem/tree/overload.go b/pkg/sql/sem/tree/overload.go index 46b1455120de..8f4b6c7c0637 100644 --- a/pkg/sql/sem/tree/overload.go +++ b/pkg/sql/sem/tree/overload.go @@ -172,6 +172,9 @@ type Overload struct { // NOTE: when set, a function should be prepared for any of its arguments to // be NULL and should act accordingly. NullableArgs bool + + // FunctionProperties are the properties of this overload. + FunctionProperties } // params implements the overloadImpl interface. diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go index 7f8bda04d1ff..cd6bb1e7ff61 100644 --- a/pkg/sql/sem/tree/type_check.go +++ b/pkg/sql/sem/tree/type_check.go @@ -929,6 +929,7 @@ func NewInvalidFunctionUsageError(class FunctionClass, context string) error { // checkFunctionUsage checks whether a given built-in function is // allowed in the current context. func (sc *SemaContext) checkFunctionUsage(expr *FuncExpr, def *FunctionDefinition) error { + // TODO(Chengxiong): Move def.UnsupportedWithIssue check to function resolver implementation. if def.UnsupportedWithIssue != 0 { // Note: no need to embed the function name in the message; the // caller will add the function name as prefix. @@ -938,15 +939,20 @@ func (sc *SemaContext) checkFunctionUsage(expr *FuncExpr, def *FunctionDefinitio } return unimplemented.NewWithIssueDetail(def.UnsupportedWithIssue, def.Name, msg) } - if def.Private { - return pgerror.Wrapf(errPrivateFunction, pgcode.ReservedName, - "%s()", errors.Safe(def.Name)) - } + if sc == nil { // We can't check anything further. Give up. return nil } + // TODO(Chengxiong): Consider doing this check when we narrow down to an + // overload. This is fine at the moment since we don't allow creating + // aggregate/window functions yet. But, ideally, we should figure out a way + // to do this check after overload resolution. + fnCls, err := def.GetClass() + if err != nil { + return err + } if expr.IsWindowFunctionApplication() { if sc.Properties.required.rejectFlags&RejectWindowApplications != 0 { return NewInvalidFunctionUsageError(WindowClass, sc.Properties.required.context) @@ -960,7 +966,7 @@ func (sc *SemaContext) checkFunctionUsage(expr *FuncExpr, def *FunctionDefinitio } else { // If it is an aggregate function *not used OVER a window*, then // we have an aggregation. - if def.Class == AggregateClass { + if fnCls == AggregateClass { if sc.Properties.Derived.inFuncExpr && sc.Properties.required.rejectFlags&RejectNestedAggregates != 0 { return NewAggInAggError() @@ -971,7 +977,7 @@ func (sc *SemaContext) checkFunctionUsage(expr *FuncExpr, def *FunctionDefinitio sc.Properties.Derived.SeenAggregate = true } } - if def.Class == GeneratorClass { + if fnCls == GeneratorClass { if sc.Properties.Derived.inFuncExpr && sc.Properties.required.rejectFlags&RejectNestedGenerators != 0 { return NewInvalidNestedSRFError(sc.Properties.required.context) @@ -1021,7 +1027,11 @@ func (sc *SemaContext) checkVolatility(v volatility.V) error { // CheckIsWindowOrAgg returns an error if the function definition is not a // window function or an aggregate. func CheckIsWindowOrAgg(def *FunctionDefinition) error { - switch def.Class { + cls, err := def.GetClass() + if err != nil { + return err + } + switch cls { case AggregateClass: case WindowClass: default: @@ -1051,6 +1061,7 @@ func (expr *FuncExpr) TypeCheck( return nil, pgerror.Wrapf(err, pgcode.InvalidParameterValue, "%s()", def.Name) } + if semaCtx != nil { // We'll need to remember we are in a function application to // generate suitable errors in checkFunctionUsage(). We cannot @@ -1072,7 +1083,11 @@ func (expr *FuncExpr) TypeCheck( } } - typedSubExprs, fns, err := typeCheckOverloadedExprs(ctx, semaCtx, desired, def.Definition, false, expr.Exprs...) + overloadImpls := make([]overloadImpl, 0, len(def.Definition)) + for _, o := range def.Definition { + overloadImpls = append(overloadImpls, o) + } + typedSubExprs, fns, err := typeCheckOverloadedExprs(ctx, semaCtx, desired, overloadImpls, false, expr.Exprs...) if err != nil { return nil, pgerror.Wrapf(err, pgcode.InvalidParameterValue, "%s()", def.Name) } @@ -1093,7 +1108,11 @@ func (expr *FuncExpr) TypeCheck( // chooses the overload with preferred type for the given category. For // example, float8 is the preferred type for the numeric category in Postgres. // To match Postgres' behavior, we should add that logic here too. - if def.FunctionProperties.Class == AggregateClass { + funcCls, err := def.GetClass() + if err != nil { + return nil, err + } + if funcCls == AggregateClass { for i := range typedSubExprs { if typedSubExprs[i].ResolvedType().Family() == types.UnknownFamily { var filtered []overloadImpl @@ -1122,8 +1141,8 @@ func (expr *FuncExpr) TypeCheck( // Return NULL if at least one overload is possible, no overload accepts // NULL arguments, the function isn't a generator or aggregate builtin, and // NULL is given as an argument. - if len(fns) > 0 && len(nullableArgFns) == 0 && def.FunctionProperties.Class != GeneratorClass && - def.FunctionProperties.Class != AggregateClass { + if len(fns) > 0 && len(nullableArgFns) == 0 && funcCls != GeneratorClass && + funcCls != AggregateClass { for _, expr := range typedSubExprs { if expr.ResolvedType().Family() == types.UnknownFamily { return DNull, nil @@ -1152,6 +1171,10 @@ func (expr *FuncExpr) TypeCheck( return nil, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous call: %s, candidates are:\n%s", sig, fnsStr) } overloadImpl := fns[0].(*Overload) + if overloadImpl.Private { + return nil, pgerror.Wrapf(errPrivateFunction, pgcode.ReservedName, + "%s()", errors.Safe(def.Name)) + } if expr.IsWindowFunctionApplication() { // Make sure the window function application is of either a built-in window @@ -1164,14 +1187,14 @@ func (expr *FuncExpr) TypeCheck( } } else { // Make sure the window function builtins are used as window function applications. - if def.Class == WindowClass { + if funcCls == WindowClass { return nil, pgerror.Newf(pgcode.WrongObjectType, "window function %s() requires an OVER clause", &expr.Func) } } if expr.Filter != nil { - if def.Class != AggregateClass { + if funcCls != AggregateClass { // Same error message as Postgres. If we have a window function, only // aggregates accept a FILTER clause. return nil, pgerror.Newf(pgcode.WrongObjectType, @@ -1199,7 +1222,7 @@ func (expr *FuncExpr) TypeCheck( expr.Exprs[i] = subExpr } expr.fn = overloadImpl - expr.fnProps = &def.FunctionProperties + expr.fnProps = &overloadImpl.FunctionProperties expr.typ = overloadImpl.returnType()(typedSubExprs) if expr.typ == UnknownReturnType { typeNames := make([]string, 0, len(expr.Exprs)) diff --git a/pkg/sql/sequence.go b/pkg/sql/sequence.go index 871b95a86700..3e068b6732f0 100644 --- a/pkg/sql/sequence.go +++ b/pkg/sql/sequence.go @@ -30,7 +30,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/sql/types" @@ -599,7 +598,7 @@ func maybeAddSequenceDependencies( backrefs map[descpb.ID]*tabledesc.Mutable, colExprKind tabledesc.ColExprKind, ) ([]*tabledesc.Mutable, error) { - seqIdentifiers, err := seqexpr.GetUsedSequences(expr, builtinsregistry.GetBuiltinProperties) + seqIdentifiers, err := seqexpr.GetUsedSequences(expr) if err != nil { return nil, err } @@ -674,7 +673,7 @@ func maybeAddSequenceDependencies( // If sequences are present in the expr (and the cluster is the right version), // walk the expr tree and replace any sequences names with their IDs. if len(seqIdentifiers) > 0 { - newExpr, err := seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID, builtinsregistry.GetBuiltinProperties) + newExpr, err := seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID) if err != nil { return nil, err } diff --git a/pkg/upgrade/upgrades/BUILD.bazel b/pkg/upgrade/upgrades/BUILD.bazel index 8dbbb9d487e8..212be6b4d513 100644 --- a/pkg/upgrade/upgrades/BUILD.bazel +++ b/pkg/upgrade/upgrades/BUILD.bazel @@ -48,7 +48,6 @@ go_library( "//pkg/sql/catalog/typedesc", "//pkg/sql/parser", "//pkg/sql/privilege", - "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/tree", "//pkg/sql/sessiondata", "//pkg/sql/sqlutil", diff --git a/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go b/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go index 8813e71fe382..482daab72a14 100644 --- a/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go +++ b/pkg/upgrade/upgrades/upgrade_sequence_to_be_referenced_by_ID.go @@ -23,7 +23,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/seqexpr" "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/parser" - "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" "github.com/cockroachdb/cockroach/pkg/upgrade" @@ -176,7 +175,7 @@ func upgradeSequenceReferenceInTable( if err != nil { return err } - seqIdentifiers, err := seqexpr.GetUsedSequences(parsedExpr, builtinsregistry.GetBuiltinProperties) + seqIdentifiers, err := seqexpr.GetUsedSequences(parsedExpr) if err != nil { return err } @@ -190,7 +189,7 @@ func upgradeSequenceReferenceInTable( } // Perform the sequence replacement in the default expression. - newExpr, err := seqexpr.ReplaceSequenceNamesWithIDs(parsedExpr, seqNameToID, builtinsregistry.GetBuiltinProperties) + newExpr, err := seqexpr.ReplaceSequenceNamesWithIDs(parsedExpr, seqNameToID) if err != nil { return err } @@ -235,7 +234,7 @@ func upgradeSequenceReferenceInView( ) error { var changedSeqDescs []*tabledesc.Mutable replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { - seqIdentifiers, err := seqexpr.GetUsedSequences(expr, builtinsregistry.GetBuiltinProperties) + seqIdentifiers, err := seqexpr.GetUsedSequences(expr) if err != nil { return false, expr, err } @@ -244,7 +243,7 @@ func upgradeSequenceReferenceInView( return false, expr, err } - newExpr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID, builtinsregistry.GetBuiltinProperties) + newExpr, err = seqexpr.ReplaceSequenceNamesWithIDs(expr, seqNameToID) if err != nil { return false, expr, err }