-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathgraph.py
463 lines (384 loc) · 15.2 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
import bz2
from collections import deque, defaultdict
from logging import getLogger
import time
import os
import pathlib
import typing
import operator
from pprint import pformat
try:
import ruamel.yaml as ruamel_yaml
except ImportError:
import ruamel_yaml
import networkx
import requests
from pandas.io import json
from sortedcontainers import SortedList
from cachetools import LRUCache, cachedmethod, cached, TTLCache
# Module-level logger; records are emitted under this module's name.
logger = getLogger(__name__)
# Server used to build channel URLs when a bare channel name (not a URL) is given.
DEFAULT_BASE_URL = "https://conda.anaconda.org/"
# Name of conda's reduced "current" repodata file. Not referenced by the code
# in this module — presumably used by callers; verify before removing.
REPODATA_FILE_CURRENT = "current_repodata.json"
# Default repodata file fetched per channel/arch; the ".bz2" suffix makes
# RawRepoData bz2-decompress the downloaded content before parsing it.
REPODATA_FILE = "repodata.json.bz2"
def build_repodata_graph(
    repodata: dict, arch: str, url_prefix: str
) -> networkx.DiGraph:
    """Build a dependency graph from a decoded ``repodata.json`` document.

    Every package *name* becomes a node carrying two attributes:

    * ``"arch"`` -- the set of subdirs that contributed artifacts to it;
    * ``f"packages_{arch}"`` -- a ``{filename: package_info}`` table.

    For every dependency ``dep`` of package ``name`` an edge ``dep -> name``
    is added, so ``G.predecessors(name)`` yields the names ``name`` depends on.

    Note: each ``package_info`` dict inside *repodata* is mutated in place to
    gain a ``"url"`` entry of the form ``f"{url_prefix}/{filename}"``.

    :param repodata: decoded repodata; only its ``"packages"`` mapping is read
        (a missing mapping yields an empty graph).
    :param arch: subdir the packages belong to, e.g. ``"linux-64"``.
    :param url_prefix: URL of the channel subdir, used to build artifact URLs.
    :returns: the dependency graph.
    """
    G = networkx.DiGraph()
    attr_key = f"packages_{arch}"
    for filename, info in repodata.get("packages", {}).items():
        name = info["name"]
        G.add_node(name)
        node = G.nodes[name]
        node.setdefault("arch", set()).add(arch)
        node.setdefault(attr_key, {})[filename] = info
        info["url"] = f"{url_prefix}/{filename}"
        # Specs look like "python >=3.6,<4"; only the leading name matters.
        # Records without a "depends" entry simply contribute no edges.
        for dep in info.get("depends", ()):
            dep_name, _, _ = dep.partition(" ")
            G.add_edge(dep_name, name)
    return G
def compose_with_attrs(G: networkx.DiGraph, H: networkx.DiGraph) -> networkx.DiGraph:
    """Compose two graphs, merging their per-arch package tables.

    ``networkx.compose`` unions nodes and edges. For a node present in both
    graphs, however, H's attribute values replace G's wholesale. This
    function therefore re-merges every ``packages_<arch>`` table from G into
    the result, with G's entries winning on a filename collision, and records
    each such arch in the node's ``"arch"`` set.

    ``networkx.compose`` copies node attribute dicts only shallowly: the sets
    and dicts stored in them stay shared with G and H. The merged containers
    are therefore built as fresh copies, so G and H (which may be long-lived
    cached graphs) are never modified.

    :param G: graph whose package tables take precedence on collisions.
    :param H: graph whose other node attributes take precedence.
    :returns: a new composed graph.
    """
    I = networkx.compose(G, H)
    prefix = "packages_"
    for name, attrs in G.nodes(data=True):
        node = I.nodes[name]
        for key, val in attrs.items():
            if not key.startswith(prefix):
                continue
            arch = key[len(prefix):]
            # Rebind to new containers instead of mutating shared ones.
            node["arch"] = set(node.get("arch", ())) | {arch}
            merged = dict(node.get(key, {}))
            merged.update(val)
            node[key] = merged
    return I
