Skip to content

Commit

Permalink
Merge pull request #9014 from planetscale/sql_mode-change-validation
Browse files Browse the repository at this point in the history
handle sql_mode differently for new value change validation
  • Loading branch information
harshit-gangal authored Oct 19, 2021
2 parents a7d542f + 87ec407 commit 315d907
Show file tree
Hide file tree
Showing 3 changed files with 384 additions and 205 deletions.
60 changes: 57 additions & 3 deletions go/vt/vtgate/engine/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo.

func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evalengine.ExpressionEnv) (bool, error) {
sysVarExprValidationQuery := fmt.Sprintf("select %s from dual where @@%s != %s", svs.Expr, svs.Name, svs.Expr)
if svs.Name == "sql_mode" {
sysVarExprValidationQuery = fmt.Sprintf("select @@%s orig, %s new", svs.Name, svs.Expr)
}
rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{key.DestinationKeyspaceID{0}})
if err != nil {
return false, err
Expand All @@ -335,18 +338,69 @@ func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evaleng
if err != nil {
return false, err
}
if len(qr.Rows) == 0 {
changed := len(qr.Rows) > 0
if !changed {
return false, nil
}
// TODO : validate how value needs to be stored.
value := qr.Rows[0][0]

var value sqltypes.Value
if svs.Name == "sql_mode" {
changed, value = sqlModeChangedValue(qr)
if !changed {
return false, nil
}
} else {
value = qr.Rows[0][0]
}
buf := new(bytes.Buffer)
value.EncodeSQL(buf)
vcursor.Session().SetSysVar(svs.Name, buf.String())
vcursor.Session().NeedsReservedConn()
return true, nil
}

func sqlModeChangedValue(qr *sqltypes.Result) (bool, sqltypes.Value) {
if len(qr.Fields) != 2 {
return false, sqltypes.Value{}
}
if len(qr.Rows[0]) != 2 {
return false, sqltypes.Value{}
}
orig := qr.Rows[0][0].ToString()
newVal := qr.Rows[0][1].ToString()

origArr := strings.Split(orig, ",")
// Keep track of if the value is seen or not.
origMap := map[string]bool{}
for _, oVal := range origArr {
// Default is not seen.
origMap[strings.ToUpper(oVal)] = true
}
uniqOrigVal := len(origMap)
origValSeen := 0

changed := false
newValArr := strings.Split(newVal, ",")
for _, nVal := range newValArr {
nVal = strings.ToUpper(nVal)
notSeen, exists := origMap[nVal]
if !exists {
changed = true
break
}
if exists && notSeen {
// Value seen. Turn it off
origMap[nVal] = false
origValSeen++
}
}
if !changed && uniqOrigVal != origValSeen {
changed = true
}

return changed, qr.Rows[0][1]
}

var _ SetOp = (*SysVarSetAware)(nil)

// MarshalJSON marshals all the json
Expand Down
Loading

0 comments on commit 315d907

Please sign in to comment.