Skip to content

Commit

Permalink
Merge pull request #6080 from planetscale/set-parser-update
Browse files Browse the repository at this point in the history
Set parser update
  • Loading branch information
systay authored Apr 20, 2020
2 parents 19d215e + ff80083 commit 7589230
Show file tree
Hide file tree
Showing 15 changed files with 3,495 additions and 3,422 deletions.
108 changes: 9 additions & 99 deletions go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package sqlparser

import (
"fmt"
"strconv"
"strings"
"unicode"

Expand Down Expand Up @@ -96,6 +95,15 @@ func CanNormalize(stmt Statement) bool {
return false
}

//IsSetStatement takes Statement and returns if the statement is set statement.
func IsSetStatement(stmt Statement) bool {
switch stmt.(type) {
case *Set:
return true
}
return false
}

// Preview analyzes the beginning of the query using a simpler and faster
// textual comparison to identify the statement type.
func Preview(sql string) StatementType {
Expand Down Expand Up @@ -365,101 +373,3 @@ func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
}
return sqltypes.PlanValue{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "expression is too complex '%v'", String(node))
}

// SetKey is the extracted key from one SetExpr
type SetKey struct {
Key string
Scope string
}

// ExtractSetValues returns a map of key-value pairs
// if the query is a SET statement. Values can be bool, int64 or string.
// Since set variable names are case insensitive, all keys are returned
// as lower case.
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
stmt, err := Parse(sql)
if err != nil {
return nil, "", err
}
setStmt, ok := stmt.(*Set)
if !ok {
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
}
result := make(map[SetKey]interface{})
for _, expr := range setStmt.Exprs {
var scope string
key := expr.Name.Lowered()

switch expr.Name.at {
case NoAt:
scope = ImplicitStr
case SingleAt:
scope = VariableStr
case DoubleAt:
switch {
case strings.HasPrefix(key, "global."):
scope = GlobalStr
key = strings.TrimPrefix(key, "global.")
case strings.HasPrefix(key, "session."):
scope = SessionStr
key = strings.TrimPrefix(key, "session.")
case strings.HasPrefix(key, "vitess_metadata."):
scope = VitessMetadataStr
key = strings.TrimPrefix(key, "vitess_metadata.")
default:
scope = SessionStr
}

// This is what correctly allows us to handle queries such as "set @@session.`autocommit`=1"
// it will remove backticks and double quotes that might surround the part after the first period
_, out := NewStringTokenizer(key).Scan()
key = string(out)
}

if setStmt.Scope != "" && scope != "" {
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
}

setKey := SetKey{
Key: key,
Scope: scope,
}

switch expr := expr.Expr.(type) {
case *SQLVal:
switch expr.Type {
case StrVal:
result[setKey] = strings.ToLower(string(expr.Val))
case IntVal:
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
if err != nil {
return nil, "", err
}
result[setKey] = num
case FloatVal:
num, err := strconv.ParseFloat(string(expr.Val), 64)
if err != nil {
return nil, "", err
}
result[setKey] = num
default:
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
}
case BoolVal:
var val int64
if expr {
val = 1
}
result[setKey] = val
case *ColName:
result[setKey] = expr.Name.String()
case *NullVal:
result[setKey] = nil
case *Default:
result[setKey] = "default"
default:
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
}
}
return result, strings.ToLower(setStmt.Scope), nil
}
185 changes: 0 additions & 185 deletions go/vt/sqlparser/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/sqltypes"
)
Expand Down Expand Up @@ -435,188 +432,6 @@ func TestNewPlanValue(t *testing.T) {
}
}