def recursive_parents(G: networkx.DiGraph, nodes):
    """Collect every node reachable from *nodes* through predecessor edges.

    For graphs built by ``build_repodata_graph``, edges point from a
    dependency to its dependent. The result is then the transitive
    dependency closure of *nodes*, including the starting nodes themselves.

    :param G: the dependency graph to walk.
    :param nodes: a single node name or an iterable of names.
    :returns: the set of visited names. Requested names absent from *G* are
        still included, and a warning is logged for each of them.
    """
    queue = deque([nodes] if isinstance(nodes, str) else nodes)
    seen = set()
    while queue:
        current = queue.popleft()
        if current in seen:
            continue
        # conda automatically adds pip as a dep of python even when the
        # repodata does not list it; preserve that quirk.
        if current == "python" and "pip" not in seen:
            queue.append("pip")
        seen.add(current)
        if current not in G.nodes:
            # TODO: switch logging to loguru so that we can have context
            logger.warning(f"Package {current} not found in graph!")
            continue
        # TODO: this seems to cause issues with root nodes like zlib
        queue.extend(G.predecessors(current))
    return seen
class RawRepoData:
    """Repodata for one channel/arch, downloaded and turned into a graph.

    Instances are memoised by ``get_repo_data`` in the class-level TTL cache
    ``_cache``; ``_expire`` performs periodic housekeeping on that cache.
    """

    _ttl = 600
    _cache = TTLCache(100, ttl=_ttl)
    _last_expiry = time.monotonic()

    def __init__(
        self,
        *,
        channel: str,
        arch: str = "linux-64",
        base_url: str = DEFAULT_BASE_URL,
        repodata_file: str = REPODATA_FILE,
        ttl=600,
        request_timeout: float = 60.0,
    ):
        """Download and parse the repodata for *channel*/*arch*.

        :param channel: a channel name (e.g. ``"conda-forge"``) or a full
            ``http(s)://`` channel URL.
        :param arch: the subdir to fetch.
        :param base_url: server URL for bare channel names. It may contain
            ``{channel}`` and/or ``{arch}`` placeholders.
        :param repodata_file: file name inside the subdir. A ``.bz2`` suffix
            makes the content get bz2-decompressed before parsing.
        :param ttl: stored as ``self.ttl``; not otherwise used by this class.
        :param request_timeout: passed to ``requests.get`` as ``timeout``, so
            a stalled server cannot block the caller indefinitely.

        ``self.graph`` is None, instead of an exception being raised, when the
        server answers with an error status or the request itself fails.
        """
        self.ttl = ttl
        logger.info(f"RETRIEVING: {channel}, {arch}")
        url_prefix = self._url_prefix(channel, arch, base_url)
        repodata_url = f"{url_prefix}/{repodata_file}"
        self.channel = channel
        self.arch = arch
        self.repodata_url = repodata_url
        self.graph = None

        try:
            response = requests.get(repodata_url, timeout=request_timeout)
        except requests.RequestException as exc:
            # Network-level failures are reported like HTTP errors: no graph.
            logger.warning(f"NO BUILD FOR {repodata_url} ({exc})")
            return
        if not response.ok:
            logger.warning(f"NO BUILD FOR {repodata_url}")
            return

        content = response.content
        if repodata_url.endswith(".bz2"):
            content = bz2.decompress(content)
        repodata = json.loads(content)
        self.graph = build_repodata_graph(repodata, arch, url_prefix)
        logger.info(f"GRAPH BUILD FOR {repodata_url}")

    @staticmethod
    def _url_prefix(channel, arch, base_url):
        """Return the URL of the *channel*/*arch* subdir."""
        # Channels given as explicit URLs ignore base_url entirely.
        if channel.startswith(("http://", "https://")):
            return channel.rstrip("/") + f"/{arch}"
        if "{channel}" in base_url and "{arch}" in base_url:
            return base_url.format(channel=channel, arch=arch)
        if "{channel}" in base_url:
            return base_url.format(channel=channel).rstrip("/") + f"/{arch}"
        return f"{base_url.rstrip('/')}/{channel}/{arch}"

    def __hash__(self):
        return hash(self.repodata_url)

    def __repr__(self):
        return f"RawRepoData({self.repodata_url})"

    @classmethod
    def _expire(cls):
        """Purge expired entries from ``_cache`` at most once every ``_ttl`` seconds."""
        current = time.monotonic()
        if current - cls._last_expiry >= cls._ttl:
            cls._cache.expire()
            cls._last_expiry = current
