diff --git a/client_encryption/api_encryption.py b/client_encryption/api_encryption.py index 0c24db4..470900b 100644 --- a/client_encryption/api_encryption.py +++ b/client_encryption/api_encryption.py @@ -34,8 +34,10 @@ def field_encryption_call_api(self, func): @wraps(func) def call_api_function(*args, **kwargs): - check_type = inspect.signature(func.__self__.call_api).parameters.get("_check_type") is None - if check_type: + original_parameters = inspect.signature(func.__self__.call_api).parameters + check_type_is_none = original_parameters.get("_check_type") is None + preload_content_is_not_none = original_parameters.get("_preload_content") is not None + if check_type_is_none and preload_content_is_not_none: kwargs["_preload_content"] = False # version 4.3.1 return func(*args, **kwargs) @@ -75,10 +77,17 @@ def _encrypt_payload(self, headers, body): conf = self._encryption_conf - if type(conf) is FieldLevelEncryptionConfig: - return self.encrypt_field_level_payload(headers, conf, body) - else: - return self.encrypt_jwe_payload(conf, body) + encrypted_payload = self.encrypt_field_level_payload(headers, conf, body) if type( + conf) is FieldLevelEncryptionConfig else self.encrypt_jwe_payload(conf, body) + + # convert the encrypted_payload to the same data type as the input body + if isinstance(body, str): + return json.dumps(encrypted_payload) + + if isinstance(body, bytes): + return json.dumps(encrypted_payload).encode("utf-8") + + return encrypted_payload def _decrypt_payload(self, headers, body): """Encryption enforcement based on configuration - decrypt using session key params from header or body""" diff --git a/client_encryption/version.py b/client_encryption/version.py index 29c1469..bed4083 100644 --- a/client_encryption/version.py +++ b/client_encryption/version.py @@ -1,3 +1,3 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -__version__ = "1.7.0" +__version__ = "1.8.0" diff --git a/tests/test_api_encryption.py b/tests/test_api_encryption.py index 8860792..d786691 100644 --- a/tests/test_api_encryption.py +++ b/tests/test_api_encryption.py @@ -39,6 +39,28 @@ def test_ApiEncryption_with_config_as_dict(self, FieldLevelEncryptionConfig): def test_ApiEncryption_fail_with_config_as_string(self): self.assertRaises(FileNotFoundError, to_test.ApiEncryption, "this is not accepted") + def test_encrypt_payload_returns_same_data_type_as_input(self): + api_encryption = to_test.ApiEncryption(self._json_config) + + test_headers = {"Content-Type": "application/json"} + + body = { + "data": { + "secret1": "test", + "secret2": "secret" + }, + "encryptedData": {} + } + + encrypted = api_encryption._encrypt_payload(body=body, headers=test_headers) + self.assertIsInstance(encrypted, dict) + + encrypted = api_encryption._encrypt_payload(body=json.dumps(body), headers=test_headers) + self.assertIsInstance(encrypted, str) + + encrypted = api_encryption._encrypt_payload(body=json.dumps(body).encode("utf-8"), headers=test_headers) + self.assertIsInstance(encrypted, bytes) + def test_encrypt_payload_with_params_in_body(self): api_encryption = to_test.ApiEncryption(self._json_config)