From 23d7c4f87e7941ad648266ce6fe7f50f0359d834 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 19 Oct 2021 01:23:38 +0530 Subject: [PATCH] handle sql_mode differently for new value change validation Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/set.go | 66 ++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index 53d8fe024de..2b5cdc5ad69 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -327,6 +327,9 @@ func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo. func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evalengine.ExpressionEnv) (bool, error) { sysVarExprValidationQuery := fmt.Sprintf("select %s from dual where @@%s != %s", svs.Expr, svs.Name, svs.Expr) + if svs.Name == "sql_mode" { + sysVarExprValidationQuery = fmt.Sprintf("select @@%s orig, %s new", svs.Name, svs.Expr) + } rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{key.DestinationKeyspaceID{0}}) if err != nil { return false, err @@ -335,18 +338,69 @@ func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evaleng if err != nil { return false, err } - if len(qr.Rows) == 0 { + changed := len(qr.Rows) > 0 + if !changed { return false, nil } - // TODO : validate how value needs to be stored. - value := qr.Rows[0][0] - buf := new(bytes.Buffer) - value.EncodeSQL(buf) - vcursor.Session().SetSysVar(svs.Name, buf.String()) + + var newVal string + if svs.Name == "sql_mode" { + changed, newVal = sqlModeChangedValue(qr) + if !changed { + return false, nil + } + } else { + buf := new(bytes.Buffer) + value := qr.Rows[0][0] + value.EncodeSQL(buf) + newVal = buf.String() + } + vcursor.Session().SetSysVar(svs.Name, newVal) vcursor.Session().NeedsReservedConn() return true, nil } +func sqlModeChangedValue(qr *sqltypes.Result) (bool, string) { + if len(qr.Fields) != 2 { + return false, "" + } + if len(qr.Rows[0]) != 2 { + return false, "" + } + orig := qr.Rows[0][0].ToString() + newVal := qr.Rows[0][1].ToString() + + origArr := strings.Split(orig, ",") + // Keep track of if the value is seen or not. + origMap := map[string]bool{} + for _, oVal := range origArr { + // Default is not seen. + origMap[oVal] = true + } + uniqOrigVal := len(origMap) + origValSeen := 0 + + changed := false + newValArr := strings.Split(newVal, ",") + for _, nVal := range newValArr { + notSeen, exists := origMap[nVal] + if !exists { + changed = true + break + } + if exists && notSeen { + // Value seen. Turn it off + origMap[nVal] = false + origValSeen++ + } + } + if !changed && uniqOrigVal != origValSeen { + changed = true + } + + return changed, newVal +} + var _ SetOp = (*SysVarSetAware)(nil) // MarshalJSON marshals all the json