From 54a1772bef4589fedf77fe95d47c268ae9c1f975 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Tue, 29 Oct 2024 02:50:06 -0700 Subject: [PATCH] Add --start-container and --end-container args These allow the container-level search to be skipped entirely, which enables bisecting at the commit level on containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series. --- .github/triage/jax_toolbox_triage/args.py | 49 +++++++++++++++-- .github/triage/jax_toolbox_triage/logic.py | 2 +- .github/triage/jax_toolbox_triage/main.py | 36 ++++++++----- .github/triage/tests/test_arg_parsing.py | 62 ++++++++++++++++++++++ docs/triage-tool.md | 16 ++++-- 5 files changed, 143 insertions(+), 22 deletions(-) create mode 100644 .github/triage/tests/test_arg_parsing.py diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index d092e7200..9122332c9 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -6,7 +6,7 @@ import tempfile -def parse_args(): +def parse_args(args=None): parser = argparse.ArgumentParser( description=""" Triage failures in JAX/XLA-related tests. The expectation is that the given @@ -37,7 +37,6 @@ def parse_args(): help=""" Container to use. Example: jax, pax, triton. Used to construct the URLs of nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""", - required=True, ) parser.add_argument( "--output-prefix", @@ -67,6 +66,15 @@ def parse_args(): Command to execute inside the container. This should be as targeted as possible.""", ) + container_search_args.add_argument( + "--end-container", + help=""" + Skip the container-level search and pass this container to the commit-level + search. If this is passed, --start-container must be too, but --container + is not required. 
This can be used to apply the commit-level bisection + search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD + series, although they must have a similar structure.""", + ) container_search_args.add_argument( "--end-date", help=""" @@ -76,6 +84,15 @@ def parse_args(): test case fails on this date.""", type=lambda s: datetime.date.fromisoformat(s), ) + container_search_args.add_argument( + "--start-container", + help=""" + Skip the container-level search and pass this container to the commit-level + search. If this is passed, --end-container must be too, but --container is + not required. This can be used to apply the commit-level bisection search + to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series, + although they must have a similar structure.""", + ) container_search_args.add_argument( "--start-date", help=""" @@ -109,4 +126,30 @@ def parse_args(): significantly speed up the commit-level search. By default, uses a temporary directory including the name of the current user.""", ) - return parser.parse_args() + args = parser.parse_args(args=args) + num_explicit_containers = (args.start_container is not None) + ( + args.end_container is not None + ) + if num_explicit_containers == 1: + raise Exception( + "--start-container and --end-container must both be passed if either is" + ) + if num_explicit_containers == 2: + # Explicit mode: --container, --start-date and --end-date must not be passed + if args.container: + raise Exception( + "--container must not be passed if --start-container and --end-container are" + ) + if args.start_date: + raise Exception( + "--start-date must not be passed if --start-container and --end-container are" + ) + if args.end_date: + raise Exception( + "--end-date must not be passed if --start-container and --end-container are" + ) + elif num_explicit_containers == 0 and args.container is None: + raise Exception( + "--container must be passed if --start-container and --end-container are not" + ) + 
return args diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py index b92a12abe..2ea726f46 100644 --- a/.github/triage/jax_toolbox_triage/logic.py +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -10,7 +10,7 @@ class TestResult: """ Hold the result/stdout/stderr of a test execution """ - + __test__ = False # stop pytest gathering this result: bool stdout: typing.Optional[str] = None stderr: typing.Optional[str] = None diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index b0dc9bb5d..8afc7a9c7 100755 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -106,35 +106,43 @@ def check_container(date: datetime.date) -> TestResult: ) return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr) - # Search through the published containers, narrowing down to a pair of dates with - # the property that the test passed on `range_start` and fails on `range_end`. - range_start, range_end = container_search( - container_exists=container_exists, - container_passes=check_container, - start_date=args.start_date, - end_date=args.end_date, - logger=logger, - skip_precondition_checks=args.skip_precondition_checks, - threshold_days=args.threshold_days, - ) + if args.start_container is not None: + assert args.end_container is not None + # Skip the container-level search because explicit end points were given + start_url = args.start_container + end_url = args.end_container + else: + # Search through the published containers, narrowing down to a pair of dates with + # the property that the test passed on `range_start` and fails on `range_end`. 
+ range_start, range_end = container_search( + container_exists=container_exists, + container_passes=check_container, + start_date=args.start_date, + end_date=args.end_date, + logger=logger, + skip_precondition_checks=args.skip_precondition_checks, + threshold_days=args.threshold_days, + ) + start_url = container_url(range_start) + end_url = container_url(range_end) # Container-level search is now complete. Triage proceeds inside the `range_end`` # container. First, we check that rewinding JAX and XLA inside the `range_end`` # container to the commits used in the `range_start` container passes, whereas # using the `range_end` commits reproduces the failure. - with Container(container_url(range_start)) as worker: + with Container(start_url) as worker: start_jax_commit, _ = get_commit(worker, "jax") start_xla_commit, _ = get_commit(worker, "xla") # Fire up the container that will be used for the fine search. - with Container(container_url(range_end)) as worker: + with Container(end_url) as worker: end_jax_commit, jax_dir = get_commit(worker, "jax") end_xla_commit, xla_dir = get_commit(worker, "xla") logger.info( ( f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and " - f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}" + f"XLA [{start_xla_commit}, {end_xla_commit}] using {end_url}" ) ) diff --git a/.github/triage/tests/test_arg_parsing.py b/.github/triage/tests/test_arg_parsing.py new file mode 100644 index 000000000..26f993eb9 --- /dev/null +++ b/.github/triage/tests/test_arg_parsing.py @@ -0,0 +1,62 @@ +import pytest +from jax_toolbox_triage.args import parse_args + +test_command = ["my-test-command"] +valid_start_end_container = [ + "--start-container", + "start-url", + "--end-container", + "end-url", +] +valid_start_end_date_args = [ + ["--container", "jax"], + ["--container", "jax", "--start-date", "2024-10-02"], + ["--container", "jax", "--end-date", "2024-10-02"], + ["--container", "jax", "--start-date", "2024-10-01", 
"--end-date", "2024-10-02"], +] + + +@pytest.mark.parametrize( + "good_args", [valid_start_end_container] + valid_start_end_date_args +) +def test_good_container_args(good_args): + args = parse_args(good_args + test_command) + assert args.test_command == test_command + + +@pytest.mark.parametrize("date_args", valid_start_end_date_args) +def test_bad_container_arg_combinations_across_groups(date_args): + # Can't combine --{start,end}-container with --container/--{start,end}-date + with pytest.raises(Exception): + parse_args(valid_start_end_container + date_args + test_command) + + +@pytest.mark.parametrize( + "container_args", + [ + # Need --container + [], + ["--start-date", "2024-10-01"], + ["--end-date", "2024-10-02"], + ["--start-date", "2024-10-01", "--end-date", "2024-10-02"], + # Need both if either is passed + ["--start-container", "start-url"], + ["--end-container", "end-url"], + ], +) +def test_bad_container_arg_combinations_within_groups(container_args): + with pytest.raises(Exception): + parse_args(container_args + test_command) + + +@pytest.mark.parametrize( + "container_args", + [ + # Need valid ISO dates + ["--container", "jax", "--start-date", "a-blue-moon-ago"], + ["--container", "jax", "--end-date", "a-year-ago-last-thursday"], + ], +) +def test_unparsable_container_args(container_args): + with pytest.raises(SystemExit): + parse_args(container_args + test_command) diff --git a/docs/triage-tool.md b/docs/triage-tool.md index b8fbc99fc..805fe600c 100644 --- a/docs/triage-tool.md +++ b/docs/triage-tool.md @@ -20,6 +20,10 @@ The tool follows a three-step process: failing, and a reference commit of XLA (JAX) that can be used to reproduce the regression. +The third step can also be used on its own, via the `--start-container` and +`--end-container` options, which allows it to be used between private container tags, +without the dependency on the `ghcr.io/nvidia/jax` registry. 
+ ## Installation The triage tool can be installed using `pip`: @@ -46,11 +50,15 @@ needed to execute the test case. ## Usage -To use the tool, there are two compulsory arguments: - * `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container - families to execute the test command in. Example: `jax` for a JAX unit test - failure, `maxtext` for a MaxText model execution failure +To use the tool, there are two compulsory inputs: * A test command to triage. + * A specification of which containers to triage in: + * `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container + families to execute the test command in. Example: `jax` for a JAX unit test + failure, `maxtext` for a MaxText model execution failure. + * `--start-container` and `--end-container`: a pair of URLs to containers to use + in the commit-level search; if these are passed then no container-level search + is performed. The test command will be executed directly in the container, not inside a shell, so be sure not to add excessive quotation marks (*i.e.* run