class FusedRepoData:
    """Utility class describing a set of repodatas treated as a single repository.

    Packages in prior repodatas take precedence. When a package name appears
    in several repodatas, the node attributes (and hence the package table)
    of the earliest one are kept. Dependency edges from all of them are
    unioned.

    Components whose ``graph`` is None (their fetch failed) are skipped with
    a warning. The fused ``graph`` is None only when no component produced a
    graph, including the case of an empty input sequence.
    """

    def __init__(self, raw_repodata: "typing.Sequence[RawRepoData]", arch):
        logger.debug(f"FUSING: {raw_repodata}")
        self.arch = arch
        self.component_channels = [raw.channel for raw in raw_repodata]
        # TODO: Maybe cache this?
        fused = None
        for raw in raw_repodata:
            if raw.graph is None:
                logger.warning(
                    f"Skipping {raw.channel}/{arch}: no repodata graph available"
                )
                continue
            if fused is None:
                fused = raw.graph
            else:
                # networkx.compose(G, H) lets H's node attributes win, so the
                # graph accumulated from earlier (higher-priority) repodatas
                # is passed second.
                fused = networkx.compose(raw.graph, fused)
        self.graph = fused

    def __repr__(self):
        return f"FusedRepoData([{', '.join(self.component_channels)}], {self.arch})"
def get_repo_data(
    channel: typing.List[str],
    arch: str,
    repodata_file: str,
    base_url: str = DEFAULT_BASE_URL,
) -> FusedRepoData:
    """Fetch (or reuse from cache) each channel's repodata and fuse them.

    :param channel: ordered channel names/URLs; earlier entries take
        precedence in the fused result.
    :param arch: subdir to fetch for every channel.
    :param repodata_file: repodata file name to fetch.
    :param base_url: server URL used for bare channel names.
    :returns: a FusedRepoData over the per-channel RawRepoData objects.
    """
    RawRepoData._expire()
    repodatas = []
    for c in channel:
        # base_url is part of the key: the same channel name resolves to a
        # different URL under a different base_url.
        key = (c, arch, repodata_file, base_url)
        # TODO: This should happen in parallel
        # A single lookup avoids a KeyError if the entry expires between a
        # membership test and a subsequent index.
        raw = RawRepoData._cache.get(key)
        if raw is None:
            logger.info("refreshing cache for %s/%s", c, arch)
            raw = RawRepoData(
                channel=c, arch=arch, base_url=base_url, repodata_file=repodata_file
            )
            RawRepoData._cache[key] = raw
        repodatas.append(raw)
    return FusedRepoData(repodatas, arch)
def parse_constraints(constraints):
    """Split *constraints* into package specs and ``--flag[=value]`` options.

    Functional constraints (those starting with ``--``) constrain *within*
    packages, e.g. version or build-number limits.

    :param constraints: iterable of constraint strings.
    :returns: ``(package_constraints, functional_constraints)``.
        ``package_constraints`` lists the non-flag entries in input order.
        ``functional_constraints`` is a ``defaultdict(set)`` mapping each
        flag to the set of its values; a flag without ``=`` maps to ``{""}``.
    """
    functional = defaultdict(set)
    packages = []
    for item in constraints:
        if not item.startswith("--"):
            packages.append(item)
            continue
        flag, _sep, value = item.partition("=")
        functional[flag].add(value)
    return packages, functional
