Skip to content

Commit

Permalink
feat: use new batched queries in task replacement
Browse files Browse the repository at this point in the history
Batching queries to get from task index to status reduces the number of
(sometimes trans-continental) queries from 2*900+ to ~2. This reduces the
time spent replacing tasks from 20% to 75% depending on the use-case.

The 20% improvement in wall time was observed when running
`mach taskgraph morphed` in a CI worker, while the 75% improvement
was observed in a developer machine in France running `mach taskgraph full`.

More information in taskcluster/taskcluster-rfcs#189.
  • Loading branch information
Alphare authored and ahal committed May 6, 2024
1 parent ac26282 commit 5207917
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 39 deletions.
32 changes: 32 additions & 0 deletions src/taskgraph/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from taskgraph.taskgraph import TaskGraph
from taskgraph.util.parameterization import resolve_task_references, resolve_timestamps
from taskgraph.util.python_path import import_sibling_modules
from taskgraph.util.taskcluster import find_task_id_batched, status_task_batched

logger = logging.getLogger(__name__)
registry = {}
Expand Down Expand Up @@ -51,6 +52,9 @@ def optimize_task_graph(
Perform task optimization, returning a taskgraph and a map from label to
assigned taskId, including replacement tasks.
"""
# avoid circular import
from taskgraph.optimize.strategies import IndexSearch

label_to_taskid = {}
if not existing_tasks:
existing_tasks = {}
Expand All @@ -70,6 +74,23 @@ def optimize_task_graph(
do_not_optimize=do_not_optimize,
)

# Gather each relevant task's index
indexes = set()
for label in target_task_graph.graph.visit_postorder():
if label in do_not_optimize:
continue
_, strategy, arg = optimizations(label)
if isinstance(strategy, IndexSearch) and arg is not None:
indexes.update(arg)

index_to_taskid = {}
taskid_to_status = {}
if indexes:
# Find their respective status using TC index/queue batch APIs
indexes = list(indexes)
index_to_taskid = find_task_id_batched(indexes)
taskid_to_status = status_task_batched(list(index_to_taskid.values()))

replaced_tasks = replace_tasks(
target_task_graph=target_task_graph,
optimizations=optimizations,
Expand All @@ -78,6 +99,8 @@ def optimize_task_graph(
label_to_taskid=label_to_taskid,
existing_tasks=existing_tasks,
removed_tasks=removed_tasks,
index_to_taskid=index_to_taskid,
taskid_to_status=taskid_to_status,
)

return (
Expand Down Expand Up @@ -259,12 +282,17 @@ def replace_tasks(
label_to_taskid,
removed_tasks,
existing_tasks,
index_to_taskid,
taskid_to_status,
):
"""
Implement the "Replacing Tasks" phase, returning a set of task labels of
all replaced tasks. The replacement taskIds are added to label_to_taskid as
a side-effect.
"""
# avoid circular import
from taskgraph.optimize.strategies import IndexSearch

opt_counts = defaultdict(int)
replaced = set()
dependents_of = target_task_graph.graph.reverse_links_dict()
Expand Down Expand Up @@ -307,6 +335,10 @@ def replace_tasks(
deadline = max(
resolve_timestamps(now, task.task["deadline"]) for task in dependents
)

if isinstance(opt, IndexSearch):
arg = arg, index_to_taskid, taskid_to_status

repl = opt.should_replace_task(task, params, deadline, arg)
if repl:
if repl is True:
Expand Down
11 changes: 6 additions & 5 deletions src/taskgraph/optimize/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from taskgraph.optimize.base import OptimizationStrategy, register_strategy
from taskgraph.util.path import match as match_path
from taskgraph.util.taskcluster import find_task_id, status_task

logger = logging.getLogger(__name__)

Expand All @@ -22,12 +21,14 @@ class IndexSearch(OptimizationStrategy):

fmt = "%Y-%m-%dT%H:%M:%S.%fZ"

def should_replace_task(self, task, params, deadline, index_paths):
def should_replace_task(self, task, params, deadline, arg):
"Look for a task with one of the given index paths"
index_paths, label_to_taskid, taskid_to_status = arg

for index_path in index_paths:
try:
task_id = find_task_id(index_path)
status = status_task(task_id)
task_id = label_to_taskid[index_path]
status = taskid_to_status[task_id]
# status can be `None` if we're in `testing` mode
# (e.g. test-action-callback)
if not status or status.get("state") in ("exception", "failed"):
Expand All @@ -40,7 +41,7 @@ def should_replace_task(self, task, params, deadline, index_paths):

return task_id
except KeyError:
# 404 will end up here and go on to the next index path
# go on to the next index path
pass

return False
Expand Down
85 changes: 85 additions & 0 deletions src/taskgraph/util/taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,48 @@ def find_task_id(index_path, use_proxy=False):
return response.json()["taskId"]


def find_task_id_batched(index_paths, use_proxy=False):
"""Gets the task id of multiple tasks given their respective index.
Args:
index_paths (List[str]): A list of task indexes.
use_proxy (bool): Whether to use taskcluster-proxy (default: False)
Returns:
Dict[str, str]: A dictionary object mapping each valid index path
to its respective task id.
See the endpoint here:
https://docs.taskcluster.net/docs/reference/core/index/api#findTasksAtIndex
"""
endpoint = liburls.api(get_root_url(use_proxy), "index", "v1", "tasks/indexes")
task_ids = {}
continuation_token = None

while True:
response = _do_request(
endpoint,
json={
"indexes": index_paths,
},
params={"continuationToken": continuation_token},
)

response_data = response.json()
if not response_data["tasks"]:
break
response_tasks = response_data["tasks"]
if (len(task_ids) + len(response_tasks)) > len(index_paths):
# Sanity check
raise ValueError("more task ids were returned than were asked for")
task_ids.update((t["namespace"], t["taskId"]) for t in response_tasks)

continuationToken = response_data.get("continuationToken")
if continuationToken is None:
break
return task_ids


def get_artifact_from_index(index_path, artifact_path, use_proxy=False):
full_path = index_path + "/artifacts/" + artifact_path
response = _do_request(get_index_url(full_path, use_proxy))
Expand Down Expand Up @@ -271,6 +313,49 @@ def status_task(task_id, use_proxy=False):
return status


def status_task_batched(task_ids, use_proxy=False):
"""Gets the status of multiple tasks given task_ids.
In testing mode, just logs that it would have retrieved statuses.
Args:
task_id (List[str]): A list of task ids.
use_proxy (bool): Whether to use taskcluster-proxy (default: False)
Returns:
dict: A dictionary object as defined here:
https://docs.taskcluster.net/docs/reference/platform/queue/api#statuses
"""
if testing:
logger.info(f"Would have gotten status for {len(task_ids)} tasks.")
return
endpoint = liburls.api(get_root_url(use_proxy), "queue", "v1", "tasks/status")
statuses = {}
continuation_token = None

while True:
response = _do_request(
endpoint,
json={
"taskIds": task_ids,
},
params={
"continuationToken": continuation_token,
},
)
response_data = response.json()
if not response_data["statuses"]:
break
response_tasks = response_data["statuses"]
if (len(statuses) + len(response_tasks)) > len(task_ids):
raise ValueError("more task statuses were returned than were asked for")
statuses.update((t["taskId"], t["status"]) for t in response_tasks)
continuationToken = response_data.get("continuationToken")
if continuationToken is None:
break
return statuses


def state_task(task_id, use_proxy=False):
"""Gets the state of a task given a task_id.
Expand Down
29 changes: 16 additions & 13 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,20 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):


