diff --git a/compute_sdk/globus_compute_sdk/errors/error_types.py b/compute_sdk/globus_compute_sdk/errors/error_types.py index 1dce3928e..19fee23a9 100644 --- a/compute_sdk/globus_compute_sdk/errors/error_types.py +++ b/compute_sdk/globus_compute_sdk/errors/error_types.py @@ -25,6 +25,12 @@ def __repr__(self): class SerdeError(ComputeError): """Base class for SerializationError and DeserializationError""" + def __init__(self, reason: str): + self.reason = reason + + def __repr__(self): + return self.reason + class SerializationError(SerdeError): """Something failed during serialization.""" diff --git a/compute_sdk/globus_compute_sdk/serialize/facade.py b/compute_sdk/globus_compute_sdk/serialize/facade.py index 0a93e3081..247171de8 100644 --- a/compute_sdk/globus_compute_sdk/serialize/facade.py +++ b/compute_sdk/globus_compute_sdk/serialize/facade.py @@ -1,9 +1,14 @@ from __future__ import annotations +import importlib import logging import typing as t -from globus_compute_sdk.errors import DeserializationError, SerializationError +from globus_compute_sdk.errors import ( + DeserializationError, + SerdeError, + SerializationError, +) from globus_compute_sdk.serialize.base import IDENTIFIER_LENGTH, SerializationStrategy from globus_compute_sdk.serialize.concretes import ( DEFAULT_STRATEGY_CODE, @@ -14,6 +19,61 @@ logger = logging.getLogger(__name__) +DeserializerAllowlist = t.Iterable[type[SerializationStrategy] | str] + + +def assert_strategy_type_valid( + strategy_type: type[SerializationStrategy], for_code: bool +) -> None: + if strategy_type not in SELECTABLE_STRATEGIES: + raise SerdeError( + f"{strategy_type.__name__} is not a known serialization strategy" + f" (must be one of {SELECTABLE_STRATEGIES})" + ) + + if strategy_type.for_code != for_code: + gtype = "code" if for_code else "data" + etype = "data" if for_code else "code" + raise SerdeError( + f"{strategy_type.__name__} is a {gtype} serialization strategy," + f" expected a {etype} strategy" + ) + + +def validate_strategy( + strategy: SerializationStrategy, for_code: bool +) -> SerializationStrategy: + assert_strategy_type_valid(type(strategy), for_code) + return strategy + + +def validate_allowlist( + unvalidated: DeserializerAllowlist, for_code: bool +) -> set[type[SerializationStrategy]]: + validated = set() + for value in unvalidated: + resolved_strategy_class = None + if isinstance(value, str): + try: + mod_name, class_name = value.rsplit(".", 1) + mod = importlib.import_module(mod_name) + resolved_strategy_class = getattr(mod, class_name) + except Exception as e: + raise SerdeError(f"`{value}` is not a valid path to a strategy") from e + else: + resolved_strategy_class = value + + if not issubclass(resolved_strategy_class, SerializationStrategy): + raise SerdeError( + "Allowed deserializers must either be SerializationStrategies" + f" or valid paths to them (got {value})" + ) + + assert_strategy_type_valid(resolved_strategy_class, for_code) + validated.add(resolved_strategy_class) + + return validated + class ComputeSerializer: """Provides uniform interface to underlying serialization strategies""" @@ -22,23 +82,23 @@ def __init__( self, strategy_code: SerializationStrategy | None = None, strategy_data: SerializationStrategy | None = None, + *, + allowed_code_deserializer_types: DeserializerAllowlist | None = None, + allowed_data_deserializer_types: DeserializerAllowlist | None = None, ): """Instantiate the appropriate classes""" - def validate(strategy: SerializationStrategy) -> SerializationStrategy: - if type(strategy) not in SELECTABLE_STRATEGIES: - raise SerializationError( - f"{strategy} is not a known serialization strategy " - f"(must be one of {SELECTABLE_STRATEGIES})" - ) - - return strategy - - self.strategy_code = ( - validate(strategy_code) if strategy_code else DEFAULT_STRATEGY_CODE + self.code_serializer = validate_strategy( + strategy_code or DEFAULT_STRATEGY_CODE, True ) - self.strategy_data = ( - validate(strategy_data) if strategy_data else DEFAULT_STRATEGY_DATA + self.data_serializer = validate_strategy( + strategy_data or DEFAULT_STRATEGY_DATA, False + ) + self.allowed_code_deserializer_types = validate_allowlist( + allowed_code_deserializer_types or [], True + ) + self.allowed_data_deserializer_types = validate_allowlist( + allowed_data_deserializer_types or [], False ) self.header_size = IDENTIFIER_LENGTH @@ -50,9 +110,9 @@ def validate(strategy: SerializationStrategy) -> SerializationStrategy: def serialize(self, data): if callable(data): - stype, strategy = "Code", self.strategy_code + stype, strategy = "Code", self.code_serializer else: - stype, strategy = "Data", self.strategy_data + stype, strategy = "Data", self.data_serializer try: return strategy.serialize(data) @@ -74,6 +134,8 @@ def deserialize(self, payload): if not strategy: raise DeserializationError(f"Invalid header: {header} in data payload") + self.assert_deserializer_allowed(strategy) + return strategy.deserialize(payload) @staticmethod @@ -149,3 +211,25 @@ def check_strategies(self, function: t.Callable, *args, **kwargs): return self.unpack_and_deserialize(packed) except Exception as e: raise DeserializationError("check_strategies failed to deserialize") from e + + def assert_deserializer_allowed(self, strategy: SerializationStrategy) -> None: + allowlist = ( + self.allowed_code_deserializer_types + if strategy.for_code + else self.allowed_data_deserializer_types + ) + + if not allowlist or type(strategy) in allowlist: + return + + allowed_names = [t.__name__ for t in allowlist] + dtype = "Code" if strategy.for_code else "Data" + htype = "function" if strategy.for_code else "arguments" + help_url = "https://globus-compute.readthedocs.io/en/stable/sdk.html#specifying-a-serialization-strategy" # noqa + + raise DeserializationError( + f"{dtype} deserializer {type(strategy).__name__} is not allowed; expected " + f"one of {allowed_names}. (Hint: reserialize the {htype} with one of the " + f"allowed serialization strategies and resubmit. For more information, see " + f"{help_url}.)" + ) diff --git a/compute_sdk/tests/integration/test_serialization.py b/compute_sdk/tests/integration/test_serialization.py index 44f988af5..9b668239b 100644 --- a/compute_sdk/tests/integration/test_serialization.py +++ b/compute_sdk/tests/integration/test_serialization.py @@ -5,7 +5,7 @@ import globus_compute_sdk.serialize.concretes as concretes import pytest -from globus_compute_sdk.errors import SerializationError +from globus_compute_sdk.errors import SerdeError, SerializationError from globus_compute_sdk.serialize.base import SerializationStrategy from globus_compute_sdk.serialize.facade import ComputeSerializer @@ -337,6 +337,18 @@ def test_selectable_serialization(strategy): assert ser_data[:ID_LEN] == strategy.identifier +@pytest.mark.parametrize("strategy", concretes.SELECTABLE_STRATEGIES) +def test_selectable_serialization_enforces_for_code(strategy): + with pytest.raises(SerdeError) as pyt_exc: + ComputeSerializer(strategy_code=strategy(), strategy_data=strategy()) + + if strategy.for_code: + e = "is a code serialization strategy, expected a data strategy" + else: + e = "is a data serialization strategy, expected a code strategy" + assert e in str(pyt_exc) + + def test_serializer_errors_on_unknown_strategy(): class NewStrategy(SerializationStrategy): identifier = "aa\n" @@ -350,12 +362,12 @@ def deserialize(self, payload): strategy = NewStrategy() - with pytest.raises(SerializationError): + with pytest.raises(SerdeError): ComputeSerializer(strategy_code=strategy) NewStrategy.for_code = False - with pytest.raises(SerializationError): + with pytest.raises(SerdeError): ComputeSerializer(strategy_data=strategy) @@ -382,3 +394,118 @@ def test_check_strategies(strategy_code, strategy_data, function, args, kwargs): new_result = new_fn(*new_args, **new_kwargs) assert original_result == new_result + + +@pytest.mark.parametrize("disallowed_strategy", concretes.SELECTABLE_STRATEGIES) +def test_allowed_deserializers(disallowed_strategy): + allowlist = [ + strategy + for strategy in concretes.SELECTABLE_STRATEGIES + if strategy.for_code == disallowed_strategy.for_code + and not strategy != disallowed_strategy + ] + + assert allowlist, "expect to have at least one allowed deserializer" + + if disallowed_strategy.for_code: + serializer = ComputeSerializer(allowed_code_deserializer_types=allowlist) + payload = disallowed_strategy().serialize(foo) + else: + serializer = ComputeSerializer(allowed_data_deserializer_types=allowlist) + payload = disallowed_strategy().serialize("foo") + + with pytest.raises(SerdeError) as pyt_exc: + serializer.deserialize(payload) + assert f"deserializer {disallowed_strategy.__name__} is not allowed" in str(pyt_exc) + + +@pytest.mark.parametrize( + "list, valid", + [ + (["globus_compute_sdk.serialize.concretes.DillCode"], True), + (["my_malicious_package.my_malicious_serializer"], False), + (["invalid_path_1"], False), + (["invalid path 2"], False), + ([""], False), + ( + [ + "globus_compute_sdk.serialize.concretes.DillCode", + "globus_compute_sdk.serialize.concretes.DillCodeTextInspect", + "globus_compute_sdk.serialize.concretes.DillCodeSource", + "globus_compute_sdk.serialize.concretes.CombinedCode", + ], + True, + ), + ( + [ + "globus_compute_sdk.serialize.concretes.DillCode", + "my_malicious_package.my_malicious_serializer", + ], + False, + ), + ( + [ + "globus_compute_sdk.serialize.concretes.DillCode", + "invalid path", + ], + False, + ), + ], +) +def test_allowed_deserializers_checks_imports(list, valid): + if valid: + ComputeSerializer(allowed_code_deserializer_types=list) + else: + with pytest.raises(SerdeError) as pyt_exc: + ComputeSerializer(allowed_code_deserializer_types=list) + assert "is not a valid path to a strategy" in str(pyt_exc) + + +@pytest.mark.parametrize( + "list, for_code", + [ + ( + [ + "globus_compute_sdk.serialize.concretes.DillCode", + "globus_compute_sdk.serialize.concretes.DillCodeTextInspect", + "globus_compute_sdk.serialize.concretes.DillCodeSource", + "globus_compute_sdk.serialize.concretes.CombinedCode", + ], + True, + ), + ( + [ + concretes.DillCode, + concretes.DillCodeTextInspect, + concretes.DillCodeSource, + concretes.CombinedCode, + ], + True, + ), + ( + [ + "globus_compute_sdk.serialize.concretes.DillDataBase64", + "globus_compute_sdk.serialize.concretes.JSONData", + ], + False, + ), + ( + [ + concretes.DillDataBase64, + concretes.JSONData, + ], + False, + ), + ], +) +def test_allowed_deserializers_enforces_for_code(list, for_code): + with pytest.raises(SerdeError) as pyt_exc: + ComputeSerializer( + allowed_code_deserializer_types=list, allowed_data_deserializer_types=list + ) + + if for_code: + e = "is a code serialization strategy, expected a data strategy" + else: + e = "is a data serialization strategy, expected a code strategy" + assert e in str(pyt_exc)