Skip to content

Commit

Permalink
sqlparser: full type safety for unions
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
  • Loading branch information
vmg committed Mar 12, 2021
1 parent a2f139f commit 9380d38
Show file tree
Hide file tree
Showing 3 changed files with 1,901 additions and 665 deletions.
39 changes: 29 additions & 10 deletions go/vt/sqlparser/goyacc/goyacc.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"go/format"
"io/ioutil"
"os"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -626,7 +627,18 @@ outer:
levprd[nprod] |= ACTFLAG
fmt.Fprintf(fcode, "\n\tcase %v:", nprod)
fmt.Fprintf(fcode, "\n\t\t%sDollar = %sS[%spt-%v:%spt+1]", prefix, prefix, prefix, mem-1, prefix)
cpyact(curprod, mem)

var act bytes.Buffer
var unionType string
cpyact(&act, curprod, mem, &unionType)

if unionType != "" {
fmt.Fprintf(fcode, "\n\t\tvar %sLOCAL %s", prefix, unionType)
}
fcode.Write(act.Bytes())
if unionType != "" {
fmt.Fprintf(fcode, "\n\t\t%sVAL.union = %sLOCAL", prefix, prefix)
}

// action within rule...
t = gettok()
Expand Down Expand Up @@ -1344,9 +1356,9 @@ l1:
return nl
}

func cpyyvalaccess(curprod []int, tok int) {
const fastAppendPrefix = " append($$,"
var fastAppendRe = regexp.MustCompile(`\s+append\(\$[$1],`)

func cpyyvalaccess(fcode *bytes.Buffer, curprod []int, tok int, unionType *string) {
if ntypes == 0 {
fmt.Fprintf(fcode, "%sVAL", prefix)
return
Expand Down Expand Up @@ -1378,11 +1390,15 @@ loop:

case '=':
lvalue = true
if allowFastAppend {
if allowFastAppend && *unionType == "" {
peek, err := finput.Peek(16)
if err == nil && bytes.HasPrefix(peek, []byte(fastAppendPrefix)) {
if err != nil {
errorf("failed to scan forward: %v", err)
}
match := fastAppendRe.Find(peek)
if len(match) > 0 {
fastAppend = true
for range fastAppendPrefix {
for range match {
_ = getrune(finput)
}
} else {
Expand All @@ -1404,17 +1420,20 @@ loop:
fmt.Fprintf(fcode, "\t%sSLICE := (*%s)(%sIaddr(%sVAL.union))\n", prefix, ti.typename, prefix, prefix)
fmt.Fprintf(fcode, "\t*%sSLICE = append(*%sSLICE, ", prefix, prefix)
} else if lvalue {
fmt.Fprintf(fcode, "%sVAL.union", prefix)
} else {
fmt.Fprintf(fcode, "%sLOCAL", prefix)
*unionType = ti.typename
} else if *unionType == "" {
fmt.Fprintf(fcode, "%sVAL.%sUnion()", prefix, typeset[tok])
} else {
fmt.Fprintf(fcode, "%sLOCAL", prefix)
}
fcode.Write(buf.Bytes())
}

//
// copy action to the next ; or closing }
//
func cpyact(curprod []int, max int) {
func cpyact(fcode *bytes.Buffer, curprod []int, max int, unionType *string) {
if !lflag {
fmt.Fprintf(fcode, "\n//line %v:%v", infile, lineno)
}
Expand Down Expand Up @@ -1453,7 +1472,7 @@ loop:
c = getrune(finput)
}
if c == '$' {
cpyyvalaccess(curprod, tok)
cpyyvalaccess(fcode, curprod, tok, unionType)
continue loop
}
if c == '-' {
Expand Down
Loading

0 comments on commit 9380d38

Please sign in to comment.