func TestExtractSetValues(t *testing.T) {
testcases := []struct {
sql string
out map[SetKey]interface{}
scope string
err string
}{{
sql: "invalid",
err: "syntax error at position 8 near 'invalid'",
}, {
sql: "select * from t",
err: "ast did not yield *sqlparser.Set: *sqlparser.Select",
}, {
sql: "set autocommit=1+1",
err: "invalid syntax: 1 + 1",
}, {
sql: "set transaction_mode='single'",
out: map[SetKey]interface{}{{Key: "transaction_mode", Scope: ImplicitStr}: "single"},
}, {
sql: "set autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(1)},
}, {
sql: "set autocommit=true",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(1)},
}, {
sql: "set autocommit=false",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(0)},
}, {
sql: "set autocommit=on",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: "on"},
}, {
sql: "set autocommit=off",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: "off"},
}, {
sql: "set @@global.autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: GlobalStr}: int64(1)},
}, {
sql: "set @@global.autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: GlobalStr}: int64(1)},
}, {
sql: "set @@session.autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.`autocommit`=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.'autocommit'=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.\"autocommit\"=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.'\"autocommit'=1",
out: map[SetKey]interface{}{{Key: "\"autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.`autocommit'`=1",
out: map[SetKey]interface{}{{Key: "autocommit'", Scope: SessionStr}: int64(1)},
}, {
sql: "set AUTOCOMMIT=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(1)},
}, {
sql: "SET character_set_results = NULL",
out: map[SetKey]interface{}{{Key: "character_set_results", Scope: ImplicitStr}: nil},
}, {
sql: "SET foo = 0x1234",
err: "invalid value type: 0x1234",
}, {
sql: "SET names utf8",
out: map[SetKey]interface{}{{Key: "names", Scope: ImplicitStr}: "utf8"},
}, {
sql: "SET names ascii collate ascii_bin",
out: map[SetKey]interface{}{{Key: "names", Scope: ImplicitStr}: "ascii"},
}, {
sql: "SET charset default",
out: map[SetKey]interface{}{{Key: "charset", Scope: ImplicitStr}: "default"},
}, {
sql: "SET character set ascii",
out: map[SetKey]interface{}{{Key: "charset", Scope: ImplicitStr}: "ascii"},
}, {
sql: "SET SESSION wait_timeout = 3600",
out: map[SetKey]interface{}{{Key: "wait_timeout", Scope: ImplicitStr}: int64(3600)},
scope: SessionStr,
}, {
sql: "SET GLOBAL wait_timeout = 3600",
out: map[SetKey]interface{}{{Key: "wait_timeout", Scope: ImplicitStr}: int64(3600)},
scope: GlobalStr,
}, {
sql: "set session transaction isolation level repeatable read",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelRepeatableRead},
scope: SessionStr,
}, {
sql: "set session transaction isolation level read committed",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelReadCommitted},
scope: SessionStr,
}, {
sql: "set session transaction isolation level read uncommitted",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelReadUncommitted},
scope: SessionStr,
}, {
sql: "set session transaction isolation level serializable",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelSerializable},
scope: SessionStr,
}, {
sql: "set transaction isolation level serializable",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelSerializable},
}, {
sql: "set transaction read only",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: TxReadOnly},
}, {
sql: "set transaction read write",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: TxReadWrite},
}, {
sql: "set session transaction read write",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: TxReadWrite},
scope: SessionStr,
}, {
sql: "set session tx_read_only = 0",
out: map[SetKey]interface{}{{Key: "tx_read_only", Scope: ImplicitStr}: int64(0)},
scope: SessionStr,
}, {
sql: "set session tx_read_only = 1",
out: map[SetKey]interface{}{{Key: "tx_read_only", Scope: ImplicitStr}: int64(1)},
scope: SessionStr,
}, {
sql: "set session sql_safe_updates = 0",
out: map[SetKey]interface{}{{Key: "sql_safe_updates", Scope: ImplicitStr}: int64(0)},
scope: SessionStr,
}, {
sql: "set session transaction_read_only = 0",
out: map[SetKey]interface{}{{Key: "transaction_read_only", Scope: ImplicitStr}: int64(0)},
scope: SessionStr,
}, {
sql: "set session transaction_read_only = 1",
out: map[SetKey]interface{}{{Key: "transaction_read_only", Scope: ImplicitStr}: int64(1)},
scope: SessionStr,
}, {
sql: "set session sql_safe_updates = 1",
out: map[SetKey]interface{}{{Key: "sql_safe_updates", Scope: ImplicitStr}: int64(1)},
scope: SessionStr,
}, {
sql: "set @foo = 42",
out: map[SetKey]interface{}{
{Key: "foo", Scope: VariableStr}: int64(42),
},
scope: ImplicitStr,
}, {
sql: "set @foo.bar.baz = 42",
out: map[SetKey]interface{}{
{Key: "foo.bar.baz", Scope: VariableStr}: int64(42),
},
scope: ImplicitStr,
}, {
sql: "set @`string` = 'abc', @`float` = 4.2, @`int` = 42",
out: map[SetKey]interface{}{
{Key: "string", Scope: VariableStr}: "abc",
{Key: "float", Scope: VariableStr}: 4.2,
{Key: "int", Scope: VariableStr}: int64(42),
},
scope: ImplicitStr,
}, {
sql: "set session @foo = 42",
err: "unsupported in set: scope and user defined variables",
}, {
sql: "set global @foo = 42",
err: "unsupported in set: scope and user defined variables",
}}
for _, tcase := range testcases {
t.Run(tcase.sql, func(t *testing.T) {
out, _, err := ExtractSetValues(tcase.sql)
if tcase.err != "" {
require.Error(t, err, tcase.err)
} else if err != nil {
require.NoError(t, err)
}

if diff := cmp.Diff(tcase.out, out); diff != "" {
t.Error(diff)
}
})
}
}

func newStrVal(in string) *SQLVal {
return NewStrVal([]byte(in))
}
Expand Down
Loading

0 comments on commit 7589230

Please sign in to comment.