Skip to content

Commit

Permalink
Add --start-container and --end-container args
Browse files Browse the repository at this point in the history
These allow the container-level search to be skipped entirely, which
enables bisecting at the commit level on containers not from the
ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series.
  • Loading branch information
olupton committed Oct 29, 2024
1 parent 3ad9cc2 commit 54a1772
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 22 deletions.
49 changes: 46 additions & 3 deletions .github/triage/jax_toolbox_triage/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile


def parse_args():
def parse_args(args=None):
parser = argparse.ArgumentParser(
description="""
Triage failures in JAX/XLA-related tests. The expectation is that the given
Expand Down Expand Up @@ -37,7 +37,6 @@ def parse_args():
help="""
Container to use. Example: jax, pax, triton. Used to construct the URLs of
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
required=True,
)
parser.add_argument(
"--output-prefix",
Expand Down Expand Up @@ -67,6 +66,15 @@ def parse_args():
Command to execute inside the container. This should be as targeted as
possible.""",
)
container_search_args.add_argument(
"--end-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --start-container must be too, but --container
is not required. This can be used to apply the commit-level bisection
search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD
series, although they must have a similar structure.""",
)
container_search_args.add_argument(
"--end-date",
help="""
Expand All @@ -76,6 +84,15 @@ def parse_args():
test case fails on this date.""",
type=lambda s: datetime.date.fromisoformat(s),
)
container_search_args.add_argument(
"--start-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --end-container must be too, but --container is
not required. This can be used to apply the commit-level bisection search
to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series,
although they must have a similar structure.""",
)
container_search_args.add_argument(
"--start-date",
help="""
Expand Down Expand Up @@ -109,4 +126,30 @@ def parse_args():
significantly speed up the commit-level search. By default, uses a temporary
directory including the name of the current user.""",
)
return parser.parse_args()
args = parser.parse_args(args=args)
num_explicit_containers = (args.start_container is not None) + (
args.end_container is not None
)
if num_explicit_containers == 1:
raise Exception(
"--start-container and --end-container must both be passed if either is"
)
if num_explicit_containers == 2:
# Explicit mode, --container, --start-date and --end-date are all ignored
if args.container:
raise Exception(
"--container must not be passed if --start-container and --end-container are"
)
if args.start_date:
raise Exception(
"--start-date must not be passed if --start-container and --end-container are"
)
if args.end_date:
raise Exception(
"--end-date must not be passed if --start-container and --end-container are"
)
elif num_explicit_containers == 0 and args.container is None:
raise Exception(
"--container must be passed if --start-container and --end-container are not"
)
return args
2 changes: 1 addition & 1 deletion .github/triage/jax_toolbox_triage/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestResult:
"""
Hold the result/stdout/stderr of a test execution
"""

__test__ = False # stop pytest gathering this
result: bool
stdout: typing.Optional[str] = None
stderr: typing.Optional[str] = None
Expand Down
36 changes: 22 additions & 14 deletions .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,35 +106,43 @@ def check_container(date: datetime.date) -> TestResult:
)
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)

# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
if args.start_container is not None:
assert args.end_container is not None
# Skip the container-level search because explicit end points were given
start_url = args.start_container
end_url = args.end_container
else:
# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
start_url = container_url(range_start)
end_url = container_url(range_end)

# Container-level search is now complete. Triage proceeds inside the `range_end``
# container. First, we check that rewinding JAX and XLA inside the `range_end``
# container to the commits used in the `range_start` container passes, whereas
# using the `range_end` commits reproduces the failure.

with Container(container_url(range_start)) as worker:
with Container(start_url) as worker:
start_jax_commit, _ = get_commit(worker, "jax")
start_xla_commit, _ = get_commit(worker, "xla")

# Fire up the container that will be used for the fine search.
with Container(container_url(range_end)) as worker:
with Container(end_url) as worker:
end_jax_commit, jax_dir = get_commit(worker, "jax")
end_xla_commit, xla_dir = get_commit(worker, "xla")
logger.info(
(
f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}"
f"XLA [{start_xla_commit}, {end_xla_commit}] using {end_url}"
)
)

Expand Down
62 changes: 62 additions & 0 deletions .github/triage/tests/test_arg_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from jax_toolbox_triage.args import parse_args

test_command = ["my-test-command"]
valid_start_end_container = [
"--start-container",
"start-url",
"--end-container",
"end-url",
]
valid_start_end_date_args = [
["--container", "jax"],
["--container", "jax", "--start-date", "2024-10-02"],
["--container", "jax", "--end-date", "2024-10-02"],
["--container", "jax", "--start-date", "2024-10-01", "--end-date", "2024-10-02"],
]


@pytest.mark.parametrize(
"good_args", [valid_start_end_container] + valid_start_end_date_args
)
def test_good_container_args(good_args):
args = parse_args(good_args + test_command)
assert args.test_command == test_command


@pytest.mark.parametrize("date_args", valid_start_end_date_args)
def test_bad_container_arg_combinations_across_groups(date_args):
# Can't combine --{start,end}-container with --container/--{start,end}-date
with pytest.raises(Exception):
parse_args(valid_start_end_container + date_args + test_command)


@pytest.mark.parametrize(
"container_args",
[
# Need --container
[],
["--start-date", "2024-10-01"],
["--end-date", "2024-10-02"],
["--start-date", "2024-10-01", "--end-date", "2024-10-02"],
# Need both if either is passed
["--start-container", "start-url"],
["--end-container", "end-url"],
],
)
def test_bad_container_arg_combinations_within_groups(container_args):
with pytest.raises(Exception):
parse_args(container_args + test_command)


@pytest.mark.parametrize(
"container_args",
[
# Need valid ISO dates
["--container", "jax", "--start-date", "a-blue-moon-ago"],
["--container", "jax", "--end-date", "a-year-ago-last-thursday"],
],
)
def test_unparsable_container_args(container_args):
with pytest.raises(SystemExit):
parse_args(container_args + test_command)
16 changes: 12 additions & 4 deletions docs/triage-tool.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ The tool follows a three-step process:
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
regression.

The third step can also be used on its own, via the `--start-container` and
`--end-container` options, which allows it to be used between private container tags,
without the dependency on the `ghcr.io/nvidia/jax` registry.

## Installation

The triage tool can be installed using `pip`:
Expand All @@ -46,11 +50,15 @@ needed to execute the test case.

## Usage

To use the tool, there are two compulsory arguments:
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
families to execute the test command in. Example: `jax` for a JAX unit test
failure, `maxtext` for a MaxText model execution failure
To use the tool, there are two compulsory inputs:
* A test command to triage.
* A specification of which containers to triage in:
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
families to execute the test command in. Example: `jax` for a JAX unit test
failure, `maxtext` for a MaxText model execution failure.
* `--start-container` and `--end-container`: a pair of URLs to containers to use
in the commit-level search; if these are passed then no container-level search
is performed.

The test command will be executed directly in the container, not inside a shell, so be
sure not to add excessive quotation marks (*i.e.* run
Expand Down

0 comments on commit 54a1772

Please sign in to comment.