@cached(cache={})
def get_blacklist(blacklist_name, channel, arch):
    """Load the set of blacklisted artifact filenames for *channel*/*arch*.

    Reads ``blacklists/<channel>/<blacklist_name>.yml`` relative to the
    current working directory and returns the entries listed under the
    *arch* key as a set. The result is an empty set in any of these cases:
    the file is missing, the arch key is missing, or the document is empty.
    Results are memoised per argument tuple for the life of the process.

    NOTE(review): *blacklist_name* and *channel* presumably originate from
    request input (a ``--blacklist=`` constraint and channel names). Paths
    that resolve outside the ``blacklists`` directory are therefore refused,
    with a warning, rather than read.
    """
    root = pathlib.Path("blacklists")
    path = root / channel / (blacklist_name + ".yml")
    try:
        path.resolve().relative_to(root.resolve())
    except ValueError:
        logger.warning(f"Refusing blacklist path outside {root}: {path}")
        return set()
    if not path.exists():
        return set()
    with path.open() as fo:
        # ruamel.yaml 0.18 removed the module-level safe_load(). The
        # YAML(typ="safe") API is available from ruamel.yaml 0.15 onwards.
        obj = ruamel_yaml.YAML(typ="safe").load(fo)
    # An empty YAML document loads as None.
    return set((obj or {}).get(arch, []))
