diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index a0d9367b619c6..542741fa98178 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -89,6 +89,12 @@ class AesEncryptor::AesEncryptorImpl { } private: + void CheckValid() { + if (ctx_ == nullptr) { + throw ParquetException("AesEncryptor was wiped out"); + } + } + EVP_CIPHER_CTX* ctx_; int32_t aes_mode_; int32_t key_length_; @@ -156,6 +162,8 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int32_t AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( span footer, span key, span aad, span nonce, span encrypted_footer) { + CheckValid(); + if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -180,6 +188,8 @@ int32_t AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, span key, span aad, span ciphertext) { + CheckValid(); + if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -413,6 +423,12 @@ class AesDecryptor::AesDecryptorImpl { } private: + void CheckValid() { + if (ctx_ == nullptr) { + throw ParquetException("AesDecryptor was wiped out"); + } + } + EVP_CIPHER_CTX* ctx_; int32_t aes_mode_; int32_t key_length_; @@ -714,6 +730,8 @@ int32_t AesDecryptor::AesDecryptorImpl::Decrypt(span ciphertext, span key, span aad, span plaintext) { + CheckValid(); + if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -806,4 +824,7 @@ void RandBytes(unsigned char* buf, size_t num) { void EnsureBackendInitialized() { openssl::EnsureInitialized(); } +#undef ENCRYPT_INIT +#undef DECRYPT_INIT + } // namespace parquet::encryption