Skip to content

Commit

Permalink
Support "External Mu" ML-DSA (#423)
Browse files Browse the repository at this point in the history
This PR adds support for the External Mu variant of ML-DSA described in
[AWS-LC PR 2113](aws/aws-lc#2113). To facilitate
calculation of mu in testing and for consumers, we expose a
`computeMLDSAMu` helper function in a new `PublicUtils` class. Notably,
External Mu ML-DSA is interoperable with Pure ML-DSA across signing and
verification, as shown in tests.

We expose External Mu as a distinct `Signature` service to preserve
forwards-compatibility with other providers' potential future
implementations. Conventionally, this "prehash" mode would be configured
with a call to
[`setParameter`](https://docs.oracle.com/en/java/javase/22/docs/api/java.base/java/security/Signature.html#setParameter(java.security.spec.AlgorithmParameterSpec)),
passing an `AlgorithmParameterSpec` implementation [as in
EdDSA](https://download.java.net/java/early_access/panama/docs/api/java.base/java/security/spec/EdDSAParameterSpec.html).
Because no other providers have yet created an `AlgorithmParameterSpec`
implementation that is aware of External Mu, this approach would have
required us to expose an ACCP-specific `AlgorithmParameterSpec` that
callers would need to use, hampering interoperability.

---

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
WillChilds-Klein authored Jan 31, 2025
1 parent b2472a5 commit aa9bfd2
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 12 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ set(C_SRC
csrc/sha384.cpp
csrc/sha512.cpp
csrc/sign.cpp
csrc/test_util.cpp
csrc/testhooks.cpp
csrc/util.cpp
csrc/util_class.cpp
Expand Down
23 changes: 13 additions & 10 deletions csrc/sign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ bool initializeContext(raii_env& env,
const EVP_MD* md,
jint paddingType,
const EVP_MD* mgfMdPtr,
jint pssSaltLen)
jint pssSaltLen,
bool preHash)
{
EVP_PKEY_CTX* pctx; // Logically owned by the ctx so doesn't need to be freed separately

Expand All @@ -69,7 +70,7 @@ bool initializeContext(raii_env& env,
#if defined(FIPS_BUILD) && !defined(EXPERIMENTAL_FIPS_BUILD)
if (md != nullptr || EVP_PKEY_id(pKey) == EVP_PKEY_ED25519) {
#else
if (md != nullptr || EVP_PKEY_id(pKey) == EVP_PKEY_ED25519 || EVP_PKEY_id(pKey) == EVP_PKEY_PQDSA) {
if (md != nullptr || EVP_PKEY_id(pKey) == EVP_PKEY_ED25519 || (EVP_PKEY_id(pKey) == EVP_PKEY_PQDSA && !preHash)) {
#endif
if (!ctx->setDigestCtx(EVP_MD_CTX_create())) {
throw_openssl("Unable to create MD_CTX");
Expand Down Expand Up @@ -171,7 +172,7 @@ JNIEXPORT jlong JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature_si
initializeContext(env, &ctx,
true, // true->sign
reinterpret_cast<EVP_PKEY*>(pKey), reinterpret_cast<const EVP_MD*>(mdPtr), paddingType,
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen);
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen, false);

update(env, &ctx, digestSignUpdate, java_buffer::from_array(env, message, offset, length));
return reinterpret_cast<jlong>(ctx.moveToHeap());
Expand All @@ -192,7 +193,7 @@ JNIEXPORT jlong JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature_si
initializeContext(env, &ctx,
true, // true->sign
reinterpret_cast<EVP_PKEY*>(pKey), reinterpret_cast<const EVP_MD*>(mdPtr), paddingType,
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen);
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen, false);
update(env, &ctx, digestSignUpdate, java_buffer::from_direct(env, message));

return reinterpret_cast<jlong>(ctx.moveToHeap());
Expand Down Expand Up @@ -221,7 +222,7 @@ JNIEXPORT jlong JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature_ve
initializeContext(env, &ctx,
false, // false->verify
reinterpret_cast<EVP_PKEY*>(pKey), reinterpret_cast<const EVP_MD*>(mdPtr), paddingType,
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen);
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen, false);
update(env, &ctx, digestVerifyUpdate, java_buffer::from_array(env, message, offset, length));

return reinterpret_cast<jlong>(ctx.moveToHeap());
Expand All @@ -242,7 +243,7 @@ JNIEXPORT jlong JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature_ve
initializeContext(env, &ctx,
false, // false->verify
reinterpret_cast<EVP_PKEY*>(pKey), reinterpret_cast<const EVP_MD*>(mdPtr), paddingType,
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen);
reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen, false);
update(env, &ctx, digestVerifyUpdate, java_buffer::from_direct(env, message));

return reinterpret_cast<jlong>(ctx.moveToHeap());
Expand Down Expand Up @@ -438,6 +439,7 @@ JNIEXPORT jbyteArray JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignatu
jclass clazz,
jlong pKey,
jint paddingType,
jboolean preHash,
jlong mgfMdPtr,
jint pssSaltLen,
jbyteArray messageArr,
Expand All @@ -453,7 +455,7 @@ JNIEXPORT jbyteArray JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignatu
true, // true->sign
reinterpret_cast<EVP_PKEY*>(pKey),
nullptr, // No message digest
paddingType, reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen);
paddingType, reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen, preHash);

std::vector<uint8_t, SecureAlloc<uint8_t> > signature;
size_t sigLength;
Expand All @@ -463,7 +465,7 @@ JNIEXPORT jbyteArray JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignatu
#if defined(FIPS_BUILD) && !defined(EXPERIMENTAL_FIPS_BUILD)
if (keyType == EVP_PKEY_ED25519) {
#else
if (keyType == EVP_PKEY_ED25519 || keyType == EVP_PKEY_PQDSA) {
if (keyType == EVP_PKEY_ED25519 || (keyType == EVP_PKEY_PQDSA && !preHash)) {
#endif
jni_borrow message(env, messageBuf, "message");

Expand Down Expand Up @@ -508,6 +510,7 @@ JNIEXPORT jboolean JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature
jclass clazz,
jlong pKey,
jint paddingType,
jboolean preHash,
jlong mgfMdPtr,
jint pssSaltLen,
jbyteArray messageArr,
Expand All @@ -527,7 +530,7 @@ JNIEXPORT jboolean JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature
false, // false->verify
reinterpret_cast<EVP_PKEY*>(pKey),
nullptr, // no message digest
paddingType, reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen);
paddingType, reinterpret_cast<const EVP_MD*>(mgfMdPtr), pssSaltLen, preHash);

jni_borrow message(env, messageBuf, "message");
jni_borrow signature(env, signatureBuf, "signature");
Expand All @@ -537,7 +540,7 @@ JNIEXPORT jboolean JNICALL Java_com_amazon_corretto_crypto_provider_EvpSignature
#if defined(FIPS_BUILD) && !defined(EXPERIMENTAL_FIPS_BUILD)
if (keyType == EVP_PKEY_ED25519) {
#else
if (keyType == EVP_PKEY_ED25519 || keyType == EVP_PKEY_PQDSA) {
if (keyType == EVP_PKEY_ED25519 || (keyType == EVP_PKEY_PQDSA && !preHash)) {
#endif
ret = EVP_DigestVerify(
ctx.getDigestCtx(), signature.data(), signature.len(), message.data(), message.len());
Expand Down
68 changes: 68 additions & 0 deletions csrc/test_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
#include <openssl/bytestring.h>
#include <openssl/evp.h>
#include <cstdlib>

#include "auto_free.h"
#include "env.h"

namespace AmazonCorrettoCryptoProvider {

/*
* Class: com_amazon_corretto_crypto_provider_test_TestUtil
* Method: computeMLDSAMuInternal
* Signature: ([B[B)[B
*/
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_amazon_corretto_crypto_provider_test_TestUtil_computeMLDSAMuInternal(
JNIEnv* pEnv, jclass, jbyteArray pubKeyEncodedArr, jbyteArray messageArr)
{
try {
raii_env env(pEnv);
jsize pub_key_der_len = env->GetArrayLength(pubKeyEncodedArr);
jsize message_len = env->GetArrayLength(messageArr);
uint8_t* pub_key_der = (uint8_t*)env->GetByteArrayElements(pubKeyEncodedArr, nullptr);
CHECK_OPENSSL(pub_key_der);
uint8_t* message = (uint8_t*)env->GetByteArrayElements(messageArr, nullptr);
CHECK_OPENSSL(message);

CBS cbs;
CBS_init(&cbs, pub_key_der, pub_key_der_len);
EVP_PKEY_auto pkey = EVP_PKEY_auto::from((EVP_parse_public_key(&cbs)));
EVP_PKEY_CTX_auto ctx = EVP_PKEY_CTX_auto::from(EVP_PKEY_CTX_new(pkey.get(), nullptr));
EVP_MD_CTX_auto md_ctx_mu = EVP_MD_CTX_auto::from(EVP_MD_CTX_new());
EVP_MD_CTX_auto md_ctx_pk = EVP_MD_CTX_auto::from(EVP_MD_CTX_new());

size_t pk_len; // fetch the public key length
CHECK_OPENSSL(EVP_PKEY_get_raw_public_key(pkey.get(), nullptr, &pk_len));
std::vector<uint8_t> pk(pk_len);
CHECK_OPENSSL(EVP_PKEY_get_raw_public_key(pkey.get(), pk.data(), &pk_len));
uint8_t tr[64] = { 0 };
uint8_t mu[64] = { 0 };
uint8_t pre[2] = { 0 };

// get raw public key and hash it
CHECK_OPENSSL(EVP_DigestInit_ex(md_ctx_pk.get(), EVP_shake256(), nullptr));
CHECK_OPENSSL(EVP_DigestUpdate(md_ctx_pk.get(), pk.data(), pk_len));
CHECK_OPENSSL(EVP_DigestFinalXOF(md_ctx_pk.get(), tr, sizeof(tr)));

// compute mu
CHECK_OPENSSL(EVP_DigestInit_ex(md_ctx_mu.get(), EVP_shake256(), nullptr));
CHECK_OPENSSL(EVP_DigestUpdate(md_ctx_mu.get(), tr, sizeof(tr)));
CHECK_OPENSSL(EVP_DigestUpdate(md_ctx_mu.get(), pre, sizeof(pre)));
CHECK_OPENSSL(EVP_DigestUpdate(md_ctx_mu.get(), message, message_len));
CHECK_OPENSSL(EVP_DigestFinalXOF(md_ctx_mu.get(), mu, sizeof(mu)));

env->ReleaseByteArrayElements(pubKeyEncodedArr, (jbyte*)pub_key_der, 0);
env->ReleaseByteArrayElements(messageArr, (jbyte*)message, 0);

jbyteArray ret = env->NewByteArray(sizeof(mu));
env->SetByteArrayRegion(ret, 0, sizeof(mu), (const jbyte*)mu);
return ret;
} catch (java_ex& ex) {
ex.throw_to_java(pEnv);
return 0;
}
}

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ private void addSignatures() {

if (shouldRegisterMLDSA) {
addService("Signature", "ML-DSA", "EvpSignatureRaw$MLDSA");
addService("Signature", "ML-DSA-ExtMu", "EvpSignatureRaw$MLDSAExtMu");
}
}

Expand Down
22 changes: 21 additions & 1 deletion src/com/amazon/corretto/crypto/provider/EvpSignatureRaw.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,22 @@
class EvpSignatureRaw extends EvpSignatureBase {
private AccessibleByteArrayOutputStream buffer =
new AccessibleByteArrayOutputStream(64, 1024 * 1024);
private final boolean preHash_;

private EvpSignatureRaw(
final AmazonCorrettoCryptoProvider provider,
final EvpKeyType keyType,
final int paddingType) {
this(provider, keyType, paddingType, false);
}

private EvpSignatureRaw(
final AmazonCorrettoCryptoProvider provider,
final EvpKeyType keyType,
final int paddingType,
final boolean preHash) {
super(provider, keyType, paddingType, 0 /* No digest */);
preHash_ = preHash;
}

@Override
Expand Down Expand Up @@ -42,7 +52,8 @@ protected byte[] engineSign() throws SignatureException {
try {
ensureInitialized(true);
return key_.use(
ptr -> signRaw(ptr, paddingType_, 0, 0, buffer.getDataBuffer(), 0, buffer.size()));
ptr ->
signRaw(ptr, paddingType_, preHash_, 0, 0, buffer.getDataBuffer(), 0, buffer.size()));
} finally {
engineReset();
}
Expand All @@ -64,6 +75,7 @@ protected boolean engineVerify(final byte[] sigBytes, final int offset, final in
verifyRaw(
ptr,
paddingType_,
preHash_,
0,
0,
buffer.getDataBuffer(),
Expand All @@ -84,6 +96,7 @@ protected boolean isBufferEmpty() {
private static native byte[] signRaw(
long privateKey,
int paddingType,
boolean preHash,
long mgfMd,
int saltLen,
byte[] message,
Expand All @@ -93,6 +106,7 @@ private static native byte[] signRaw(
private static native boolean verifyRaw(
long publicKey,
int paddingType,
boolean preHash,
long mgfMd,
int saltLen,
byte[] message,
Expand Down Expand Up @@ -120,4 +134,10 @@ static final class MLDSA extends EvpSignatureRaw {
super(provider, EvpKeyType.MLDSA, 0);
}
}

static final class MLDSAExtMu extends EvpSignatureRaw {
MLDSAExtMu(final AmazonCorrettoCryptoProvider provider) {
super(provider, EvpKeyType.MLDSA, 0, /*preHash*/ true);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ public void simpleCorrectnessAllAlgorithms() throws Throwable {
if (!service.getType().equals("Signature") || "RSASSA-PSS".equals(algorithm)) {
continue;
}
if (algorithm.equals("Ed25519") || algorithm.equals("EdDSA") || algorithm.equals("ML-DSA")) {
if (algorithm.equals("Ed25519")
|| algorithm.equals("EdDSA")
|| algorithm.startsWith("ML-DSA")) {
continue;
}
String bcAlgorithm = algorithm;
Expand Down
80 changes: 80 additions & 0 deletions tst/com/amazon/corretto/crypto/provider/test/MLDSATest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package com.amazon.corretto.crypto.provider.test;

import static com.amazon.corretto.crypto.provider.test.TestUtil.assertThrows;
import static org.junit.Assume.assumeTrue;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -254,4 +255,83 @@ public void documentBouncyCastleDifferences() throws Exception {
nativeSignature.initVerify(bcPub);
assertTrue(nativeSignature.verify(sigBytes));
}

@ParameterizedTest
@MethodSource("getParams")
public void testExtMu(TestParams params) throws Exception {
// Only ACCP currently supports External Mu
assumeTrue(params.signerProv == NATIVE_PROVIDER && params.verifierProv == NATIVE_PROVIDER);

Signature signer = Signature.getInstance("ML-DSA", NATIVE_PROVIDER);
Signature verifier = Signature.getInstance("ML-DSA", NATIVE_PROVIDER);
Signature extMuSigner = Signature.getInstance("ML-DSA-ExtMu", NATIVE_PROVIDER);
Signature extMuVerifier = Signature.getInstance("ML-DSA-ExtMu", NATIVE_PROVIDER);
PrivateKey priv = params.priv;
PublicKey pub = params.pub;

byte[] message = Arrays.copyOf(params.message, params.message.length);
byte[] mu = TestUtil.computeMLDSAMu(pub, message);
assertEquals(64, mu.length);
byte[] fakeMu = new byte[64];
Arrays.fill(fakeMu, (byte) 0);

// Test with "fake mu" -- contents don't matter if we're signing and verifying mu
extMuSigner.initSign(priv);
extMuSigner.update(fakeMu);
byte[] signatureBytes = extMuSigner.sign();
extMuVerifier.initVerify(pub);
extMuVerifier.update(fakeMu);
assertTrue(extMuVerifier.verify(signatureBytes));

// Test with real mu
extMuSigner.initSign(priv);
extMuSigner.update(mu);
signatureBytes = extMuSigner.sign();
extMuVerifier.initVerify(pub);
extMuVerifier.update(mu);
assertTrue(extMuVerifier.verify(signatureBytes));

// Sign mu, verify with message
extMuSigner.initSign(priv);
extMuSigner.update(mu);
signatureBytes = extMuSigner.sign();
verifier.initVerify(pub);
verifier.update(message);
assertTrue(verifier.verify(signatureBytes));

// Sign message, verify with mu
signer.initSign(priv);
signer.update(message);
signatureBytes = signer.sign();
extMuVerifier.initVerify(pub);
extMuVerifier.update(mu);
assertTrue(extMuVerifier.verify(signatureBytes));

// Tampering the signature induces failure
extMuSigner.initSign(priv);
extMuSigner.update(mu);
signatureBytes = extMuSigner.sign();
signatureBytes[0] ^= 0xff;
extMuVerifier.initVerify(pub);
extMuVerifier.update(mu);
assertFalse(extMuVerifier.verify(signatureBytes));
}

@ParameterizedTest
@ValueSource(strings = {"ML-DSA-44", "ML-DSA-65", "ML-DSA-87"})
public void testComputeMLDSAExtMu(String algorithm) throws Exception {
KeyPair keyPair = KeyPairGenerator.getInstance(algorithm, NATIVE_PROVIDER).generateKeyPair();
PublicKey nativePub = keyPair.getPublic();
KeyFactory bcKf = KeyFactory.getInstance("ML-DSA", TestUtil.BC_PROVIDER);
PublicKey bcPub = bcKf.generatePublic(new X509EncodedKeySpec(nativePub.getEncoded()));

byte[] message = new byte[256];
Arrays.fill(message, (byte) 0x41);
byte[] mu = TestUtil.computeMLDSAMu(nativePub, message);
assertEquals(64, mu.length);
// We don't have any other implementations of mu calculation to test against, so just assert
// that mu is equivalent
// generated from both ACCP and BouncyCastle keys.
assertArrayEquals(mu, TestUtil.computeMLDSAMu(bcPub, message));
}
}
19 changes: 19 additions & 0 deletions tst/com/amazon/corretto/crypto/provider/test/TestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Security;
import java.util.ArrayList;
Expand Down Expand Up @@ -841,4 +842,22 @@ static boolean edKeyFactoryRegistered() {
return "true"
.equals(System.getProperty("com.amazon.corretto.crypto.provider.registerEdKeyFactory"));
}

private static native byte[] computeMLDSAMuInternal(byte[] pubKeyEncoded, byte[] message);

/**
* Computes mu as defined on line 6 of Algorithm 7 and line 7 of Algorithm 8 in NIST FIPS 204.
*
* <p>See <a href="https://csrc.nist.gov/pubs/fips/204/final">FIPS 204</a>
*
* @param publicKey ML-DSA public key
* @param message byte array of the message over which to compute mu
* @return a byte[] of length 64 containing mu
*/
static byte[] computeMLDSAMu(PublicKey publicKey, byte[] message) {
if (publicKey == null || !publicKey.getAlgorithm().startsWith("ML-DSA") || message == null) {
throw new IllegalArgumentException();
}
return computeMLDSAMuInternal(publicKey.getEncoded(), message);
}
}

0 comments on commit aa9bfd2

Please sign in to comment.