Skip to content

Commit

Permalink
Merge pull request #387 from dolthub/fulghum/caching_sha2
Browse files Browse the repository at this point in the history
Add support for serializing/deserializing `caching_sha2_password` auth strings
  • Loading branch information
fulghum authored Dec 9, 2024
2 parents 814752c + 0cfa560 commit 588631a
Show file tree
Hide file tree
Showing 3 changed files with 375 additions and 3 deletions.
11 changes: 8 additions & 3 deletions go/mysql/auth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,13 +618,18 @@ func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string,
return result, nil
}
if !c.TLSEnabled() && !c.IsUnixSocket() {
return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError,
"Access denied for user '%v' (not using TLS or Unix socket)", user)
}
data := c.startEphemeralPacket(1)

data := c.startEphemeralPacket(2)
pos := 0
pos = writeByte(data, pos, AuthMoreDataPacket)
writeByte(data, pos, CachingSha2FullAuth)
c.writeEphemeralPacket()
if err = c.writeEphemeralPacket(); err != nil {
return nil, err
}

password, err := readPacketPasswordString(c)
if err != nil {
return nil, err
Expand Down
243 changes: 243 additions & 0 deletions go/mysql/auth_server_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
/*
Copyright ApeCloud, Inc.
Licensed under the Apache v2(found in the LICENSE file in the root directory).
*/

// NOTE: The logic in SerializeCachingSha2PasswordAuthString, and the b64From24bit and sha256Hash functions
// were taken from the wesql/wescale project (https://github.com/wesql/wescale) and is copyright ApeCloud, Inc.
// All other code in this file is copyright DoltHub, Inc.

package mysql

import (
"bytes"
"crypto/sha256"
"fmt"
"strconv"
)

const (
// DefaultCachingSha2PasswordHashIterations is the default number of hashing iterations used (before the
// iterationMultiplier is applied) when hashing a password using the caching_sha2_password auth plugin.
DefaultCachingSha2PasswordHashIterations = 5

// mixChars is the number of characters to use in the mix
mixChars = 32

// iterationMultiplier is the multiplier applied to the number of hashing iterations the user has requested.
// For example, if the user requests 10 iterations, the actual number of iterations will be 10 * iterationMultiplier.
iterationMultiplier = 1000

// delimiter is used to separate the metadata fields in a caching_sha2_password authentication string.
delimiter = '$'

// saltLength is the length of the salt used in the caching_sha2_password authentication protocol.
saltLength = 20

// storedSha256DigestLength is the length of the base64 encoded sha256 digest in an auth string
storedSha256DigestLength = 43

// maxIterations is the maximum iterations (before the iterationMultiplier is applied) that can be used
// in the hasing process for the caching_sha2_password auth plugin. The iterations applied are not directly
// user-controllable, so realistically, this limit can't be breached.
maxIterations = 0xFFF
)

// DeserializeCachingSha2PasswordAuthString takes in |authStringBytes|, a caching_sha2_password auth plugin generated
// authentication string, and parses out the individual components: the digest type, number of iterations, salt, and
// the password hash. |iterations| is the number of iterations the hashing function has been through (not including
// the internal iteration multiplier, 1,000). If any errors are encountered during parsing, such as the authentication
// string bytes not having the expected format, an error is returned.
//
// The protocol for generating an auth string for the caching_sha2_password plugin is not documented, but the MySQL
// source code can be found here: https://github.com/mysql/mysql-server/blob/trunk/sql/auth/sha2_password.cc#L440
func DeserializeCachingSha2PasswordAuthString(authStringBytes []byte) (digestType string, iterations int, salt, digest []byte, err error) {
if authStringBytes[0] != delimiter {
return "", 0, nil, nil, fmt.Errorf(
"authentication string does not start with the expected delimiter '$'")
}

// Digest Type
digestTypeCode := authStringBytes[1]
switch digestTypeCode {
case 'A':
digestType = "SHA256"
default:
return "", 0, nil, nil, fmt.Errorf(
"unsupported digest type: %v", digestTypeCode)
}

// Validate the delimiter
if authStringBytes[2] != delimiter {
return "", 0, nil, nil, fmt.Errorf(
"authentication string does not contain with the expected delimiter '$' between digest type and iterations")
}

// Iterations
iterationsString := string(authStringBytes[3:6])
iterations32bit, err := strconv.ParseInt(iterationsString, 16, 32)
if err != nil {
return "", 0, nil, nil, fmt.Errorf(
"iterations specified in authentication string is not a valid integer: %v", iterationsString)
}
iterations = int(iterations32bit)

// Validate the delimiter
if authStringBytes[6] != delimiter {
return "", 0, nil, nil, fmt.Errorf(
"authentication string does not contain with the expected delimiter '$' between iterations and salt")
}

// Salt
salt = authStringBytes[7 : 7+saltLength]

// Digest
digest = authStringBytes[7+saltLength:]
if len(digest) != storedSha256DigestLength {
return "", 0, nil, nil, fmt.Errorf("Unexpected digest length: %v", len(digest))
}

return digestType, iterations, salt, digest, nil
}

// SerializeCachingSha2PasswordAuthString uses SHA256 hashing algorithm to hash a plaintext password (|plaintext|)
// with the specified |salt|. The hashing is repeated |iterations| times. Note that |iterations| is the external,
// user-controllable number of iterations BEFORE the iterations multipler (i.e. 1000) is applied. The return bytes
// represent an authentication string compatible with the caching_sha2_password plugin authentication method.
func SerializeCachingSha2PasswordAuthString(plaintext string, salt []byte, iterations int) ([]byte, error) {
if iterations > maxIterations {
return nil, fmt.Errorf("iterations value (%d) is greater than max allowed iterations (%d)", iterations, maxIterations)
}

// 1, 2, 3
bufA := bytes.NewBuffer(make([]byte, 0, 4096))
bufA.WriteString(plaintext)
bufA.Write(salt)

// 4, 5, 6, 7, 8
bufB := bytes.NewBuffer(make([]byte, 0, 4096))
bufB.WriteString(plaintext)
bufB.Write(salt)
bufB.WriteString(plaintext)
sumB := sha256Hash(bufB.Bytes())
bufB.Reset()

// 9, 10
var i int
for i = len(plaintext); i > mixChars; i -= mixChars {
bufA.Write(sumB[:mixChars])
}
bufA.Write(sumB[:i])
// 11
for i = len(plaintext); i > 0; i >>= 1 {
if i%2 == 0 {
bufA.WriteString(plaintext)
} else {
bufA.Write(sumB[:])
}
}

// 12
sumA := sha256Hash(bufA.Bytes())
bufA.Reset()

// 13, 14, 15
bufDP := bufA
for range []byte(plaintext) {
bufDP.WriteString(plaintext)
}
sumDP := sha256Hash(bufDP.Bytes())
bufDP.Reset()

// 16
p := make([]byte, 0, sha256.Size)
for i = len(plaintext); i > 0; i -= mixChars {
if i > mixChars {
p = append(p, sumDP[:]...)
} else {
p = append(p, sumDP[0:i]...)
}
}
// 17, 18, 19
bufDS := bufA
for i = 0; i < 16+int(sumA[0]); i++ {
bufDS.Write(salt)
}
sumDS := sha256Hash(bufDS.Bytes())
bufDS.Reset()

// 20
s := make([]byte, 0, 32)
for i = len(salt); i > 0; i -= mixChars {
if i > mixChars {
s = append(s, sumDS[:]...)
} else {
s = append(s, sumDS[0:i]...)
}
}

// 21
bufC := bufA
var sumC []byte
for i = 0; i < iterations*iterationMultiplier; i++ {
bufC.Reset()
if i&1 != 0 {
bufC.Write(p)
} else {
bufC.Write(sumA[:])
}
if i%3 != 0 {
bufC.Write(s)
}
if i%7 != 0 {
bufC.Write(p)
}
if i&1 != 0 {
bufC.Write(sumA[:])
} else {
bufC.Write(p)
}
sumC = sha256Hash(bufC.Bytes())
sumA = sumC
}
// 22
buf := bytes.NewBuffer(make([]byte, 0, 100))
buf.Write([]byte{'$', 'A', '$'})
rounds := fmt.Sprintf("%03X", iterations)
buf.WriteString(rounds)
buf.Write([]byte{'$'})
buf.Write(salt)

b64From24bit([]byte{sumC[0], sumC[10], sumC[20]}, 4, buf)
b64From24bit([]byte{sumC[21], sumC[1], sumC[11]}, 4, buf)
b64From24bit([]byte{sumC[12], sumC[22], sumC[2]}, 4, buf)
b64From24bit([]byte{sumC[3], sumC[13], sumC[23]}, 4, buf)
b64From24bit([]byte{sumC[24], sumC[4], sumC[14]}, 4, buf)
b64From24bit([]byte{sumC[15], sumC[25], sumC[5]}, 4, buf)
b64From24bit([]byte{sumC[6], sumC[16], sumC[26]}, 4, buf)
b64From24bit([]byte{sumC[27], sumC[7], sumC[17]}, 4, buf)
b64From24bit([]byte{sumC[18], sumC[28], sumC[8]}, 4, buf)
b64From24bit([]byte{sumC[9], sumC[19], sumC[29]}, 4, buf)
b64From24bit([]byte{0, sumC[31], sumC[30]}, 3, buf)

return []byte(buf.String()), nil
}

// sha256Hash is a util function to calculate a sha256 hash.
func sha256Hash(input []byte) []byte {
res := sha256.Sum256(input)
return res[:]
}

// b64From24bit is a util function to base64 encode up to 24 bits at a time (|n|) from the
// byte slice |b| and writes the encoded data to |buf|.
func b64From24bit(b []byte, n int, buf *bytes.Buffer) {
b64t := []byte("./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

w := (int64(b[0]) << 16) | (int64(b[1]) << 8) | int64(b[2])
for n > 0 {
n--
buf.WriteByte(b64t[w&0x3f])
w >>= 6
}
}
124 changes: 124 additions & 0 deletions go/mysql/auth_server_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql

import (
"encoding/hex"
"strconv"
"testing"

"github.com/stretchr/testify/require"
)

// TestDeserializeCachingSha2PasswordAuthString tests that MySQL-generated caching_sha2_password authentication strings
// can be correctly deserialized into their component parts. We use a hex encoded string for the authentication string,
// because it is binary data and displaying it as a string and copying/pasting it corrupts the data.
func TestDeserializeCachingSha2PasswordAuthString(t *testing.T) {
tests := []struct {
hexEncodedAuthStringBytes string
expectedDigestType string
expectedIterations int
expectedSalt []byte
expectedDigest []byte
expectedErrorSubstring string
}{
{
hexEncodedAuthStringBytes: "2441243030352434341F5017121D0420134615056D3519305C4C57507A4B4E584E482E5351544E324B2E44764B586566567243336F56367739736F61386E424B695741395443",
expectedDigestType: "SHA256",
expectedIterations: 5,
expectedSalt: []byte{52, 52, 31, 80, 23, 18, 29, 4, 32, 19, 70, 21, 5, 109, 53, 25, 48, 92, 76, 87},
expectedDigest: []byte{0x50, 0x7a, 0x4b, 0x4e, 0x58, 0x4e, 0x48, 0x2e, 0x53, 0x51, 0x54,
0x4e, 0x32, 0x4b, 0x2e, 0x44, 0x76, 0x4b, 0x58, 0x65, 0x66, 0x56, 0x72, 0x43, 0x33, 0x6f, 0x56,
0x36, 0x77, 0x39, 0x73, 0x6f, 0x61, 0x38, 0x6e, 0x42, 0x4b, 0x69, 0x57, 0x41, 0x39, 0x54, 0x43},
},
{
hexEncodedAuthStringBytes: "244124303035241A502F3D02576A0150494D096659325E017E08086E516B42326E5762733366615556756E6131666174354533594255684536356E79772F5971397876772F32",
expectedDigestType: "SHA256",
expectedIterations: 5,
expectedSalt: []byte{0x1a, 0x50, 0x2f, 0x3d, 0x2, 0x57, 0x6a, 0x1, 0x50, 0x49, 0x4d, 0x9, 0x66, 0x59, 0x32, 0x5e, 0x1, 0x7e, 0x8, 0x8},
expectedDigest: []byte{0x6e, 0x51, 0x6b, 0x42, 0x32, 0x6e, 0x57, 0x62, 0x73, 0x33, 0x66,
0x61, 0x55, 0x56, 0x75, 0x6e, 0x61, 0x31, 0x66, 0x61, 0x74, 0x35, 0x45, 0x33, 0x59, 0x42, 0x55,
0x68, 0x45, 0x36, 0x35, 0x6e, 0x79, 0x77, 0x2f, 0x59, 0x71, 0x39, 0x78, 0x76, 0x77, 0x2f, 0x32},
},

// TODO: Test malformed auth strings
}

for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
bytes, err := hex.DecodeString(test.hexEncodedAuthStringBytes)
require.NoError(t, err)

digestType, iterations, salt, digest, err := DeserializeCachingSha2PasswordAuthString(bytes)
if test.expectedErrorSubstring == "" {
require.NoError(t, err)
require.Equal(t, test.expectedDigestType, digestType)
require.Equal(t, test.expectedIterations, iterations)
require.Equal(t, test.expectedSalt, salt)
require.Equal(t, test.expectedDigest, digest)
} else {
require.Error(t, err)
require.Contains(t, err.Error(), test.expectedErrorSubstring)
}
})
}
}

