diff --git a/pkg/sql/crdb_internal.go b/pkg/sql/crdb_internal.go index 9613b2fdb860..0bf4fdec44c1 100644 --- a/pkg/sql/crdb_internal.go +++ b/pkg/sql/crdb_internal.go @@ -2648,7 +2648,11 @@ CREATE TABLE crdb_internal.create_function_statements ( } for i := range treeNode.Options { if body, ok := treeNode.Options[i].(tree.FunctionBodyStr); ok { - seqReplacedBody, err := formatQuerySequencesForDisplay(ctx, &p.semaCtx, string(body), true /* multiStmt */) + typeReplacedBody, err := formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), string(body)) + if err != nil { + return err + } + seqReplacedBody, err := formatQuerySequencesForDisplay(ctx, &p.semaCtx, typeReplacedBody, true /* multiStmt */) if err != nil { return err } diff --git a/pkg/sql/logictest/testdata/logic_test/udf b/pkg/sql/logictest/testdata/logic_test/udf index 6d845c869f58..da095c14db2d 100644 --- a/pkg/sql/logictest/testdata/logic_test/udf +++ b/pkg/sql/logictest/testdata/logic_test/udf @@ -1251,7 +1251,7 @@ CREATE FUNCTION public.f_udt_rewrite() CALLED ON NULL INPUT LANGUAGE SQL AS $$ - SELECT b'@':::@100115; + SELECT 'Monday':::test.public.notmyworkday; $$ query T diff --git a/pkg/sql/show_create_clauses.go b/pkg/sql/show_create_clauses.go index d04de0a6b82f..751c77995546 100644 --- a/pkg/sql/show_create_clauses.go +++ b/pkg/sql/show_create_clauses.go @@ -281,6 +281,77 @@ func formatViewQueryTypesForDisplay( return newStmt.String(), nil } +// formatFunctionQueryTypesForDisplay is similar to +// formatViewQueryTypesForDisplay but can only be used for function. +// nil is used as the table descriptor for schemaexpr.FormatExprForDisplay call. +// This is fine assuming that UDFs cannot be created with expression casting a +// column/var to an enum in function body. This is super rare case for now, and +// it's tracked with issue #87475. We should also unify this function with +// formatViewQueryTypesForDisplay. +func formatFunctionQueryTypesForDisplay( + ctx context.Context, + semaCtx *tree.SemaContext, + sessionData *sessiondata.SessionData, + queries string, +) (string, error) { + replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + // We need to resolve the type to check if it's user-defined. If not, + // no other work is needed. + var typRef tree.ResolvableTypeReference + switch n := expr.(type) { + case *tree.CastExpr: + typRef = n.Type + case *tree.AnnotateTypeExpr: + typRef = n.Type + default: + return true, expr, nil + } + var typ *types.T + typ, err = tree.ResolveType(ctx, typRef, semaCtx.TypeResolver) + if err != nil { + return false, expr, err + } + if !typ.UserDefined() { + return true, expr, nil + } + formattedExpr, err := schemaexpr.FormatExprForDisplay( + ctx, nil, expr.String(), semaCtx, sessionData, tree.FmtParsable, + ) + if err != nil { + return false, expr, err + } + newExpr, err = parser.ParseExpr(formattedExpr) + if err != nil { + return false, expr, err + } + return false, newExpr, nil + } + + var stmts tree.Statements + parsedStmts, err := parser.Parse(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query") + } + stmts = make(tree.Statements, len(parsedStmts)) + for i, stmt := range parsedStmts { + stmts[i] = stmt.AST + } + + fmtCtx := tree.NewFmtCtx(tree.FmtSimple) + for i, stmt := range stmts { + newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + fmtCtx.WriteString(";") + } + return fmtCtx.CloseAndGetString(), nil +} + // showComments prints out the COMMENT statements sufficient to populate a // table's comments, including its index and column comments. func showComments(