use graph.nodes(data='payload') to simplify code #1036

Open · wants to merge 1 commit into base: main
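For context on the refactor: networkx can hand back `(node, attribute)` pairs directly when you pass `data=<key>` to `G.nodes`, which replaces the `G.nodes.items()` plus manual `"payload"` lookup pattern used throughout the bot. A minimal sketch of the before/after idiom — the toy graph and node names are illustrative, not from the real dependency graph:

```python
import networkx as nx

gx = nx.DiGraph()
gx.add_node("numpy", payload={"requirements": {"host": {"python"}}})
gx.add_node("orphan")  # hypothetical node with no payload attribute

# Before: iterate (node, attr-dict) pairs and pull the payload out by hand.
for node, node_attrs in gx.nodes.items():
    payload = node_attrs.get("payload", {})

# After: let networkx extract the attribute for us.
for node, payload in gx.nodes(data="payload", default={}):
    requirements = payload.get("requirements", {})
```

One caveat: without `default=`, `G.nodes(data="payload")` yields `None` for nodes that lack the attribute, whereas the old `node_attrs.get("payload", {})` call sites fell back to `{}`; the `default={}` above restores that behavior if any node can be missing its payload.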
12 changes: 6 additions & 6 deletions conda_forge_tick/auto_tick.py
@@ -326,8 +326,8 @@ def add_replacement_migrator(
"""
total_graph = copy.deepcopy(gx)

for node, node_attrs in gx.nodes.items():
requirements = node_attrs["payload"].get("requirements", {})
for node, node_attrs in gx.nodes(data='payload'):
requirements = node_attrs.get("requirements", {})
rq = (
requirements.get("build", set())
| requirements.get("host", set())
@@ -476,14 +476,14 @@ def migration_factory(
    # TODO: use the inbuilt LUT in the graph
    output_to_feedstock = {
        output: name
        for name, node in gx.nodes.items()
        for output in node.get("payload", {}).get("outputs_names", [])
        for name, node in gx.nodes(data='payload')
        for output in node.get("outputs_names", [])
    }
    all_package_names = set(gx.nodes) | set(
        sum(
            [
                node.get("payload", {}).get("outputs_names", [])
                for node in gx.nodes.values()
                node.get("outputs_names", [])
                for _, node in gx.nodes(data='payload')
            ],
            [],
        )
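To make the rewritten comprehensions concrete, here is a minimal sketch of how the output-to-feedstock lookup table is built from `(name, payload)` pairs; the feedstock and output names are made up for illustration:

```python
import networkx as nx

gx = nx.DiGraph()
gx.add_node("mpi", payload={"outputs_names": ["mpich", "openmpi"]})
gx.add_node("python", payload={"outputs_names": ["python"]})

# Each output name maps back to the feedstock that builds it.
output_to_feedstock = {
    output: name
    for name, node in gx.nodes(data="payload", default={})
    for output in node.get("outputs_names", [])
}
assert output_to_feedstock == {"mpich": "mpi", "openmpi": "mpi", "python": "python"}
```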
149 changes: 74 additions & 75 deletions conda_forge_tick/contexts.py
@@ -1,75 +1,74 @@
import copy
from dataclasses import dataclass
from networkx import DiGraph
import typing
import threading
import github3

if typing.TYPE_CHECKING:
    from conda_forge_tick.migrators import Migrator
    from conda_forge_tick.migrators_types import AttrsTypedDict


@dataclass
class GithubContext:
    github_username: str
    github_password: str
    circle_build_url: str
    github_token: typing.Optional[str] = ""
    dry_run: bool = True
    _tl: threading.local = threading.local()

    @property
    def gh(self) -> github3.GitHub:
        if getattr(self._tl, "gh", None) is None:
            if self.github_token:
                gh = github3.login(token=self.github_token)
            else:
                gh = github3.login(self.github_username, self.github_password)
            setattr(self._tl, "gh", gh)
        return self._tl.gh


@dataclass
class MigratorSessionContext(GithubContext):
    """Singleton session context. There should generally only be one of these"""

    graph: DiGraph = None
    smithy_version: str = ""
    pinning_version: str = ""
    prjson_dir = "pr_json"
    rever_dir: str = "./feedstocks/"
    quiet = True


@dataclass
class MigratorContext:
    """The context for a given migrator. This houses the runtime information that a migrator needs"""

    session: MigratorSessionContext
    migrator: "Migrator"
    _effective_graph: DiGraph = None

    @property
    def github_username(self) -> str:
        return self.session.github_username

    @property
    def effective_graph(self) -> DiGraph:
        if self._effective_graph is None:
            gx2 = copy.deepcopy(getattr(self.migrator, "graph", self.session.graph))

            # Prune graph to only things that need builds right now
            for node, node_attrs in self.session.graph.nodes.items():
                attrs = node_attrs.get("payload", {})
            for node, attrs in self.session.graph.nodes(data="payload"):
                if node in gx2 and self.migrator.filter(attrs):
                    gx2.remove_node(node)
            self._effective_graph = gx2
        return self._effective_graph


@dataclass
class FeedstockContext:
    package_name: str
    feedstock_name: str
    attrs: "AttrsTypedDict"
16 changes: 8 additions & 8 deletions conda_forge_tick/make_graph.py
@@ -291,21 +291,21 @@ def make_graph(
    # make the outputs look up table so we can link properly
    outputs_lut = {
        k: node_name
        for node_name, node in gx.nodes.items()
        for k in node.get("payload", {}).get("outputs_names", [])
        for node_name, node in gx.nodes(data='payload')
        for k in node.get("outputs_names", [])
    }
    # add this as an attr so we can use later
    gx.graph["outputs_lut"] = outputs_lut
    strong_exports = {
        node_name
        for node_name, node in gx.nodes.items()
        if node.get("payload").get("strong_exports", False)
        for node_name, node in gx.nodes(data='payload')
        if node.get("strong_exports", False)
    }
    # This drops all the edge data and only keeps the node data
    gx = nx.create_empty_copy(gx)
    # TODO: label these edges with the kind of dep they are and their platform
    for node, node_attrs in gx2.nodes.items():
        with node_attrs["payload"] as attrs:
    for node, attrs in gx2.nodes(data='payload'):
        with attrs:
            # replace output package names with feedstock names via LUT
            deps = set(
                map(
@@ -335,8 +335,8 @@ def make_graph(

def update_nodes_with_bot_rerun(gx):
    """Go through all the open PRs and check if they are rerun"""
    for name, node in gx.nodes.items():
        with node['payload'] as payload:
    for name, payload in gx.nodes(data='payload'):
        with payload:
            for migration in payload.get('PRed', []):
                pr_json = migration.get('PR', {})
                # if there is a valid PR and it isn't currently listed as rerun
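The `with payload:` blocks above depend on the payload objects being context managers — in the bot they are lazily loaded JSON blobs that flush changes back to disk on exit. A stripped-down stand-in that assumes only dict-like behavior; the class and data are illustrative, not the bot's actual implementation:

```python
class LazyPayload(dict):
    """Toy stand-in for a payload that persists its changes on exit."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # The real object would serialize itself back to storage here.
        return False


payload = LazyPayload({"PRed": [{"PR": {"state": "open"}}]})
with payload:
    for migration in payload.get("PRed", []):
        pr_json = migration.get("PR", {})
        if pr_json.get("state", "closed") == "open":
            migration["bot_rerun"] = False  # mutations happen inside the write block
```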
4 changes: 2 additions & 2 deletions conda_forge_tick/migrators/core.py
@@ -453,8 +453,8 @@ def __init__(
        else:
            self.outputs_lut = {
                k: node_name
                for node_name, node in self.graph.nodes.items()
                for k in node.get("payload", {}).get("outputs_names", [])
                for node_name, node in self.graph.nodes(data='payload')
                for k in node.get("outputs_names", [])
            }

        self.name = name
8 changes: 4 additions & 4 deletions conda_forge_tick/migrators/migration_yaml.py
@@ -151,8 +151,8 @@ def __init__(
        number_pred = len(
            [
                k
                for k, v in self.graph.nodes.items()
                if self.migrator_uid(v.get("payload", {}))
                in [vv.get("data", {}) for vv in v.get("payload", {}).get("PRed", [])]
                for k, v in self.graph.nodes(data='payload')
                if self.migrator_uid(v)
                in [vv.get("data", {}) for vv in v.get("PRed", [])]
            ],
        )
@@ -422,11 +422,11 @@ def create_rebuild_graph(
    total_graph = copy.deepcopy(gx)
    excluded_feedstocks = set() if excluded_feedstocks is None else excluded_feedstocks

    for node, node_attrs in gx.nodes.items():
    for node, attrs in gx.nodes(data='payload'):
        # always keep pinning
        if node == 'conda-forge-pinning':
            continue
        attrs: "AttrsTypedDict" = node_attrs["payload"]
        attrs: "AttrsTypedDict"
        requirements = attrs.get("requirements", {})
        host = requirements.get("host", set())
        build = requirements.get("build", set())
9 changes: 4 additions & 5 deletions conda_forge_tick/status_report.py
@@ -85,8 +85,7 @@ def graph_migrator_status(
    if 'conda-forge-pinning' in gx2.nodes():
        gx2.remove_node('conda-forge-pinning')

    for node, node_attrs in gx2.nodes.items():
        attrs = node_attrs["payload"]
    for node, attrs in gx2.nodes(data="payload"):
        # remove archived from status
        if attrs.get("archived", False):
            continue
@@ -235,11 +234,11 @@ def main(args: Any = None) -> None:

    lst = [
        k
        for k, v in mctx.graph.nodes.items()
        for k, v in mctx.graph.nodes(data="payload")
        if len(
            [
                z
                for z in v.get("payload", {}).get("PRed", [])
                for z in v.get("PRed", [])
                if z.get("PR", {}).get("state", "closed") == "open"
                and z.get("data", {}).get("migrator_name", "") == "Version"
            ],
@@ -259,7 +258,7 @@ def main(args: Any = None) -> None:

    lm = LicenseMigrator()
    lst = [
        k for k, v in mctx.graph.nodes.items() if not lm.filter(v.get("payload", {}))
        k for k, v in mctx.graph.nodes(data="payload") if not lm.filter(v)
    ]
    with open("./status/unlicensed.json", "w") as f:
        json.dump(
11 changes: 5 additions & 6 deletions conda_forge_tick/update_upstream_versions.py
@@ -498,12 +498,11 @@ def get_latest_version(
def _update_upstream_versions_sequential(
    gx: nx.DiGraph, sources: Iterable[AbstractSource],
) -> None:
    _all_nodes = [t for t in gx.nodes.items()]
    _all_nodes = [t for t in gx.nodes(data='payload')]
    random.shuffle(_all_nodes)

    to_update = []
    for node, node_attrs in _all_nodes:
        attrs = node_attrs["payload"]
    for node, attrs in _all_nodes:
        if attrs.get("bad") or attrs.get("archived"):
            attrs["new_version"] = False
            continue
@@ -534,10 +533,10 @@ def _update_upstream_versions_process_pool(
    # this has to be threads because the url hashing code uses a Pipe which
    # cannot be spawned from a process
    with executor(kind="dask", max_workers=20) as pool:
        _all_nodes = [t for t in gx.nodes.items()]
        _all_nodes = [t for t in gx.nodes(data='payload')]
        random.shuffle(_all_nodes)
        for node, node_attrs in tqdm.tqdm(_all_nodes):
            with node_attrs["payload"] as attrs:
        for node, attrs in tqdm.tqdm(_all_nodes):
            with attrs:
                if node == "ca-policy-lcg":
                    attrs["new_version"] = False
                    continue