Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[config] protect is_serializable against circular references #12196

Merged
merged 6 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions sphinx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,30 @@ class ConfigValue(NamedTuple):
rebuild: _ConfigRebuild


def is_serializable(obj: Any) -> bool:
def is_serializable(obj: object, *, _recursive_guard: frozenset[int] = frozenset()) -> bool:
"""Check if object is serializable or not."""
if isinstance(obj, UNSERIALIZABLE_TYPES):
return False
elif isinstance(obj, dict):

# use id() to handle un-hashable objects
if id(obj) in _recursive_guard:
return True

if isinstance(obj, dict):
guard = _recursive_guard | {id(obj)}
for key, value in obj.items():
if not is_serializable(key) or not is_serializable(value):
if (
not is_serializable(key, _recursive_guard=guard)
or not is_serializable(value, _recursive_guard=guard)
):
return False
elif isinstance(obj, (list, tuple, set)):
picnixz marked this conversation as resolved.
Show resolved Hide resolved
return all(map(is_serializable, obj))
guard = _recursive_guard | {id(obj)}
return all(is_serializable(item, _recursive_guard=guard) for item in obj)

# if an issue occurs for a non-serializable type, pickle will complain
# since the object is likely coming from a third-party extension (we
# natively expect 'simple' types and not weird ones)
return True


Expand Down
175 changes: 175 additions & 0 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Test the sphinx.config.Config class."""
from __future__ import annotations

import pickle
import time
from collections import Counter
from pathlib import Path
from unittest import mock

Expand All @@ -14,11 +17,38 @@
_Opt,
check_confval_types,
correct_copyright_year,
is_serializable,
)
from sphinx.deprecation import RemovedInSphinx90Warning
from sphinx.errors import ConfigError, ExtensionError, VersionRequirementError


def test_is_serializable():
# check that objects with circular references are correctly handled

a, b = [1], [2]
a.append(b)
b.append(a)
assert is_serializable(a)

x = {'a': 1, 'b': {'c': 1}}
x['b'] = x
assert is_serializable(x)

class _IgnoreExtend(frozenset):
def __or__(self, other):
picnixz marked this conversation as resolved.
Show resolved Hide resolved
# do nothing
return self

# check that without recursive guards, a recursion error occurs

with pytest.raises(RecursionError):
assert is_serializable(a, _recursive_guard=_IgnoreExtend())

with pytest.raises(RecursionError):
assert is_serializable(x, _recursive_guard=_IgnoreExtend())


def test_config_opt_deprecated(recwarn):
opt = _Opt('default', '', ())

Expand Down Expand Up @@ -102,6 +132,151 @@ def test_config_pickle_protocol(tmp_path, protocol: int):
assert repr(config) == repr(pickled_config)


def test_config_pickle_circular_reference_in_list():
a, b = [1], [2]
a.append(b)
b.append(a)

assert is_serializable(a)
assert is_serializable(b)

config = Config()
config.add('a', [], 'env', types=list)
config.add('b', [], 'env', types=list)
config.a, config.b = a, b

actual = pickle.loads(pickle.dumps(config))
assert isinstance(actual.a, list)
assert is_serializable(actual.a)

assert isinstance(actual.b, list)
assert is_serializable(actual.b)

assert actual.a[0] == 1
assert actual.a[1][0] == 2
assert actual.a[1][1][0] == 1
assert actual.a[1][1][1][0] == 2

assert actual.b[0] == 2
assert actual.b[1][0] == 1
assert actual.b[1][1][0] == 2
assert actual.b[1][1][1][0] == 1

assert len(actual.a) == 2
assert len(actual.a[1]) == 2
assert len(actual.a[1][1]) == 2
assert len(actual.a[1][1][1]) == 2
assert len(actual.a[1][1][1][1]) == 2

assert len(actual.b) == 2
assert len(actual.b[1]) == 2
assert len(actual.b[1][1]) == 2
assert len(actual.b[1][1][1]) == 2
assert len(actual.b[1][1][1][1]) == 2

def check(
u: list[list | int],
v: list[list | int],
*,
counter: Counter[type, int] | None = None,
guard: frozenset[int] = frozenset(),
) -> Counter[type, int]:
counter = Counter() if counter is None else counter

if id(u) in guard and id(v) in guard:
return counter

if isinstance(u, int):
assert v.__class__ is u.__class__
assert u == v
counter[type(u)] += 1
return counter

assert isinstance(u, list)
assert v.__class__ is u.__class__
assert len(u) == len(v)

for u_i, v_i in zip(u, v):
counter[type(u)] += 1
check(u_i, v_i, counter=counter, guard=guard | {id(u), id(v)})

return counter

counter = check(actual.a, a)
# check(actual.a, a)
# check(actual.a[0], a[0]) -> ++counter[dict]
# ++counter[int] (a[0] is an int)
# check(actual.a[1], a[1]) -> ++counter[dict]
# check(actual.a[1][0], a[1][0]) -> ++counter[dict]
# ++counter[int] (a[1][0] is an int)
# check(actual.a[1][1], a[1][1]) -> ++counter[dict]
# recursive guard since a[1][1] == a
assert counter[type(a[0])] == 2
assert counter[type(a[1])] == 4

# same logic as above
counter = check(actual.b, b)
assert counter[type(b[0])] == 2
assert counter[type(b[1])] == 4


def test_config_pickle_circular_reference_in_dict():
x = {'a': 1, 'b': {'c': 1}}
x['b'] = x
assert is_serializable(x)

config = Config()
config.add('x', [], 'env', types=dict)
config.x = x

actual = pickle.loads(pickle.dumps(config))
assert is_serializable(actual.x)
assert isinstance(actual.x, dict)

assert actual.x['a'] == 1
assert actual.x['b']['a'] == 1

assert len(actual.x) == 2
assert len(actual.x['b']) == 2
assert len(actual.x['b']['b']) == 2

def check(
u: dict[dict | int],
v: dict[dict | int],
*,
counter: Counter[type, int] | None = None,
guard: frozenset[int] = frozenset(),
) -> Counter:
counter = Counter() if counter is None else counter

if id(u) in guard and id(v) in guard:
return counter

if isinstance(u, int):
assert v.__class__ is u.__class__
assert u == v
counter[type(u)] += 1
return counter

assert isinstance(u, dict)
assert v.__class__ is u.__class__
assert len(u) == len(v)

for u_i, v_i in zip(u, v):
counter[type(u)] += 1
check(u[u_i], v[v_i], counter=counter, guard=guard | {id(u), id(v)})
return counter

counters = check(actual.x, x, counter=Counter())
# check(actual.x, x)
# check(actual.x['a'], x['a']) -> ++counter[dict]
# ++counter[int] (x['a'] is an int)
# check(actual.x['b'], x['b']) -> ++counter[dict]
# recursive guard since x['b'] == x
assert counters[type(x['a'])] == 1
assert counters[type(x['b'])] == 2


def test_extension_values():
config = Config()

Expand Down