@pytest.mark.parametrize(
"graph,kwargs,exp_replaced,exp_removed,exp_label_to_taskid",
"graph,kwargs,exp_replaced,exp_removed",
(
# A task cannot be replaced if it depends on one that was not replaced
pytest.param(
make_triangle(
t1={"replace": "e1"},
t3={"replace": "e3"},
),
{},
{
"index_to_taskid": {"t1": "e1"},
},
# expectations
{"t1"},
set(),
{"t1": "e1"},
id="blocked",
),
# A task cannot be replaced if it should not be optimized
Expand All @@ -291,11 +292,13 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
t2={"replace": "xxx"}, # but do_not_optimize
t3={"replace": "e3"},
),
{"do_not_optimize": {"t2"}},
{
"do_not_optimize": {"t2"},
"index_to_taskid": {"t1": "e1"},
},
# expectations
{"t1"},
set(),
{"t1": "e1"},
id="do_not_optimize",
),
# No tasks are replaced when strategy is 'never'
Expand All @@ -305,7 +308,6 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
# expectations
set(),
set(),
{},
id="never",
),
# All replaceable tasks are replaced when strategy is 'replace'
Expand All @@ -315,11 +317,12 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
t2={"replace": "e2"},
t3={"replace": "e3"},
),
{},
{
"index_to_taskid": {"t1": "e1", "t2": "e2", "t3": "e3"},
},
# expectations
{"t1", "t2", "t3"},
set(),
{"t1": "e1", "t2": "e2", "t3": "e3"},
id="all",
),
# A task can be replaced with nothing
Expand All @@ -329,11 +332,12 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
t2={"replace": True},
t3={"replace": True},
),
{},
{
"index_to_taskid": {"t1": "e1"},
},
# expectations
{"t1"},
{"t2", "t3"},
{"t1": "e1"},
id="tasks_removed",
),
# A task which expires before a dependents deadline is not a valid replacement.
Expand All @@ -353,7 +357,6 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
# expectations
set(),
set(),
{},
id="deadline",
),
),
Expand All @@ -363,7 +366,6 @@ def test_replace_tasks(
kwargs,
exp_replaced,
exp_removed,
exp_label_to_taskid,
):
"""Tests the `replace_tasks` function.
Expand All @@ -378,6 +380,8 @@ def test_replace_tasks(
kwargs.setdefault("params", {})
kwargs.setdefault("do_not_optimize", set())
kwargs.setdefault("label_to_taskid", {})
kwargs.setdefault("index_to_taskid", {})
kwargs.setdefault("taskid_to_status", {})
kwargs.setdefault("removed_tasks", set())
kwargs.setdefault("existing_tasks", {})

Expand All @@ -388,7 +392,6 @@ def test_replace_tasks(
)
assert got_replaced == exp_replaced
assert kwargs["removed_tasks"] == exp_removed
assert kwargs["label_to_taskid"] == exp_label_to_taskid


@pytest.mark.parametrize(
Expand Down
35 changes: 14 additions & 21 deletions test/test_optimize_strategies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Any copyright is dedicated to the public domain.
# http://creativecommons.org/publicdomain/zero/1.0/

import os
from datetime import datetime
from test.fixtures.gen import make_task
from time import mktime
Expand Down Expand Up @@ -44,31 +43,25 @@ def params():
),
),
)
def test_index_search(responses, params, state, expires, expected):
def test_index_search(state, expires, expected):
taskid = "abc"
index_path = "foo.bar.latest"
responses.add(
responses.GET,
f"{os.environ['TASKCLUSTER_ROOT_URL']}/api/index/v1/task/{index_path}",
json={"taskId": taskid},
status=200,
)

responses.add(
responses.GET,
f"{os.environ['TASKCLUSTER_ROOT_URL']}/api/queue/v1/task/{taskid}/status",
json={
"status": {
"state": state,
"expires": expires,
}
},
status=200,
)
label_to_taskid = {index_path: taskid}
taskid_to_status = {
taskid: {
"state": state,
"expires": expires,
}
}

opt = IndexSearch()
deadline = "2021-06-07T19:03:20.482Z"
assert opt.should_replace_task({}, params, deadline, (index_path,)) == expected
assert (
opt.should_replace_task(
{}, params, deadline, ((index_path,), label_to_taskid, taskid_to_status)
)
== expected
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 5207917

Please sign in to comment.