Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax-toolbox-triage: minor usability/doc improvements #1125

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions .github/triage/jax_toolbox_triage/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile


def parse_args():
def parse_args(args=None):
parser = argparse.ArgumentParser(
description="""
Triage failures in JAX/XLA-related tests. The expectation is that the given
Expand Down Expand Up @@ -37,7 +37,6 @@ def parse_args():
help="""
Container to use. Example: jax, pax, triton. Used to construct the URLs of
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
required=True,
)
parser.add_argument(
"--output-prefix",
Expand Down Expand Up @@ -67,6 +66,15 @@ def parse_args():
Command to execute inside the container. This should be as targeted as
possible.""",
)
container_search_args.add_argument(
"--failing-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --passing-container must be too, but --container
is not required. This can be used to apply the commit-level bisection
search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD
series, although they must have a similar structure.""",
)
container_search_args.add_argument(
"--end-date",
help="""
Expand All @@ -76,6 +84,15 @@ def parse_args():
test case fails on this date.""",
type=lambda s: datetime.date.fromisoformat(s),
)
container_search_args.add_argument(
"--passing-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --failing-container must be too, but --container is
not required. This can be used to apply the commit-level bisection search
to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series,
although they must have a similar structure.""",
)
container_search_args.add_argument(
"--start-date",
help="""
Expand Down Expand Up @@ -109,4 +126,30 @@ def parse_args():
significantly speed up the commit-level search. By default, uses a temporary
directory including the name of the current user.""",
)
return parser.parse_args()
args = parser.parse_args(args=args)
num_explicit_containers = (args.passing_container is not None) + (
args.failing_container is not None
)
if num_explicit_containers == 1:
raise Exception(
"--passing-container and --failing-container must both be passed if either is"
)
if num_explicit_containers == 2:
# Explicit mode: --container, --start-date and --end-date must not be passed
if args.container:
raise Exception(
"--container must not be passed if --passing-container and --failing-container are"
)
if args.start_date:
raise Exception(
"--start-date must not be passed if --passing-container and --failing-container are"
)
if args.end_date:
raise Exception(
"--end-date must not be passed if --passing-container and --failing-container are"
)
elif num_explicit_containers == 0 and args.container is None:
raise Exception(
"--container must be passed if --passing-container and --failing-container are not"
)
return args
30 changes: 26 additions & 4 deletions .github/triage/jax_toolbox_triage/logic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from dataclasses import dataclass
import datetime
import functools
import logging
import typing


@dataclass
class TestResult:
    """
    Hold the result/stdout/stderr of a test execution.

    Only ``result`` is required; ``stdout`` and ``stderr`` default to ``None``
    so callers that only care about pass/fail can omit them.
    """

    __test__ = False  # stop pytest gathering this
    # True when the test passed, False when it failed.
    result: bool
    # Captured standard output of the test command, if available.
    stdout: typing.Optional[str] = None
    # Captured standard error of the test command, if available.
    stderr: typing.Optional[str] = None


def as_datetime(date: datetime.date) -> datetime.datetime:
    """
    Convert a date into a naive datetime at midnight (00:00:00) of that day.
    """
    # datetime.time() with no arguments is 00:00:00 with no tzinfo.
    return datetime.datetime.combine(date, datetime.time())

Expand Down Expand Up @@ -59,7 +72,7 @@ def adjust_date(
def container_search(
*,
container_exists: typing.Callable[[datetime.date], bool],
container_passes: typing.Callable[[datetime.date], bool],
container_passes: typing.Callable[[datetime.date], TestResult],
start_date: typing.Optional[datetime.date],
end_date: typing.Optional[datetime.date],
logger: logging.Logger,
Expand Down Expand Up @@ -88,8 +101,17 @@ def container_search(
logger.info(f"Skipping check for end-of-range failure in {end_date}")
else:
logger.info(f"Checking end-of-range failure in {end_date}")
if container_passes(end_date):
test_end_date = container_passes(end_date)
logger.info(f"stdout: {test_end_date.stdout}")
logger.info(f"stderr: {test_end_date.stderr}")
if test_end_date.result:
raise Exception(f"Could not reproduce failure in {end_date}")
logger.info(
"IMPORTANT: you should check that the test output above shows the "
f"*expected* failure of your test case in the {end_date} container. It is "
"very easy to accidentally provide a test case that fails for the wrong "
"reason, which will not triage the correct issue!"
)

# Start the coarse, container-level, search for a starting point to the bisection range
earliest_failure = end_date
Expand Down Expand Up @@ -127,7 +149,7 @@ def container_search(
logger.info(f"Skipping check that the test passes on start_date={start_date}")
else:
# While condition prints an info message
while not container_passes(search_date):
while not container_passes(search_date).result:
# Test failed on `search_date`, go further into the past
earliest_failure = search_date
new_search_date = adjust(
Expand Down Expand Up @@ -155,7 +177,7 @@ def container_search(
if range_mid is None:
# It wasn't possible to refine further.
break
result = container_passes(range_mid)
result = container_passes(range_mid).result
if result:
range_start = range_mid
else:
Expand Down
48 changes: 30 additions & 18 deletions .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .args import parse_args
from .docker import DockerContainer
from .logic import commit_search, container_search
from .logic import commit_search, container_search, TestResult
from .utils import (
container_exists as container_exists_base,
container_url as container_url_base,
Expand All @@ -21,6 +21,10 @@ def main():
args = parse_args()
bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache)
logger = get_logger(args.output_prefix)
logger.info(
"Verbose output, including stdout/err of triage commands, will be written to "
f'{(args.output_prefix / "debug.log").resolve()}'
)
container_url = functools.partial(container_url_base, container=args.container)
container_exists = functools.partial(
container_exists_base, container=args.container, logger=logger
Expand Down Expand Up @@ -75,7 +79,7 @@ def get_commit(container: DockerContainer, repo: str) -> typing.Tuple[str, str]:
f"Could not extract commit of {repo} from {args.container} container {container}"
)

def check_container(date: datetime.date) -> bool:
def check_container(date: datetime.date) -> TestResult:
"""
See if the test passes in the given container.
"""
Expand All @@ -100,37 +104,45 @@ def check_container(date: datetime.date) -> bool:
"xla": xla_commit,
},
)
return test_pass

# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)

if args.passing_container is not None:
assert args.failing_container is not None
# Skip the container-level search because explicit end points were given
passing_url = args.passing_container
failing_url = args.failing_container
else:
# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
passing_url = container_url(range_start)
failing_url = container_url(range_end)

# Container-level search is now complete. Triage proceeds inside the `range_end`
# container. First, we check that rewinding JAX and XLA inside the `range_end`
# container to the commits used in the `range_start` container passes, whereas
# using the `range_end` commits reproduces the failure.

with Container(container_url(range_start)) as worker:
with Container(passing_url) as worker:
start_jax_commit, _ = get_commit(worker, "jax")
start_xla_commit, _ = get_commit(worker, "xla")

# Fire up the container that will be used for the fine search.
with Container(container_url(range_end)) as worker:
with Container(failing_url) as worker:
end_jax_commit, jax_dir = get_commit(worker, "jax")
end_xla_commit, xla_dir = get_commit(worker, "xla")
logger.info(
(
f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}"
f"XLA [{start_xla_commit}, {end_xla_commit}] using {failing_url}"
)
)

Expand Down
62 changes: 62 additions & 0 deletions .github/triage/tests/test_arg_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from jax_toolbox_triage.args import parse_args

# Positional test command appended to every argument list under test.
test_command = ["my-test-command"]

# Explicit-endpoint mode: both containers given, no --container/date options.
valid_start_end_container = [
    "--passing-container", "passing-url",
    "--failing-container", "failing-url",
]  # fmt: skip

# Date-search mode: --container alone, or with either or both date bounds.
valid_start_end_date_args = [
    ["--container", "jax"],
    ["--container", "jax", "--start-date", "2024-10-02"],
    ["--container", "jax", "--end-date", "2024-10-02"],
    [
        "--container",
        "jax",
        "--start-date",
        "2024-10-01",
        "--end-date",
        "2024-10-02",
    ],
]


@pytest.mark.parametrize(
    "good_args", [valid_start_end_container] + valid_start_end_date_args
)
def test_good_container_args(good_args):
    """
    Each valid container-selection argument set parses without error and the
    trailing positional test command is preserved.
    """
    args = parse_args(good_args + test_command)
    assert args.test_command == test_command


@pytest.mark.parametrize("date_args", valid_start_end_date_args)
def test_bad_container_arg_combinations_across_groups(date_args):
    """
    Explicit passing/failing containers must be rejected when combined with any
    of the date-search argument sets.
    """
    # Can't combine --{passing,failing}-container with
    # --container/--{start,end}-date
    with pytest.raises(Exception):
        parse_args(valid_start_end_container + date_args + test_command)


@pytest.mark.parametrize(
    "container_args",
    [
        # Need --container
        [],
        ["--start-date", "2024-10-01"],
        ["--end-date", "2024-10-02"],
        ["--start-date", "2024-10-01", "--end-date", "2024-10-02"],
        # Need both if either is passed
        ["--passing-container", "passing-url"],
        ["--failing-container", "failing-url"],
    ],
)
def test_bad_container_arg_combinations_within_groups(container_args):
    """
    Incomplete argument sets must be rejected: date-search mode without
    --container, or only one of --passing-container/--failing-container.
    """
    with pytest.raises(Exception):
        parse_args(container_args + test_command)


@pytest.mark.parametrize(
    "container_args",
    [
        # Need valid ISO dates
        ["--container", "jax", "--start-date", "a-blue-moon-ago"],
        ["--container", "jax", "--end-date", "a-year-ago-last-thursday"],
    ],
)
def test_unparsable_container_args(container_args):
    """
    Date options that are not valid ISO dates must make parsing exit.

    The expected exception is SystemExit rather than Exception; presumably
    this is argparse's own error path for a failed type conversion.
    """
    with pytest.raises(SystemExit):
        parse_args(container_args + test_command)
8 changes: 4 additions & 4 deletions .github/triage/tests/test_triage_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import pytest
import random
from jax_toolbox_triage.logic import commit_search, container_search
from jax_toolbox_triage.logic import commit_search, container_search, TestResult


def wrap(b):
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_container_search_limits(
with pytest.raises(Exception, match=match_string):
container_search(
container_exists=lambda dt: dt in dates_that_exist,
container_passes=lambda dt: False,
container_passes=lambda dt: TestResult(result=False),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_container_search_checks(
with pytest.raises(Exception, match=match_string):
container_search(
container_exists=lambda dt: True,
container_passes=lambda dt: dt in dates_that_pass,
container_passes=lambda dt: TestResult(result=dt in dates_that_pass),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand All @@ -374,7 +374,7 @@ def test_container_search(logger, start_date, days_of_failure, threshold_days):
assert start_date is None or threshold_date >= start_date
good_date, bad_date = container_search(
container_exists=lambda dt: True,
container_passes=lambda dt: dt < threshold_date,
container_passes=lambda dt: TestResult(result=dt < threshold_date),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand Down
Loading
Loading