-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
state.py
236 lines (180 loc) · 7.72 KB
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""Shared state for the sphinx-graph extension."""
from __future__ import annotations
from collections.abc import Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING
import rustworkx as rx
from sphinx.errors import DocumentError
from sphinx.util import logging
from sphinx_graph.vertex.info import Info
if TYPE_CHECKING:
from sphinx.application import Sphinx
from sphinx.environment import BuildEnvironment
# Module-level logger, created through the `logging` helper imported from
# `sphinx.util`; used below to report graph consistency problems.
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module: only `State`.
__all__ = [
"State",
]
class DuplicateIdError(DocumentError):
    """Signal that a vertex ID has been registered more than once.

    `insert_vertex` raises this when the uid it is given is already present
    among the vertices collected for the build environment.
    """

    # Label attached to this error type; presumably consumed by Sphinx when it
    # reports the error (not visible in this module).
    category = "Document Error"
@contextmanager
def _vertices_tmp(env: BuildEnvironment) -> Iterator[dict[str, Info]]:
vertices: dict[str, Info] = getattr(env, "graph_vertices_tmp", {})
yield vertices
env.graph_vertices_tmp = vertices # type: ignore[attr-defined]
def purge(_app: Sphinx, env: BuildEnvironment, docname: str) -> None:
    """Clear out all stale vertices belonging to one document.

    Every vertex stored in ``env.graph_vertices_tmp`` whose ``docname`` matches
    the given one is removed. If the document still contains vertices, they
    will be added again during parsing.

    Args:
        _app: The Sphinx application (unused).
        env: The build environment holding the collected vertices.
        docname: Name of the document whose vertices should be dropped.
    """
    vertices: dict[str, Info] = getattr(env, "graph_vertices_tmp", {})
    # Assign the filtered dict onto the environment explicitly. Rebinding a
    # local name (for instance inside ``with _vertices_tmp(env)``) would leave
    # the stored dict untouched and the purge would silently do nothing.
    env.graph_vertices_tmp = {  # type: ignore[attr-defined]
        uid: vert for uid, vert in vertices.items() if vert.docname != docname
    }
def merge(
    _app: Sphinx,
    env: BuildEnvironment,
    _docnames: list[str],
    other: BuildEnvironment,
) -> None:
    """Fold the vertices collected by ``other`` into ``env``.

    Used to combine environments during parallel builds. Entries from
    ``other`` replace entries in ``env`` that share the same uid (plain
    ``dict.update`` semantics). Both environments end up with their
    ``graph_vertices_tmp`` attribute assigned.
    """
    with _vertices_tmp(env) as merged:
        with _vertices_tmp(other) as incoming:
            merged.update(incoming)
def insert_vertex(env: BuildEnvironment, uid: str, info: Info) -> None:
    """Register a vertex in the build environment under ``uid``.

    Raises:
        DuplicateIdError: If a vertex with the same uid is already registered.
    """
    with _vertices_tmp(env) as pending:
        if uid not in pending:
            pending[uid] = info
        else:
            # Raised inside the ``with`` so that the helper does not write the
            # dict back to the environment on failure.
            err_msg = f"Vertex {uid} already exists."
            raise DuplicateIdError(err_msg)
def build_and_check_graph(env: BuildEnvironment) -> State:
    """Build the graph from the collected vertices.

    One graph node is created per collected vertex, edges are added by
    `build_graph_edges` (which also reports consistency problems), and the
    result is stored on the environment as ``graph_vertices`` and
    ``graph_graph``.

    Args:
        env: The build environment holding ``graph_vertices_tmp``.

    Returns:
        A `State` wrapping the freshly built vertex table and graph.
    """
    # Fall back to an empty mapping when nothing was ever collected, matching
    # the default used by `_vertices_tmp`; a bare attribute access would raise
    # AttributeError in that case.
    vertices_tmp: dict[str, Info] = getattr(env, "graph_vertices_tmp", {})
    vertices: dict[str, tuple[int, Info]] = {}
    graph: rx.PyDiGraph[str, str | None] = rx.PyDiGraph()
    for uid, info in vertices_tmp.items():
        # The uid is stored as the node payload, so it can be recovered from a
        # node id via ``graph[node_id]``.
        node_id = graph.add_node(uid)
        vertices[uid] = node_id, info
    build_graph_edges(vertices, graph)
    env.graph_vertices = vertices  # type: ignore[attr-defined]
    env.graph_graph = graph  # type: ignore[attr-defined]
    return State(vertices, graph)
class State:
    """State object for Sphinx Graph vertices.

    Bundles the table mapping each vertex uid to its ``(node id, Info)`` pair
    with the directed graph that links those node ids.
    """

    def __init__(
        self,
        vertices: dict[str, tuple[int, Info]],
        graph: rx.PyDiGraph[str, str | None],
    ) -> None:
        """Create a new state object from a vertex table and its graph."""
        self._vertices = vertices
        self._graph = graph

    @classmethod
    def read(cls, env: BuildEnvironment) -> State:
        """Read the State object for the given environment.

        This is a read-only view of the state. Changes will not be saved.

        Missing environment attributes fall back to an empty vertex table and
        an empty graph.
        """
        vertices = getattr(env, "graph_vertices", {})
        try:
            graph: rx.PyDiGraph[str, str | None] = env.graph_graph  # type: ignore[attr-defined]
        except AttributeError:
            # Only build the placeholder graph when the environment really has
            # none, instead of allocating one on every call.
            graph = rx.PyDiGraph(multigraph=False)
        # Use ``cls`` so that subclasses get an instance of their own type.
        return cls(vertices, graph)

    @property
    def graph(self) -> rx.PyDiGraph[str, str | None]:
        """A graph representing the relationships between vertices.

        Vertices in the graph are stored using 'node ids'. These can be
        retrieved using the `State.node_ids` mapping.
        """
        return self._graph

    @property
    def vertices(self) -> Mapping[str, Info]:
        """A mapping from vertex uid to vertex Info."""
        return Vertices(self._vertices)

    @property
    def node_ids(self) -> Mapping[str, int]:
        """A mapping from vertex uid to graph node ID."""
        return NodeIds(self._vertices)

    def children(self, uid: str) -> Iterable[str]:
        """Iterate over the children of the given node.

        Raises:
            KeyError: If ``uid`` is not a known vertex (raised on iteration).
        """
        node_id, _info = self._vertices[uid]
        yield from self._graph.successors(node_id)

    def ancestors(self, uid: str) -> Iterable[str]:
        """Recursively find all direct parents and ancestors of the given node.

        Raises:
            KeyError: If ``uid`` is not a known vertex (raised on iteration).
        """
        node_id = self.node_ids[uid]
        # `rx.ancestors` yields node ids; indexing the graph turns each one
        # back into its payload.
        yield from (
            self.graph[anc_node_id] for anc_node_id in rx.ancestors(self.graph, node_id)
        )

    def descendants(self, uid: str) -> Iterable[str]:
        """Recursively find all direct children and descendants of the given node.

        Raises:
            KeyError: If ``uid`` is not a known vertex (raised on iteration).
        """
        node_id = self.node_ids[uid]
        yield from (
            self.graph[desc_node_id]
            for desc_node_id in rx.descendants(self.graph, node_id)
        )
class Vertices(Mapping[str, Info]):
    """Mapping view over the vertex table that exposes only the Info objects.

    Wraps a dict of ``uid -> (node id, Info)`` without copying it, so later
    changes to that dict are visible through the view.
    """

    def __init__(self, vertices: dict[str, tuple[int, Info]]) -> None:
        """Wrap the given vertex table."""
        self._vertices = vertices

    def __getitem__(self, key: str) -> Info:
        """Return the Info stored for ``key``; raises KeyError if unknown."""
        entry = self._vertices[key]
        # Each entry is ``(node id, Info)``; this view hands out the Info half.
        return entry[1]

    def __iter__(self) -> Iterator[str]:
        """Iterate over the vertex uids."""
        return iter(self._vertices)

    def __len__(self) -> int:
        """Return the number of vertices in the table."""
        return len(self._vertices)
class NodeIds(Mapping[str, int]):
    """Mapping view over the vertex table that exposes only the graph node ids.

    Wraps a dict of ``uid -> (node id, Info)`` without copying it, so later
    changes to that dict are visible through the view.
    """

    def __init__(self, vertices: dict[str, tuple[int, Info]]) -> None:
        """Wrap the given vertex table."""
        self._vertices = vertices

    def __getitem__(self, key: str) -> int:
        """Return the node id stored for ``key``; raises KeyError if unknown."""
        entry = self._vertices[key]
        # Each entry is ``(node id, Info)``; this view hands out the node id.
        return entry[0]

    def __iter__(self) -> Iterator[str]:
        """Iterate over the vertex uids."""
        return iter(self._vertices)

    def __len__(self) -> int:
        """Return the number of vertices in the table."""
        return len(self._vertices)
def build_graph_edges(
    vertices: Mapping[str, tuple[int, Info]], graph: rx.PyDiGraph[str, str | None]
) -> None:
    """Add the parent -> child edges to the graph and report inconsistencies.

    For every vertex, one edge is added from each of its parents to the vertex
    itself, carrying the link fingerprint as edge payload. Along the way the
    following problems are logged (none of them raises):

    * a parent uid that does not exist (the link is skipped),
    * a missing link fingerprint when the vertex's config requires one,
    * a link fingerprint that differs from the parent's fingerprint,
    * cyclic dependencies between vertices.

    Args:
        vertices: Table mapping each uid to its graph node id and Info.
        graph: Graph that already contains one node per entry in ``vertices``;
            it is modified in place.
    """
    # add all 'parent' edges
    for uid, (node_id, info) in vertices.items():
        fingerprints_required = info.config.require_fingerprints
        for parent_uid, fingerprint in info.parents.items():
            try:
                parent_node_id, parent = vertices[parent_uid]
            except KeyError:
                logger.exception(
                    f"vertex '{uid}' has a parent link to '{parent_uid}',"
                    f" but '{parent_uid}' doesn't exist"
                )
                # There is no node to link to. Falling through would use
                # `parent` / `parent_node_id` while unbound, or worse, reuse
                # the values left over from the previous parent.
                continue
            if fingerprints_required and fingerprint is None:
                logger.warning(
                    f"link fingerprints are required, but {uid} doesn't have a"
                    f" fingerprint for its link to its parent {parent_uid}.\nthe"
                    f" fingerprint can be added by changing the parent reference on"
                    f" {uid} to '{parent_uid}:{parent.fingerprint}'.",
                )
            if fingerprint and fingerprint != parent.fingerprint:
                logger.warning(
                    f"suspect link found. vertex {uid} is linked to vertex"
                    f" {parent_uid} with a fingerprint of '{fingerprint}', but"
                    f" {parent_uid}'s fingerprint is '{parent.fingerprint}'.\n{uid}"
                    " should be reviewed, and the link fingerprint manually updated.",
                )
            graph.add_edge(parent_node_id, node_id, fingerprint)

    # Translate each cycle from node ids back to vertex uids for reporting.
    cycles = [
        [graph[node_id] for node_id in node_ids] for node_ids in rx.simple_cycles(graph)
    ]
    if cycles:
        # Close each cycle by repeating its first uid; joining the full list
        # keeps a one-node cycle readable ("[a -> a]").
        suffix = ", ".join(f"[{' -> '.join([*uids, uids[0]])}]" for uids in cycles)
        # No exception is being handled here, so `logger.exception` would only
        # append an empty traceback to the message.
        logger.error(
            f"vertices must not have cyclic dependencies. cycles detected: {suffix}"
        )