Skip to content

Commit

Permalink
Add options to restrict deserializers in ComputeSerializer
Browse files Browse the repository at this point in the history
Also move some errors to SerdeErrors instead of SerializationError /
DeserializationError, if those errors happen before anything is actually
serialized or deserialized; and make the ComputeSerializer enforce that
selectable serialization strategies are properly for_code or not.
  • Loading branch information
chris-janidlo committed Dec 16, 2024
1 parent f710b80 commit 9f12022
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 19 deletions.
6 changes: 6 additions & 0 deletions compute_sdk/globus_compute_sdk/errors/error_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def __repr__(self):
class SerdeError(ComputeError):
"""Base class for SerializationError and DeserializationError"""

def __init__(self, reason: str):
self.reason = reason

def __repr__(self):
return self.reason


class SerializationError(SerdeError):
"""Something failed during serialization."""
Expand Down
116 changes: 100 additions & 16 deletions compute_sdk/globus_compute_sdk/serialize/facade.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import importlib
import logging
import typing as t

from globus_compute_sdk.errors import DeserializationError, SerializationError
from globus_compute_sdk.errors import (
DeserializationError,
SerdeError,
SerializationError,
)
from globus_compute_sdk.serialize.base import IDENTIFIER_LENGTH, SerializationStrategy
from globus_compute_sdk.serialize.concretes import (
DEFAULT_STRATEGY_CODE,
Expand All @@ -14,6 +19,61 @@

logger = logging.getLogger(__name__)

DeserializerAllowlist = t.Iterable[type[SerializationStrategy] | str]


def assert_strategy_type_valid(
strategy_type: type[SerializationStrategy], for_code: bool
) -> None:
if strategy_type not in SELECTABLE_STRATEGIES:
raise SerdeError(
f"{strategy_type.__name__} is not a known serialization strategy"
f" (must be one of {SELECTABLE_STRATEGIES})"
)

if strategy_type.for_code != for_code:
gtype = "code" if for_code else "data"
etype = "data" if for_code else "code"
raise SerdeError(
f"{strategy_type.__name__} is a {gtype} serialization strategy,"
f" expected a {etype} strategy"
)


def validate_strategy(
strategy: SerializationStrategy, for_code: bool
) -> SerializationStrategy:
assert_strategy_type_valid(type(strategy), for_code)
return strategy


def validate_allowlist(
unvalidated: DeserializerAllowlist, for_code: bool
) -> set[type[SerializationStrategy]]:
validated = set()
for value in unvalidated:
resolved_strategy_class = None
if isinstance(value, str):
try:
mod_name, class_name = value.rsplit(".", 1)
mod = importlib.import_module(mod_name)
resolved_strategy_class = getattr(mod, class_name)
except Exception as e:
raise SerdeError(f"`{value}` is not a valid path to a strategy") from e
else:
resolved_strategy_class = value

if not issubclass(resolved_strategy_class, SerializationStrategy):
raise SerdeError(
"Allowed deserializers must either be SerializationStrategies"
f" or valid paths to them (got {value})"
)

assert_strategy_type_valid(resolved_strategy_class, for_code)
validated.add(resolved_strategy_class)

return validated


class ComputeSerializer:
"""Provides uniform interface to underlying serialization strategies"""
Expand All @@ -22,23 +82,23 @@ def __init__(
self,
strategy_code: SerializationStrategy | None = None,
strategy_data: SerializationStrategy | None = None,
*,
allowed_code_deserializer_types: DeserializerAllowlist | None = None,
allowed_data_deserializer_types: DeserializerAllowlist | None = None,
):
"""Instantiate the appropriate classes"""

def validate(strategy: SerializationStrategy) -> SerializationStrategy:
if type(strategy) not in SELECTABLE_STRATEGIES:
raise SerializationError(
f"{strategy} is not a known serialization strategy "
f"(must be one of {SELECTABLE_STRATEGIES})"
)

return strategy

self.strategy_code = (
validate(strategy_code) if strategy_code else DEFAULT_STRATEGY_CODE
self.code_serializer = validate_strategy(
strategy_code or DEFAULT_STRATEGY_CODE, True
)
self.strategy_data = (
validate(strategy_data) if strategy_data else DEFAULT_STRATEGY_DATA
self.data_serializer = validate_strategy(
strategy_data or DEFAULT_STRATEGY_DATA, False
)
self.allowed_code_deserializer_types = validate_allowlist(
allowed_code_deserializer_types or [], True
)
self.allowed_data_deserializer_types = validate_allowlist(
allowed_data_deserializer_types or [], False
)

self.header_size = IDENTIFIER_LENGTH
Expand All @@ -50,9 +110,9 @@ def validate(strategy: SerializationStrategy) -> SerializationStrategy:

def serialize(self, data):
if callable(data):
stype, strategy = "Code", self.strategy_code
stype, strategy = "Code", self.code_serializer
else:
stype, strategy = "Data", self.strategy_data
stype, strategy = "Data", self.data_serializer

try:
return strategy.serialize(data)
Expand All @@ -74,6 +134,8 @@ def deserialize(self, payload):
if not strategy:
raise DeserializationError(f"Invalid header: {header} in data payload")

self.assert_deserializer_allowed(strategy)

return strategy.deserialize(payload)

@staticmethod
Expand Down Expand Up @@ -149,3 +211,25 @@ def check_strategies(self, function: t.Callable, *args, **kwargs):
return self.unpack_and_deserialize(packed)
except Exception as e:
raise DeserializationError("check_strategies failed to deserialize") from e

def assert_deserializer_allowed(self, strategy: SerializationStrategy) -> None:
allowlist = (
self.allowed_code_deserializer_types
if strategy.for_code
else self.allowed_data_deserializer_types
)

if not allowlist or type(strategy) in allowlist:
return

allowed_names = [t.__name__ for t in allowlist]
dtype = "Code" if strategy.for_code else "Data"
htype = "function" if strategy.for_code else "arguments"
help_url = "https://globus-compute.readthedocs.io/en/stable/sdk.html#specifying-a-serialization-strategy" # noqa

raise DeserializationError(
f"{dtype} deserializer {type(strategy).__name__} is not allowed; expected "
f"one of {allowed_names}. (Hint: reserialize the {htype} with one of the "
f"allowed serialization strategies and resubmit. For more information, see "
f"{help_url}.)"
)
133 changes: 130 additions & 3 deletions compute_sdk/tests/integration/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import globus_compute_sdk.serialize.concretes as concretes
import pytest
from globus_compute_sdk.errors import SerializationError
from globus_compute_sdk.errors import SerdeError, SerializationError
from globus_compute_sdk.serialize.base import SerializationStrategy
from globus_compute_sdk.serialize.facade import ComputeSerializer

Expand Down Expand Up @@ -337,6 +337,18 @@ def test_selectable_serialization(strategy):
assert ser_data[:ID_LEN] == strategy.identifier


@pytest.mark.parametrize("strategy", concretes.SELECTABLE_STRATEGIES)
def test_selectable_serialization_enforces_for_code(strategy):
with pytest.raises(SerdeError) as pyt_exc:
ComputeSerializer(strategy_code=strategy(), strategy_data=strategy())

if strategy.for_code:
e = "is a code serialization strategy, expected a data strategy"
else:
e = "is a data serialization strategy, expected a code strategy"
assert e in str(pyt_exc)


def test_serializer_errors_on_unknown_strategy():
class NewStrategy(SerializationStrategy):
identifier = "aa\n"
Expand All @@ -350,12 +362,12 @@ def deserialize(self, payload):

strategy = NewStrategy()

with pytest.raises(SerializationError):
with pytest.raises(SerdeError):
ComputeSerializer(strategy_code=strategy)

NewStrategy.for_code = False

with pytest.raises(SerializationError):
with pytest.raises(SerdeError):
ComputeSerializer(strategy_data=strategy)


Expand All @@ -382,3 +394,118 @@ def test_check_strategies(strategy_code, strategy_data, function, args, kwargs):
new_result = new_fn(*new_args, **new_kwargs)

assert original_result == new_result


@pytest.mark.parametrize("disallowed_strategy", concretes.SELECTABLE_STRATEGIES)
def test_allowed_deserializers(disallowed_strategy):
allowlist = [
strategy
for strategy in concretes.SELECTABLE_STRATEGIES
if strategy.for_code == disallowed_strategy.for_code
and not strategy != disallowed_strategy
]

assert allowlist, "expect to have at least one allowed deserializer"

if disallowed_strategy.for_code:
serializer = ComputeSerializer(allowed_code_deserializer_types=allowlist)
payload = disallowed_strategy().serialize(foo)
else:
serializer = ComputeSerializer(allowed_data_deserializer_types=allowlist)
payload = disallowed_strategy().serialize("foo")

with pytest.raises(SerdeError) as pyt_exc:
serializer.deserialize(payload)
assert f"deserializer {disallowed_strategy.__name__} is not allowed" in str(pyt_exc)


@pytest.mark.parametrize(
"list, valid",
[
(["globus_compute_sdk.serialize.concretes.DillCode"], True),
(["my_malicious_package.my_malicious_serializer"], False),
(["invalid_path_1"], False),
(["invalid path 2"], False),
([""], False),
(
[
"globus_compute_sdk.serialize.concretes.DillCode",
"globus_compute_sdk.serialize.concretes.DillCodeTextInspect",
"globus_compute_sdk.serialize.concretes.DillCodeSource",
"globus_compute_sdk.serialize.concretes.CombinedCode",
],
True,
),
(
[
"globus_compute_sdk.serialize.concretes.DillCode",
"my_malicious_package.my_malicious_serializer",
],
False,
),
(
[
"globus_compute_sdk.serialize.concretes.DillCode",
"invalid path",
],
False,
),
],
)
def test_allowed_deserializers_checks_imports(list, valid):
if valid:
ComputeSerializer(allowed_code_deserializer_types=list)
else:
with pytest.raises(SerdeError) as pyt_exc:
ComputeSerializer(allowed_code_deserializer_types=list)
assert "is not a valid path to a strategy" in str(pyt_exc)


@pytest.mark.parametrize(
"list, for_code",
[
(
[
"globus_compute_sdk.serialize.concretes.DillCode",
"globus_compute_sdk.serialize.concretes.DillCodeTextInspect",
"globus_compute_sdk.serialize.concretes.DillCodeSource",
"globus_compute_sdk.serialize.concretes.CombinedCode",
],
True,
),
(
[
concretes.DillCode,
concretes.DillCodeTextInspect,
concretes.DillCodeSource,
concretes.CombinedCode,
],
True,
),
(
[
"globus_compute_sdk.serialize.concretes.DillDataBase64",
"globus_compute_sdk.serialize.concretes.JSONData",
],
False,
),
(
[
concretes.DillDataBase64,
concretes.JSONData,
],
False,
),
],
)
def test_allowed_deserializers_enforces_for_code(list, for_code):
with pytest.raises(SerdeError) as pyt_exc:
ComputeSerializer(
allowed_code_deserializer_types=list, allowed_data_deserializer_types=list
)

if for_code:
e = "is a code serialization strategy, expected a data strategy"
else:
e = "is a data serialization strategy, expected a code strategy"
assert e in str(pyt_exc)

0 comments on commit 9f12022

Please sign in to comment.