Skip to content

Commit

Permalink
Add helper functions for SQL(), Pos(), End() (#120)
Browse files Browse the repository at this point in the history
* Bump Go version to 1.20, and add util.go

* Add comments to ast/util.go

* Add strOpt in ast/util.go

* Refactor (*CallExpr).SQL()

* Refactor (*Select).SQL()

* Refactor (*Select).End()

* Bump Go version in .github/workflows

* Refactor (*Path).Pos(), (*Path).End()

* Refactor (*Path).SQL()

* Remove lastElem

* Revert "Remove lastElem"

This reverts commit e1a3e81.

* Re-organize helper functions
  • Loading branch information
apstndb authored Oct 15, 2024
1 parent cdd31d5 commit e57746c
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 81 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:

- name: Set up Go 1.19
uses: actions/setup-go@v5
- name: Set up Go 1.20
uses: actions/setup-go@v4
with:
go-version: 1.19
go-version: "1.20"
id: go

- name: Check out code into the Go module directory
Expand Down
69 changes: 48 additions & 21 deletions ast/pos.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,51 @@ import (
"github.com/cloudspannerecosystem/memefish/token"
)

// ================================================================================
//
// Helper functions for Pos(), End()
// These functions are intended for use within this file only.
//
// ================================================================================

// lastNode returns last element of Node slice.
// This function corresponds to NodeSliceVar[$] in ast.go.
func lastNode[T Node](s []T) T {
return s[len(s)-1]
}

// firstValidEnd returns the first valid Pos() in argument.
// "valid" means the node is not nil and Pos().Invalid() is not true.
// This function corresponds to "(n0 ?? n1 ?? ...).End()"
func firstValidEnd(ns ...Node) token.Pos {
for _, n := range ns {
if n != nil && !n.End().Invalid() {
return n.End()
}
}
return token.InvalidPos
}

// firstPos returns the Pos() of the first node.
// If argument is an empty slice, this function returns token.InvalidPos.
// This function corresponds to NodeSliceVar[0].pos in ast.go.
func firstPos[T Node](s []T) token.Pos {
if len(s) == 0 {
return token.InvalidPos
}
return s[0].Pos()
}

// lastEnd returns the End() of the last node.
// If argument is an empty slice, this function returns token.InvalidPos.
// This function corresponds to NodeSliceVar[$].end in ast.go.
func lastEnd[T Node](s []T) token.Pos {
if len(s) == 0 {
return token.InvalidPos
}
return lastNode(s).End()
}

// ================================================================================
//
// SELECT
Expand Down Expand Up @@ -39,25 +84,7 @@ func (c *CTE) End() token.Pos { return c.Rparen + 1 }
func (s *Select) Pos() token.Pos { return s.Select }

func (s *Select) End() token.Pos {
if s.Limit != nil {
return s.Limit.End()
}
if s.OrderBy != nil {
return s.OrderBy.End()
}
if s.Having != nil {
return s.Having.End()
}
if s.GroupBy != nil {
return s.GroupBy.End()
}
if s.Where != nil {
return s.Where.End()
}
if s.From != nil {
return s.From.End()
}
return s.Results[len(s.Results)-1].End()
return firstValidEnd(s.Limit, s.OrderBy, s.Having, s.GroupBy, s.Where, s.From, lastNode(s.Results))
}

func (c *CompoundQuery) Pos() token.Pos {
Expand Down Expand Up @@ -376,8 +403,8 @@ func (p *Param) End() token.Pos { return p.Atmark + 1 + token.Pos(len(p.Name)) }
func (i *Ident) Pos() token.Pos { return i.NamePos }
func (i *Ident) End() token.Pos { return i.NameEnd }

func (p *Path) Pos() token.Pos { return p.Idents[0].Pos() }
func (p *Path) End() token.Pos { return p.Idents[len(p.Idents)-1].End() }
func (p *Path) Pos() token.Pos { return firstPos(p.Idents) }
func (p *Path) End() token.Pos { return lastEnd(p.Idents) }

func (a *ArrayLiteral) Pos() token.Pos {
if !a.Array.Invalid() {
Expand Down
123 changes: 67 additions & 56 deletions ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,59 @@ package ast

import (
"github.com/cloudspannerecosystem/memefish/token"
"strings"
)

// ================================================================================
//
// Helper functions for SQL()
// These functions are intended for use within this file only.
//
// ================================================================================

// sqlOpt outputs:
//
// when node != nil: left + node.SQL() + right
// else : empty string
//
// This function corresponds to sqlOpt in ast.go
func sqlOpt[T interface {
Node
comparable
}](left string, node T, right string) string {
var zero T
if node == zero {
return ""
}
return left + node.SQL() + right
}

// strOpt outputs:
//
// when pred == true: s
// else : empty string
//
// This function corresponds to {{if pred}}s{{end}} in ast.go
func strOpt(pred bool, s string) string {
if pred {
return s
}
return ""
}

// sqlJoin outputs joined string of SQL() of all elems by sep.
// This function corresponds to sqlJoin in ast.go
func sqlJoin[T Node](elems []T, sep string) string {
var b strings.Builder
for i, r := range elems {
if i > 0 {
b.WriteString(sep)
}
b.WriteString(r.SQL())
}
return b.String()
}

type prec int

const (
Expand Down Expand Up @@ -116,36 +167,16 @@ func (c *CTE) SQL() string {
}

func (s *Select) SQL() string {
sql := "SELECT "
if s.Distinct {
sql += "DISTINCT "
}
if s.AsStruct {
sql += "AS STRUCT "
}
sql += s.Results[0].SQL()
for _, r := range s.Results[1:] {
sql += ", " + r.SQL()
}
if s.From != nil {
sql += " " + s.From.SQL()
}
if s.Where != nil {
sql += " " + s.Where.SQL()
}
if s.GroupBy != nil {
sql += " " + s.GroupBy.SQL()
}
if s.Having != nil {
sql += " " + s.Having.SQL()
}
if s.OrderBy != nil {
sql += " " + s.OrderBy.SQL()
}
if s.Limit != nil {
sql += " " + s.Limit.SQL()
}
return sql
return "SELECT " +
strOpt(s.Distinct, "DISTINCT ") +
strOpt(s.AsStruct, "AS STRUCT ") +
sqlJoin(s.Results, ", ") +
sqlOpt(" ", s.From, "") +
sqlOpt(" ", s.Where, "") +
sqlOpt(" ", s.GroupBy, "") +
sqlOpt(" ", s.Having, "") +
sqlOpt(" ", s.OrderBy, "") +
sqlOpt(" ", s.Limit, "")
}

func (c *CompoundQuery) SQL() string {
Expand Down Expand Up @@ -464,27 +495,11 @@ func (i *IndexExpr) SQL() string {
}

func (c *CallExpr) SQL() string {
sql := c.Func.SQL() + "("
if c.Distinct {
sql += "DISTINCT "
}
for i, a := range c.Args {
if i != 0 {
sql += ", "
}
sql += a.SQL()
}
if len(c.Args) > 0 && len(c.NamedArgs) > 0 {
sql += ", "
}
for i, v := range c.NamedArgs {
if i != 0 {
sql += ", "
}
sql += v.SQL()
}
sql += ")"
return sql
return c.Func.SQL() + "(" + strOpt(c.Distinct, "DISTINCT ") +
sqlJoin(c.Args, ", ") +
strOpt(len(c.Args) > 0 && len(c.NamedArgs) > 0, ", ") +
sqlJoin(c.NamedArgs, ", ") +
")"
}

func (n *NamedArg) SQL() string { return n.Name.SQL() + " => " + n.Value.SQL() }
Expand Down Expand Up @@ -595,11 +610,7 @@ func (i *Ident) SQL() string {
}

func (p *Path) SQL() string {
sql := p.Idents[0].SQL()
for _, id := range p.Idents[1:] {
sql += "." + id.SQL()
}
return sql
return sqlJoin(p.Idents, ".")
}

func (a *ArrayLiteral) SQL() string {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/cloudspannerecosystem/memefish

go 1.19
go 1.20

require (
github.com/MakeNowJust/heredoc/v2 v2.0.1
Expand Down

0 comments on commit e57746c

Please sign in to comment.