Skip to content

Commit

Permalink
feat: add scaffolding and basic tests for taskgraph generation (#776)
Browse files Browse the repository at this point in the history
This is prep work for #628, where I'd like to add some tests to avoid regressing that again in the future.

The fixtures here are based on similar tests from Gecko: https://searchfox.org/mozilla-central/source/taskcluster/test. There's a bit of a terrible hack to make optimized task graphs testable, described more in the comments.
  • Loading branch information
bhearsum authored Aug 7, 2024
1 parent 19fc7b9 commit f66a7b6
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 3 deletions.
7 changes: 7 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ tasks:
--output-file "{{.OUTPUT_FILE}}"
--diff "{{.BASE_REV}}"
taskgraph-test:
desc: Run tests and validations against task generation
cmds:
- >-
poetry run --directory taskgraph --
pytest taskcluster/test
docs:
desc: Run the GitHub pages Jekyll theme locally.
cmds:
Expand Down
15 changes: 14 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ sh = "^2.0.6"
zstandard = "^0.22.0"
translations_parser = {path="./tracking/", develop=true}
taskcluster-taskgraph = "^10.0.1"
translations_taskgraph = {path="./taskcluster/", develop=true}

[tool.black]
extend-exclude= "/3rd_party"
Expand Down
16 changes: 16 additions & 0 deletions taskcluster/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
[build-system]
requires = [ "setuptools>=61.2",]
build-backend = "setuptools.build_meta"

[project]
name = "translations_taskgraph"
version = "0.1.0"
description = "Translations specific code needed to generate Taskcluster tasks & graphs"
requires-python = ">=3.10"

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
namespaces = false

[tool.ruff]
line-length = 120
target-version = "py37"
Expand Down
86 changes: 86 additions & 0 deletions taskcluster/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from copy import deepcopy
import pytest
import requests_mock
from typing import Any, Dict, Generator, List, Protocol

from taskgraph.generator import TaskGraph, TaskGraphGenerator
from taskgraph.parameters import Parameters, parameters_loader
from translations_taskgraph.util.substitution import substitute


class CreateTgg(Protocol):
def __call__(
self, parameters: Parameters | None = None, overrides: dict | None = None
) -> TaskGraphGenerator:
...


# These fixtures are largely cribbed from Gecko:
# https://searchfox.org/mozilla-central/source/taskcluster/test
@pytest.fixture(scope="session")
def create_tgg():
def inner(
parameters: Parameters | None = None, overrides: dict | None = None
) -> TaskGraphGenerator:
params = parameters_loader(parameters, strict=False, overrides=overrides)
return TaskGraphGenerator(None, params)

return inner


@pytest.fixture(scope="module")
def mock_requests() -> Generator[requests_mock.Mocker, None, None]:
with requests_mock.Mocker() as m:
yield m


# Scoping this at the module level means that each module will only generate
# a taskgraph one time, no matter how many tests are within it. This is
# beneficial for performance reasons, but forces any tests that need distinct
# parameters to be moved to their own modules.
@pytest.fixture(scope="module")
def tgg(request: pytest.FixtureRequest, create_tgg: CreateTgg) -> TaskGraphGenerator:
if not hasattr(request.module, "PARAMS"):
pytest.fail("'tgg' fixture requires a module-level 'PARAMS' variable")

return create_tgg(overrides=request.module.PARAMS)


@pytest.fixture(scope="module")
def full_task_graph(tgg: TaskGraphGenerator) -> TaskGraph:
return tgg.full_task_graph


@pytest.fixture(scope="module")
def target_task_graph(tgg: TaskGraphGenerator) -> TaskGraph:
return tgg.target_task_graph


@pytest.fixture(scope="module")
def target_task_set(tgg: TaskGraphGenerator) -> TaskGraph:
return tgg.target_task_set


@pytest.fixture(scope="module")
def optimized_task_graph(
request: pytest.FixtureRequest, mock_requests: requests_mock.Mocker, tgg: TaskGraphGenerator
) -> TaskGraph:
for resp in getattr(request.module, "MOCK_REQUESTS", {}):
responses: List[Dict[str, Any]] = deepcopy(resp["responses"])
digests = {}
# This is a bit of a terrible hack, but it allows for cached task digests
# to be substituted into mocked API responses, which is needed to test
# the optimized and/or morphed task graph. Cached task digests are
# generated as part of earlier phases, so there's no sensible way for
# them to defined concretely at the same time as other parts of the
# MOCK_REQUESTS.
for label, key in resp.get("substitute_digest", {}).items():
digests[key] = tgg.full_task_set[label].attributes["cached_task"]["digest"]
responses = substitute(responses, **digests)
mock_requests.request(
resp["method"],
resp["url"],
responses,
)

return tgg.optimized_task_graph
99 changes: 99 additions & 0 deletions taskcluster/test/test_default_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from copy import deepcopy

from taskgraph.taskgraph import TaskGraph

from translations_taskgraph.parameters import get_defaults

PARAMS = deepcopy(get_defaults(None))
PARAMS["target_tasks_method"] = "train-target-tasks"

MOCK_REQUESTS = [
{
"substitute_digest": {
"build-docker-image-base": "digest_base",
"build-docker-image-test": "digest_test",
"build-docker-image-toolchain-build": "digest_toolchain",
"build-docker-image-train": "digest_train",
},
"method": "POST",
"url": "https://firefox-ci-tc.services.mozilla.com/api/index/v1/tasks/indexes",
"responses": [
{
"json": {
"tasks": [
{
"namespace": "translations.cache.level-3.docker-images.v2.base.hash.{digest_base}",
"taskId": "build-docker-image-base",
},
{
"namespace": "translations.cache.level-3.docker-images.v2.test.hash.{digest_test}",
"taskId": "build-docker-image-test",
},
{
"namespace": "translations.cache.level-3.docker-images.v2.toolchain-build.hash.{digest_toolchain}",
"taskId": "build-docker-image-toolchain-build",
},
{
"namespace": "translations.cache.level-3.docker-images.v2.train.hash.{digest_train}",
"taskId": "build-docker-image-train",
},
],
},
"status_code": 200,
},
],
},
{
"method": "POST",
"url": "https://firefox-ci-tc.services.mozilla.com/api/queue/v1/tasks/status",
"responses": [
{
"json": {
"statuses": [
{
"status": {
"state": "completed",
"expires": "3024-08-21T22:37:28.781Z",
},
"taskId": "build-docker-image-base",
},
{
"status": {
"state": "completed",
"expires": "3024-08-21T22:37:28.781Z",
},
"taskId": "build-docker-image-test",
},
{
"status": {
"state": "completed",
"expires": "3024-08-21T22:37:28.781Z",
},
"taskId": "build-docker-image-toolchain-build",
},
{
"status": {
"state": "completed",
"expires": "3024-08-21T22:37:28.781Z",
},
"taskId": "build-docker-image-train",
},
],
},
"status_code": 200,
},
],
},
]


def test_last_task_is_targeted(target_task_set: TaskGraph):
"""Ensure that the last task in the pipeline is targeted by default"""
assert any([task == "all-ru-en-1" for task in target_task_set.tasks])


def test_cached_tasks_optimized_away(optimized_task_graph: TaskGraph):
"""Ensure that any tasks found in a cache route are _not_ present
in the optimized graph (ie: they will not be scheduled)."""
for task in optimized_task_graph.tasks.values():
assert not task.label.startswith("build-docker-image")
19 changes: 19 additions & 0 deletions taskcluster/test/test_target_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from copy import deepcopy

from taskgraph.taskgraph import TaskGraph

from translations_taskgraph.parameters import get_defaults

PARAMS = deepcopy(get_defaults(None))
PARAMS["target_tasks_method"] = "train-target-tasks"
PARAMS["training_config"]["target-stage"] = "train-teacher"


def test_nothing_downstream_of_target(target_task_graph: TaskGraph):
# despite being called `reverse_links_dict`, this actually
# gives us a dict where we can find tasks _downstream_ of
# each task by label
links = target_task_graph.graph.reverse_links_dict()
for task in target_task_graph.graph.nodes:
if task.startswith("train-teacher"):
assert links[task] == set()
28 changes: 28 additions & 0 deletions taskcluster/test/test_training_continuation_backwards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from copy import deepcopy

from taskgraph.taskgraph import TaskGraph

from translations_taskgraph.parameters import get_defaults

PARAMS = deepcopy(get_defaults(None))
PARAMS["target_tasks_method"] = "train-target-tasks"
PARAMS["training_config"]["experiment"]["pretrained-models"] = {
"train-backwards": {
"mode": "use",
"type": "default",
"urls": [
"https://storage.googleapis.com/releng-translations-dev/models/ru-en/better-teacher/student"
],
},
}


def test_artifact_mounts(full_task_graph: TaskGraph):
task = [t for t in full_task_graph.tasks.values() if t.label == "train-backwards-ru-en"][0]
# No need to bother looking for _all_ files (we'd just duplicate
# the full list if we did that...), but we verify that one file
# is well formed.
mounted_files = {m["file"]: m for m in task.task["payload"]["mounts"] if "file" in m}
assert mounted_files["./artifacts/model.npz"]["content"] == {
"url": "https://storage.googleapis.com/releng-translations-dev/models/ru-en/better-teacher/student/model.npz",
}
2 changes: 1 addition & 1 deletion taskcluster/translations_taskgraph/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# the entire pipeline reasonably quickly to validate changes to the pipeline
# itself. Any real training should be overriding most, if not all, of these
# via the input to the `train` action.
def get_defaults(_):
def get_defaults(_) -> dict:
return {
"training_config": {
"target-stage": "all",
Expand Down
5 changes: 4 additions & 1 deletion taskcluster/translations_taskgraph/util/substitution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any


class PartialSubstitutionDict(dict):
"""A dictionary that will return any missing keys as their formatable
version. Useful when a string needs to be formatted multiple times
Expand All @@ -7,7 +10,7 @@ def __missing__(self, key):
return "{" + key + "}"


def substitute(item, **subs):
def substitute(item: Any, **subs):
if isinstance(item, list):
for i in range(len(item)):
item[i] = substitute(item[i], **subs)
Expand Down

0 comments on commit f66a7b6

Please sign in to comment.