Skip to content

Commit

Permalink
apacheGH-39645: [Python] Fix read_table for encrypted parquet (apache…
Browse files Browse the repository at this point in the history
…#39438)

### Rationale for this change

Currently, if you try to read an encrypted parquet file with read_table while passing decryption_properties, then in the happy path (pyarrow.dataset available for import) the reading/decryption of the file fails, because the decryption properties are never passed through.

### What changes are included in this PR?

Pass through the argument that was intended to have been passed.

### Are these changes tested?

We have tested this locally on an encrypted parquet dataset - please advise on any further testing you would like beyond that and the standard CI.

### Are there any user-facing changes?

None for code that was previously working. Reading encrypted parquet data via read_table with decryption_properties, which previously failed, should now work as intended.

* Closes: apache#39645

Lead-authored-by: Tom McTiernan <[email protected]>
Co-authored-by: Don <[email protected]>
Co-authored-by: Rok Mihevc <[email protected]>
Signed-off-by: Rok Mihevc <[email protected]>
  • Loading branch information
3 people authored and JerAguilon committed May 29, 2024
1 parent 6dfe9c3 commit ec8b53a
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 94 deletions.
1 change: 1 addition & 0 deletions python/pyarrow/_dataset_parquet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
cdef:
CParquetFragmentScanOptions* parquet_options
object _parquet_decryption_config
object _decryption_properties

cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp)
cdef CReaderProperties* reader_properties(self)
Expand Down
30 changes: 27 additions & 3 deletions python/pyarrow/_dataset_parquet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ from pyarrow._parquet cimport (

try:
from pyarrow._dataset_parquet_encryption import (
set_encryption_config, set_decryption_config
set_encryption_config, set_decryption_config, set_decryption_properties
)
parquet_encryption_enabled = True
except ImportError:
Expand Down Expand Up @@ -127,8 +127,7 @@ cdef class ParquetFileFormat(FileFormat):
'instance of ParquetReadOptions')