// TestSerializeCachingSha2PasswordAuthString tests that we can generate a correct caching_sha2_password authentication
// string from a password, salt, and number of iterations. We use a hex encoded string for the expected authentication
// string, because it is binary data and displaying it as a string and copying/pasting it corrupts the data.
func TestSerializeCachingSha2PasswordAuthString(t *testing.T) {
tests := []struct {
password string
salt []byte
iterations int
expectedHexEncodedAuthString string
expectedErrorSubstring string
}{
{
password: "pass3",
salt: []byte{52, 52, 31, 80, 23, 18, 29, 4, 32, 19, 70, 21, 5, 109, 53, 25, 48, 92, 76, 87},
iterations: 5,
expectedHexEncodedAuthString: "2441243030352434341F5017121D0420134615056D3519305C4C57507A4B4E584E482E5351544E324B2E44764B586566567243336F56367739736F61386E424B695741395443",
},
{
password: "pass1",
salt: []byte{0x1a, 0x50, 0x2f, 0x3d, 0x2, 0x57, 0x6a, 0x1, 0x50, 0x49, 0x4d, 0x9, 0x66, 0x59, 0x32, 0x5e, 0x1, 0x7e, 0x8, 0x8},
iterations: 5,
expectedHexEncodedAuthString: "244124303035241A502F3D02576A0150494D096659325E017E08086E516B42326E5762733366615556756E6131666174354533594255684536356E79772F5971397876772F32",
},
{
// When an iteration count larger than 0xFFF is specified, an error should be returned
password: "pass1",
salt: []byte{0x1a, 0x50, 0x2f, 0x3d, 0x2, 0x57, 0x6a, 0x1, 0x50, 0x49, 0x4d, 0x9, 0x66, 0x59, 0x32, 0x5e, 0x1, 0x7e, 0x8, 0x8},
iterations: maxIterations + 1,
expectedErrorSubstring: "iterations value (4096) is greater than max allowed iterations (4095)",
},
}

for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
authStringBytes, err := SerializeCachingSha2PasswordAuthString(test.password, test.salt, test.iterations)
if test.expectedErrorSubstring == "" {
expectedBytes, err := hex.DecodeString(test.expectedHexEncodedAuthString)
require.NoError(t, err)
require.Equal(t, expectedBytes, authStringBytes)
} else {
require.Error(t, err)
require.Contains(t, err.Error(), test.expectedErrorSubstring)
}
})
}
}

0 comments on commit 588631a

Please sign in to comment.