From e50e20719cab121702ad57919f68bd4c2a6e1d03 Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Wed, 15 Aug 2018 21:59:28 +0800 Subject: [PATCH 01/12] expression: ECB/CBC modes with 128/192/256-bit key length for aes_decrypt/aes_encrypt --- expression/builtin_encryption.go | 137 +++++++++++++++++++++++--- expression/builtin_encryption_test.go | 49 ++++++--- expression/errors.go | 1 + session/session.go | 1 + sessionctx/variable/sysvar.go | 4 +- util/encrypt/aes.go | 40 ++++++++ 6 files changed, 206 insertions(+), 26 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 0430499bbc09d..fc2fc7d174698 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -33,6 +33,10 @@ import ( "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" + "github.com/pingcap/tidb/sessionctx/variable" + "strings" + "strconv" + "crypto/aes" ) var ( @@ -68,10 +72,41 @@ var ( _ builtinFunc = &builtinUncompressedLengthSig{} ) -// TODO: support other mode -const ( - aes128ecbBlobkSize = 16 -) +var aesMode map[string]ModeAES + +type ModeAES struct { + Name string + Mode string + KeySize int + IvRequired bool +} + +func registerModeAES(name string) { + it := strings.Split(name, "-") + keyLen, _ := strconv.Atoi(it[1]) + keySize := keyLen / 8 + ivRequired := true + if it[2] == "ecb" { + ivRequired = false + } + aesMode[name] = ModeAES{ + Name: name, + Mode: it[2], + KeySize: keySize, + IvRequired: ivRequired, + } +} + +func init() { + aesMode = make(map[string]ModeAES, 16) + //TODO support more mode + registerModeAES("aes-128-ecb") + registerModeAES("aes-192-ecb") + registerModeAES("aes-256-ecb") + registerModeAES("aes-128-cbc") + registerModeAES("aes-192-cbc") + registerModeAES("aes-256-cbc") +} type aesDecryptFunctionClass struct { baseFunctionClass @@ -81,7 +116,11 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(c.verifyArgs(args)) } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) bf.tp.Flen = args[0].GetType().Flen // At most. types.SetBinChsClnFlag(bf.tp) sig := &builtinAesDecryptSig{bf} @@ -112,9 +151,43 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } - // TODO: Support other modes. - key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize) - plainText, err := encrypt.AESDecryptWithECB([]byte(cryptStr), key) + modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode := aesMode[modeName] //TODO check mode is exists. + var iv string + if len(b.args) == 3 { + iv, isNull, err = b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + } + if mode.IvRequired { + if len(b.args) != 3 { + err = ErrIncorrectParameterCount.GenByArgs("aes_decrypt") + return "", true, err + } + if len(iv) < aes.BlockSize { + err = errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") + return "", true, err + + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + } else if len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.KeySize) + var plainText []byte + switch modeName { + case "aes-128-ecb", "aes-192-ecb", "aes-256-ecb": + plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key) + case "aes-128-cbc", "aes-192-cbc", "aes-256-cbc": + plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) + default: + //TODO + } + if err != nil { return "", true, nil } @@ -129,8 +202,12 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(c.verifyArgs(args)) } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) - bf.tp.Flen = aes128ecbBlobkSize * (args[0].GetType().Flen/aes128ecbBlobkSize + 1) // At most. + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) + bf.tp.Flen = aes.BlockSize * (args[0].GetType().Flen/aes.BlockSize + 1) // At most. types.SetBinChsClnFlag(bf.tp) sig := &builtinAesEncryptSig{bf} return sig, nil @@ -160,9 +237,43 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } - // TODO: Support other modes. - key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize) - cipherText, err := encrypt.AESEncryptWithECB([]byte(str), key) + modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode := aesMode[modeName] //TODO check mode is exists. + var iv string + if len(b.args) == 3 { + iv, isNull, err = b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + } + if mode.IvRequired { + if len(b.args) != 3 { + err = ErrIncorrectParameterCount.GenByArgs("aes_encrypt") + return "", true, err + } + if len(iv) < aes.BlockSize { + err = errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least 16 bytes long") + return "", true, err + + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + } else if len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.KeySize) + var cipherText []byte + switch modeName { + case "aes-128-ecb", "aes-192-ecb", "aes-256-ecb": + cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key) + case "aes-128-cbc", "aes-192-cbc", "aes-256-cbc": + cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) + default: + //TODO + } + if err != nil { return "", true, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 5d80199f15778..db1cc3ddb1fd3 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -24,42 +24,67 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/testleak" + "github.com/pingcap/tidb/sessionctx/variable" ) var aesTests = []struct { - origin interface{} - key interface{} - crypt interface{} + mode string + origin interface{} + key interface{} + initVector interface{} + crypt interface{} }{ - {"pingcap", "1234567890123456", "697BFE9B3F8C2F289DD82C88C7BC95C4"}, - {"pingcap123", "1234567890123456", "CEC348F4EF5F84D3AA6C4FA184C65766"}, - {"pingcap", "123456789012345678901234", "6F1589686860C8E8C7A40A78B25FF2C0"}, - {"pingcap", "123", "996E0CA8688D7AD20819B90B273E01C6"}, - {"pingcap", 123, "996E0CA8688D7AD20819B90B273E01C6"}, - {nil, 123, nil}, + // test for ecb + {"aes-128-ecb", "pingcap", "1234567890123456", nil, "697BFE9B3F8C2F289DD82C88C7BC95C4"}, + {"aes-128-ecb", "pingcap123", "1234567890123456", nil, "CEC348F4EF5F84D3AA6C4FA184C65766"}, + {"aes-128-ecb", "pingcap", "123456789012345678901234", nil, "6F1589686860C8E8C7A40A78B25FF2C0"}, + {"aes-128-ecb", "pingcap", "123", nil, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", "pingcap", 123, nil, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", nil, 123, nil, nil}, + {"aes-192-ecb", "pingcap", "1234567890123456", nil, "9B139FD002E6496EA2D5C73A2265E661"}, + {"aes-256-ecb", "pingcap", "1234567890123456", nil, "F80DCDEDDBE5663BDB68F74AEDDB8EE3"}, + // test for cbc + {"aes-128-cbc", "pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792"}, + {"aes-128-cbc", "pingcap", "123456789012345678901234", "1234567890123456", "483788634DA8817423BA0934FD2C096E"}, + {"aes-192-cbc", "pingcap", "1234567890123456", "1234567890123456", "516391DB38E908ECA93AAB22870EC787"}, + {"aes-256-cbc", "pingcap", "1234567890123456", "1234567890123456", "5D0E22C1E77523AEF5C3E10B65653C8F"}, + {"aes-256-cbc", "pingcap", "12345678901234561234567890123456", "1234567890123456", "A26BA27CA4BE9D361D545AA84A17002D"}, + {"aes-256-cbc", "pingcap", "1234567890123456", "12345678901234561234567890123456", "5D0E22C1E77523AEF5C3E10B65653C8F"}, } func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesEncrypt] for _, tt := range aesTests { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) str := types.NewDatum(tt.origin) key := types.NewDatum(tt.key) - f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{str, key})) + args := []types.Datum{str, key} + if tt.initVector != nil { + vec := types.NewDatum(tt.initVector) + args = append(args, vec) + } + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) crypt, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt)) } - s.testNullInput(c, ast.AesDecrypt) + s.testNullInput(c, ast.AesEncrypt) } func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesDecrypt] for _, test := range aesTests { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(test.mode)) cryptStr := fromHex(test.crypt) key := types.NewDatum(test.key) - f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{cryptStr, key})) + args := []types.Datum{cryptStr, key} + if test.initVector != nil { + vec := types.NewDatum(test.initVector) + args = append(args, vec) + } + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) str, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(str, DeepEquals, types.NewDatum(test.origin)) diff --git a/expression/errors.go b/expression/errors.go index f342bf42916ae..c436e9d871046 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -38,6 +38,7 @@ var ( errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField]) errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) + errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored]) ) func init() { diff --git a/session/session.go b/session/session.go index d76313e581abd..5b12298bea4df 100644 --- a/session/session.go +++ b/session/session.go @@ -1250,6 +1250,7 @@ const loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variab variable.SQLModeVar + quoteCommaQuote + variable.MaxAllowedPacket + quoteCommaQuote + variable.TimeZone + quoteCommaQuote + + variable.BlockEncryptionMode + quoteCommaQuote + /* TiDB specific global variables: */ variable.TiDBSkipUTF8Check + quoteCommaQuote + variable.TiDBIndexJoinBatchSize + quoteCommaQuote + diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index cd57cdc78f380..8eb3a940d39b5 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -246,7 +246,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "myisam_mmap_size", "18446744073709551615"}, {ScopeGlobal, "init_slave", ""}, {ScopeNone, "innodb_buffer_pool_instances", "8"}, - {ScopeGlobal | ScopeSession, "block_encryption_mode", "aes-128-ecb"}, + {ScopeGlobal | ScopeSession, BlockEncryptionMode, "aes-128-ecb"}, {ScopeGlobal | ScopeSession, "max_length_for_sort_data", "1024"}, {ScopeNone, "character_set_system", "utf8"}, {ScopeGlobal | ScopeSession, "interactive_timeout", "28800"}, @@ -745,6 +745,8 @@ const ( WarningCount = "warning_count" // ErrorCount is the name for 'error_count' system variable. ErrorCount = "error_count" + // BlockEncryptionMode is the name for 'block_encryption_mode' system variable. + BlockEncryptionMode = "block_encryption_mode" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index deeea49e48976..6eb81049ed717 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -172,3 +172,43 @@ func DeriveKeyMySQL(key []byte, blockSize int) []byte { } return rKey } + +// AESEncryptWithCBC encrypts data using AES with CBC mode. +func AESEncryptWithCBC(str, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + blockSize := cb.BlockSize() + // The str arguments can be any length, and padding is automatically added to + // str so it is a multiple of a block as required by block-based algorithms such as AES. + // This padding is automatically removed by the AES_DECRYPT() function. + data, err := PKCS7Pad(str, blockSize) + if err != nil { + return nil, err + } + cbc := cipher.NewCBCEncrypter(cb, iv) + crypted := make([]byte, len(data)) + cbc.CryptBlocks(crypted, data) + return crypted, nil +} + +// AESDecryptWithCBC decrypts data using AES with CBC mode. +func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + blockSize := cb.BlockSize() + if len(cryptStr)%blockSize != 0 { + return nil, errors.New("Corrupted data") + } + cbc := cipher.NewCBCDecrypter(cb, iv) + data := make([]byte, len(cryptStr)) + cbc.CryptBlocks(data, cryptStr) + plain, err := PKCS7Unpad(data, blockSize) + if err != nil { + return nil, err + } + return plain, nil +} From 3ce847d2bbf45a35027fb591989a8babc5af73f8 Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Fri, 17 Aug 2018 14:29:37 +0800 Subject: [PATCH 02/12] expression: add more tests for AES --- expression/builtin_encryption.go | 129 +++++++++++--------------- expression/builtin_encryption_test.go | 23 ++++- expression/integration_test.go | 12 +++ util/encrypt/aes_test.go | 74 +++++++++++++++ 4 files changed, 159 insertions(+), 79 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index fc2fc7d174698..bc45e456e9bbf 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -34,9 +34,8 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" "github.com/pingcap/tidb/sessionctx/variable" - "strings" - "strconv" "crypto/aes" + "strings" ) var ( @@ -72,40 +71,20 @@ var ( _ builtinFunc = &builtinUncompressedLengthSig{} ) -var aesMode map[string]ModeAES - -type ModeAES struct { - Name string - Mode string - KeySize int - IvRequired bool -} - -func registerModeAES(name string) { - it := strings.Split(name, "-") - keyLen, _ := strconv.Atoi(it[1]) - keySize := keyLen / 8 - ivRequired := true - if it[2] == "ecb" { - ivRequired = false - } - aesMode[name] = ModeAES{ - Name: name, - Mode: it[2], - KeySize: keySize, - IvRequired: ivRequired, - } +type aesModeAttr struct { + modeName string + keySize int + ivRequired bool } -func init() { - aesMode = make(map[string]ModeAES, 16) - //TODO support more mode - registerModeAES("aes-128-ecb") - registerModeAES("aes-192-ecb") - registerModeAES("aes-256-ecb") - registerModeAES("aes-128-cbc") - registerModeAES("aes-192-cbc") - registerModeAES("aes-256-cbc") +var aesModes = map[string]*aesModeAttr{ + //TODO support more modes + "aes-128-ecb": &aesModeAttr{"ecb", 16, false}, + "aes-192-ecb": &aesModeAttr{"ecb", 24, false}, + "aes-256-ecb": &aesModeAttr{"ecb", 32, false}, + "aes-128-cbc": &aesModeAttr{"cbc", 16, true}, + "aes-192-cbc": &aesModeAttr{"cbc", 24, true}, + "aes-256-cbc": &aesModeAttr{"cbc", 32, true}, } type aesDecryptFunctionClass struct { @@ -151,43 +130,42 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } - modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) - mode := aesMode[modeName] //TODO check mode is exists. var iv string - if len(b.args) == 3 { + modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(modeName)] + if !exists { + return "", true, errors.Errorf("unsupported block encryption mode - %v", modeName) + } + if mode.ivRequired { + if len(b.args) != 3 { + return "", true, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") + } iv, isNull, err = b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } - } - if mode.IvRequired { - if len(b.args) != 3 { - err = ErrIncorrectParameterCount.GenByArgs("aes_decrypt") - return "", true, err - } if len(iv) < aes.BlockSize { - err = errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") - return "", true, err - + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) iv = iv[0:aes.BlockSize] - } else if len(b.args) == 3 { - // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } else { + if len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } } - key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.KeySize) + key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.keySize) var plainText []byte - switch modeName { - case "aes-128-ecb", "aes-192-ecb", "aes-256-ecb": + switch mode.modeName { + case "ecb": plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key) - case "aes-128-cbc", "aes-192-cbc", "aes-256-cbc": + case "cbc": plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) default: - //TODO + return "", true, errors.Errorf("unsupported block encryption mode - %v", mode.modeName) } - if err != nil { return "", true, nil } @@ -237,43 +215,42 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } - modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) - mode := aesMode[modeName] //TODO check mode is exists. var iv string - if len(b.args) == 3 { + modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(modeName)] + if !exists { + return "", true, errors.Errorf("unsupported block encryption mode - %v", modeName) + } + if mode.ivRequired { + if len(b.args) != 3 { + return "", true, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") + } iv, isNull, err = b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } - } - if mode.IvRequired { - if len(b.args) != 3 { - err = ErrIncorrectParameterCount.GenByArgs("aes_encrypt") - return "", true, err - } if len(iv) < aes.BlockSize { - err = errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least 16 bytes long") - return "", true, err - + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least 16 bytes long") } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) iv = iv[0:aes.BlockSize] - } else if len(b.args) == 3 { - // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } else { + if len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } } - key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.KeySize) + key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.keySize) var cipherText []byte - switch modeName { - case "aes-128-ecb", "aes-192-ecb", "aes-256-ecb": + switch mode.modeName { + case "ecb": cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key) - case "aes-128-cbc", "aes-192-cbc", "aes-256-cbc": + case "cbc": cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) default: - //TODO + return "", true, errors.Errorf("unsupported block encryption mode - %v", mode.modeName) } - if err != nil { return "", true, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index db1cc3ddb1fd3..4f990d0b490db 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -69,7 +69,7 @@ func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { c.Assert(err, IsNil) c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt)) } - s.testNullInput(c, ast.AesEncrypt) + s.testAmbiguousInput(c, ast.AesEncrypt) } func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { @@ -89,10 +89,10 @@ func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { c.Assert(err, IsNil) c.Assert(str, DeepEquals, types.NewDatum(test.origin)) } - s.testNullInput(c, ast.AesDecrypt) + s.testAmbiguousInput(c, ast.AesDecrypt) } -func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { +func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { fc := funcs[fnName] arg := types.NewStringDatum("str") var argNull types.Datum @@ -105,6 +105,23 @@ func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { crypt, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(crypt.IsNull(), IsTrue) + + // test for modes that require init_vector + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-cbc")) + f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) + crypt, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, NotNil) + f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, types.NewStringDatum("iv < 16 bytes")})) + crypt, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, NotNil) + + // test for modes that do not require init_vector + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, arg})) + crypt, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), GreaterEqual, 1) } func toHex(d types.Datum) (h types.Datum) { diff --git a/expression/integration_test.go b/expression/integration_test.go index 52ee004d0e1f4..c991feb366722 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -977,16 +977,28 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))") tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`) + tk.MustExec("SET block_encryption_mode='aes-128-ecb';") result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key')), HEX(AES_ENCRYPT(b, 'key')), HEX(AES_ENCRYPT(c, 'key')), HEX(AES_ENCRYPT(d, 'key')), HEX(AES_ENCRYPT(e, 'key')), HEX(AES_ENCRYPT(f, 'key')), HEX(AES_ENCRYPT(g, 'key')), HEX(AES_ENCRYPT(h, 'key')), HEX(AES_ENCRYPT(i, 'key')) from t") result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC 9E018F7F2838DBA23C57F0E4CCF93287 E764D3E9D4AF8F926CD0979DDB1D0AF40C208B20A6C39D5D028644885280973A C452FFEEB76D3F5E9B26B8D48F7A228C 181BD5C81CBD36779A3C9DD5FF486B35 CE15F14AC7FF4E56ECCF148DE60E4BEDBDB6900AD51383970A5F32C59B3AC6E3 E1B29995CCF423C75519790F54A08CD2 84525677E95AC97698D22E1125B67E92")) result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar')), HEX(AES_ENCRYPT(123, 'foobar')), HEX(AES_ENCRYPT('', 'foobar')), HEX(AES_ENCRYPT('你好', 'foobar')), AES_ENCRYPT(NULL, 'foobar')") result.Check(testkit.Rows(`45ABDD5C4802EFA6771A94C43F805208 45ABDD5C4802EFA6771A94C43F805208 791F1AEB6A6B796E6352BF381895CA0E D0147E2EB856186F146D9F6DE33F9546 `)) + tk.MustExec("SET block_encryption_mode='aes-128-cbc';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("341672829F84CB6B0BE690FEC4C4DAE9 341672829F84CB6B0BE690FEC4C4DAE9 D43734E147A12BB96C6897C4BBABA283 16F2C972411948DCEF3659B726D2CCB04AD1379A1A367FA64242058A50211B67 41E71D0C58967C1F50EEC074523946D1 1117D292E2D39C3EAA3B435371BE56FC 8ACB7ECC0883B672D7BD1CFAA9FA5FAF5B731ADE978244CD581F114D591C2E7E D2B13C30937E3251AEDA73859BA32E4B 2CF4A6051FF248A67598A17AA2C17267")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`80D5646F07B4654B05A02D9085759770 80D5646F07B4654B05A02D9085759770 B3C14BA15030D2D7E99376DBE011E752 0CD2936EE4FEC7A8CDF6208438B2BC05 `)) // for AES_DECRYPT + tk.MustExec("SET block_encryption_mode='aes-128-ecb';") result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar'), 'bar')") result.Check(testkit.Rows("foo")) result = tk.MustQuery("select AES_DECRYPT(UNHEX('45ABDD5C4802EFA6771A94C43F805208'), 'foobar'), AES_DECRYPT(UNHEX('791F1AEB6A6B796E6352BF381895CA0E'), 'foobar'), AES_DECRYPT(UNHEX('D0147E2EB856186F146D9F6DE33F9546'), 'foobar'), AES_DECRYPT(NULL, 'foobar'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar')") result.Check(testkit.Rows(`123 你好 `)) + tk.MustExec("SET block_encryption_mode='aes-128-cbc';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('80D5646F07B4654B05A02D9085759770'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('B3C14BA15030D2D7E99376DBE011E752'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('0CD2936EE4FEC7A8CDF6208438B2BC05'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`123 你好 `)) // for COMPRESS tk.MustExec("DROP TABLE IF EXISTS t1;") diff --git a/util/encrypt/aes_test.go b/util/encrypt/aes_test.go index d1e28128e2aa6..b6a6f60c847ec 100644 --- a/util/encrypt/aes_test.go +++ b/util/encrypt/aes_test.go @@ -269,6 +269,80 @@ func (s *testEncryptSuite) TestAESDecryptWithECB(c *C) { } } +func (s *testEncryptSuite) TestAESEncryptWithCBC(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "042962D340F2F95BCC07B56EAC378D3A", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "EDECE05D9FE662E381130F7F19BA67F7", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + + for _, t := range tests { + str := []byte(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + crypted, err := AESEncryptWithCBC(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + result := toHex(crypted) + c.Assert(result, Equals, t.expect, Commentf("%v", t)) + } +} + +func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + expect string + key string + iv string + hexCryptStr string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "042962D340F2F95BCC07B56EAC378D3A", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "EDECE05D9FE662E381130F7F19BA67F7", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + // negtive cases: invalid padding / padding size + {"", "1234567890123456", "1234567890123456", "11223344556677112233", true}, + {"", "1234567890123456", "1234567890123456", "11223344556677112233112233445566", true}, + {"", "1234567890123456", "1234567890123456", "1122334455667711223311223344556611", true}, + } + + for _, t := range tests { + cryptStr, _ := hex.DecodeString(t.hexCryptStr) + key := []byte(t.key) + iv := []byte(t.iv) + + result, err := AESDecryptWithCBC(cryptStr, key, iv) + if t.isError { + c.Assert(err, NotNil) + continue + } + c.Assert(err, IsNil) + c.Assert(string(result), Equals, t.expect) + } +} + func (s *testEncryptSuite) TestDeriveKeyMySQL(c *C) { defer testleak.AfterTest(c)() From 2f0de8d660d5a1de44fc61e8fb7d034c5c80a96d Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Fri, 17 Aug 2018 14:55:58 +0800 Subject: [PATCH 03/12] *: fix ci --- expression/builtin_encryption.go | 16 ++++++++-------- expression/builtin_encryption_test.go | 2 +- util/encrypt/aes_test.go | 1 - 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index bc45e456e9bbf..704b075249914 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -26,15 +26,15 @@ import ( "hash" "io" + "crypto/aes" "github.com/juju/errors" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" - "github.com/pingcap/tidb/sessionctx/variable" - "crypto/aes" "strings" ) @@ -79,12 +79,12 @@ type aesModeAttr struct { var aesModes = map[string]*aesModeAttr{ //TODO support more modes - "aes-128-ecb": &aesModeAttr{"ecb", 16, false}, - "aes-192-ecb": &aesModeAttr{"ecb", 24, false}, - "aes-256-ecb": &aesModeAttr{"ecb", 32, false}, - "aes-128-cbc": &aesModeAttr{"cbc", 16, true}, - "aes-192-cbc": &aesModeAttr{"cbc", 24, true}, - "aes-256-cbc": &aesModeAttr{"cbc", 32, true}, + "aes-128-ecb": {"ecb", 16, false}, + "aes-192-ecb": {"ecb", 24, false}, + "aes-256-ecb": {"ecb", 32, false}, + "aes-128-cbc": {"cbc", 16, true}, + "aes-192-cbc": {"cbc", 24, true}, + "aes-256-cbc": {"cbc", 32, true}, } type aesDecryptFunctionClass struct { diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 4f990d0b490db..aa09bdd7893e4 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -19,12 +19,12 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/testleak" - "github.com/pingcap/tidb/sessionctx/variable" ) var aesTests = []struct { diff --git a/util/encrypt/aes_test.go b/util/encrypt/aes_test.go index b6a6f60c847ec..f93b83fde4307 100644 --- a/util/encrypt/aes_test.go +++ b/util/encrypt/aes_test.go @@ -288,7 +288,6 @@ func (s *testEncryptSuite) TestAESEncryptWithCBC(c *C) { {"pingcap", "123456789012345", "1234567890123456", "", true}, } - for _, t := range tests { str := []byte(t.str) key := []byte(t.key) From 1c10b1965048ff7fa83a5dd6cbe7cbf04e612cac Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Mon, 20 Aug 2018 22:39:14 +0800 Subject: [PATCH 04/12] refactor AES Encrypt/Decrypt --- expression/builtin_encryption.go | 9 ++-- util/encrypt/aes.go | 90 ++++++++++++++------------------ 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 704b075249914..8935056487808 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -25,6 +25,7 @@ import ( "fmt" "hash" "io" + "strings" "crypto/aes" "github.com/juju/errors" @@ -35,7 +36,6 @@ import ( "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" - "strings" ) var ( @@ -71,6 +71,9 @@ var ( _ builtinFunc = &builtinUncompressedLengthSig{} ) +// IVSize indicates the initialization vector supplied to aes_decrypt +const IVSize = aes.BlockSize + type aesModeAttr struct { modeName string keySize int @@ -144,11 +147,11 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } - if len(iv) < aes.BlockSize { + if len(iv) < IVSize { return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) - iv = iv[0:aes.BlockSize] + iv = iv[0:IVSize] } else { if len(b.args) == 3 { // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index 6eb81049ed717..2c89aa7668738 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -21,6 +21,8 @@ import ( "github.com/juju/errors" ) +type blockModeBuild func(block cipher.Block) cipher.BlockMode + type ecb struct { b cipher.Block blockSize int @@ -120,42 +122,16 @@ func PKCS7Unpad(data []byte, blockSize int) ([]byte, error) { // AESEncryptWithECB encrypts data using AES with ECB mode. func AESEncryptWithECB(str, key []byte) ([]byte, error) { - cb, err := aes.NewCipher(key) - if err != nil { - return nil, errors.Trace(err) - } - blockSize := cb.BlockSize() - // The str arguments can be any length, and padding is automatically added to - // str so it is a multiple of a block as required by block-based algorithms such as AES. - // This padding is automatically removed by the AES_DECRYPT() function. - data, err := PKCS7Pad(str, blockSize) - if err != nil { - return nil, err - } - crypted := make([]byte, len(data)) - ecb := newECBEncrypter(cb) - ecb.CryptBlocks(crypted, data) - return crypted, nil + return aesEncrypt(str, key, func(block cipher.Block) cipher.BlockMode { + return newECBEncrypter(block) + }) } // AESDecryptWithECB decrypts data using AES with ECB mode. func AESDecryptWithECB(cryptStr, key []byte) ([]byte, error) { - cb, err := aes.NewCipher(key) - if err != nil { - return nil, errors.Trace(err) - } - blockSize := cb.BlockSize() - if len(cryptStr)%blockSize != 0 { - return nil, errors.New("Corrupted data") - } - mode := newECBDecrypter(cb) - data := make([]byte, len(cryptStr)) - mode.CryptBlocks(data, cryptStr) - plain, err := PKCS7Unpad(data, blockSize) - if err != nil { - return nil, err - } - return plain, nil + return aesDecrypt(cryptStr, key, func(block cipher.Block) cipher.BlockMode { + return newECBDecrypter(block) + }) } // DeriveKeyMySQL derives the encryption key from a password in MySQL algorithm. @@ -175,26 +151,20 @@ func DeriveKeyMySQL(key []byte, blockSize int) []byte { // AESEncryptWithCBC encrypts data using AES with CBC mode. func AESEncryptWithCBC(str, key []byte, iv []byte) ([]byte, error) { - cb, err := aes.NewCipher(key) - if err != nil { - return nil, errors.Trace(err) - } - blockSize := cb.BlockSize() - // The str arguments can be any length, and padding is automatically added to - // str so it is a multiple of a block as required by block-based algorithms such as AES. - // This padding is automatically removed by the AES_DECRYPT() function. - data, err := PKCS7Pad(str, blockSize) - if err != nil { - return nil, err - } - cbc := cipher.NewCBCEncrypter(cb, iv) - crypted := make([]byte, len(data)) - cbc.CryptBlocks(crypted, data) - return crypted, nil + return aesEncrypt(str, key, func(block cipher.Block) cipher.BlockMode { + return cipher.NewCBCEncrypter(block, iv) + }) } // AESDecryptWithCBC decrypts data using AES with CBC mode. func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { + return aesDecrypt(cryptStr, key, func(block cipher.Block) cipher.BlockMode { + return cipher.NewCBCDecrypter(block, iv) + }) +} + +// aesDecrypt decrypts data using AES. +func aesDecrypt(cryptStr, key []byte, build blockModeBuild) ([]byte, error) { cb, err := aes.NewCipher(key) if err != nil { return nil, errors.Trace(err) @@ -203,12 +173,32 @@ func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { if len(cryptStr)%blockSize != 0 { return nil, errors.New("Corrupted data") } - cbc := cipher.NewCBCDecrypter(cb, iv) + mode := build(cb) data := make([]byte, len(cryptStr)) - cbc.CryptBlocks(data, cryptStr) + mode.CryptBlocks(data, cryptStr) plain, err := PKCS7Unpad(data, blockSize) if err != nil { return nil, err } return plain, nil } + +// aesEncrypt encrypts data using AES. +func aesEncrypt(str, key []byte, build blockModeBuild) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + blockSize := cb.BlockSize() + // The str arguments can be any length, and padding is automatically added to + // str so it is a multiple of a block as required by block-based algorithms such as AES. + // This padding is automatically removed by the AES_DECRYPT() function. + data, err := PKCS7Pad(str, blockSize) + if err != nil { + return nil, err + } + mode := build(cb) + crypted := make([]byte, len(data)) + mode.CryptBlocks(crypted, data) + return crypted, nil +} From 817f05a55bec0c616c0817ed15a3b45a96030364 Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Sun, 26 Aug 2018 11:21:40 +0800 Subject: [PATCH 05/12] fix review problem --- expression/builtin_encryption.go | 59 +++++++++++++++----------- expression/builtin_encryption_test.go | 61 ++++++++++++--------------- 2 files changed, 61 insertions(+), 59 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 8935056487808..32a13cf67574f 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -71,9 +71,12 @@ var ( _ builtinFunc = &builtinUncompressedLengthSig{} ) -// IVSize indicates the initialization vector supplied to aes_decrypt -const IVSize = aes.BlockSize +// ivSize indicates the initialization vector supplied to aes_decrypt +const ivSize = aes.BlockSize +// aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. +// keySize is the key length in bits and mode is the encryption mode. +// ivRequired indicates that initialization vector is required or not. type aesModeAttr struct { modeName string keySize int @@ -81,7 +84,7 @@ type aesModeAttr struct { } var aesModes = map[string]*aesModeAttr{ - //TODO support more modes + //TODO support more modes, permitted mode values are: ECB, CBC, CFB1, CFB8, CFB128, OFB "aes-128-ecb": {"ecb", 16, false}, "aes-192-ecb": {"ecb", 24, false}, "aes-256-ecb": {"ecb", 32, false}, @@ -105,17 +108,25 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) bf.tp.Flen = args[0].GetType().Flen // At most. types.SetBinChsClnFlag(bf.tp) - sig := &builtinAesDecryptSig{bf} + + modeName, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(modeName)] + if !exists { + return nil, errors.Errorf("unsupported block encryption mode - %v", modeName) + } + sig := &builtinAesDecryptSig{bf, mode} return sig, nil } type builtinAesDecryptSig struct { baseBuiltinFunc + *aesModeAttr } func (b *builtinAesDecryptSig) Clone() builtinFunc { newSig := &builtinAesDecryptSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr return newSig } @@ -134,12 +145,7 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { } var iv string - modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) - mode, exists := aesModes[strings.ToLower(modeName)] - if !exists { - return "", true, errors.Errorf("unsupported block encryption mode - %v", modeName) - } - if mode.ivRequired { + if b.ivRequired { if len(b.args) != 3 { return "", true, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") } @@ -147,11 +153,11 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } - if len(iv) < IVSize { + if len(iv) < ivSize { return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) - iv = iv[0:IVSize] + iv = iv[0:ivSize] } else { if len(b.args) == 3 { // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. @@ -159,15 +165,15 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { } } - key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.keySize) + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) var plainText []byte - switch mode.modeName { + switch b.modeName { case "ecb": plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key) case "cbc": plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) default: - return "", true, errors.Errorf("unsupported block encryption mode - %v", mode.modeName) + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) } if err != nil { return "", true, nil @@ -190,17 +196,25 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) bf.tp.Flen = aes.BlockSize * (args[0].GetType().Flen/aes.BlockSize + 1) // At most. types.SetBinChsClnFlag(bf.tp) - sig := &builtinAesEncryptSig{bf} + + modeName, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(modeName)] + if !exists { + return nil, errors.Errorf("unsupported block encryption mode - %v", modeName) + } + sig := &builtinAesEncryptSig{bf, mode} return sig, nil } type builtinAesEncryptSig struct { baseBuiltinFunc + *aesModeAttr } func (b *builtinAesEncryptSig) Clone() builtinFunc { newSig := &builtinAesEncryptSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr return newSig } @@ -219,12 +233,7 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { } var iv string - modeName, _ := b.ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) - mode, exists := aesModes[strings.ToLower(modeName)] - if !exists { - return "", true, errors.Errorf("unsupported block encryption mode - %v", modeName) - } - if mode.ivRequired { + if b.ivRequired { if len(b.args) != 3 { return "", true, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") } @@ -244,15 +253,15 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { } } - key := encrypt.DeriveKeyMySQL([]byte(keyStr), mode.keySize) + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) var cipherText []byte - switch mode.modeName { + switch b.modeName { case "ecb": cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key) case "cbc": cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) default: - return "", true, errors.Errorf("unsupported block encryption mode - %v", mode.modeName) + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) } if err != nil { return "", true, nil diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index aa09bdd7893e4..d674f9462300e 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -28,28 +28,27 @@ import ( ) var aesTests = []struct { - mode string - origin interface{} - key interface{} - initVector interface{} - crypt interface{} + mode string + origin interface{} + params []interface{} + crypt interface{} }{ // test for ecb - {"aes-128-ecb", "pingcap", "1234567890123456", nil, "697BFE9B3F8C2F289DD82C88C7BC95C4"}, - {"aes-128-ecb", "pingcap123", "1234567890123456", nil, "CEC348F4EF5F84D3AA6C4FA184C65766"}, - {"aes-128-ecb", "pingcap", "123456789012345678901234", nil, "6F1589686860C8E8C7A40A78B25FF2C0"}, - {"aes-128-ecb", "pingcap", "123", nil, "996E0CA8688D7AD20819B90B273E01C6"}, - {"aes-128-ecb", "pingcap", 123, nil, "996E0CA8688D7AD20819B90B273E01C6"}, - {"aes-128-ecb", nil, 123, nil, nil}, - {"aes-192-ecb", "pingcap", "1234567890123456", nil, "9B139FD002E6496EA2D5C73A2265E661"}, - {"aes-256-ecb", "pingcap", "1234567890123456", nil, "F80DCDEDDBE5663BDB68F74AEDDB8EE3"}, + {"aes-128-ecb", "pingcap", []interface{}{"1234567890123456"}, "697BFE9B3F8C2F289DD82C88C7BC95C4"}, + {"aes-128-ecb", "pingcap123", []interface{}{"1234567890123456"}, "CEC348F4EF5F84D3AA6C4FA184C65766"}, + {"aes-128-ecb", "pingcap", []interface{}{"123456789012345678901234"}, "6F1589686860C8E8C7A40A78B25FF2C0"}, + {"aes-128-ecb", "pingcap", []interface{}{"123"}, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", "pingcap", []interface{}{123}, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", nil, []interface{}{123}, nil}, + {"aes-192-ecb", "pingcap", []interface{}{"1234567890123456"}, "9B139FD002E6496EA2D5C73A2265E661"}, + {"aes-256-ecb", "pingcap", []interface{}{"1234567890123456"}, "F80DCDEDDBE5663BDB68F74AEDDB8EE3"}, // test for cbc - {"aes-128-cbc", "pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792"}, - {"aes-128-cbc", "pingcap", "123456789012345678901234", "1234567890123456", "483788634DA8817423BA0934FD2C096E"}, - {"aes-192-cbc", "pingcap", "1234567890123456", "1234567890123456", "516391DB38E908ECA93AAB22870EC787"}, - {"aes-256-cbc", "pingcap", "1234567890123456", "1234567890123456", "5D0E22C1E77523AEF5C3E10B65653C8F"}, - {"aes-256-cbc", "pingcap", "12345678901234561234567890123456", "1234567890123456", "A26BA27CA4BE9D361D545AA84A17002D"}, - {"aes-256-cbc", "pingcap", "1234567890123456", "12345678901234561234567890123456", "5D0E22C1E77523AEF5C3E10B65653C8F"}, + {"aes-128-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "2ECA0077C5EA5768A0485AA522774792"}, + {"aes-128-cbc", "pingcap", []interface{}{"123456789012345678901234", "1234567890123456"}, "483788634DA8817423BA0934FD2C096E"}, + {"aes-192-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "516391DB38E908ECA93AAB22870EC787"}, + {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, + {"aes-256-cbc", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "A26BA27CA4BE9D361D545AA84A17002D"}, + {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, } func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { @@ -57,12 +56,9 @@ func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { fc := funcs[ast.AesEncrypt] for _, tt := range aesTests { variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) - str := types.NewDatum(tt.origin) - key := types.NewDatum(tt.key) - args := []types.Datum{str, key} - if tt.initVector != nil { - vec := types.NewDatum(tt.initVector) - args = append(args, vec) + args := []types.Datum{types.NewDatum(tt.origin)} + for _, param := range tt.params { + args = append(args, types.NewDatum(param)) } f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) crypt, err := evalBuiltinFunc(f, chunk.Row{}) @@ -75,19 +71,16 @@ func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesDecrypt] - for _, test := range aesTests { - variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(test.mode)) - cryptStr := fromHex(test.crypt) - key := types.NewDatum(test.key) - args := []types.Datum{cryptStr, key} - if test.initVector != nil { - vec := types.NewDatum(test.initVector) - args = append(args, vec) + for _, tt := range aesTests { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) + args := []types.Datum{fromHex(tt.crypt)} + for _, param := range tt.params { + args = append(args, types.NewDatum(param)) } f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) str, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) - c.Assert(str, DeepEquals, types.NewDatum(test.origin)) + c.Assert(str, DeepEquals, types.NewDatum(tt.origin)) } s.testAmbiguousInput(c, ast.AesDecrypt) } From d20672517aff19ff075d5e1e2c47e96ae07dc7e9 Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Sun, 26 Aug 2018 12:47:41 +0800 Subject: [PATCH 06/12] refactor sig to two parts --- expression/builtin_encryption.go | 176 ++++++++++++++++++-------- expression/builtin_encryption_test.go | 5 +- 2 files changed, 127 insertions(+), 54 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 32a13cf67574f..1c558a619c313 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -60,7 +60,9 @@ var ( var ( _ builtinFunc = &builtinAesDecryptSig{} + _ builtinFunc = &builtinAesDecryptIVSig{} _ builtinFunc = &builtinAesEncryptSig{} + _ builtinFunc = &builtinAesEncryptIVSig{} _ builtinFunc = &builtinCompressSig{} _ builtinFunc = &builtinMD5Sig{} _ builtinFunc = &builtinPasswordSig{} @@ -109,13 +111,21 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp bf.tp.Flen = args[0].GetType().Flen // At most. types.SetBinChsClnFlag(bf.tp) - modeName, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) - mode, exists := aesModes[strings.ToLower(modeName)] + blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(blockMode)] if !exists { - return nil, errors.Errorf("unsupported block encryption mode - %v", modeName) + return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) } - sig := &builtinAesDecryptSig{bf, mode} - return sig, nil + if mode.ivRequired { + if len(args) != 3 { + return nil, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") + } + return &builtinAesDecryptIVSig{bf, mode}, nil + } else if len(args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } + return &builtinAesDecryptSig{bf, mode}, nil } type builtinAesDecryptSig struct { @@ -144,32 +154,59 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } - var iv string - if b.ivRequired { - if len(b.args) != 3 { - return "", true, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") - } - iv, isNull, err = b.args[2].EvalString(b.ctx, row) - if isNull || err != nil { - return "", true, errors.Trace(err) - } - if len(iv) < ivSize { - return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") - } - // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) - iv = iv[0:ivSize] - } else { - if len(b.args) == 3 { - // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) - } - } - key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) var plainText []byte switch b.modeName { case "ecb": plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + if err != nil { + return "", true, nil + } + return string(plainText), false, nil +} + +type builtinAesDecryptIVSig struct { + baseBuiltinFunc + *aesModeAttr +} + +func (b *builtinAesDecryptIVSig) Clone() builtinFunc { + newSig := &builtinAesDecryptIVSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr + return newSig +} + +// evalString evals AES_DECRYPT(crypt_str, key_key, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt +func (b *builtinAesDecryptIVSig) evalString(row chunk.Row) (string, bool, error) { + // According to doc: If either function argument is NULL, the function returns NULL. + cryptStr, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + iv, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if len(iv) < aes.BlockSize { + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var plainText []byte + switch b.modeName { case "cbc": plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) default: @@ -197,13 +234,21 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp bf.tp.Flen = aes.BlockSize * (args[0].GetType().Flen/aes.BlockSize + 1) // At most. types.SetBinChsClnFlag(bf.tp) - modeName, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) - mode, exists := aesModes[strings.ToLower(modeName)] + blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(blockMode)] if !exists { - return nil, errors.Errorf("unsupported block encryption mode - %v", modeName) + return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) } - sig := &builtinAesEncryptSig{bf, mode} - return sig, nil + if mode.ivRequired { + if len(args) != 3 { + return nil, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") + } + return &builtinAesEncryptIVSig{bf, mode}, nil + } else if len(args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } + return &builtinAesEncryptSig{bf, mode}, nil } type builtinAesEncryptSig struct { @@ -232,32 +277,59 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } - var iv string - if b.ivRequired { - if len(b.args) != 3 { - return "", true, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") - } - iv, isNull, err = b.args[2].EvalString(b.ctx, row) - if isNull || err != nil { - return "", true, errors.Trace(err) - } - if len(iv) < aes.BlockSize { - return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least 16 bytes long") - } - // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) - iv = iv[0:aes.BlockSize] - } else { - if len(b.args) == 3 { - // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) - } - } - key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) var cipherText []byte switch b.modeName { case "ecb": cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + if err != nil { + return "", true, nil + } + return string(cipherText), false, nil +} + +type builtinAesEncryptIVSig struct { + baseBuiltinFunc + *aesModeAttr +} + +func (b *builtinAesEncryptIVSig) Clone() builtinFunc { + newSig := &builtinAesEncryptIVSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr + return newSig +} + +// evalString evals AES_ENCRYPT(str, key_str, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt +func (b *builtinAesEncryptIVSig) evalString(row chunk.Row) (string, bool, error) { + // According to doc: If either function argument is NULL, the function returns NULL. + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + iv, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if len(iv) < aes.BlockSize { + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least 16 bytes long") + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var cipherText []byte + switch b.modeName { case "cbc": cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) default: diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index d674f9462300e..db6a2d706a159 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -86,6 +86,7 @@ func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { } func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) fc := funcs[fnName] arg := types.NewStringDatum("str") var argNull types.Datum @@ -101,10 +102,10 @@ func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { // test for modes that require init_vector variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-cbc")) - f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) - crypt, err = evalBuiltinFunc(f, chunk.Row{}) + _, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) c.Assert(err, NotNil) f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, types.NewStringDatum("iv < 16 bytes")})) + c.Assert(err, IsNil) crypt, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, NotNil) From 4279572918e68e1bb336e23259bbb65fd02179f9 Mon Sep 17 00:00:00 2001 From: xiaojian cai Date: Wed, 29 Aug 2018 11:52:31 +0800 Subject: [PATCH 07/12] revert merge code for mistake --- expression/errors.go | 1 - 1 file changed, 1 deletion(-) diff --git a/expression/errors.go b/expression/errors.go index 62059687fda3f..92be6e1487054 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -38,7 +38,6 @@ var ( errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField]) errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) - errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored]) errTruncatedWrongValue = terror.ClassExpression.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) From dda3ec7ebde53706b88c7bebb75f789076478862 Mon Sep 17 00:00:00 2001 From: xiaojian cai Date: Wed, 29 Aug 2018 18:02:31 +0800 Subject: [PATCH 08/12] fix ci --- expression/errors.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/expression/errors.go b/expression/errors.go index 92be6e1487054..517398bd0a9a3 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -38,6 +38,7 @@ var ( errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField]) errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) + errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored]) errTruncatedWrongValue = terror.ClassExpression.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) @@ -54,6 +55,7 @@ func init() { mysql.ErrOperandColumns: mysql.ErrOperandColumns, mysql.ErrRegexp: mysql.ErrRegexp, mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, + mysql.WarnOptionIgnored: mysql.WarnOptionIgnored, mysql.ErrTruncatedWrongValue: mysql.ErrTruncatedWrongValue, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes From 492f94c33e9b34f785315067d2581df54088401f Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Sun, 9 Sep 2018 17:50:30 +0800 Subject: [PATCH 09/12] move 'iv' check to execution phase --- expression/builtin_encryption.go | 20 ++++++++++---------- expression/integration_test.go | 3 +++ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 1c558a619c313..31a446fa5f408 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -121,9 +121,6 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp return nil, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") } return &builtinAesDecryptIVSig{bf, mode}, nil - } else if len(args) == 3 { - // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) } return &builtinAesDecryptSig{bf, mode}, nil } @@ -148,11 +145,14 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } - keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } + if !b.ivRequired && len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) var plainText []byte @@ -199,7 +199,7 @@ func (b *builtinAesDecryptIVSig) evalString(row chunk.Row) (string, bool, error) return "", true, errors.Trace(err) } if len(iv) < aes.BlockSize { - return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least 16 bytes long") + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize) } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) iv = iv[0:aes.BlockSize] @@ -244,9 +244,6 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp return nil, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") } return &builtinAesEncryptIVSig{bf, mode}, nil - } else if len(args) == 3 { - // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) } return &builtinAesEncryptSig{bf, mode}, nil } @@ -271,11 +268,14 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } - keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } + if !b.ivRequired && len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) var cipherText []byte @@ -322,7 +322,7 @@ func (b *builtinAesEncryptIVSig) evalString(row chunk.Row) (string, bool, error) return "", true, errors.Trace(err) } if len(iv) < aes.BlockSize { - return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least 16 bytes long") + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least %d bytes long", aes.BlockSize) } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) iv = iv[0:aes.BlockSize] diff --git a/expression/integration_test.go b/expression/integration_test.go index 6458873dd57ad..8b36da9c21b18 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -994,6 +994,9 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC 9E018F7F2838DBA23C57F0E4CCF93287 E764D3E9D4AF8F926CD0979DDB1D0AF40C208B20A6C39D5D028644885280973A C452FFEEB76D3F5E9B26B8D48F7A228C 181BD5C81CBD36779A3C9DD5FF486B35 CE15F14AC7FF4E56ECCF148DE60E4BEDBDB6900AD51383970A5F32C59B3AC6E3 E1B29995CCF423C75519790F54A08CD2 84525677E95AC97698D22E1125B67E92")) result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar')), HEX(AES_ENCRYPT(123, 'foobar')), HEX(AES_ENCRYPT('', 'foobar')), HEX(AES_ENCRYPT('你好', 'foobar')), AES_ENCRYPT(NULL, 'foobar')") result.Check(testkit.Rows(`45ABDD5C4802EFA6771A94C43F805208 45ABDD5C4802EFA6771A94C43F805208 791F1AEB6A6B796E6352BF381895CA0E D0147E2EB856186F146D9F6DE33F9546 `)) + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', 'iv')), HEX(AES_ENCRYPT(b, 'key', 'iv')) from t") + result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1618| option ignored", "Warning|1618| option ignored")) tk.MustExec("SET block_encryption_mode='aes-128-cbc';") result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") result.Check(testkit.Rows("341672829F84CB6B0BE690FEC4C4DAE9 341672829F84CB6B0BE690FEC4C4DAE9 D43734E147A12BB96C6897C4BBABA283 16F2C972411948DCEF3659B726D2CCB04AD1379A1A367FA64242058A50211B67 41E71D0C58967C1F50EEC074523946D1 1117D292E2D39C3EAA3B435371BE56FC 8ACB7ECC0883B672D7BD1CFAA9FA5FAF5B731ADE978244CD581F114D591C2E7E D2B13C30937E3251AEDA73859BA32E4B 2CF4A6051FF248A67598A17AA2C17267")) From d732b472ab763cfd678613b7fceec92fad70703b Mon Sep 17 00:00:00 2001 From: CaiXiaoJian Date: Wed, 12 Sep 2018 08:51:25 +0000 Subject: [PATCH 10/12] formate import --- expression/builtin_encryption.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 31a446fa5f408..89347d25c3f20 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -16,6 +16,7 @@ package expression import ( "bytes" "compress/zlib" + "crypto/aes" "crypto/md5" "crypto/rand" "crypto/sha1" @@ -27,7 +28,6 @@ import ( "io" "strings" - "crypto/aes" "github.com/juju/errors" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" From ec0e07b6f0bf5680de8b621b4038a1c98cfcecef Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Wed, 12 Sep 2018 20:52:15 +0800 Subject: [PATCH 11/12] fix ci --- expression/builtin_encryption.go | 12 ++++++------ expression/builtin_encryption_test.go | 19 ++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 3a033243ccdeb..94eaef8a84574 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -118,7 +118,7 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp } if mode.ivRequired { if len(args) != 3 { - return nil, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") + return nil, ErrIncorrectParameterCount.GenWithStackByArgs("aes_decrypt") } return &builtinAesDecryptIVSig{bf, mode}, nil } @@ -151,7 +151,7 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { } if !b.ivRequired && len(b.args) == 3 { // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenWithStackByArgs("IV")) } key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) @@ -199,7 +199,7 @@ func (b *builtinAesDecryptIVSig) evalString(row chunk.Row) (string, bool, error) return "", true, errors.Trace(err) } if len(iv) < aes.BlockSize { - return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize) + return "", true, errIncorrectArgs.GenWithStack("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize) } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) iv = iv[0:aes.BlockSize] @@ -241,7 +241,7 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp } if mode.ivRequired { if len(args) != 3 { - return nil, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") + return nil, ErrIncorrectParameterCount.GenWithStackByArgs("aes_encrypt") } return &builtinAesEncryptIVSig{bf, mode}, nil } @@ -274,7 +274,7 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { } if !b.ivRequired && len(b.args) == 3 { // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenWithStackByArgs("IV")) } key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) @@ -322,7 +322,7 @@ func (b *builtinAesEncryptIVSig) evalString(row chunk.Row) (string, bool, error) return "", true, errors.Trace(err) } if len(iv) < aes.BlockSize { - return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least %d bytes long", aes.BlockSize) + return "", true, errIncorrectArgs.GenWithStack("The initialization vector supplied to aes_encrypt is too short. Must be at least %d bytes long", aes.BlockSize) } // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) iv = iv[0:aes.BlockSize] diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index b5608a3d0aa5d..99f29ea972933 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -115,6 +115,8 @@ func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { c.Assert(err, IsNil) c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt)) } + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + s.testNullInput(c, ast.AesEncrypt) s.testAmbiguousInput(c, ast.AesEncrypt) } @@ -132,10 +134,12 @@ func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { c.Assert(err, IsNil) c.Assert(str, DeepEquals, types.NewDatum(tt.origin)) } + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + s.testNullInput(c, ast.AesDecrypt) s.testAmbiguousInput(c, ast.AesDecrypt) } -func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { +func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) fc := funcs[fnName] arg := types.NewStringDatum("str") @@ -149,20 +153,25 @@ func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { crypt, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(crypt.IsNull(), IsTrue) +} +func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { + fc := funcs[fnName] + arg := types.NewStringDatum("str") // test for modes that require init_vector variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-cbc")) - _, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) + _, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) c.Assert(err, NotNil) - f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, types.NewStringDatum("iv < 16 bytes")})) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, types.NewStringDatum("iv < 16 bytes")})) c.Assert(err, IsNil) - crypt, err = evalBuiltinFunc(f, chunk.Row{}) + _, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, NotNil) // test for modes that do not require init_vector variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, arg})) - crypt, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() c.Assert(len(warnings), GreaterEqual, 1) From b423fb8d0a88e4cbd30988e5546d6cf4e8abc710 Mon Sep 17 00:00:00 2001 From: caixiaojian Date: Sun, 23 Sep 2018 16:23:02 +0800 Subject: [PATCH 12/12] remove buildBlockMode --- util/encrypt/aes.go | 54 ++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index cc3d3e6a877d0..b1b90f3a524b8 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -21,8 +21,6 @@ import ( "github.com/pkg/errors" ) -type blockModeBuild func(block cipher.Block) cipher.BlockMode - type ecb struct { b cipher.Block blockSize int @@ -122,16 +120,22 @@ func PKCS7Unpad(data []byte, blockSize int) ([]byte, error) { // AESEncryptWithECB encrypts data using AES with ECB mode. func AESEncryptWithECB(str, key []byte) ([]byte, error) { - return aesEncrypt(str, key, func(block cipher.Block) cipher.BlockMode { - return newECBEncrypter(block) - }) + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := newECBEncrypter(cb) + return aesEncrypt(str, mode) } // AESDecryptWithECB decrypts data using AES with ECB mode. func AESDecryptWithECB(cryptStr, key []byte) ([]byte, error) { - return aesDecrypt(cryptStr, key, func(block cipher.Block) cipher.BlockMode { - return newECBDecrypter(block) - }) + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := newECBDecrypter(cb) + return aesDecrypt(cryptStr, mode) } // DeriveKeyMySQL derives the encryption key from a password in MySQL algorithm. @@ -151,29 +155,30 @@ func DeriveKeyMySQL(key []byte, blockSize int) []byte { // AESEncryptWithCBC encrypts data using AES with CBC mode. func AESEncryptWithCBC(str, key []byte, iv []byte) ([]byte, error) { - return aesEncrypt(str, key, func(block cipher.Block) cipher.BlockMode { - return cipher.NewCBCEncrypter(block, iv) - }) + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewCBCEncrypter(cb, iv) + return aesEncrypt(str, mode) } // AESDecryptWithCBC decrypts data using AES with CBC mode. func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { - return aesDecrypt(cryptStr, key, func(block cipher.Block) cipher.BlockMode { - return cipher.NewCBCDecrypter(block, iv) - }) -} - -// aesDecrypt decrypts data using AES. -func aesDecrypt(cryptStr, key []byte, build blockModeBuild) ([]byte, error) { cb, err := aes.NewCipher(key) if err != nil { return nil, errors.Trace(err) } - blockSize := cb.BlockSize() + mode := cipher.NewCBCDecrypter(cb, iv) + return aesDecrypt(cryptStr, mode) +} + +// aesDecrypt decrypts data using AES. +func aesDecrypt(cryptStr []byte, mode cipher.BlockMode) ([]byte, error) { + blockSize := mode.BlockSize() if len(cryptStr)%blockSize != 0 { return nil, errors.New("Corrupted data") } - mode := build(cb) data := make([]byte, len(cryptStr)) mode.CryptBlocks(data, cryptStr) plain, err := PKCS7Unpad(data, blockSize) @@ -184,12 +189,8 @@ func aesDecrypt(cryptStr, key []byte, build blockModeBuild) ([]byte, error) { } // aesEncrypt encrypts data using AES. -func aesEncrypt(str, key []byte, build blockModeBuild) ([]byte, error) { - cb, err := aes.NewCipher(key) - if err != nil { - return nil, errors.Trace(err) - } - blockSize := cb.BlockSize() +func aesEncrypt(str []byte, mode cipher.BlockMode) ([]byte, error) { + blockSize := mode.BlockSize() // The str arguments can be any length, and padding is automatically added to // str so it is a multiple of a block as required by block-based algorithms such as AES. // This padding is automatically removed by the AES_DECRYPT() function. @@ -197,7 +198,6 @@ func aesEncrypt(str, key []byte, build blockModeBuild) ([]byte, error) { if err != nil { return nil, err } - mode := build(cb) crypted := make([]byte, len(data)) mode.CryptBlocks(crypted, data) return crypted, nil