Skip to content

Commit

Permalink
[config] protect is_serializable against circular references (#12196)
Browse files
Browse the repository at this point in the history
  • Loading branch information
picnixz authored Mar 25, 2024
1 parent 885818b commit f26d492
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 5 deletions.
23 changes: 18 additions & 5 deletions sphinx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,30 @@ class ConfigValue(NamedTuple):
rebuild: _ConfigRebuild


def is_serializable(obj: object, *, _recursive_guard: frozenset[int] = frozenset()) -> bool:
    """Check if object is serializable or not.

    Dictionaries and ``list``/``tuple``/``set``/``frozenset`` containers are
    walked recursively; any instance of ``UNSERIALIZABLE_TYPES`` found along
    the way makes the whole object non-serializable.

    :param obj: The object to inspect.
    :param _recursive_guard: Private. The ``id()`` of every container that is
        currently being visited, used to stop on circular references.
    :returns: ``False`` if *obj* is, or contains, an instance of
        ``UNSERIALIZABLE_TYPES``; ``True`` otherwise.
    """
    if isinstance(obj, UNSERIALIZABLE_TYPES):
        return False

    # use id() to handle un-hashable objects
    if id(obj) in _recursive_guard:
        # this container is already being checked higher up in the call
        # stack: the outer call decides, so do not recurse into it again
        return True

    if isinstance(obj, dict):
        guard = _recursive_guard | {id(obj)}
        for key, value in obj.items():
            if (
                not is_serializable(key, _recursive_guard=guard)
                or not is_serializable(value, _recursive_guard=guard)
            ):
                return False
    elif isinstance(obj, (list, tuple, set, frozenset)):
        guard = _recursive_guard | {id(obj)}
        return all(is_serializable(item, _recursive_guard=guard) for item in obj)

    # if an issue occurs for a non-serializable type, pickle will complain
    # since the object is likely coming from a third-party extension (we
    # natively expect 'simple' types and not weird ones)
    return True


Expand Down
190 changes: 190 additions & 0 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Test the sphinx.config.Config class."""
from __future__ import annotations

import pickle
import time
from collections import Counter
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock

import pytest
Expand All @@ -14,10 +18,51 @@
_Opt,
check_confval_types,
correct_copyright_year,
is_serializable,
)
from sphinx.deprecation import RemovedInSphinx90Warning
from sphinx.errors import ConfigError, ExtensionError, VersionRequirementError

if TYPE_CHECKING:
    from collections.abc import Iterable
    from typing import Union

    # Recursive aliases used only in annotations and type comments; they
    # live under TYPE_CHECKING because ``Union`` is only imported here.
    CircularList = list[Union[int, 'CircularList']]
    CircularDict = dict[str, Union[int, 'CircularDict']]


def check_is_serializable(subject: object, *, circular: bool) -> None:
    """Assert that *subject* is reported as serializable.

    When *circular* is true, additionally assert that the recursion guard is
    what makes the check terminate: with a guard that never records any id,
    ``is_serializable`` must raise :exc:`RecursionError`.
    """
    assert is_serializable(subject)

    if circular:
        class UselessGuard(frozenset[int]):
            """A guard that stays empty, so no visited container is remembered."""

            def __or__(self, other: object, /) -> UselessGuard:
                # do nothing
                return self

            def union(self, *args: Iterable[object]) -> UselessGuard:
                # do nothing
                return self

        # check that without recursive guards, a recursion error occurs
        with pytest.raises(RecursionError):
            assert is_serializable(subject, _recursive_guard=UselessGuard())


def test_is_serializable() -> None:
    """Check ``is_serializable`` on nested, then self-referencing, containers."""
    # nested but acyclic: list, set, dict and frozenset are all traversed
    subject = [1, [2, {3, 'a'}], {'x': {'y': frozenset((4, 5))}}]
    check_is_serializable(subject, circular=False)

    # two lists that contain each other
    a, b = [1], [2]  # type: (CircularList, CircularList)
    a.append(b)
    b.append(a)
    check_is_serializable(a, circular=True)
    check_is_serializable(b, circular=True)

    # a dict that contains itself
    x: CircularDict = {'a': 1, 'b': {'c': 1}}
    x['b'] = x
    check_is_serializable(x, circular=True)


def test_config_opt_deprecated(recwarn):
opt = _Opt('default', '', ())
Expand Down Expand Up @@ -102,6 +147,151 @@ def test_config_pickle_protocol(tmp_path, protocol: int):
assert repr(config) == repr(pickled_config)


def test_config_pickle_circular_reference_in_list():
    """Pickle round-trip of a config whose list values reference each other."""
    a, b = [1], [2]  # type: (CircularList, CircularList)
    a.append(b)
    b.append(a)

    check_is_serializable(a, circular=True)
    check_is_serializable(b, circular=True)

    config = Config()
    config.add('a', [], '', types=list)
    config.add('b', [], '', types=list)
    config.a, config.b = a, b

    actual = pickle.loads(pickle.dumps(config))
    assert isinstance(actual.a, list)
    check_is_serializable(actual.a, circular=True)

    assert isinstance(actual.b, list)
    check_is_serializable(actual.b, circular=True)

    # the cycle a -> b -> a -> ... must survive the round-trip
    assert actual.a[0] == 1
    assert actual.a[1][0] == 2
    assert actual.a[1][1][0] == 1
    assert actual.a[1][1][1][0] == 2

    assert actual.b[0] == 2
    assert actual.b[1][0] == 1
    assert actual.b[1][1][0] == 2
    assert actual.b[1][1][1][0] == 1

    assert len(actual.a) == 2
    assert len(actual.a[1]) == 2
    assert len(actual.a[1][1]) == 2
    assert len(actual.a[1][1][1]) == 2
    assert len(actual.a[1][1][1][1]) == 2

    assert len(actual.b) == 2
    assert len(actual.b[1]) == 2
    assert len(actual.b[1][1]) == 2
    assert len(actual.b[1][1][1]) == 2
    assert len(actual.b[1][1][1][1]) == 2

    def check(
        u: list[list[object] | int],
        v: list[list[object] | int],
        *,
        counter: Counter[type] | None = None,
        guard: frozenset[int] = frozenset(),
    ) -> Counter[type]:
        """Compare *u* and *v* structurally, counting visits per type.

        ``counter[list]`` is incremented once per list item visited and
        ``counter[int]`` once per integer compared; *guard* holds the ids
        of the lists already on the current path so cycles terminate.
        """
        counter = Counter() if counter is None else counter

        if id(u) in guard and id(v) in guard:
            return counter

        if isinstance(u, int):
            assert v.__class__ is u.__class__
            assert u == v
            counter[type(u)] += 1
            return counter

        assert isinstance(u, list)
        assert v.__class__ is u.__class__
        assert len(u) == len(v)

        for u_i, v_i in zip(u, v):
            counter[type(u)] += 1
            check(u_i, v_i, counter=counter, guard=guard | {id(u), id(v)})

        return counter

    counter = check(actual.a, a)
    # check(actual.a, a)
    #    check(actual.a[0], a[0]) -> ++counter[list]
    #        ++counter[int] (a[0] is an int)
    #    check(actual.a[1], a[1]) -> ++counter[list]
    #        check(actual.a[1][0], a[1][0]) -> ++counter[list]
    #            ++counter[int] (a[1][0] is an int)
    #        check(actual.a[1][1], a[1][1]) -> ++counter[list]
    #            recursive guard since a[1][1] == a
    assert counter[type(a[0])] == 2
    assert counter[type(a[1])] == 4

    # same logic as above
    counter = check(actual.b, b)
    assert counter[type(b[0])] == 2
    assert counter[type(b[1])] == 4


def test_config_pickle_circular_reference_in_dict():
    """Pickle round-trip of a config whose dict value references itself."""
    x: CircularDict = {'a': 1, 'b': {'c': 1}}
    x['b'] = x
    check_is_serializable(x, circular=True)

    config = Config()
    # NOTE(review): the default is a list although ``types=dict``; it is
    # overwritten on the next line, so it looks harmless here -- confirm.
    config.add('x', [], '', types=dict)
    config.x = x

    actual = pickle.loads(pickle.dumps(config))
    check_is_serializable(actual.x, circular=True)
    assert isinstance(actual.x, dict)

    # the self-reference x['b'] is x must survive the round-trip
    assert actual.x['a'] == 1
    assert actual.x['b']['a'] == 1

    assert len(actual.x) == 2
    assert len(actual.x['b']) == 2
    assert len(actual.x['b']['b']) == 2

    def check(
        u: dict[str, dict[str, object] | int],
        v: dict[str, dict[str, object] | int],
        *,
        counter: Counter[type] | None = None,
        guard: frozenset[int] = frozenset(),
    ) -> Counter[type]:
        """Compare *u* and *v* structurally, counting visits per type.

        ``counter[dict]`` is incremented once per dict entry visited and
        ``counter[int]`` once per integer compared; *guard* holds the ids
        of the dicts already on the current path so cycles terminate.
        """
        counter = Counter() if counter is None else counter

        if id(u) in guard and id(v) in guard:
            return counter

        if isinstance(u, int):
            assert v.__class__ is u.__class__
            assert u == v
            counter[type(u)] += 1
            return counter

        assert isinstance(u, dict)
        assert v.__class__ is u.__class__
        assert len(u) == len(v)

        # keys are paired positionally, i.e. by iteration order of each dict
        for u_i, v_i in zip(u, v):
            counter[type(u)] += 1
            check(u[u_i], v[v_i], counter=counter, guard=guard | {id(u), id(v)})
        return counter

    counters = check(actual.x, x, counter=Counter())
    # check(actual.x, x)
    #    check(actual.x['a'], x['a']) -> ++counter[dict]
    #        ++counter[int] (x['a'] is an int)
    #    check(actual.x['b'], x['b']) -> ++counter[dict]
    #        recursive guard since x['b'] == x
    assert counters[type(x['a'])] == 1
    assert counters[type(x['b'])] == 2


def test_extension_values():
config = Config()

Expand Down

0 comments on commit f26d492

Please sign in to comment.