diff --git a/ast/sql.go b/ast/sql.go index eedd404a..9eff0352 100644 --- a/ast/sql.go +++ b/ast/sql.go @@ -7,6 +7,86 @@ import ( "github.com/cloudspannerecosystem/memefish/token" ) +type FormatOption struct { + Newline bool + Indent int +} + +var emptyFormatContext = &FormatContext{ + Option: FormatOption{ + Newline: false, + Indent: 0, + }, + Current: 0, +} + +type FormatContext struct { + Option FormatOption + Current int +} + +func (fc *FormatContext) SQL(node Node) string { + if nodeFormat, ok := node.(NodeFormat); fc != nil && ok { + return nodeFormat.SQLContext(fc) + } else { + return node.SQL() + } +} + +func (fc *FormatContext) Newline() string { + if fc != nil && fc.Option.Newline { + return "\n" + strings.Repeat(" ", fc.Current) + } + return " " +} + +func (fc *FormatContext) NewlineOrEmpty() string { + if fc != nil && fc.Option.Newline { + return "\n" + strings.Repeat(" ", fc.Current) + } + return "" +} + +func (fc *FormatContext) WithIndent(f func(fc *FormatContext) string) string { + var newFc FormatContext + if fc != nil { + newFc = *fc + newFc.Current += newFc.Option.Indent + return f(&newFc) + } else { + return f(emptyFormatContext) + } +} + +func sqlOptCtx[T interface { + Node + comparable +}](fc *FormatContext, left string, node T, right string) string { + var zero T + if node == zero { + return "" + } + return left + fc.SQL(node) + right +} + +// sqlJoin outputs joined string of SQL() of all elems by sep. +// This function corresponds to sqlJoin in ast.go +func sqlJoinCtx[T Node](fc *FormatContext, elems []T, sep string) string { + var b strings.Builder + for i, r := range elems { + if i > 0 { + b.WriteString(sep) + } + b.WriteString(fc.SQL(r)) + } + return b.String() +} + +type NodeFormat interface { + Node + SQLContext(fmtCtx *FormatContext) string +} + // ================================================================================ // // Helper functions for SQL() @@ -134,16 +214,14 @@ func paren(p prec, e Expr) string { // // ================================================================================ +func (q *QueryStatement) SQLContext(fc *FormatContext) string { + return sqlOptCtx(fc, "", q.Hint, fc.Newline()) + + sqlOptCtx(fc, "", q.With, fc.Newline()) + + fc.SQL(q.Query) +} + func (q *QueryStatement) SQL() string { - var sql string - if q.Hint != nil { - sql += q.Hint.SQL() + " " - } - if q.With != nil { - sql += q.With.SQL() + " " - } - sql += q.Query.SQL() - return sql + return q.SQLContext(nil) } func (h *Hint) SQL() string { @@ -171,17 +249,23 @@ func (c *CTE) SQL() string { return c.Name.SQL() + " AS (" + c.QueryExpr.SQL() + ")" } -func (s *Select) SQL() string { +func (s *Select) SQLContext(fc *FormatContext) string { return "SELECT " + strOpt(s.Distinct, "DISTINCT ") + - sqlOpt("", s.As, " ") + - sqlJoin(s.Results, ", ") + - sqlOpt(" ", s.From, "") + - sqlOpt(" ", s.Where, "") + - sqlOpt(" ", s.GroupBy, "") + - sqlOpt(" ", s.Having, "") + - sqlOpt(" ", s.OrderBy, "") + - sqlOpt(" ", s.Limit, "") + sqlOptCtx(fc, "", s.As, " ") + + fc.WithIndent(func(fc *FormatContext) string { + return sqlJoinCtx(fc, s.Results, ","+fc.Newline()) + }) + + sqlOptCtx(fc, fc.Newline(), s.From, "") + + sqlOptCtx(fc, fc.Newline(), s.Where, "") + + sqlOptCtx(fc, fc.Newline(), s.GroupBy, "") + + sqlOptCtx(fc, fc.Newline(), s.Having, "") + + sqlOptCtx(fc, fc.Newline(), s.OrderBy, "") + + sqlOptCtx(fc, fc.Newline(), s.Limit, "") +} + +func (s *Select) SQL() string { + return s.SQLContext(nil) } func (a *AsStruct) SQL() string { return "AS STRUCT" } @@ -242,8 +326,11 @@ func (e *ExprSelectItem) SQL() string { return e.Expr.SQL() } +func (f *From) SQLContext(fc *FormatContext) string { + return "FROM " + fc.SQL(f.Source) +} func (f *From) SQL() string { - return "FROM " + f.Source.SQL() + return f.SQLContext(nil) } func (w *Where) SQL() string { @@ -341,23 +428,27 @@ func (e *PathTableExpr) SQL() string { sqlOpt(" ", e.Sample, "") } +func (s *SubQueryTableExpr) SQLContext(fc *FormatContext) string { + return "(" + fc.WithIndent( + func(fc *FormatContext) string { + return fc.NewlineOrEmpty() + fc.SQL(s.Query) + }) + fc.Newline() + ")" + + sqlOptCtx(fc, " ", s.As, "") + + sqlOptCtx(fc, " ", s.Sample, "") +} + func (s *SubQueryTableExpr) SQL() string { - sql := "(" + s.Query.SQL() + ")" - if s.As != nil { - sql += " " + s.As.SQL() - } - if s.Sample != nil { - sql += " " + s.Sample.SQL() - } - return sql + return s.SQLContext(nil) +} + +func (p *ParenTableExpr) SQLContext(fc *FormatContext) string { + return "(" + fc.WithIndent(func(fc *FormatContext) string { + return fc.NewlineOrEmpty() + p.Source.SQL() + }) + fc.NewlineOrEmpty() + ")" + sqlOpt(" ", p.Sample, "") } func (p *ParenTableExpr) SQL() string { - sql := "(" + p.Source.SQL() + ")" - if p.Sample != nil { - sql += " " + p.Sample.SQL() - } - return sql + return p.SQLContext(nil) } func (j *Join) SQL() string { diff --git a/tools/parse/main.go b/tools/parse/main.go index 17d4a6ac..642ac184 100644 --- a/tools/parse/main.go +++ b/tools/parse/main.go @@ -11,11 +11,12 @@ import ( "unicode" "github.com/MakeNowJust/heredoc/v2" + "github.com/k0kubun/pp" + "github.com/cloudspannerecosystem/memefish" "github.com/cloudspannerecosystem/memefish/ast" "github.com/cloudspannerecosystem/memefish/token" "github.com/cloudspannerecosystem/memefish/tools/util/poslang" - "github.com/k0kubun/pp" ) var usage = heredoc.Doc(` @@ -118,6 +119,14 @@ func main() { fmt.Println() fmt.Println("--- SQL") fmt.Println(node.SQL()) + fmt.Println("--- SQL with indentation") + fmt.Println((&ast.FormatContext{ + Option: ast.FormatOption{ + Newline: true, + Indent: 2, + }, + Current: 0, + }).SQL(node)) if *pos != "" { fmt.Println("--- POS")