jax-toolbox-triage: minor usability/doc improvements #1125

Merged
merged 5 commits into from
Oct 29, 2024
Changes from 1 commit
49 changes: 46 additions & 3 deletions .github/triage/jax_toolbox_triage/args.py
@@ -6,7 +6,7 @@
import tempfile


def parse_args():
def parse_args(args=None):
parser = argparse.ArgumentParser(
description="""
Triage failures in JAX/XLA-related tests. The expectation is that the given
@@ -37,7 +37,6 @@ def parse_args():
help="""
Container to use. Example: jax, pax, triton. Used to construct the URLs of
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
required=True,
)
parser.add_argument(
"--output-prefix",
@@ -67,6 +66,15 @@ def parse_args():
Command to execute inside the container. This should be as targeted as
possible.""",
)
container_search_args.add_argument(
"--end-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --start-container must be too, and --container,
--start-date and --end-date must not be. This can be used to apply the
commit-level bisection search to containers not from the
ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series, although they must have a
similar structure.""",
)
container_search_args.add_argument(
"--end-date",
help="""
@@ -76,6 +84,15 @@ def parse_args():
test case fails on this date.""",
type=lambda s: datetime.date.fromisoformat(s),
)
container_search_args.add_argument(
"--start-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --end-container must be too, and --container,
--start-date and --end-date must not be. This can be used to apply the
commit-level bisection search to containers not from the
ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series, although they must have a
similar structure.""",
)
container_search_args.add_argument(
"--start-date",
help="""
@@ -109,4 +126,30 @@ def parse_args():
significantly speed up the commit-level search. By default, uses a temporary
directory including the name of the current user.""",
)
return parser.parse_args()
args = parser.parse_args(args=args)
num_explicit_containers = (args.start_container is not None) + (
args.end_container is not None
)
if num_explicit_containers == 1:
raise Exception(
"--start-container and --end-container must both be passed if either is"
)
if num_explicit_containers == 2:
# Explicit mode: --container, --start-date and --end-date must not be passed
if args.container:
raise Exception(
"--container must not be passed if --start-container and --end-container are"
)
if args.start_date:
raise Exception(
"--start-date must not be passed if --start-container and --end-container are"
)
if args.end_date:
raise Exception(
"--end-date must not be passed if --start-container and --end-container are"
)
elif num_explicit_containers == 0 and args.container is None:
raise Exception(
"--container must be passed if --start-container and --end-container are not"
)
return args
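
The validation above can be sketched as a standalone function. This is a minimal illustration, not the code in `args.py`: the helper name `parse_mode` and its `"explicit"`/`"search"` return values are hypothetical, the options are reduced to the five that matter here, and it raises `ValueError` where the PR raises a plain `Exception`.

```python
import argparse


def parse_mode(argv):
    """Hypothetical helper: classify argv as "explicit" or "search" mode."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--container")
    parser.add_argument("--start-container")
    parser.add_argument("--end-container")
    parser.add_argument("--start-date")
    parser.add_argument("--end-date")
    args = parser.parse_args(argv)

    # Count how many of the two explicit end points were given.
    num_explicit = sum(
        value is not None for value in (args.start_container, args.end_container)
    )
    if num_explicit == 1:
        raise ValueError(
            "--start-container and --end-container must both be passed if either is"
        )
    if num_explicit == 2:
        # Explicit mode: the options that drive the container-level search
        # are rejected rather than silently ignored.
        for name in ("container", "start_date", "end_date"):
            if getattr(args, name) is not None:
                flag = "--" + name.replace("_", "-")
                raise ValueError(
                    f"{flag} must not be passed with --start-container/--end-container"
                )
        return "explicit"
    if args.container is None:
        raise ValueError(
            "--container must be passed if --start-container and --end-container are not"
        )
    return "search"
```

Rejecting the conflicting options, instead of ignoring them, means a mistyped invocation fails immediately rather than running a search the user did not intend.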
1 change: 1 addition & 0 deletions .github/triage/jax_toolbox_triage/logic.py
@@ -11,6 +11,7 @@ class TestResult:
Hold the result/stdout/stderr of a test execution
"""

__test__ = False # stop pytest gathering this
result: bool
stdout: typing.Optional[str] = None
stderr: typing.Optional[str] = None
36 changes: 22 additions & 14 deletions .github/triage/jax_toolbox_triage/main.py
@@ -106,35 +106,43 @@ def check_container(date: datetime.date) -> TestResult:
)
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)

# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
if args.start_container is not None:
assert args.end_container is not None
# Skip the container-level search because explicit end points were given
start_url = args.start_container
end_url = args.end_container
else:
# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
start_url = container_url(range_start)
end_url = container_url(range_end)

# Container-level search is now complete. Triage proceeds inside the `range_end`
# container. First, we check that rewinding JAX and XLA inside the `range_end`
# container to the commits used in the `range_start` container passes, whereas
# using the `range_end` commits reproduces the failure.

with Container(container_url(range_start)) as worker:
with Container(start_url) as worker:
start_jax_commit, _ = get_commit(worker, "jax")
start_xla_commit, _ = get_commit(worker, "xla")

# Fire up the container that will be used for the fine search.
with Container(container_url(range_end)) as worker:
with Container(end_url) as worker:
end_jax_commit, jax_dir = get_commit(worker, "jax")
end_xla_commit, xla_dir = get_commit(worker, "xla")
logger.info(
(
f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}"
f"XLA [{start_xla_commit}, {end_xla_commit}] using {end_url}"
)
)

62 changes: 62 additions & 0 deletions .github/triage/tests/test_arg_parsing.py
@@ -0,0 +1,62 @@
import pytest
from jax_toolbox_triage.args import parse_args

test_command = ["my-test-command"]
valid_start_end_container = [
"--start-container",
"start-url",
"--end-container",
"end-url",
]
valid_start_end_date_args = [
["--container", "jax"],
["--container", "jax", "--start-date", "2024-10-02"],
["--container", "jax", "--end-date", "2024-10-02"],
["--container", "jax", "--start-date", "2024-10-01", "--end-date", "2024-10-02"],
]


@pytest.mark.parametrize(
"good_args", [valid_start_end_container] + valid_start_end_date_args
)
def test_good_container_args(good_args):
args = parse_args(good_args + test_command)
assert args.test_command == test_command


@pytest.mark.parametrize("date_args", valid_start_end_date_args)
def test_bad_container_arg_combinations_across_groups(date_args):
# Can't combine --{start,end}-container with --container/--{start,end}-date
with pytest.raises(Exception):
parse_args(valid_start_end_container + date_args + test_command)


@pytest.mark.parametrize(
"container_args",
[
# Need --container
[],
["--start-date", "2024-10-01"],
["--end-date", "2024-10-02"],
["--start-date", "2024-10-01", "--end-date", "2024-10-02"],
# Need both if either is passed
["--start-container", "start-url"],
["--end-container", "end-url"],
],
)
def test_bad_container_arg_combinations_within_groups(container_args):
with pytest.raises(Exception):
parse_args(container_args + test_command)


@pytest.mark.parametrize(
"container_args",
[
# Need valid ISO dates
["--container", "jax", "--start-date", "a-blue-moon-ago"],
["--container", "jax", "--end-date", "a-year-ago-last-thursday"],
],
)
def test_unparsable_container_args(container_args):
with pytest.raises(SystemExit):
parse_args(container_args + test_command)
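
The last test expects `SystemExit` rather than `Exception` because the failure happens inside `argparse` itself: when a `type=` callable raises `ValueError`, `argparse` reports a usage error and exits with status 2. A minimal standalone sketch of that behaviour, assuming a cut-down parser with only `--start-date` (the helper `exit_code_for` is hypothetical):

```python
import argparse
import datetime

# Cut-down parser: only the date option, converted the same way as in args.py.
parser = argparse.ArgumentParser()
parser.add_argument("--start-date", type=lambda s: datetime.date.fromisoformat(s))


def exit_code_for(argv):
    """Return the SystemExit code raised while parsing argv, or None on success."""
    try:
        parser.parse_args(argv)
    except SystemExit as exc:
        return exc.code
    return None
```

Here `exit_code_for(["--start-date", "a-blue-moon-ago"])` returns 2, while a valid ISO date such as `2024-10-01` parses without exiting.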
16 changes: 12 additions & 4 deletions docs/triage-tool.md
@@ -20,6 +20,10 @@ The tool follows a three-step process:
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
regression.

The third step can also be used on its own, via the `--start-container` and
`--end-container` options. This allows bisection between private container tags,
without depending on the `ghcr.io/nvidia/jax` registry.

## Installation

The triage tool can be installed using `pip`:
@@ -46,11 +50,15 @@ needed to execute the test case.

## Usage

To use the tool, there are two compulsory arguments:
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
families to execute the test command in. Example: `jax` for a JAX unit test
failure, `maxtext` for a MaxText model execution failure
To use the tool, there are two compulsory inputs:
* A test command to triage.
* A specification of which containers to triage in:
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
families to execute the test command in. Example: `jax` for a JAX unit test
failure, `maxtext` for a MaxText model execution failure.
* `--start-container` and `--end-container`: a pair of URLs to containers to use
in the commit-level search; if these are passed then no container-level search
is performed.

The test command will be executed directly in the container, not inside a shell, so be
sure not to add excessive quotation marks (*i.e.* run