Skip to content

Commit

Permalink
Use EVP_PKEY_up_ref if available (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thalhammer authored Sep 13, 2022
1 parent c9a511f commit 3ed4ff9
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 68 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ test
*.o
*.o.d
.vscode
# ClangD cache files
.cache

doxy/
doxygen-awesome*.css
Expand Down
183 changes: 121 additions & 62 deletions include/jwt-cpp/jwt.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,83 @@ namespace jwt {
* you maybe need to extract the modulus and exponent of an RSA Public Key.
*/
namespace helper {
/**
* \brief Handle class for EVP_PKEY structures
*
* Starting from OpenSSL 1.1.0, EVP_PKEY has internal reference counting. This handle class allows
* jwt-cpp to leverage that and thus safe an allocation for the control block in std::shared_ptr.
* The handle uses shared_ptr as a fallback on older versions. The behaviour should be identical between both.
*/
class evp_pkey_handle {
public:
constexpr evp_pkey_handle() noexcept = default;
#ifdef JWT_OPENSSL_1_0_0
/**
* \brief Contruct a new handle. The handle takes ownership of the key.
* \param key The key to store
*/
explicit evp_pkey_handle(EVP_PKEY* key) { m_key = std::shared_ptr<EVP_PKEY>(key, EVP_PKEY_free); }

EVP_PKEY* get() const noexcept { return m_key.get(); }
bool operator!() const noexcept { return m_key == nullptr; }
explicit operator bool() const noexcept { return m_key != nullptr; }

private:
std::shared_ptr<EVP_PKEY> m_key{nullptr};
#else
/**
* \brief Contruct a new handle. The handle takes ownership of the key.
* \param key The key to store
*/
explicit constexpr evp_pkey_handle(EVP_PKEY* key) noexcept : m_key{key} {}
evp_pkey_handle(const evp_pkey_handle& other) : m_key{other.m_key} {
if (m_key != nullptr && EVP_PKEY_up_ref(m_key) != 1) throw std::runtime_error("EVP_PKEY_up_ref failed");
}
// C++11 requires the body of a constexpr constructor to be empty
#if __cplusplus >= 201402L
constexpr
#endif
evp_pkey_handle(evp_pkey_handle&& other) noexcept
: m_key{other.m_key} {
other.m_key = nullptr;
}
evp_pkey_handle& operator=(const evp_pkey_handle& other) {
if (&other == this) return *this;
decrement_ref_count(m_key);
m_key = other.m_key;
increment_ref_count(m_key);
return *this;
}
evp_pkey_handle& operator=(evp_pkey_handle&& other) noexcept {
if (&other == this) return *this;
decrement_ref_count(m_key);
m_key = other.m_key;
other.m_key = nullptr;
return *this;
}
evp_pkey_handle& operator=(EVP_PKEY* key) {
decrement_ref_count(m_key);
m_key = key;
increment_ref_count(m_key);
return *this;
}
~evp_pkey_handle() noexcept { decrement_ref_count(m_key); }

EVP_PKEY* get() const noexcept { return m_key; }
bool operator!() const noexcept { return m_key == nullptr; }
explicit operator bool() const noexcept { return m_key != nullptr; }

private:
EVP_PKEY* m_key{nullptr};

static void increment_ref_count(EVP_PKEY* key) {
if (key != nullptr && EVP_PKEY_up_ref(key) != 1) throw std::runtime_error("EVP_PKEY_up_ref failed");
}
static void decrement_ref_count(EVP_PKEY* key) noexcept {
if (key != nullptr) EVP_PKEY_free(key);
}
#endif
};
/**
* \brief Extract the public key of a pem certificate
*
Expand Down Expand Up @@ -556,38 +633,34 @@ namespace jwt {
* \param password Password used to decrypt certificate (leave empty if not encrypted)
* \param ec error_code for error_detection (gets cleared if no error occures)
*/
inline std::shared_ptr<EVP_PKEY> load_public_key_from_string(const std::string& key,
const std::string& password, std::error_code& ec) {
inline evp_pkey_handle load_public_key_from_string(const std::string& key, const std::string& password,
std::error_code& ec) {
ec.clear();
std::unique_ptr<BIO, decltype(&BIO_free_all)> pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if (!pubkey_bio) {
ec = error::rsa_error::create_mem_bio_failed;
return nullptr;
return {};
}
if (key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") {
auto epkey = helper::extract_pubkey_from_cert(key, password, ec);
if (ec) return nullptr;
if (ec) return {};
const int len = static_cast<int>(epkey.size());
if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len) {
ec = error::rsa_error::load_key_bio_write;
return nullptr;
return {};
}
} else {
const int len = static_cast<int>(key.size());
if (BIO_write(pubkey_bio.get(), key.data(), len) != len) {
ec = error::rsa_error::load_key_bio_write;
return nullptr;
return {};
}
}

std::shared_ptr<EVP_PKEY> pkey(
PEM_read_bio_PUBKEY(pubkey_bio.get(), nullptr, nullptr,
(void*)password.data()), // NOLINT(google-readability-casting) requires `const_cast`
EVP_PKEY_free);
if (!pkey) {
ec = error::rsa_error::load_key_bio_read;
return nullptr;
}
evp_pkey_handle pkey(PEM_read_bio_PUBKEY(
pubkey_bio.get(), nullptr, nullptr,
(void*)password.data())); // NOLINT(google-readability-casting) requires `const_cast`
if (!pkey) ec = error::rsa_error::load_key_bio_read;
return pkey;
}

Expand All @@ -600,8 +673,7 @@ namespace jwt {
* \param password Password used to decrypt certificate or key (leave empty if not encrypted)
* \throw rsa_exception if an error occurred
*/
inline std::shared_ptr<EVP_PKEY> load_public_key_from_string(const std::string& key,
const std::string& password = "") {
inline evp_pkey_handle load_public_key_from_string(const std::string& key, const std::string& password = "") {
std::error_code ec;
auto res = load_public_key_from_string(key, password, ec);
error::throw_if_error(ec);
Expand All @@ -615,25 +687,21 @@ namespace jwt {
* \param password Password used to decrypt key (leave empty if not encrypted)
* \param ec error_code for error_detection (gets cleared if no error occures)
*/
inline std::shared_ptr<EVP_PKEY>
load_private_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) {
inline evp_pkey_handle load_private_key_from_string(const std::string& key, const std::string& password,
std::error_code& ec) {
std::unique_ptr<BIO, decltype(&BIO_free_all)> privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if (!privkey_bio) {
ec = error::rsa_error::create_mem_bio_failed;
return nullptr;
return {};
}
const int len = static_cast<int>(key.size());
if (BIO_write(privkey_bio.get(), key.data(), len) != len) {
ec = error::rsa_error::load_key_bio_write;
return nullptr;
}
std::shared_ptr<EVP_PKEY> pkey(
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())),
EVP_PKEY_free);
if (!pkey) {
ec = error::rsa_error::load_key_bio_read;
return nullptr;
return {};
}
evp_pkey_handle pkey(
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())));
if (!pkey) ec = error::rsa_error::load_key_bio_read;
return pkey;
}

Expand All @@ -644,8 +712,7 @@ namespace jwt {
* \param password Password used to decrypt key (leave empty if not encrypted)
* \throw rsa_exception if an error occurred
*/
inline std::shared_ptr<EVP_PKEY> load_private_key_from_string(const std::string& key,
const std::string& password = "") {
inline evp_pkey_handle load_private_key_from_string(const std::string& key, const std::string& password = "") {
std::error_code ec;
auto res = load_private_key_from_string(key, password, ec);
error::throw_if_error(ec);
Expand All @@ -661,38 +728,34 @@ namespace jwt {
* \param password Password used to decrypt certificate (leave empty if not encrypted)
* \param ec error_code for error_detection (gets cleared if no error occures)
*/
inline std::shared_ptr<EVP_PKEY>
load_public_ec_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) {
inline evp_pkey_handle load_public_ec_key_from_string(const std::string& key, const std::string& password,
std::error_code& ec) {
ec.clear();
std::unique_ptr<BIO, decltype(&BIO_free_all)> pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if (!pubkey_bio) {
ec = error::ecdsa_error::create_mem_bio_failed;
return nullptr;
return {};
}
if (key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") {
auto epkey = helper::extract_pubkey_from_cert(key, password, ec);
if (ec) return nullptr;
if (ec) return {};
const int len = static_cast<int>(epkey.size());
if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len) {
ec = error::ecdsa_error::load_key_bio_write;
return nullptr;
return {};
}
} else {
const int len = static_cast<int>(key.size());
if (BIO_write(pubkey_bio.get(), key.data(), len) != len) {
ec = error::ecdsa_error::load_key_bio_write;
return nullptr;
return {};
}
}

std::shared_ptr<EVP_PKEY> pkey(
PEM_read_bio_PUBKEY(pubkey_bio.get(), nullptr, nullptr,
(void*)password.data()), // NOLINT(google-readability-casting) requires `const_cast`
EVP_PKEY_free);
if (!pkey) {
ec = error::ecdsa_error::load_key_bio_read;
return nullptr;
}
evp_pkey_handle pkey(PEM_read_bio_PUBKEY(
pubkey_bio.get(), nullptr, nullptr,
(void*)password.data())); // NOLINT(google-readability-casting) requires `const_cast`
if (!pkey) ec = error::ecdsa_error::load_key_bio_read;
return pkey;
}

Expand All @@ -705,8 +768,8 @@ namespace jwt {
* \param password Password used to decrypt certificate or key (leave empty if not encrypted)
* \throw ecdsa_exception if an error occurred
*/
inline std::shared_ptr<EVP_PKEY> load_public_ec_key_from_string(const std::string& key,
const std::string& password = "") {
inline evp_pkey_handle load_public_ec_key_from_string(const std::string& key,
const std::string& password = "") {
std::error_code ec;
auto res = load_public_ec_key_from_string(key, password, ec);
error::throw_if_error(ec);
Expand All @@ -720,25 +783,21 @@ namespace jwt {
* \param password Password used to decrypt key (leave empty if not encrypted)
* \param ec error_code for error_detection (gets cleared if no error occures)
*/
inline std::shared_ptr<EVP_PKEY>
load_private_ec_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) {
inline evp_pkey_handle load_private_ec_key_from_string(const std::string& key, const std::string& password,
std::error_code& ec) {
std::unique_ptr<BIO, decltype(&BIO_free_all)> privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if (!privkey_bio) {
ec = error::ecdsa_error::create_mem_bio_failed;
return nullptr;
return {};
}
const int len = static_cast<int>(key.size());
if (BIO_write(privkey_bio.get(), key.data(), len) != len) {
ec = error::ecdsa_error::load_key_bio_write;
return nullptr;
}
std::shared_ptr<EVP_PKEY> pkey(
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())),
EVP_PKEY_free);
if (!pkey) {
ec = error::ecdsa_error::load_key_bio_read;
return nullptr;
return {};
}
evp_pkey_handle pkey(
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())));
if (!pkey) ec = error::ecdsa_error::load_key_bio_read;
return pkey;
}

Expand All @@ -749,8 +808,8 @@ namespace jwt {
* \param password Password used to decrypt key (leave empty if not encrypted)
* \throw ecdsa_exception if an error occurred
*/
inline std::shared_ptr<EVP_PKEY> load_private_ec_key_from_string(const std::string& key,
const std::string& password = "") {
inline evp_pkey_handle load_private_ec_key_from_string(const std::string& key,
const std::string& password = "") {
std::error_code ec;
auto res = load_private_ec_key_from_string(key, password, ec);
error::throw_if_error(ec);
Expand Down Expand Up @@ -990,7 +1049,7 @@ namespace jwt {

private:
/// OpenSSL structure containing converted keys
std::shared_ptr<EVP_PKEY> pkey;
helper::evp_pkey_handle pkey;
/// Hash generator
const EVP_MD* (*md)();
/// algorithm's name
Expand Down Expand Up @@ -1214,7 +1273,7 @@ namespace jwt {
}

/// OpenSSL struct containing keys
std::shared_ptr<EVP_PKEY> pkey;
helper::evp_pkey_handle pkey;
/// Hash generator function
const EVP_MD* (*md)();
/// algorithm's name
Expand Down Expand Up @@ -1360,7 +1419,7 @@ namespace jwt {

private:
/// OpenSSL struct containing keys
std::shared_ptr<EVP_PKEY> pkey;
helper::evp_pkey_handle pkey;
/// algorithm's name
const std::string alg_name;
};
Expand Down Expand Up @@ -1496,7 +1555,7 @@ namespace jwt {

private:
/// OpenSSL structure containing keys
std::shared_ptr<EVP_PKEY> pkey;
helper::evp_pkey_handle pkey;
/// Hash generator function
const EVP_MD* (*md)();
/// algorithm's name
Expand Down
12 changes: 6 additions & 6 deletions tests/OpenSSLErrorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ TEST(OpenSSLErrorTest, ConvertCertBase64DerToPemErrorCode) {

TEST(OpenSSLErrorTest, LoadPublicKeyFromStringReference) {
auto res = jwt::helper::load_public_key_from_string(rsa_pub_key, "");
ASSERT_NE(res, nullptr);
ASSERT_TRUE(res);
}

TEST(OpenSSLErrorTest, LoadPublicKeyFromString) {
Expand All @@ -556,13 +556,13 @@ TEST(OpenSSLErrorTest, LoadPublicKeyFromStringErrorCode) {

run_multitest(mapping, [](std::error_code& ec) {
auto res = jwt::helper::load_public_key_from_string(rsa_pub_key, "", ec);
ASSERT_EQ(res, nullptr);
ASSERT_FALSE(res);
});
}

TEST(OpenSSLErrorTest, LoadPublicKeyCertFromStringReference) {
auto res = jwt::helper::load_public_key_from_string(sample_cert, "");
ASSERT_NE(res, nullptr);
ASSERT_TRUE(res);
}

TEST(OpenSSLErrorTest, LoadPublicKeyCertFromString) {
Expand Down Expand Up @@ -601,13 +601,13 @@ TEST(OpenSSLErrorTest, LoadPublicKeyCertFromStringErrorCode) {

run_multitest(mapping, [](std::error_code& ec) {
auto res = jwt::helper::load_public_key_from_string(sample_cert, "", ec);
ASSERT_EQ(res, nullptr);
ASSERT_FALSE(res);
});
}

TEST(OpenSSLErrorTest, LoadPrivateKeyFromStringReference) {
auto res = jwt::helper::load_private_key_from_string(rsa_priv_key, "");
ASSERT_NE(res, nullptr);
ASSERT_TRUE(res);
}

TEST(OpenSSLErrorTest, LoadPrivateKeyFromString) {
Expand All @@ -630,7 +630,7 @@ TEST(OpenSSLErrorTest, LoadPrivateKeyFromStringErrorCode) {

run_multitest(mapping, [](std::error_code& ec) {
auto res = jwt::helper::load_private_key_from_string(rsa_priv_key, "", ec);
ASSERT_EQ(res, nullptr);
ASSERT_FALSE(res);
});
}

Expand Down

0 comments on commit 3ed4ff9

Please sign in to comment.