Skip to content

Commit

Permalink
Add support for EdDSA, XDH and RSA-PSS key parsing
Browse files Browse the repository at this point in the history
This works with Java 17 and up. Also refactor the test for more
structure.

Closes gh-37237
  • Loading branch information
mhalbritter committed Sep 15, 2023
1 parent 16d1a31 commit 408fb8a
Show file tree
Hide file tree
Showing 43 changed files with 536 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
Expand All @@ -48,26 +48,28 @@
*/
final class PrivateKeyParser {

private static final String PKCS1_HEADER = "-+BEGIN\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";
private static final String PKCS1_RSA_HEADER = "-+BEGIN\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";

private static final String PKCS1_FOOTER = "-+END\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+";
private static final String PKCS1_RSA_FOOTER = "-+END\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+";

private static final String PKCS8_HEADER = "-+BEGIN\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";

private static final String PKCS8_FOOTER = "-+END\\s+PRIVATE\\s+KEY[^-]*-+";

private static final String EC_HEADER = "-+BEGIN\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";
private static final String SEC1_EC_HEADER = "-+BEGIN\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";

private static final String EC_FOOTER = "-+END\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+";
private static final String SEC1_EC_FOOTER = "-+END\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+";

private static final String BASE64_TEXT = "([a-z0-9+/=\\r\\n]+)";

private static final List<PemParser> PEM_PARSERS;
static {
List<PemParser> parsers = new ArrayList<>();
parsers.add(new PemParser(PKCS1_HEADER, PKCS1_FOOTER, PrivateKeyParser::createKeySpecForPkcs1, "RSA"));
parsers.add(new PemParser(EC_HEADER, EC_FOOTER, PrivateKeyParser::createKeySpecForEc, "EC"));
parsers.add(new PemParser(PKCS8_HEADER, PKCS8_FOOTER, PKCS8EncodedKeySpec::new, "RSA", "EC", "DSA"));
parsers
.add(new PemParser(PKCS1_RSA_HEADER, PKCS1_RSA_FOOTER, PrivateKeyParser::createKeySpecForPkcs1Rsa, "RSA"));
parsers.add(new PemParser(SEC1_EC_HEADER, SEC1_EC_FOOTER, PrivateKeyParser::createKeySpecForSec1Ec, "EC"));
parsers.add(new PemParser(PKCS8_HEADER, PKCS8_FOOTER, PKCS8EncodedKeySpec::new, "RSA", "RSASSA-PSS", "EC",
"DSA", "EdDSA", "XDH"));
PEM_PARSERS = Collections.unmodifiableList(parsers);
}

Expand All @@ -89,11 +91,11 @@ final class PrivateKeyParser {
private PrivateKeyParser() {
}

private static PKCS8EncodedKeySpec createKeySpecForPkcs1(byte[] bytes) {
private static PKCS8EncodedKeySpec createKeySpecForPkcs1Rsa(byte[] bytes) {
return createKeySpecForAlgorithm(bytes, RSA_ALGORITHM, null);
}

private static PKCS8EncodedKeySpec createKeySpecForEc(byte[] bytes) {
private static PKCS8EncodedKeySpec createKeySpecForSec1Ec(byte[] bytes) {
DerElement ecPrivateKey = DerElement.of(bytes);
Assert.state(ecPrivateKey.isType(ValueType.ENCODED, TagType.SEQUENCE),
"Key spec should be an ASN.1 encoded sequence");
Expand Down Expand Up @@ -200,21 +202,16 @@ private static byte[] decodeBase64(String content) {
}

private PrivateKey parse(byte[] bytes) {
try {
PKCS8EncodedKeySpec keySpec = this.keySpecFactory.apply(bytes);
for (String algorithm : this.algorithms) {
PKCS8EncodedKeySpec keySpec = this.keySpecFactory.apply(bytes);
for (String algorithm : this.algorithms) {
try {
KeyFactory keyFactory = KeyFactory.getInstance(algorithm);
try {
return keyFactory.generatePrivate(keySpec);
}
catch (InvalidKeySpecException ex) {
}
return keyFactory.generatePrivate(keySpec);
}
catch (InvalidKeySpecException | NoSuchAlgorithmException ex) {
}
return null;
}
catch (GeneralSecurityException ex) {
throw new IllegalArgumentException("Unexpected key format", ex);
}
return null;
}

}
Expand Down Expand Up @@ -302,7 +299,7 @@ static final class DerElement {

private final long tagType;

private ByteBuffer contents;
private final ByteBuffer contents;

private DerElement(ByteBuffer bytes) {
byte b = bytes.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import java.io.Reader;
import java.net.URL;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
Expand All @@ -50,26 +50,28 @@
*/
final class PrivateKeyParser {

private static final String PKCS1_HEADER = "-+BEGIN\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";
private static final String PKCS1_RSA_HEADER = "-+BEGIN\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";

private static final String PKCS1_FOOTER = "-+END\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+";
private static final String PKCS1_RSA_FOOTER = "-+END\\s+RSA\\s+PRIVATE\\s+KEY[^-]*-+";

private static final String PKCS8_HEADER = "-+BEGIN\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";

private static final String PKCS8_FOOTER = "-+END\\s+PRIVATE\\s+KEY[^-]*-+";

private static final String EC_HEADER = "-+BEGIN\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";
private static final String SEC1_EC_HEADER = "-+BEGIN\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+";

private static final String EC_FOOTER = "-+END\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+";
private static final String SEC1_EC_FOOTER = "-+END\\s+EC\\s+PRIVATE\\s+KEY[^-]*-+";

private static final String BASE64_TEXT = "([a-z0-9+/=\\r\\n]+)";

private static final List<PemParser> PEM_PARSERS;
static {
List<PemParser> parsers = new ArrayList<>();
parsers.add(new PemParser(PKCS1_HEADER, PKCS1_FOOTER, PrivateKeyParser::createKeySpecForPkcs1, "RSA"));
parsers.add(new PemParser(EC_HEADER, EC_FOOTER, PrivateKeyParser::createKeySpecForEc, "EC"));
parsers.add(new PemParser(PKCS8_HEADER, PKCS8_FOOTER, PKCS8EncodedKeySpec::new, "RSA", "EC", "DSA"));
parsers
.add(new PemParser(PKCS1_RSA_HEADER, PKCS1_RSA_FOOTER, PrivateKeyParser::createKeySpecForPkcs1Rsa, "RSA"));
parsers.add(new PemParser(SEC1_EC_HEADER, SEC1_EC_FOOTER, PrivateKeyParser::createKeySpecForSec1Ec, "EC"));
parsers.add(new PemParser(PKCS8_HEADER, PKCS8_FOOTER, PKCS8EncodedKeySpec::new, "RSA", "RSASSA-PSS", "EC",
"DSA", "EdDSA", "XDH"));
PEM_PARSERS = Collections.unmodifiableList(parsers);
}

Expand All @@ -91,11 +93,11 @@ final class PrivateKeyParser {
private PrivateKeyParser() {
}

private static PKCS8EncodedKeySpec createKeySpecForPkcs1(byte[] bytes) {
private static PKCS8EncodedKeySpec createKeySpecForPkcs1Rsa(byte[] bytes) {
return createKeySpecForAlgorithm(bytes, RSA_ALGORITHM, null);
}

private static PKCS8EncodedKeySpec createKeySpecForEc(byte[] bytes) {
private static PKCS8EncodedKeySpec createKeySpecForSec1Ec(byte[] bytes) {
DerElement ecPrivateKey = DerElement.of(bytes);
Assert.state(ecPrivateKey.isType(ValueType.ENCODED, TagType.SEQUENCE),
"Key spec should be an ASN.1 encoded sequence");
Expand Down Expand Up @@ -203,21 +205,16 @@ private static byte[] decodeBase64(String content) {
}

private PrivateKey parse(byte[] bytes) {
try {
PKCS8EncodedKeySpec keySpec = this.keySpecFactory.apply(bytes);
for (String algorithm : this.algorithms) {
PKCS8EncodedKeySpec keySpec = this.keySpecFactory.apply(bytes);
for (String algorithm : this.algorithms) {
try {
KeyFactory keyFactory = KeyFactory.getInstance(algorithm);
try {
return keyFactory.generatePrivate(keySpec);
}
catch (InvalidKeySpecException ex) {
}
return keyFactory.generatePrivate(keySpec);
}
catch (InvalidKeySpecException | NoSuchAlgorithmException ex) {
}
return null;
}
catch (GeneralSecurityException ex) {
throw new IllegalArgumentException("Unexpected key format", ex);
}
return null;
}

}
Expand Down Expand Up @@ -305,7 +302,7 @@ static final class DerElement {

private final long tagType;

private ByteBuffer contents;
private final ByteBuffer contents;

private DerElement(ByteBuffer bytes) {
byte b = bytes.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
import java.security.interfaces.ECPrivateKey;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Tests for {@link PrivateKeyParser}.
Expand All @@ -33,50 +37,166 @@
* @author Moritz Halbritter
* @author Phillip Webb
*/
// https://docs.oracle.com/en/java/javase/17/security/oracle-providers.html#GUID-091BF58C-82AB-4C9C-850F-1660824D5254
class PrivateKeyParserTests {

@Test
void parsePkcs8RsaKeyFile() {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:ssl/pkcs8/key-rsa.pem");
@ParameterizedTest
// @formatter:off
@CsvSource({
"dsa.key, DSA",
"rsa.key, RSA",
"rsa-pss.key, RSASSA-PSS"
})
// @formatter:on
void shouldParseTraditionalPkcs8(String file, String algorithm) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs8/" + file);
assertThat(privateKey).isNotNull();
assertThat(privateKey.getFormat()).isEqualTo("PKCS#8");
assertThat(privateKey.getAlgorithm()).isEqualTo(algorithm);
}

@ParameterizedTest
// @formatter:off
@CsvSource({
"rsa.key, RSA"
})
// @formatter:on
void shouldParseTraditionalPkcs1(String file, String algorithm) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs1/" + file);
assertThat(privateKey).isNotNull();
assertThat(privateKey.getFormat()).isEqualTo("PKCS#8");
assertThat(privateKey.getAlgorithm()).isEqualTo("RSA");
assertThat(privateKey.getAlgorithm()).isEqualTo(algorithm);
}

@ParameterizedTest
// @formatter:off
@ValueSource(strings = {
"dsa.key"
})
// @formatter:on
void shouldNotParseUnsupportedTraditionalPkcs1(String file) {
assertThatThrownBy(() -> PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs1/" + file))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Error loading private key file")
.hasCauseInstanceOf(IllegalStateException.class)
.getCause()
.hasMessageContaining("Unrecognized private key format");
}

@ParameterizedTest
@ValueSource(strings = { "key-ec-nist-p256.pem", "key-ec-nist-p384.pem", "key-ec-prime256v1.pem",
"key-ec-secp256r1.pem" })
void parsePkcs8EcKeyFile(String fileName) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:ssl/pkcs8/" + fileName);
// @formatter:off
@CsvSource({
"brainpoolP256r1.key, brainpoolP256r1, 1.3.36.3.3.2.8.1.1.7",
"brainpoolP320r1.key, brainpoolP320r1, 1.3.36.3.3.2.8.1.1.9",
"brainpoolP384r1.key, brainpoolP384r1, 1.3.36.3.3.2.8.1.1.11",
"brainpoolP512r1.key, brainpoolP512r1, 1.3.36.3.3.2.8.1.1.13",
"prime256v1.key, secp256r1, 1.2.840.10045.3.1.7",
"secp224r1.key, secp224r1, 1.3.132.0.33",
"secp256k1.key, secp256k1, 1.3.132.0.10",
"secp256r1.key, secp256r1, 1.2.840.10045.3.1.7",
"secp384r1.key, secp384r1, 1.3.132.0.34",
"secp521r1.key, secp521r1, 1.3.132.0.35"
})
// @formatter:on
void shouldParseEcPkcs8(String file, String curveName, String oid) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs8/" + file);
assertThat(privateKey).isNotNull();
assertThat(privateKey.getFormat()).isEqualTo("PKCS#8");
assertThat(privateKey.getAlgorithm()).isEqualTo("EC");
assertThat(privateKey).isInstanceOf(ECPrivateKey.class);
ECPrivateKey ecPrivateKey = (ECPrivateKey) privateKey;
assertThat(ecPrivateKey.getParams().toString()).contains(curveName).contains(oid);
}

@Test
void parsePkcs8DsaKeyFile() {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:ssl/pkcs8/key-dsa.pem");
@ParameterizedTest
// @formatter:off
@ValueSource(strings = {
"brainpoolP256t1.key",
"brainpoolP320t1.key",
"brainpoolP384t1.key",
"brainpoolP512t1.key"
})
// @formatter:on
void shouldNotParseUnsupportedEcPkcs8(String file) {
assertThatThrownBy(() -> PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs8/" + file))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Error loading private key file")
.hasCauseInstanceOf(IllegalStateException.class)
.getCause()
.hasMessageContaining("Unrecognized private key format");
}

@EnabledForJreRange(min = JRE.JAVA_17, disabledReason = "EdDSA is only supported since Java 17")
@ParameterizedTest
// @formatter:off
@ValueSource(strings = {
"ed448.key",
"ed25519.key"
})
// @formatter:on
void shouldParseEdDsaPkcs8(String file) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs8/" + file);
assertThat(privateKey).isNotNull();
assertThat(privateKey.getFormat()).isEqualTo("PKCS#8");
assertThat(privateKey.getAlgorithm()).isEqualTo("DSA");
assertThat(privateKey.getAlgorithm()).isEqualTo("EdDSA");
}

@Test
void parsePemKeyFileWithEcdsa() {
ECPrivateKey privateKey = (ECPrivateKey) PrivateKeyParser.parse("classpath:test-ec-key.pem");
@EnabledForJreRange(min = JRE.JAVA_17, disabledReason = "XDH is only supported since Java 17")
@ParameterizedTest
// @formatter:off
@ValueSource(strings = {
"x448.key",
"x25519.key"
})
// @formatter:on
void shouldParseXdhPkcs8(String file) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/pkcs8/" + file);
assertThat(privateKey).isNotNull();
assertThat(privateKey.getFormat()).isEqualTo("PKCS#8");
assertThat(privateKey.getAlgorithm()).isEqualTo("EC");
assertThat(privateKey.getParams().toString()).contains("1.3.132.0.34").doesNotContain("prime256v1");
assertThat(privateKey.getAlgorithm()).isEqualTo("XDH");
}

@Test
void parsePemKeyFileWithEcdsaPrime256v1() {
ECPrivateKey privateKey = (ECPrivateKey) PrivateKeyParser.parse("classpath:test-ec-key-prime256v1.pem");
@ParameterizedTest
// @formatter:off
@CsvSource({
"brainpoolP256r1.key, brainpoolP256r1, 1.3.36.3.3.2.8.1.1.7",
"brainpoolP320r1.key, brainpoolP320r1, 1.3.36.3.3.2.8.1.1.9",
"brainpoolP384r1.key, brainpoolP384r1, 1.3.36.3.3.2.8.1.1.11",
"brainpoolP512r1.key, brainpoolP512r1, 1.3.36.3.3.2.8.1.1.13",
"prime256v1.key, secp256r1, 1.2.840.10045.3.1.7",
"secp224r1.key, secp224r1, 1.3.132.0.33",
"secp256k1.key, secp256k1, 1.3.132.0.10",
"secp256r1.key, secp256r1, 1.2.840.10045.3.1.7",
"secp384r1.key, secp384r1, 1.3.132.0.34",
"secp521r1.key, secp521r1, 1.3.132.0.35"
})
// @formatter:on
void shouldParseEcSec1(String file, String curveName, String oid) {
PrivateKey privateKey = PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/sec1/" + file);
assertThat(privateKey).isNotNull();
assertThat(privateKey.getFormat()).isEqualTo("PKCS#8");
assertThat(privateKey.getAlgorithm()).isEqualTo("EC");
assertThat(privateKey.getParams().toString()).contains("prime256v1").doesNotContain("1.3.132.0.34");
assertThat(privateKey).isInstanceOf(ECPrivateKey.class);
ECPrivateKey ecPrivateKey = (ECPrivateKey) privateKey;
assertThat(ecPrivateKey.getParams().toString()).contains(curveName).contains(oid);
}

@ParameterizedTest
// @formatter:off
@ValueSource(strings = {
"brainpoolP256t1.key",
"brainpoolP320t1.key",
"brainpoolP384t1.key",
"brainpoolP512t1.key"
})
// @formatter:on
void shouldNotParseUnsupportedEcSec1(String file) {
assertThatThrownBy(() -> PrivateKeyParser.parse("classpath:org/springframework/boot/web/server/sec1/" + file))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Error loading private key file")
.hasCauseInstanceOf(IllegalStateException.class)
.getCause()
.hasMessageContaining("Unrecognized private key format");
}

@Test
Expand Down
Loading

0 comments on commit 408fb8a

Please sign in to comment.