Skip to content

Commit

Permalink
WIP pretty printing
Browse files Browse the repository at this point in the history
  • Loading branch information
apstndb committed Dec 17, 2024
1 parent 66dfc61 commit b84d9a5
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 33 deletions.
155 changes: 123 additions & 32 deletions ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,86 @@ import (
"github.com/cloudspannerecosystem/memefish/token"
)

type FormatOption struct {
Newline bool
Indent int
}

var emptyFormatContext = &FormatContext{
Option: FormatOption{
Newline: false,
Indent: 0,
},
Current: 0,
}

type FormatContext struct {
Option FormatOption
Current int
}

func (fc *FormatContext) SQL(node Node) string {
if nodeFormat, ok := node.(NodeFormat); fc != nil && ok {
return nodeFormat.SQLContext(fc)
} else {
return node.SQL()
}
}

func (fc *FormatContext) Newline() string {
if fc != nil && fc.Option.Newline {
return "\n" + strings.Repeat(" ", fc.Current)
}
return " "
}

func (fc *FormatContext) NewlineOrEmpty() string {
if fc != nil && fc.Option.Newline {
return "\n" + strings.Repeat(" ", fc.Current)
}
return ""
}

func (fc *FormatContext) WithIndent(f func(fc *FormatContext) string) string {
var newFc FormatContext
if fc != nil {
newFc = *fc
newFc.Current += newFc.Option.Indent
return f(&newFc)
} else {
return f(emptyFormatContext)
}
}

func sqlOptCtx[T interface {
Node
comparable
}](fc *FormatContext, left string, node T, right string) string {
var zero T
if node == zero {
return ""
}
return left + fc.SQL(node) + right
}

// sqlJoin outputs joined string of SQL() of all elems by sep.
// This function corresponds to sqlJoin in ast.go
func sqlJoinCtx[T Node](fc *FormatContext, elems []T, sep string) string {
var b strings.Builder
for i, r := range elems {
if i > 0 {
b.WriteString(sep)
}
b.WriteString(fc.SQL(r))
}
return b.String()
}

type NodeFormat interface {
Node
SQLContext(fmtCtx *FormatContext) string
}

// ================================================================================
//
// Helper functions for SQL()
Expand Down Expand Up @@ -134,16 +214,14 @@ func paren(p prec, e Expr) string {
//
// ================================================================================

func (q *QueryStatement) SQLContext(fc *FormatContext) string {
return sqlOptCtx(fc, "", q.Hint, fc.Newline()) +
sqlOptCtx(fc, "", q.With, fc.Newline()) +
fc.SQL(q.Query)
}

func (q *QueryStatement) SQL() string {
var sql string
if q.Hint != nil {
sql += q.Hint.SQL() + " "
}
if q.With != nil {
sql += q.With.SQL() + " "
}
sql += q.Query.SQL()
return sql
return q.SQLContext(nil)
}

func (h *Hint) SQL() string {
Expand Down Expand Up @@ -171,17 +249,23 @@ func (c *CTE) SQL() string {
return c.Name.SQL() + " AS (" + c.QueryExpr.SQL() + ")"
}

func (s *Select) SQL() string {
func (s *Select) SQLContext(fc *FormatContext) string {
return "SELECT " +
strOpt(s.Distinct, "DISTINCT ") +
sqlOpt("", s.As, " ") +
sqlJoin(s.Results, ", ") +
sqlOpt(" ", s.From, "") +
sqlOpt(" ", s.Where, "") +
sqlOpt(" ", s.GroupBy, "") +
sqlOpt(" ", s.Having, "") +
sqlOpt(" ", s.OrderBy, "") +
sqlOpt(" ", s.Limit, "")
sqlOptCtx(fc, "", s.As, " ") +
fc.WithIndent(func(fc *FormatContext) string {
return sqlJoinCtx(fc, s.Results, ","+fc.Newline())
}) +
sqlOptCtx(fc, fc.Newline(), s.From, "") +
sqlOptCtx(fc, fc.Newline(), s.Where, "") +
sqlOptCtx(fc, fc.Newline(), s.GroupBy, "") +
sqlOptCtx(fc, fc.Newline(), s.Having, "") +
sqlOptCtx(fc, fc.Newline(), s.OrderBy, "") +
sqlOptCtx(fc, fc.Newline(), s.Limit, "")
}

func (s *Select) SQL() string {
return s.SQLContext(nil)
}

func (a *AsStruct) SQL() string { return "AS STRUCT" }
Expand Down Expand Up @@ -242,8 +326,11 @@ func (e *ExprSelectItem) SQL() string {
return e.Expr.SQL()
}

func (f *From) SQLContext(fc *FormatContext) string {
return "FROM " + fc.SQL(f.Source)
}
func (f *From) SQL() string {
return "FROM " + f.Source.SQL()
return f.SQLContext(nil)
}

func (w *Where) SQL() string {
Expand Down Expand Up @@ -341,23 +428,27 @@ func (e *PathTableExpr) SQL() string {
sqlOpt(" ", e.Sample, "")
}

func (s *SubQueryTableExpr) SQLContext(fc *FormatContext) string {
return "(" + fc.WithIndent(
func(fc *FormatContext) string {
return fc.NewlineOrEmpty() + fc.SQL(s.Query)
}) + fc.Newline() + ")" +
sqlOptCtx(fc, " ", s.As, "") +
sqlOptCtx(fc, " ", s.Sample, "")
}

func (s *SubQueryTableExpr) SQL() string {
sql := "(" + s.Query.SQL() + ")"
if s.As != nil {
sql += " " + s.As.SQL()
}
if s.Sample != nil {
sql += " " + s.Sample.SQL()
}
return sql
return s.SQLContext(nil)
}

func (p *ParenTableExpr) SQLContext(fc *FormatContext) string {
return "(" + fc.WithIndent(func(fc *FormatContext) string {
return fc.NewlineOrEmpty() + p.Source.SQL()
}) + fc.NewlineOrEmpty() + ")" + sqlOpt(" ", p.Sample, "")
}

func (p *ParenTableExpr) SQL() string {
sql := "(" + p.Source.SQL() + ")"
if p.Sample != nil {
sql += " " + p.Sample.SQL()
}
return sql
return p.SQLContext(nil)
}

func (j *Join) SQL() string {
Expand Down
11 changes: 10 additions & 1 deletion tools/parse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
"unicode"

"github.com/MakeNowJust/heredoc/v2"
"github.com/k0kubun/pp"

"github.com/cloudspannerecosystem/memefish"
"github.com/cloudspannerecosystem/memefish/ast"
"github.com/cloudspannerecosystem/memefish/token"
"github.com/cloudspannerecosystem/memefish/tools/util/poslang"
"github.com/k0kubun/pp"
)

var usage = heredoc.Doc(`
Expand Down Expand Up @@ -118,6 +119,14 @@ func main() {
fmt.Println()
fmt.Println("--- SQL")
fmt.Println(node.SQL())
fmt.Println("--- SQL with indentation")
fmt.Println((&ast.FormatContext{
Option: ast.FormatOption{
Newline: true,
Indent: 2,
},
Current: 0,
}).SQL(node))

if *pos != "" {
fmt.Println("--- POS")
Expand Down

0 comments on commit b84d9a5

Please sign in to comment.