if default_fragment_scan_options is None:
default_fragment_scan_options = ParquetFragmentScanOptions(
**scan_args)
default_fragment_scan_options = ParquetFragmentScanOptions(**scan_args)
elif isinstance(default_fragment_scan_options, dict):
default_fragment_scan_options = ParquetFragmentScanOptions(
**default_fragment_scan_options)
Expand Down Expand Up @@ -715,6 +714,9 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
decryption_config : pyarrow.dataset.ParquetDecryptionConfig, default None
If not None, use the provided ParquetDecryptionConfig to decrypt the
Parquet file.
decryption_properties : pyarrow.parquet.FileDecryptionProperties, default None
If not None, use the provided FileDecryptionProperties to decrypt encrypted
Parquet file.
page_checksum_verification : bool, default False
If True, verify the page checksum for each page read from the file.
"""
Expand All @@ -729,6 +731,7 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
thrift_string_size_limit=None,
thrift_container_size_limit=None,
decryption_config=None,
decryption_properties=None,
bint page_checksum_verification=False):
self.init(shared_ptr[CFragmentScanOptions](
new CParquetFragmentScanOptions()))
Expand All @@ -743,6 +746,8 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
self.thrift_container_size_limit = thrift_container_size_limit
if decryption_config is not None:
self.parquet_decryption_config = decryption_config
if decryption_properties is not None:
self.decryption_properties = decryption_properties
self.page_checksum_verification = page_checksum_verification

cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp):
Expand Down Expand Up @@ -812,6 +817,25 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
raise ValueError("size must be larger than zero")
self.reader_properties().set_thrift_container_size_limit(size)

@property
def decryption_properties(self):
    """
    Decryption properties previously assigned to these scan options.

    Returns
    -------
    object
        The value stored by the ``decryption_properties`` setter.

    Raises
    ------
    NotImplementedError
        If ``parquet_encryption_enabled`` is false, i.e. the encryption
        helpers could not be imported in this installation.
    """
    if parquet_encryption_enabled:
        return self._decryption_properties
    raise NotImplementedError(
        "Unable to access encryption features. "
        "Encryption is not enabled in your installation of pyarrow."
    )

@decryption_properties.setter
def decryption_properties(self, config):
    """
    Assign the decryption properties used by these scan options.

    Parameters
    ----------
    config : pyarrow.parquet.FileDecryptionProperties
        Properties forwarded to ``set_decryption_properties``.

    Raises
    ------
    NotImplementedError
        If ``parquet_encryption_enabled`` is false, i.e. the encryption
        helpers could not be imported in this installation.
    """
    if parquet_encryption_enabled:
        # Apply the properties to the underlying reader properties first,
        # then keep the Python object so the getter can hand it back.
        set_decryption_properties(self, config)
        self._decryption_properties = config
    else:
        raise NotImplementedError(
            "Encryption is not enabled in your installation of pyarrow, but "
            "decryption_properties were provided."
        )

@property
def parquet_decryption_config(self):
if not parquet_encryption_enabled:
Expand Down
8 changes: 8 additions & 0 deletions python/pyarrow/_dataset_parquet_encryption.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ def set_encryption_config(
opts.parquet_options.parquet_encryption_config = c_config


def set_decryption_properties(
    ParquetFragmentScanOptions opts not None,
    FileDecryptionProperties config not None
):
    """
    Set file decryption properties on the reader properties of ``opts``.

    Parameters
    ----------
    opts : ParquetFragmentScanOptions
        Scan options whose reader properties are updated. Must not be None.
    config : FileDecryptionProperties
        Decryption properties; their unwrapped form is handed to the reader
        properties. Must not be None.
    """
    # Fetch the reader properties owned by the scan options, then install
    # the unwrapped decryption properties on them.
    cdef CReaderProperties* props = opts.reader_properties()
    props.file_decryption_properties(config.unwrap())


def set_decryption_config(
ParquetFragmentScanOptions opts not None,
ParquetDecryptionConfig config not None
Expand Down
5 changes: 2 additions & 3 deletions python/pyarrow/parquet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ def __init__(self, path_or_paths, filesystem=None, schema=None, *, filters=None,
f"local file systems, not {type(filesystem)}"
)

# check for single fragment dataset
# check for single fragment dataset or dataset directory
single_file = None
self._base_dir = None
if not isinstance(path_or_paths, list):
Expand All @@ -1313,8 +1313,6 @@ def __init__(self, path_or_paths, filesystem=None, schema=None, *, filters=None,
except ValueError:
filesystem = LocalFileSystem(use_mmap=memory_map)
finfo = filesystem.get_file_info(path_or_paths)
if finfo.is_file:
single_file = path_or_paths
if finfo.type == FileType.Directory:
self._base_dir = path_or_paths
else:
Expand Down Expand Up @@ -1771,6 +1769,7 @@ def read_table(source, *, columns=None, use_threads=True,
ignore_prefixes=ignore_prefixes,
pre_buffer=pre_buffer,
coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
decryption_properties=decryption_properties,
thrift_string_size_limit=thrift_string_size_limit,
thrift_container_size_limit=thrift_container_size_limit,
page_checksum_verification=page_checksum_verification,
Expand Down
180 changes: 92 additions & 88 deletions python/pyarrow/tests/parquet/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,44 @@ def basic_encryption_config():
return basic_encryption_config


def setup_encryption_environment(custom_kms_conf):
    """
    Build the KMS connection configuration and crypto factory for a test.

    Parameters
    ----------
    custom_kms_conf : dict
        Passed unchanged as ``custom_kms_conf`` to ``pe.KmsConnectionConfig``.

    Returns
    -------
    tuple
        ``(kms_connection_config, crypto_factory)``, where the crypto factory
        creates ``InMemoryKmsClient`` instances.
    """
    def in_memory_kms_factory(kms_connection_configuration):
        # Every KMS client handed to the crypto factory is the in-memory one.
        return InMemoryKmsClient(kms_connection_configuration)

    connection_config = pe.KmsConnectionConfig(custom_kms_conf=custom_kms_conf)
    factory = pe.CryptoFactory(in_memory_kms_factory)
    return connection_config, factory


def write_encrypted_file(path, data_table, footer_key_name, col_key_name,
                         footer_key, col_key, encryption_config):
    """
    Write ``data_table`` to ``path`` as an encrypted parquet file.

    The footer and column keys are given as bytes and are decoded as UTF-8
    before being placed in the custom KMS configuration under their names.

    Returns
    -------
    tuple
        ``(kms_connection_config, crypto_factory)`` used for the write, so
        callers can reuse them when reading the file back.
    """
    # Key name -> key text, as expected by setup_encryption_environment.
    key_material = {
        footer_key_name: footer_key.decode("UTF-8"),
        col_key_name: col_key.decode("UTF-8"),
    }
    encryption_env = setup_encryption_environment(key_material)

    # encryption_env unpacks to (kms_connection_config, crypto_factory).
    write_encrypted_parquet(path, data_table, encryption_config,
                            *encryption_env)
    return encryption_env


def test_encrypted_parquet_write_read(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted, and then read it."""
path = tempdir / PARQUET_NAME
Expand All @@ -81,20 +119,10 @@ def test_encrypted_parquet_write_read(tempdir, data_table):
cache_lifetime=timedelta(minutes=5.0),
data_key_length_bits=256)

kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)

def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
encryption_config)

crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
verify_file_encrypted(path)

# Read with decryption properties
Expand Down Expand Up @@ -150,36 +178,22 @@ def test_encrypted_parquet_write_read_wrong_key(tempdir, data_table):
cache_lifetime=timedelta(minutes=5.0),
data_key_length_bits=256)

kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)

def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
write_encrypted_file(path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME,
FOOTER_KEY, COL_KEY, encryption_config)

crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
verify_file_encrypted(path)

# Read with decryption properties
wrong_kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
# Wrong keys - mixup in names
FOOTER_KEY_NAME: COL_KEY.decode("UTF-8"),
COL_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
}
)
wrong_kms_connection_config, wrong_crypto_factory = setup_encryption_environment({
FOOTER_KEY_NAME: COL_KEY.decode("UTF-8"), # Intentionally wrong
COL_KEY_NAME: FOOTER_KEY.decode("UTF-8"), # Intentionally wrong
})

decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
with pytest.raises(ValueError, match=r"Incorrect master key used"):
read_encrypted_parquet(
path, decryption_config, wrong_kms_connection_config,
crypto_factory)
wrong_crypto_factory)


def test_encrypted_parquet_read_no_decryption_config(tempdir, data_table):
Expand Down Expand Up @@ -219,23 +233,12 @@ def test_encrypted_parquet_write_no_col_key(tempdir, data_table):
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME)

kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)

def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)

crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(OSError,
match="Either column_keys or uniform_encryption "
"must be set"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
write_encrypted_file(path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME,
FOOTER_KEY, b"", encryption_config)


def test_encrypted_parquet_write_kms_error(tempdir, data_table,
Expand Down Expand Up @@ -497,24 +500,11 @@ def test_encrypted_parquet_loop(tempdir, data_table, basic_encryption_config):

# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = basic_encryption_config
# keep `c` plaintext, defined in basic_encryption_config
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
basic_encryption_config)

kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)

def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)

crypto_factory = pe.CryptoFactory(kms_factory)

# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
verify_file_encrypted(path)

decryption_config = pe.DecryptionConfiguration(
Expand All @@ -537,32 +527,46 @@ def test_read_with_deleted_crypto_factory(tempdir, data_table, basic_encryption_
Test that decryption properties can be used if the crypto factory is no longer alive
"""
path = tempdir / PARQUET_NAME
encryption_config = basic_encryption_config
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)

def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)

encryption_crypto_factory = pe.CryptoFactory(kms_factory)
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, encryption_crypto_factory)
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
basic_encryption_config)
verify_file_encrypted(path)

# Use a local function to get decryption properties, so the crypto factory that
# creates the properties will be deleted after it returns.
def get_decryption_properties():
decryption_crypto_factory = pe.CryptoFactory(kms_factory)
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
return decryption_crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
# Create decryption properties and delete the crypto factory that created
# the properties afterwards.
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
del crypto_factory

result = pq.ParquetFile(
path, decryption_properties=get_decryption_properties())
path, decryption_properties=file_decryption_properties)
result_table = result.read(use_threads=True)
assert data_table.equals(result_table)


def test_encrypted_parquet_read_table(tempdir, data_table, basic_encryption_config):
    """Check that read_table decrypts an encrypted parquet file and folder."""
    path = tempdir / PARQUET_NAME

    # Write the encrypted file and keep the KMS objects for decryption.
    kms_connection_config, crypto_factory = write_encrypted_file(
        path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
        basic_encryption_config)

    file_decryption_properties = crypto_factory.file_decryption_properties(
        kms_connection_config,
        pe.DecryptionConfiguration(cache_lifetime=timedelta(minutes=5.0)))

    # First read the single file, then the folder that contains it; both
    # must round-trip to the original table.
    for source in (path, tempdir):
        result_table = pq.read_table(
            source, decryption_properties=file_decryption_properties)
        assert data_table.equals(result_table)
12 changes: 12 additions & 0 deletions python/pyarrow/tests/test_dataset_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,18 @@ def test_dataset_encryption_decryption():

assert table.equals(dataset.to_table())

# set decryption properties for parquet fragment scan options
decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
pq_scan_opts = ds.ParquetFragmentScanOptions(
decryption_properties=decryption_properties
)

pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

assert table.equals(dataset.to_table())


@pytest.mark.skipif(
not encryption_unavailable, reason="Parquet Encryption is currently enabled"
Expand Down

0 comments on commit ec8b53a

Please sign in to comment.