class ArtifactGraph:
    """Dependency graph for a set of channels, renderable as repodata JSON.

    The graph for ``arch`` is combined with a companion subdir. The companion
    is ``noarch``, or ``linux-64`` when ``arch`` itself is ``noarch``. The
    combined graph is optionally restricted to the transitive dependencies of
    the requested packages, and is filtered by the functional (``--``)
    constraints when rendered.

    Instances are memoised by ``get_artifact_graph`` in the class-level TTL
    cache returned by ``artifact_graph_cache()``.
    """

    _ttl = 600
    _artifact_graph_cache = TTLCache(100, ttl=_ttl)
    _last_expiry = time.monotonic()

    # Features that untrack_features rewrites into explicit dependency specs.
    _feature_map = {
        "blas_openblas": "blas * openblas",
        "blas_mkl": "blas * mkl",
        "blas_nomkl": "blas * nomkl",
        "vc9": "vs2008_runtime",
        "vc10": "vs2010_runtime",
        "vc14": "vs2015_runtime",
    }

    def __init__(
        self, channel, arch, constraints, repodata_file, base_url=DEFAULT_BASE_URL
    ):
        self.base_url = base_url
        self.channel = channel
        self.arch = arch
        self.constraints = constraints
        # Separate per-instance caches for the two rendered forms. Both
        # methods take no arguments, so sharing one cache would give them the
        # same key and let one return the other's (differently typed) result.
        self._repodata_cache = TTLCache(100, ttl=self._ttl)
        self._repodata_bzip_cache = TTLCache(100, ttl=self._ttl)
        self.package_constraints, self.functional_constraints = parse_constraints(
            constraints or ()
        )
        self.raw = get_repo_data(
            channel=channel, arch=arch, base_url=base_url, repodata_file=repodata_file
        )
        self.noarch = None
        if self.raw.graph is None:
            self.constrained_graph = None
            return
        # TODO: Solving the artifact graph happens twice per conda operation:
        # once for arch and once for noarch. So when serving noarch we treat
        # another subdir as the "companion" channel. linux-64 as the noarch
        # stand-in is mostly convenience. In the future it may be wiser to
        # store all subdirs together.
        companion_arch = "linux-64" if arch == "noarch" else "noarch"
        self.noarch = get_repo_data(
            channel=channel,
            arch=companion_arch,
            base_url=base_url,
            repodata_file=repodata_file,
        )
        self.constrain_graph(
            self.raw.graph, self.noarch.graph, self.package_constraints
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.channel!r}, {self.arch!r}, {self.constraints!r})"

    @classmethod
    def artifact_graph_cache(cls):
        """Return the shared instance cache, purging expired entries at most once per ``_ttl``."""
        current = time.monotonic()
        if current - cls._last_expiry >= cls._ttl:
            cls._artifact_graph_cache.expire()
            cls._last_expiry = current
        return cls._artifact_graph_cache

    def constrain_graph(self, graph, noarch_graph, constraints):
        """Set ``self.constrained_graph`` from *graph* plus its companion graph.

        noarch is solved together with the normal channel, so the effective
        graph is the combination of the two. When *constraints* is non-empty,
        the graph is restricted to those packages and their transitive
        dependencies.
        """
        if noarch_graph is None:
            # The companion subdir could not be fetched (failed request or a
            # subdir the channel does not have): serve the arch graph alone.
            combined_graph = graph
        else:
            combined_graph = compose_with_attrs(graph, noarch_graph)
        if constraints:
            nodes = recursive_parents(combined_graph, constraints)
            self.constrained_graph = combined_graph.subgraph(nodes)
        else:
            self.constrained_graph = combined_graph

    def repodata_json_dict(self):
        """Render the constrained graph as a repodata-style dict.

        Collects the ``packages_<self.arch>`` table of every node in the
        constrained graph. The functional constraints are applied in this
        order: ``--max-build-no``, ``--untrack-features``, then each
        ``--blacklist`` name.

        :returns: ``{"packages": {...}}``, or None when the repodata fetch
            failed and no graph exists. A graph with no nodes yields an empty
            ``"packages"`` mapping.
        """
        # Compare against None explicitly: a networkx graph with no nodes is
        # falsy, and an empty graph is not a failed fetch.
        if self.constrained_graph is None:
            return None
        functional = self.functional_constraints
        attr_key = f"packages_{self.arch}"
        all_packages = {}
        for n in self.constrained_graph:
            logger.debug(n)
            packages = self.constrained_graph.nodes[n].get(attr_key, {})
            if "--max-build-no" in functional:
                packages = self.constrain_by_build_number(packages)
            if "--untrack-features" in functional:
                packages = self.untrack_features(packages)
            # .get() avoids inserting a key into the defaultdict.
            for blacklist_name in functional.get("--blacklist", ()):
                packages = self.constrain_by_blacklist(packages, blacklist_name)
            if n == "blas":
                logger.debug(pformat(packages))
            all_packages.update(packages)
        return {"packages": all_packages}

    def constrain_by_build_number(self, packages):
        """Keep only the top build number for each (version, build prefix).

        The build string (e.g. ``py27_1``) is split at its last underscore
        into a prefix (``py27``) and a number (``1``). Among packages sharing
        ``version`` and prefix, only the one with the largest
        ``build_number`` field survives; on a tie the first one seen wins.
        Packages whose build string does not end in ``_<digits>`` (such as the
        blas mutex package) are unaffected.

        k: artifact_name, v: package information dictionary

        For example::

            0.23.0-py27_0, 0.23.0-py27_1, 0.23.0-py36_0
            ->
            0.23.0-py27_1, 0.23.0-py36_0

        :returns: a new dict. Unaffected packages come first, followed by one
            winner per group in order of the group's first appearance.
        """
        kept = []
        best = {}
        for k, v in packages.items():
            prefix, _, number = v.get("build", "").rpartition("_")
            if not number.isnumeric():
                kept.append((k, v))
                continue
            group = (v["version"], prefix)
            current = best.get(group)
            if current is None or v.get("build_number", 0) > current[1].get(
                "build_number", 0
            ):
                best[group] = (k, v)
        kept.extend(best.values())
        return dict(kept)

    def constrain_by_blacklist(self, packages, blacklist_name):
        """Drop artifacts named in blacklist *blacklist_name* of any component channel."""
        effective_blacklist = set()
        for channel in self.raw.component_channels:
            effective_blacklist.update(
                get_blacklist(blacklist_name, channel, self.arch)
            )
        if not effective_blacklist:
            return packages
        o = {k: v for k, v in packages.items() if k not in effective_blacklist}
        logger.debug(
            "constrained channel from {} to {} artifacts".format(len(packages), len(o))
        )
        return o

    def untrack_features(self, packages: dict) -> dict:
        """Replace known features by explicit dependencies.

        For every package, each space-separated entry of ``features`` found
        in ``_feature_map`` is removed and its mapped spec is appended to
        ``depends``. ``features`` is dropped when nothing remains. Mapped names
        are likewise removed from ``track_features``, which is dropped when
        nothing remains.

        The input package dicts are left untouched. They are shared with the
        cached raw repodata graphs, so editing them in place would leak into
        every other ArtifactGraph built from the same cache. Each returned
        entry is therefore a shallow copy, with a fresh ``depends`` list when
        one is extended.
        """
        fmap = self._feature_map
        out = {}
        for k, v in packages.items():
            features = v.get("features", "").split(" ")
            extra_deps = [fmap[f] for f in features if f in fmap]
            kept_features = " ".join(f for f in features if f not in fmap)
            v = dict(v)
            if extra_deps:
                v["depends"] = list(v.get("depends", [])) + extra_deps
            if kept_features:
                v["features"] = kept_features
            else:
                v.pop("features", None)
            # For feature packages get rid of mapped things.
            track = v.get("track_features")
            if isinstance(track, str):
                tokens = track.split()
                if any(t in fmap for t in tokens):
                    remaining = " ".join(t for t in tokens if t not in fmap)
                    if remaining:
                        v["track_features"] = remaining
                    else:
                        del v["track_features"]
            out[k] = v
        return out

    @cachedmethod(operator.attrgetter("_repodata_cache"))
    def repodata_json(self) -> str:
        """Return ``repodata_json_dict()`` serialised as a JSON string (cached)."""
        return json.dumps(self.repodata_json_dict())

    @cachedmethod(operator.attrgetter("_repodata_bzip_cache"))
    def repodata_json_bzip(self) -> bytes:
        """Return ``repodata_json()`` UTF-8 encoded and bz2-compressed (cached)."""
        # compresslevel=1 presumably favours speed over compression ratio.
        return bz2.compress(self.repodata_json().encode("utf8"), compresslevel=1)
