Skip to content

Commit

Permalink
handle sql_mode differently for new value change validation
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Oct 18, 2021
1 parent ce8d37b commit 23d7c4f
Showing 1 changed file with 60 additions and 6 deletions.
66 changes: 60 additions & 6 deletions go/vt/vtgate/engine/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo.

func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evalengine.ExpressionEnv) (bool, error) {
sysVarExprValidationQuery := fmt.Sprintf("select %s from dual where @@%s != %s", svs.Expr, svs.Name, svs.Expr)
if svs.Name == "sql_mode" {
sysVarExprValidationQuery = fmt.Sprintf("select @@%s orig, %s new", svs.Name, svs.Expr)
}
rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{key.DestinationKeyspaceID{0}})
if err != nil {
return false, err
Expand All @@ -335,18 +338,69 @@ func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evaleng
if err != nil {
return false, err
}
if len(qr.Rows) == 0 {
changed := len(qr.Rows) > 0
if !changed {
return false, nil
}
// TODO : validate how value needs to be stored.
value := qr.Rows[0][0]
buf := new(bytes.Buffer)
value.EncodeSQL(buf)
vcursor.Session().SetSysVar(svs.Name, buf.String())

var newVal string
if svs.Name == "sql_mode" {
changed, newVal = sqlModeChangedValue(qr)
if !changed {
return false, nil
}
} else {
buf := new(bytes.Buffer)
value := qr.Rows[0][0]
value.EncodeSQL(buf)
newVal = buf.String()
}
vcursor.Session().SetSysVar(svs.Name, newVal)
vcursor.Session().NeedsReservedConn()
return true, nil
}

func sqlModeChangedValue(qr *sqltypes.Result) (bool, string) {
if len(qr.Fields) != 2 {
return false, ""
}
if len(qr.Rows[0]) != 2 {
return false, ""
}
orig := qr.Rows[0][0].ToString()
newVal := qr.Rows[0][1].ToString()

origArr := strings.Split(orig, ",")
// Keep track of if the value is seen or not.
origMap := map[string]bool{}
for _, oVal := range origArr {
// Default is not seen.
origMap[oVal] = true
}
uniqOrigVal := len(origMap)
origValSeen := 0

changed := false
newValArr := strings.Split(newVal, ",")
for _, nVal := range newValArr {
notSeen, exists := origMap[nVal]
if !exists {
changed = true
break
}
if exists && notSeen {
// Value seen. Turn it off
origMap[nVal] = false
origValSeen++
}
}
if !changed && uniqOrigVal != origValSeen {
changed = true
}

return changed, newVal
}

var _ SetOp = (*SysVarSetAware)(nil)

// MarshalJSON marshals all the json
Expand Down

0 comments on commit 23d7c4f

Please sign in to comment.