diff --git a/python/pyarrow/_dataset_parquet.pxd b/python/pyarrow/_dataset_parquet.pxd index d5bc172d324d5..0a3a2ff526ea4 100644 --- a/python/pyarrow/_dataset_parquet.pxd +++ b/python/pyarrow/_dataset_parquet.pxd @@ -29,6 +29,7 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions): cdef: CParquetFragmentScanOptions* parquet_options object _parquet_decryption_config + object _decryption_properties cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp) cdef CReaderProperties* reader_properties(self) diff --git a/python/pyarrow/_dataset_parquet.pyx b/python/pyarrow/_dataset_parquet.pyx index a55e889ba8246..4942336a12666 100644 --- a/python/pyarrow/_dataset_parquet.pyx +++ b/python/pyarrow/_dataset_parquet.pyx @@ -56,7 +56,7 @@ from pyarrow._parquet cimport ( try: from pyarrow._dataset_parquet_encryption import ( - set_encryption_config, set_decryption_config + set_encryption_config, set_decryption_config, set_decryption_properties ) parquet_encryption_enabled = True except ImportError: @@ -127,8 +127,7 @@ cdef class ParquetFileFormat(FileFormat): 'instance of ParquetReadOptions') if default_fragment_scan_options is None: - default_fragment_scan_options = ParquetFragmentScanOptions( - **scan_args) + default_fragment_scan_options = ParquetFragmentScanOptions(**scan_args) elif isinstance(default_fragment_scan_options, dict): default_fragment_scan_options = ParquetFragmentScanOptions( **default_fragment_scan_options) @@ -715,6 +714,9 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions): decryption_config : pyarrow.dataset.ParquetDecryptionConfig, default None If not None, use the provided ParquetDecryptionConfig to decrypt the Parquet file. + decryption_properties : pyarrow.parquet.FileDecryptionProperties, default None + If not None, use the provided FileDecryptionProperties to decrypt encrypted + Parquet file. page_checksum_verification : bool, default False If True, verify the page checksum for each page read from the file. """ @@ -729,6 +731,7 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions): thrift_string_size_limit=None, thrift_container_size_limit=None, decryption_config=None, + decryption_properties=None, bint page_checksum_verification=False): self.init(shared_ptr[CFragmentScanOptions]( new CParquetFragmentScanOptions())) @@ -743,6 +746,8 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions): self.thrift_container_size_limit = thrift_container_size_limit if decryption_config is not None: self.parquet_decryption_config = decryption_config + if decryption_properties is not None: + self.decryption_properties = decryption_properties self.page_checksum_verification = page_checksum_verification cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): @@ -812,6 +817,25 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions): raise ValueError("size must be larger than zero") self.reader_properties().set_thrift_container_size_limit(size) + @property + def decryption_properties(self): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Unable to access encryption features. " + "Encryption is not enabled in your installation of pyarrow." + ) + return self._decryption_properties + + @decryption_properties.setter + def decryption_properties(self, config): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Encryption is not enabled in your installation of pyarrow, but " + "decryption_properties were provided." + ) + set_decryption_properties(self, config) + self._decryption_properties = config + @property def parquet_decryption_config(self): if not parquet_encryption_enabled: diff --git a/python/pyarrow/_dataset_parquet_encryption.pyx b/python/pyarrow/_dataset_parquet_encryption.pyx index 11a7174eb3c9d..c8f5e5b01bf81 100644 --- a/python/pyarrow/_dataset_parquet_encryption.pyx +++ b/python/pyarrow/_dataset_parquet_encryption.pyx @@ -162,6 +162,14 @@ def set_encryption_config( opts.parquet_options.parquet_encryption_config = c_config +def set_decryption_properties( + ParquetFragmentScanOptions opts not None, + FileDecryptionProperties config not None +): + cdef CReaderProperties* reader_props = opts.reader_properties() + reader_props.file_decryption_properties(config.unwrap()) + + def set_decryption_config( ParquetFragmentScanOptions opts not None, ParquetDecryptionConfig config not None diff --git a/python/pyarrow/parquet/core.py b/python/pyarrow/parquet/core.py index 69a1c9d19aae2..f54a203c8794c 100644 --- a/python/pyarrow/parquet/core.py +++ b/python/pyarrow/parquet/core.py @@ -1299,7 +1299,7 @@ def __init__(self, path_or_paths, filesystem=None, schema=None, *, filters=None, f"local file systems, not {type(filesystem)}" ) - # check for single fragment dataset + # check for single fragment dataset or dataset directory single_file = None self._base_dir = None if not isinstance(path_or_paths, list): @@ -1313,8 +1313,6 @@ def __init__(self, path_or_paths, filesystem=None, schema=None, *, filters=None, except ValueError: filesystem = LocalFileSystem(use_mmap=memory_map) finfo = filesystem.get_file_info(path_or_paths) - if finfo.is_file: - single_file = path_or_paths if finfo.type == FileType.Directory: self._base_dir = path_or_paths else: @@ -1771,6 +1769,7 @@ def read_table(source, *, columns=None, use_threads=True, ignore_prefixes=ignore_prefixes, pre_buffer=pre_buffer, coerce_int96_timestamp_unit=coerce_int96_timestamp_unit, + decryption_properties=decryption_properties, thrift_string_size_limit=thrift_string_size_limit, thrift_container_size_limit=thrift_container_size_limit, page_checksum_verification=page_checksum_verification, diff --git a/python/pyarrow/tests/parquet/test_encryption.py b/python/pyarrow/tests/parquet/test_encryption.py index edb6410d2fa0d..ff388ef506997 100644 --- a/python/pyarrow/tests/parquet/test_encryption.py +++ b/python/pyarrow/tests/parquet/test_encryption.py @@ -65,6 +65,44 @@ def basic_encryption_config(): return basic_encryption_config +def setup_encryption_environment(custom_kms_conf): + """ + Sets up and returns the KMS connection configuration and crypto factory + based on provided KMS configuration parameters. + """ + kms_connection_config = pe.KmsConnectionConfig(custom_kms_conf=custom_kms_conf) + + def kms_factory(kms_connection_configuration): + return InMemoryKmsClient(kms_connection_configuration) + + # Create our CryptoFactory + crypto_factory = pe.CryptoFactory(kms_factory) + + return kms_connection_config, crypto_factory + + +def write_encrypted_file(path, data_table, footer_key_name, col_key_name, + footer_key, col_key, encryption_config): + """ + Writes an encrypted parquet file based on the provided parameters. + """ + # Setup the custom KMS configuration with provided keys + custom_kms_conf = { + footer_key_name: footer_key.decode("UTF-8"), + col_key_name: col_key.decode("UTF-8"), + } + + # Setup encryption environment + kms_connection_config, crypto_factory = setup_encryption_environment( + custom_kms_conf) + + # Write the encrypted parquet file + write_encrypted_parquet(path, data_table, encryption_config, + kms_connection_config, crypto_factory) + + return kms_connection_config, crypto_factory + + def test_encrypted_parquet_write_read(tempdir, data_table): """Write an encrypted parquet, verify it's encrypted, and then read it.""" path = tempdir / PARQUET_NAME @@ -81,20 +119,10 @@ def test_encrypted_parquet_write_read(tempdir, data_table): cache_lifetime=timedelta(minutes=5.0), data_key_length_bits=256) - kms_connection_config = pe.KmsConnectionConfig( - custom_kms_conf={ - FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - COL_KEY_NAME: COL_KEY.decode("UTF-8"), - } - ) - - def kms_factory(kms_connection_configuration): - return InMemoryKmsClient(kms_connection_configuration) + kms_connection_config, crypto_factory = write_encrypted_file( + path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY, + encryption_config) - crypto_factory = pe.CryptoFactory(kms_factory) - # Write with encryption properties - write_encrypted_parquet(path, data_table, encryption_config, - kms_connection_config, crypto_factory) verify_file_encrypted(path) # Read with decryption properties @@ -150,36 +178,22 @@ def test_encrypted_parquet_write_read_wrong_key(tempdir, data_table): cache_lifetime=timedelta(minutes=5.0), data_key_length_bits=256) - kms_connection_config = pe.KmsConnectionConfig( - custom_kms_conf={ - FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - COL_KEY_NAME: COL_KEY.decode("UTF-8"), - } - ) - - def kms_factory(kms_connection_configuration): - return InMemoryKmsClient(kms_connection_configuration) + write_encrypted_file(path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, + FOOTER_KEY, COL_KEY, encryption_config) - crypto_factory = pe.CryptoFactory(kms_factory) - # Write with encryption properties - write_encrypted_parquet(path, data_table, encryption_config, - kms_connection_config, crypto_factory) verify_file_encrypted(path) - # Read with decryption properties - wrong_kms_connection_config = pe.KmsConnectionConfig( - custom_kms_conf={ - # Wrong keys - mixup in names - FOOTER_KEY_NAME: COL_KEY.decode("UTF-8"), - COL_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - } - ) + wrong_kms_connection_config, wrong_crypto_factory = setup_encryption_environment({ + FOOTER_KEY_NAME: COL_KEY.decode("UTF-8"), # Intentionally wrong + COL_KEY_NAME: FOOTER_KEY.decode("UTF-8"), # Intentionally wrong + }) + decryption_config = pe.DecryptionConfiguration( cache_lifetime=timedelta(minutes=5.0)) with pytest.raises(ValueError, match=r"Incorrect master key used"): read_encrypted_parquet( path, decryption_config, wrong_kms_connection_config, - crypto_factory) + wrong_crypto_factory) def test_encrypted_parquet_read_no_decryption_config(tempdir, data_table): @@ -219,23 +233,12 @@ def test_encrypted_parquet_write_no_col_key(tempdir, data_table): encryption_config = pe.EncryptionConfiguration( footer_key=FOOTER_KEY_NAME) - kms_connection_config = pe.KmsConnectionConfig( - custom_kms_conf={ - FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - COL_KEY_NAME: COL_KEY.decode("UTF-8"), - } - ) - - def kms_factory(kms_connection_configuration): - return InMemoryKmsClient(kms_connection_configuration) - - crypto_factory = pe.CryptoFactory(kms_factory) with pytest.raises(OSError, match="Either column_keys or uniform_encryption " "must be set"): # Write with encryption properties - write_encrypted_parquet(path, data_table, encryption_config, - kms_connection_config, crypto_factory) + write_encrypted_file(path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, + FOOTER_KEY, b"", encryption_config) def test_encrypted_parquet_write_kms_error(tempdir, data_table, @@ -497,24 +500,11 @@ def test_encrypted_parquet_loop(tempdir, data_table, basic_encryption_config): # Encrypt the footer with the footer key, # encrypt column `a` and column `b` with another key, - # keep `c` plaintext - encryption_config = basic_encryption_config + # keep `c` plaintext, defined in basic_encryption_config + kms_connection_config, crypto_factory = write_encrypted_file( + path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY, + basic_encryption_config) - kms_connection_config = pe.KmsConnectionConfig( - custom_kms_conf={ - FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - COL_KEY_NAME: COL_KEY.decode("UTF-8"), - } - ) - - def kms_factory(kms_connection_configuration): - return InMemoryKmsClient(kms_connection_configuration) - - crypto_factory = pe.CryptoFactory(kms_factory) - - # Write with encryption properties - write_encrypted_parquet(path, data_table, encryption_config, - kms_connection_config, crypto_factory) verify_file_encrypted(path) decryption_config = pe.DecryptionConfiguration( @@ -537,32 +527,46 @@ def test_read_with_deleted_crypto_factory(tempdir, data_table, basic_encryption_ Test that decryption properties can be used if the crypto factory is no longer alive """ path = tempdir / PARQUET_NAME - encryption_config = basic_encryption_config - kms_connection_config = pe.KmsConnectionConfig( - custom_kms_conf={ - FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - COL_KEY_NAME: COL_KEY.decode("UTF-8"), - } - ) - - def kms_factory(kms_connection_configuration): - return InMemoryKmsClient(kms_connection_configuration) - - encryption_crypto_factory = pe.CryptoFactory(kms_factory) - write_encrypted_parquet(path, data_table, encryption_config, - kms_connection_config, encryption_crypto_factory) + kms_connection_config, crypto_factory = write_encrypted_file( + path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY, + basic_encryption_config) verify_file_encrypted(path) - # Use a local function to get decryption properties, so the crypto factory that - # creates the properties will be deleted after it returns. - def get_decryption_properties(): - decryption_crypto_factory = pe.CryptoFactory(kms_factory) - decryption_config = pe.DecryptionConfiguration( - cache_lifetime=timedelta(minutes=5.0)) - return decryption_crypto_factory.file_decryption_properties( - kms_connection_config, decryption_config) + # Create decryption properties and delete the crypto factory that created + # the properties afterwards. + decryption_config = pe.DecryptionConfiguration( + cache_lifetime=timedelta(minutes=5.0)) + file_decryption_properties = crypto_factory.file_decryption_properties( + kms_connection_config, decryption_config) + del crypto_factory result = pq.ParquetFile( - path, decryption_properties=get_decryption_properties()) + path, decryption_properties=file_decryption_properties) result_table = result.read(use_threads=True) assert data_table.equals(result_table) + + +def test_encrypted_parquet_read_table(tempdir, data_table, basic_encryption_config): + """Write an encrypted parquet then read it back using read_table.""" + path = tempdir / PARQUET_NAME + + # Write the encrypted parquet file using the utility function + kms_connection_config, crypto_factory = write_encrypted_file( + path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY, + basic_encryption_config) + + decryption_config = pe.DecryptionConfiguration( + cache_lifetime=timedelta(minutes=5.0)) + file_decryption_properties = crypto_factory.file_decryption_properties( + kms_connection_config, decryption_config) + + # Read the encrypted parquet file using read_table + result_table = pq.read_table(path, decryption_properties=file_decryption_properties) + + # Assert that the read table matches the original data + assert data_table.equals(result_table) + + # Read the encrypted parquet folder using read_table + result_table = pq.read_table( + tempdir, decryption_properties=file_decryption_properties) + assert data_table.equals(result_table) diff --git a/python/pyarrow/tests/test_dataset_encryption.py b/python/pyarrow/tests/test_dataset_encryption.py index 2a631db9fc0fa..0d8b4a152ab9f 100644 --- a/python/pyarrow/tests/test_dataset_encryption.py +++ b/python/pyarrow/tests/test_dataset_encryption.py @@ -142,6 +142,18 @@ def test_dataset_encryption_decryption(): assert table.equals(dataset.to_table()) + # set decryption properties for parquet fragment scan options + decryption_properties = crypto_factory.file_decryption_properties( + kms_connection_config, decryption_config) + pq_scan_opts = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties + ) + + pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + + assert table.equals(dataset.to_table()) + @pytest.mark.skipif( not encryption_unavailable, reason="Parquet Encryption is currently enabled"