def get_artifact_graph(
    channel: typing.List[str],
    arch: str,
    constraints,
    repodata_file: str,
    base_url: str = DEFAULT_BASE_URL,
) -> ArtifactGraph:
    """Return the (TTL-cached) ArtifactGraph for a request.

    :param channel: ordered channel names/URLs; a single string is also
        accepted. An entry ``"defaults"`` is expanded in place into conda's
        default channel URLs (main and r, plus msys2 for ``win-*`` subdirs).
    :param arch: subdir to serve.
    :param constraints: a single constraint string or an iterable of them
        (see ``parse_constraints``).
    :param repodata_file: repodata file name fetched from each channel.
    :param base_url: server URL used for bare channel names.
    :returns: an ArtifactGraph. It is shared with other calls that use the
        same channels, arch, constraints, repodata file and base URL for as
        long as it stays in the cache.
    """
    if isinstance(channel, str):
        channel = [channel]
    if isinstance(constraints, str):
        constraints = [constraints]
    # Normalise to a list so the "defaults" splice below works for tuples too.
    channel = list(channel)
    # Special handling for defaults because it is special.
    if "defaults" in channel:
        # Mirrors conda's DEFAULT_CHANNELS_UNIX / DEFAULT_CHANNELS_WIN
        # (conda.base.constants): Windows adds pkgs/msys2.
        default_channels = [
            "https://repo.anaconda.com/pkgs/main",
            "https://repo.anaconda.com/pkgs/r",
        ]
        if arch.startswith("win-"):
            default_channels.append("https://repo.anaconda.com/pkgs/msys2")
        idx = channel.index("defaults")
        channel = channel[:idx] + default_channels + channel[idx + 1 :]
        logger.info(f"Using channel {channel}")
    # base_url is part of the key: it changes which URLs bare names map to.
    key = (tuple(channel), arch, tuple(sorted(constraints)), repodata_file, base_url)
    agcache = ArtifactGraph.artifact_graph_cache()
    # A single lookup avoids a KeyError if the entry expires between a
    # membership test and a subsequent index.
    graph = agcache.get(key)
    if graph is None:
        graph = ArtifactGraph(
            channel=channel,
            arch=arch,
            constraints=constraints,
            repodata_file=repodata_file,
            base_url=base_url,
        )
        agcache[key] = graph
    return graph