diff --git a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java index 02bc7912c9..542ce47360 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java @@ -28,6 +28,8 @@ import org.bson.BsonString; import javax.crypto.Mac; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; import javax.crypto.spec.SecretKeySpec; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; @@ -36,6 +38,7 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; +import java.security.spec.InvalidKeySpecException; import java.util.Base64; import java.util.HashMap; import java.util.Random; @@ -55,7 +58,6 @@ class ScramShaAuthenticator extends SaslAuthenticator { private static final int MINIMUM_ITERATION_COUNT = 4096; private static final String GS2_HEADER = "n,,"; private static final int RANDOM_LENGTH = 24; - private static final byte[] INT_1 = {0, 0, 0, 1}; ScramShaAuthenticator(final MongoCredentialWithCache credential, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { @@ -65,8 +67,8 @@ class ScramShaAuthenticator extends SaslAuthenticator { } ScramShaAuthenticator(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator, - final AuthenticationHashGenerator authenticationHashGenerator, final ClusterConnectionMode clusterConnectionMode, - @Nullable final ServerApi serverApi) { + final AuthenticationHashGenerator authenticationHashGenerator, final ClusterConnectionMode clusterConnectionMode, + @Nullable final ServerApi serverApi) { super(credential, clusterConnectionMode, serverApi); this.randomStringGenerator = randomStringGenerator; this.authenticationHashGenerator = authenticationHashGenerator; @@ -127,6 +129,8 @@ class ScramShaSaslClient extends SaslClientImpl { private final AuthenticationHashGenerator authenticationHashGenerator; private final String hAlgorithm; private final String hmacAlgorithm; + private final String pbeAlgorithm; + private final int keyLength; private String clientFirstMessageBare; private String clientNonce; @@ -137,16 +141,20 @@ class ScramShaSaslClient extends SaslClientImpl { ScramShaSaslClient( final MongoCredential credential, final RandomStringGenerator randomStringGenerator, - final AuthenticationHashGenerator authenticationHashGenerator) { + final AuthenticationHashGenerator authenticationHashGenerator) { super(credential); this.randomStringGenerator = randomStringGenerator; this.authenticationHashGenerator = authenticationHashGenerator; if (assertNotNull(credential.getAuthenticationMechanism()).equals(SCRAM_SHA_1)) { hAlgorithm = "SHA-1"; hmacAlgorithm = "HmacSHA1"; + pbeAlgorithm = "PBKDF2WithHmacSHA1"; + keyLength = 160; } else { hAlgorithm = "SHA-256"; hmacAlgorithm = "HmacSHA256"; + pbeAlgorithm = "PBKDF2WithHmacSHA256"; + keyLength = 256; } } @@ -224,7 +232,7 @@ String getClientProof(final String password, final String salt, final int iterat CacheKey cacheKey = new CacheKey(hashedPasswordAndSalt, salt, iterationCount); CacheValue cachedKeys = getMongoCredentialWithCache().getFromCache(cacheKey, CacheValue.class); if (cachedKeys == null) { - byte[] saltedPassword = hi(password.getBytes(StandardCharsets.UTF_8), Base64.getDecoder().decode(salt), iterationCount); + byte[] saltedPassword = hi(password, Base64.getDecoder().decode(salt), iterationCount); byte[] clientKey = hmac(saltedPassword, "Client Key"); byte[] serverKey = hmac(saltedPassword, "Server Key"); cachedKeys = new CacheValue(clientKey, serverKey); @@ -246,25 +254,15 @@ private byte[] h(final byte[] data) throws SaslException { } } - private byte[] hi(final byte[] password, final byte[] salt, final int iterations) throws SaslException { + private byte[] hi(final String password, final byte[] salt, final int iterations) throws SaslException { try { - SecretKeySpec key = new SecretKeySpec(password, hmacAlgorithm); - Mac mac = Mac.getInstance(hmacAlgorithm); - mac.init(key); - mac.update(salt); - mac.update(INT_1); - byte[] result = mac.doFinal(); - byte[] previous = null; - for (int i = 1; i < iterations; i++) { - mac.update(previous != null ? previous : result); - previous = mac.doFinal(); - xorInPlace(result, previous); - } - return result; + SecretKeyFactory secretKeyFactory = SecretKeyFactory.getInstance(pbeAlgorithm); + PBEKeySpec spec = new PBEKeySpec(password.toCharArray(), salt, iterations, keyLength); + return secretKeyFactory.generateSecret(spec).getEncoded(); } catch (NoSuchAlgorithmException e) { - throw new SaslException(format("Algorithm for '%s' could not be found.", hmacAlgorithm), e); - } catch (InvalidKeyException e) { - throw new SaslException(format("Invalid key for %s", hmacAlgorithm), e); + throw new SaslException(format("Algorithm for '%s' could not be found.", pbeAlgorithm), e); + } catch (InvalidKeySpecException e) { + throw new SaslException(format("Invalid key specification for '%s'", pbeAlgorithm), e); } }