diff --git a/sphinx/config.py b/sphinx/config.py index ef1eaa848e3..9e49423d5cc 100644 --- a/sphinx/config.py +++ b/sphinx/config.py @@ -51,17 +51,30 @@ class ConfigValue(NamedTuple): rebuild: _ConfigRebuild -def is_serializable(obj: Any) -> bool: +def is_serializable(obj: object, *, _recursive_guard: frozenset[int] = frozenset()) -> bool: """Check if object is serializable or not.""" if isinstance(obj, UNSERIALIZABLE_TYPES): return False - elif isinstance(obj, dict): + + # use id() to handle un-hashable objects + if id(obj) in _recursive_guard: + return True + + if isinstance(obj, dict): + guard = _recursive_guard | {id(obj)} for key, value in obj.items(): - if not is_serializable(key) or not is_serializable(value): + if ( + not is_serializable(key, _recursive_guard=guard) + or not is_serializable(value, _recursive_guard=guard) + ): return False - elif isinstance(obj, (list, tuple, set)): - return all(map(is_serializable, obj)) + elif isinstance(obj, (list, tuple, set, frozenset)): + guard = _recursive_guard | {id(obj)} + return all(is_serializable(item, _recursive_guard=guard) for item in obj) + # if an issue occurs for a non-serializable type, pickle will complain + # since the object is likely coming from a third-party extension (we + # natively expect 'simple' types and not weird ones) return True diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index ee305274ecf..d269b7169b0 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -1,7 +1,11 @@ """Test the sphinx.config.Config class.""" +from __future__ import annotations + import pickle import time +from collections import Counter from pathlib import Path +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -14,10 +18,51 @@ _Opt, check_confval_types, correct_copyright_year, + is_serializable, ) from sphinx.deprecation import RemovedInSphinx90Warning from sphinx.errors import ConfigError, ExtensionError, VersionRequirementError +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Union + + CircularList = list[Union[int, 'CircularList']] + CircularDict = dict[str, Union[int, 'CircularDict']] + + +def check_is_serializable(subject: object, *, circular: bool) -> None: + assert is_serializable(subject) + + if circular: + class UselessGuard(frozenset[int]): + def __or__(self, other: object, /) -> UselessGuard: + # do nothing + return self + + def union(self, *args: Iterable[object]) -> UselessGuard: + # do nothing + return self + + # check that without recursive guards, a recursion error occurs + with pytest.raises(RecursionError): + assert is_serializable(subject, _recursive_guard=UselessGuard()) + + +def test_is_serializable() -> None: + subject = [1, [2, {3, 'a'}], {'x': {'y': frozenset((4, 5))}}] + check_is_serializable(subject, circular=False) + + a, b = [1], [2] # type: (CircularList, CircularList) + a.append(b) + b.append(a) + check_is_serializable(a, circular=True) + check_is_serializable(b, circular=True) + + x: CircularDict = {'a': 1, 'b': {'c': 1}} + x['b'] = x + check_is_serializable(x, circular=True) + def test_config_opt_deprecated(recwarn): opt = _Opt('default', '', ()) @@ -102,6 +147,151 @@ def test_config_pickle_protocol(tmp_path, protocol: int): assert repr(config) == repr(pickled_config) +def test_config_pickle_circular_reference_in_list(): + a, b = [1], [2] # type: (CircularList, CircularList) + a.append(b) + b.append(a) + + check_is_serializable(a, circular=True) + check_is_serializable(b, circular=True) + + config = Config() + config.add('a', [], '', types=list) + config.add('b', [], '', types=list) + config.a, config.b = a, b + + actual = pickle.loads(pickle.dumps(config)) + assert isinstance(actual.a, list) + check_is_serializable(actual.a, circular=True) + + assert isinstance(actual.b, list) + check_is_serializable(actual.b, circular=True) + + assert actual.a[0] == 1 + assert actual.a[1][0] == 2 + assert actual.a[1][1][0] == 1 + assert actual.a[1][1][1][0] == 2 + + assert actual.b[0] == 2 + assert actual.b[1][0] == 1 + assert actual.b[1][1][0] == 2 + assert actual.b[1][1][1][0] == 1 + + assert len(actual.a) == 2 + assert len(actual.a[1]) == 2 + assert len(actual.a[1][1]) == 2 + assert len(actual.a[1][1][1]) == 2 + assert len(actual.a[1][1][1][1]) == 2 + + assert len(actual.b) == 2 + assert len(actual.b[1]) == 2 + assert len(actual.b[1][1]) == 2 + assert len(actual.b[1][1][1]) == 2 + assert len(actual.b[1][1][1][1]) == 2 + + def check( + u: list[list[object] | int], + v: list[list[object] | int], + *, + counter: Counter[type, int] | None = None, + guard: frozenset[int] = frozenset(), + ) -> Counter[type, int]: + counter = Counter() if counter is None else counter + + if id(u) in guard and id(v) in guard: + return counter + + if isinstance(u, int): + assert v.__class__ is u.__class__ + assert u == v + counter[type(u)] += 1 + return counter + + assert isinstance(u, list) + assert v.__class__ is u.__class__ + assert len(u) == len(v) + + for u_i, v_i in zip(u, v): + counter[type(u)] += 1 + check(u_i, v_i, counter=counter, guard=guard | {id(u), id(v)}) + + return counter + + counter = check(actual.a, a) + # check(actual.a, a) + # check(actual.a[0], a[0]) -> ++counter[dict] + # ++counter[int] (a[0] is an int) + # check(actual.a[1], a[1]) -> ++counter[dict] + # check(actual.a[1][0], a[1][0]) -> ++counter[dict] + # ++counter[int] (a[1][0] is an int) + # check(actual.a[1][1], a[1][1]) -> ++counter[dict] + # recursive guard since a[1][1] == a + assert counter[type(a[0])] == 2 + assert counter[type(a[1])] == 4 + + # same logic as above + counter = check(actual.b, b) + assert counter[type(b[0])] == 2 + assert counter[type(b[1])] == 4 + + +def test_config_pickle_circular_reference_in_dict(): + x: CircularDict = {'a': 1, 'b': {'c': 1}} + x['b'] = x + check_is_serializable(x, circular=True) + + config = Config() + config.add('x', [], '', types=dict) + config.x = x + + actual = pickle.loads(pickle.dumps(config)) + check_is_serializable(actual.x, circular=True) + assert isinstance(actual.x, dict) + + assert actual.x['a'] == 1 + assert actual.x['b']['a'] == 1 + + assert len(actual.x) == 2 + assert len(actual.x['b']) == 2 + assert len(actual.x['b']['b']) == 2 + + def check( + u: dict[str, dict[str, object] | int], + v: dict[str, dict[str, object] | int], + *, + counter: Counter[type, int] | None = None, + guard: frozenset[int] = frozenset(), + ) -> Counter: + counter = Counter() if counter is None else counter + + if id(u) in guard and id(v) in guard: + return counter + + if isinstance(u, int): + assert v.__class__ is u.__class__ + assert u == v + counter[type(u)] += 1 + return counter + + assert isinstance(u, dict) + assert v.__class__ is u.__class__ + assert len(u) == len(v) + + for u_i, v_i in zip(u, v): + counter[type(u)] += 1 + check(u[u_i], v[v_i], counter=counter, guard=guard | {id(u), id(v)}) + return counter + + counters = check(actual.x, x, counter=Counter()) + # check(actual.x, x) + # check(actual.x['a'], x['a']) -> ++counter[dict] + # ++counter[int] (x['a'] is an int) + # check(actual.x['b'], x['b']) -> ++counter[dict] + # recursive guard since x['b'] == x + assert counters[type(x['a'])] == 1 + assert counters[type(x['b'])] == 2 + + def test_extension_values(): config = Config()