diff --git a/.gitignore b/.gitignore index 3a02c0454..751d5958c 100644 --- a/.gitignore +++ b/.gitignore @@ -307,6 +307,8 @@ test *.o *.o.d .vscode +# ClangD cache files +.cache doxy/ doxygen-awesome*.css diff --git a/include/jwt-cpp/jwt.h b/include/jwt-cpp/jwt.h index f2be680be..a31d913f5 100644 --- a/include/jwt-cpp/jwt.h +++ b/include/jwt-cpp/jwt.h @@ -388,6 +388,83 @@ namespace jwt { * you maybe need to extract the modulus and exponent of an RSA Public Key. */ namespace helper { + /** + * \brief Handle class for EVP_PKEY structures + * + * Starting from OpenSSL 1.1.0, EVP_PKEY has internal reference counting. This handle class allows + * jwt-cpp to leverage that and thus safe an allocation for the control block in std::shared_ptr. + * The handle uses shared_ptr as a fallback on older versions. The behaviour should be identical between both. + */ + class evp_pkey_handle { + public: + constexpr evp_pkey_handle() noexcept = default; +#ifdef JWT_OPENSSL_1_0_0 + /** + * \brief Contruct a new handle. The handle takes ownership of the key. + * \param key The key to store + */ + explicit evp_pkey_handle(EVP_PKEY* key) { m_key = std::shared_ptr(key, EVP_PKEY_free); } + + EVP_PKEY* get() const noexcept { return m_key.get(); } + bool operator!() const noexcept { return m_key == nullptr; } + explicit operator bool() const noexcept { return m_key != nullptr; } + + private: + std::shared_ptr m_key{nullptr}; +#else + /** + * \brief Contruct a new handle. The handle takes ownership of the key. + * \param key The key to store + */ + explicit constexpr evp_pkey_handle(EVP_PKEY* key) noexcept : m_key{key} {} + evp_pkey_handle(const evp_pkey_handle& other) : m_key{other.m_key} { + if (m_key != nullptr && EVP_PKEY_up_ref(m_key) != 1) throw std::runtime_error("EVP_PKEY_up_ref failed"); + } +// C++11 requires the body of a constexpr constructor to be empty +#if __cplusplus >= 201402L + constexpr +#endif + evp_pkey_handle(evp_pkey_handle&& other) noexcept + : m_key{other.m_key} { + other.m_key = nullptr; + } + evp_pkey_handle& operator=(const evp_pkey_handle& other) { + if (&other == this) return *this; + decrement_ref_count(m_key); + m_key = other.m_key; + increment_ref_count(m_key); + return *this; + } + evp_pkey_handle& operator=(evp_pkey_handle&& other) noexcept { + if (&other == this) return *this; + decrement_ref_count(m_key); + m_key = other.m_key; + other.m_key = nullptr; + return *this; + } + evp_pkey_handle& operator=(EVP_PKEY* key) { + decrement_ref_count(m_key); + m_key = key; + increment_ref_count(m_key); + return *this; + } + ~evp_pkey_handle() noexcept { decrement_ref_count(m_key); } + + EVP_PKEY* get() const noexcept { return m_key; } + bool operator!() const noexcept { return m_key == nullptr; } + explicit operator bool() const noexcept { return m_key != nullptr; } + + private: + EVP_PKEY* m_key{nullptr}; + + static void increment_ref_count(EVP_PKEY* key) { + if (key != nullptr && EVP_PKEY_up_ref(key) != 1) throw std::runtime_error("EVP_PKEY_up_ref failed"); + } + static void decrement_ref_count(EVP_PKEY* key) noexcept { + if (key != nullptr) EVP_PKEY_free(key); + } +#endif + }; /** * \brief Extract the public key of a pem certificate * @@ -556,38 +633,34 @@ namespace jwt { * \param password Password used to decrypt certificate (leave empty if not encrypted) * \param ec error_code for error_detection (gets cleared if no error occures) */ - inline std::shared_ptr load_public_key_from_string(const std::string& key, - const std::string& password, std::error_code& ec) { + inline evp_pkey_handle load_public_key_from_string(const std::string& key, const std::string& password, + std::error_code& ec) { ec.clear(); std::unique_ptr pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all); if (!pubkey_bio) { ec = error::rsa_error::create_mem_bio_failed; - return nullptr; + return {}; } if (key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") { auto epkey = helper::extract_pubkey_from_cert(key, password, ec); - if (ec) return nullptr; + if (ec) return {}; const int len = static_cast(epkey.size()); if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len) { ec = error::rsa_error::load_key_bio_write; - return nullptr; + return {}; } } else { const int len = static_cast(key.size()); if (BIO_write(pubkey_bio.get(), key.data(), len) != len) { ec = error::rsa_error::load_key_bio_write; - return nullptr; + return {}; } } - std::shared_ptr pkey( - PEM_read_bio_PUBKEY(pubkey_bio.get(), nullptr, nullptr, - (void*)password.data()), // NOLINT(google-readability-casting) requires `const_cast` - EVP_PKEY_free); - if (!pkey) { - ec = error::rsa_error::load_key_bio_read; - return nullptr; - } + evp_pkey_handle pkey(PEM_read_bio_PUBKEY( + pubkey_bio.get(), nullptr, nullptr, + (void*)password.data())); // NOLINT(google-readability-casting) requires `const_cast` + if (!pkey) ec = error::rsa_error::load_key_bio_read; return pkey; } @@ -600,8 +673,7 @@ namespace jwt { * \param password Password used to decrypt certificate or key (leave empty if not encrypted) * \throw rsa_exception if an error occurred */ - inline std::shared_ptr load_public_key_from_string(const std::string& key, - const std::string& password = "") { + inline evp_pkey_handle load_public_key_from_string(const std::string& key, const std::string& password = "") { std::error_code ec; auto res = load_public_key_from_string(key, password, ec); error::throw_if_error(ec); @@ -615,25 +687,21 @@ namespace jwt { * \param password Password used to decrypt key (leave empty if not encrypted) * \param ec error_code for error_detection (gets cleared if no error occures) */ - inline std::shared_ptr - load_private_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) { + inline evp_pkey_handle load_private_key_from_string(const std::string& key, const std::string& password, + std::error_code& ec) { std::unique_ptr privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all); if (!privkey_bio) { ec = error::rsa_error::create_mem_bio_failed; - return nullptr; + return {}; } const int len = static_cast(key.size()); if (BIO_write(privkey_bio.get(), key.data(), len) != len) { ec = error::rsa_error::load_key_bio_write; - return nullptr; - } - std::shared_ptr pkey( - PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast(password.c_str())), - EVP_PKEY_free); - if (!pkey) { - ec = error::rsa_error::load_key_bio_read; - return nullptr; + return {}; } + evp_pkey_handle pkey( + PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast(password.c_str()))); + if (!pkey) ec = error::rsa_error::load_key_bio_read; return pkey; } @@ -644,8 +712,7 @@ namespace jwt { * \param password Password used to decrypt key (leave empty if not encrypted) * \throw rsa_exception if an error occurred */ - inline std::shared_ptr load_private_key_from_string(const std::string& key, - const std::string& password = "") { + inline evp_pkey_handle load_private_key_from_string(const std::string& key, const std::string& password = "") { std::error_code ec; auto res = load_private_key_from_string(key, password, ec); error::throw_if_error(ec); @@ -661,38 +728,34 @@ namespace jwt { * \param password Password used to decrypt certificate (leave empty if not encrypted) * \param ec error_code for error_detection (gets cleared if no error occures) */ - inline std::shared_ptr - load_public_ec_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) { + inline evp_pkey_handle load_public_ec_key_from_string(const std::string& key, const std::string& password, + std::error_code& ec) { ec.clear(); std::unique_ptr pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all); if (!pubkey_bio) { ec = error::ecdsa_error::create_mem_bio_failed; - return nullptr; + return {}; } if (key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") { auto epkey = helper::extract_pubkey_from_cert(key, password, ec); - if (ec) return nullptr; + if (ec) return {}; const int len = static_cast(epkey.size()); if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len) { ec = error::ecdsa_error::load_key_bio_write; - return nullptr; + return {}; } } else { const int len = static_cast(key.size()); if (BIO_write(pubkey_bio.get(), key.data(), len) != len) { ec = error::ecdsa_error::load_key_bio_write; - return nullptr; + return {}; } } - std::shared_ptr pkey( - PEM_read_bio_PUBKEY(pubkey_bio.get(), nullptr, nullptr, - (void*)password.data()), // NOLINT(google-readability-casting) requires `const_cast` - EVP_PKEY_free); - if (!pkey) { - ec = error::ecdsa_error::load_key_bio_read; - return nullptr; - } + evp_pkey_handle pkey(PEM_read_bio_PUBKEY( + pubkey_bio.get(), nullptr, nullptr, + (void*)password.data())); // NOLINT(google-readability-casting) requires `const_cast` + if (!pkey) ec = error::ecdsa_error::load_key_bio_read; return pkey; } @@ -705,8 +768,8 @@ namespace jwt { * \param password Password used to decrypt certificate or key (leave empty if not encrypted) * \throw ecdsa_exception if an error occurred */ - inline std::shared_ptr load_public_ec_key_from_string(const std::string& key, - const std::string& password = "") { + inline evp_pkey_handle load_public_ec_key_from_string(const std::string& key, + const std::string& password = "") { std::error_code ec; auto res = load_public_ec_key_from_string(key, password, ec); error::throw_if_error(ec); @@ -720,25 +783,21 @@ namespace jwt { * \param password Password used to decrypt key (leave empty if not encrypted) * \param ec error_code for error_detection (gets cleared if no error occures) */ - inline std::shared_ptr - load_private_ec_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) { + inline evp_pkey_handle load_private_ec_key_from_string(const std::string& key, const std::string& password, + std::error_code& ec) { std::unique_ptr privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all); if (!privkey_bio) { ec = error::ecdsa_error::create_mem_bio_failed; - return nullptr; + return {}; } const int len = static_cast(key.size()); if (BIO_write(privkey_bio.get(), key.data(), len) != len) { ec = error::ecdsa_error::load_key_bio_write; - return nullptr; - } - std::shared_ptr pkey( - PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast(password.c_str())), - EVP_PKEY_free); - if (!pkey) { - ec = error::ecdsa_error::load_key_bio_read; - return nullptr; + return {}; } + evp_pkey_handle pkey( + PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast(password.c_str()))); + if (!pkey) ec = error::ecdsa_error::load_key_bio_read; return pkey; } @@ -749,8 +808,8 @@ namespace jwt { * \param password Password used to decrypt key (leave empty if not encrypted) * \throw ecdsa_exception if an error occurred */ - inline std::shared_ptr load_private_ec_key_from_string(const std::string& key, - const std::string& password = "") { + inline evp_pkey_handle load_private_ec_key_from_string(const std::string& key, + const std::string& password = "") { std::error_code ec; auto res = load_private_ec_key_from_string(key, password, ec); error::throw_if_error(ec); @@ -990,7 +1049,7 @@ namespace jwt { private: /// OpenSSL structure containing converted keys - std::shared_ptr pkey; + helper::evp_pkey_handle pkey; /// Hash generator const EVP_MD* (*md)(); /// algorithm's name @@ -1214,7 +1273,7 @@ namespace jwt { } /// OpenSSL struct containing keys - std::shared_ptr pkey; + helper::evp_pkey_handle pkey; /// Hash generator function const EVP_MD* (*md)(); /// algorithm's name @@ -1360,7 +1419,7 @@ namespace jwt { private: /// OpenSSL struct containing keys - std::shared_ptr pkey; + helper::evp_pkey_handle pkey; /// algorithm's name const std::string alg_name; }; @@ -1496,7 +1555,7 @@ namespace jwt { private: /// OpenSSL structure containing keys - std::shared_ptr pkey; + helper::evp_pkey_handle pkey; /// Hash generator function const EVP_MD* (*md)(); /// algorithm's name diff --git a/tests/OpenSSLErrorTest.cpp b/tests/OpenSSLErrorTest.cpp index eda24705f..02da9793c 100644 --- a/tests/OpenSSLErrorTest.cpp +++ b/tests/OpenSSLErrorTest.cpp @@ -533,7 +533,7 @@ TEST(OpenSSLErrorTest, ConvertCertBase64DerToPemErrorCode) { TEST(OpenSSLErrorTest, LoadPublicKeyFromStringReference) { auto res = jwt::helper::load_public_key_from_string(rsa_pub_key, ""); - ASSERT_NE(res, nullptr); + ASSERT_TRUE(res); } TEST(OpenSSLErrorTest, LoadPublicKeyFromString) { @@ -556,13 +556,13 @@ TEST(OpenSSLErrorTest, LoadPublicKeyFromStringErrorCode) { run_multitest(mapping, [](std::error_code& ec) { auto res = jwt::helper::load_public_key_from_string(rsa_pub_key, "", ec); - ASSERT_EQ(res, nullptr); + ASSERT_FALSE(res); }); } TEST(OpenSSLErrorTest, LoadPublicKeyCertFromStringReference) { auto res = jwt::helper::load_public_key_from_string(sample_cert, ""); - ASSERT_NE(res, nullptr); + ASSERT_TRUE(res); } TEST(OpenSSLErrorTest, LoadPublicKeyCertFromString) { @@ -601,13 +601,13 @@ TEST(OpenSSLErrorTest, LoadPublicKeyCertFromStringErrorCode) { run_multitest(mapping, [](std::error_code& ec) { auto res = jwt::helper::load_public_key_from_string(sample_cert, "", ec); - ASSERT_EQ(res, nullptr); + ASSERT_FALSE(res); }); } TEST(OpenSSLErrorTest, LoadPrivateKeyFromStringReference) { auto res = jwt::helper::load_private_key_from_string(rsa_priv_key, ""); - ASSERT_NE(res, nullptr); + ASSERT_TRUE(res); } TEST(OpenSSLErrorTest, LoadPrivateKeyFromString) { @@ -630,7 +630,7 @@ TEST(OpenSSLErrorTest, LoadPrivateKeyFromStringErrorCode) { run_multitest(mapping, [](std::error_code& ec) { auto res = jwt::helper::load_private_key_from_string(rsa_priv_key, "", ec); - ASSERT_EQ(res, nullptr); + ASSERT_FALSE(res); }); }