Skip to content

Commit

Permalink
jax-toolbox-triage: minor usability/doc improvements
Browse files Browse the repository at this point in the history
- Print the stdout/stderr of the first execution of the test case, which
  is supposed to fail, at INFO level along with a message encouraging
  the user to check that it is the correct failure.
- Print the path to the DEBUG log file at INFO level and, therefore, to
  the console.
- Expand the documentation.
  • Loading branch information
olupton committed Oct 29, 2024
1 parent bde47a4 commit 1e95fa5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 11 deletions.
29 changes: 25 additions & 4 deletions .github/triage/jax_toolbox_triage/logic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from dataclasses import dataclass
import datetime
import functools
import logging
import typing


@dataclass
class TestResult:
"""
Hold the result/stdout/stderr of a test execution
"""

result: bool
stdout: typing.Optional[str]
stderr: typing.Optional[str]


def as_datetime(date: datetime.date) -> datetime.datetime:
return datetime.datetime.combine(date, datetime.time())

Expand Down Expand Up @@ -59,7 +71,7 @@ def adjust_date(
def container_search(
*,
container_exists: typing.Callable[[datetime.date], bool],
container_passes: typing.Callable[[datetime.date], bool],
container_passes: typing.Callable[[datetime.date], TestResult],
start_date: typing.Optional[datetime.date],
end_date: typing.Optional[datetime.date],
logger: logging.Logger,
Expand Down Expand Up @@ -88,8 +100,17 @@ def container_search(
logger.info(f"Skipping check for end-of-range failure in {end_date}")
else:
logger.info(f"Checking end-of-range failure in {end_date}")
if container_passes(end_date):
test_end_date = container_passes(end_date)
logger.info(f"stdout: {test_end_date.stdout}")
logger.info(f"stderr: {test_end_date.stderr}")
if test_end_date.result:
raise Exception(f"Could not reproduce failure in {end_date}")
logger.info(
"IMPORTANT: you should check that the test output above shows the "
f"*expected* failure of your test case in the {end_date} container. It is "
"very easy to accidentally provide a test case that fails for the wrong "
"reason, which will not triage the correct issue!"
)

# Start the coarse, container-level, search for a starting point to the bisection range
earliest_failure = end_date
Expand Down Expand Up @@ -127,7 +148,7 @@ def container_search(
logger.info(f"Skipping check that the test passes on start_date={start_date}")
else:
# While condition prints an info message
while not container_passes(search_date):
while not container_passes(search_date).result:
# Test failed on `search_date`, go further into the past
earliest_failure = search_date
new_search_date = adjust(
Expand Down Expand Up @@ -155,7 +176,7 @@ def container_search(
if range_mid is None:
# It wasn't possible to refine further.
break
result = container_passes(range_mid)
result = container_passes(range_mid).result
if result:
range_start = range_mid
else:
Expand Down
10 changes: 7 additions & 3 deletions .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .args import parse_args
from .docker import DockerContainer
from .logic import commit_search, container_search
from .logic import commit_search, container_search, TestResult
from .utils import (
container_exists as container_exists_base,
container_url as container_url_base,
Expand All @@ -21,6 +21,10 @@ def main():
args = parse_args()
bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache)
logger = get_logger(args.output_prefix)
logger.info(
"Verbose output, including stdout/err of triage commands, will be written to "
f'{(args.output_prefix / "debug.log").resolve()}'
)
container_url = functools.partial(container_url_base, container=args.container)
container_exists = functools.partial(
container_exists_base, container=args.container, logger=logger
Expand Down Expand Up @@ -75,7 +79,7 @@ def get_commit(container: DockerContainer, repo: str) -> typing.Tuple[str, str]:
f"Could not extract commit of {repo} from {args.container} container {container}"
)

def check_container(date: datetime.date) -> bool:
def check_container(date: datetime.date) -> TestResult:
"""
See if the test passes in the given container.
"""
Expand All @@ -100,7 +104,7 @@ def check_container(date: datetime.date) -> bool:
"xla": xla_commit,
},
)
return test_pass
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)

# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
Expand Down
8 changes: 4 additions & 4 deletions .github/triage/tests/test_triage_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import pytest
import random
from jax_toolbox_triage.logic import commit_search, container_search
from jax_toolbox_triage.logic import commit_search, container_search, TestResult


def wrap(b):
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_container_search_limits(
with pytest.raises(Exception, match=match_string):
container_search(
container_exists=lambda dt: dt in dates_that_exist,
container_passes=lambda dt: False,
container_passes=lambda dt: TestResult(result=False),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_container_search_checks(
with pytest.raises(Exception, match=match_string):
container_search(
container_exists=lambda dt: True,
container_passes=lambda dt: dt in dates_that_pass,
container_passes=lambda dt: TestResult(result=dt in dates_that_pass),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand All @@ -374,7 +374,7 @@ def test_container_search(logger, start_date, days_of_failure, threshold_days):
assert start_date is None or threshold_date >= start_date
good_date, bad_date = container_search(
container_exists=lambda dt: True,
container_passes=lambda dt: dt < threshold_date,
container_passes=lambda dt: TestResult(result=dt < threshold_date),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand Down
36 changes: 36 additions & 0 deletions docs/triage-tool.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,20 @@ The triage tool can be installed using `pip`:
pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
```
or directly from a checkout of the JAX-Toolbox repository.

Because the tool needs to orchestrate running commands in multiple containers, it is
most convenient to install it in a virtual environment on the host system, rather than
attempting to install it inside a container.

The recommended installation method is to install `virtualenv` natively on the host
system, and then use that to create an isolated environment on the host system for the
triage tool, *i.e.*:
```bash
virtualenv triage-venv
./triage-venv/bin/pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
./triage-venv/bin/jax-toolbox-triage ...
```

The tool should be invoked on a machine with `docker` available and whatever GPUs are
needed to execute the test case.

Expand All @@ -47,6 +57,10 @@ sure not to add excessive quotation marks (*i.e.* run
`jax-toolbox-triage --container=jax test-jax.sh foo` not
`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it
as fast and targeted as possible.

If you want to run multiple commands, you might want to use something like
`jax-toolbox-triage --container=jax sh -c "command1 && command2"`.

The expectation is that the test case will be executed successfully several times as
part of the triage, so you may want to tune some parameters to reduce the execution
time in the successful case.
Expand All @@ -55,6 +69,28 @@ probably reduce `--steps` to optimise execution time in the successful case.

A JSON status file and both info-level and debug-level logfiles are written to the
directory given by `--output-prefix`.
Info-level output is also written to the console, and includes the path to the debug
log file.

You should pay attention to the first execution of your test case, to make sure it is
failing for the correct reason. For example:
```console
$ jax-toolbox-triage --container jax command-you-forgot-to-install
```
will not immediately abort, because the tool is **expecting** the command to fail in
the early stages of the triage:
```
[INFO] 2024-10-29 01:49:01 Verbose output, including stdout/err of triage commands, will be written to /home/olupton/JAX-Toolbox/triage-2024-10-29-01-49-01/debug.log
[INFO] 2024-10-29 01:49:05 Checking end-of-range failure in 2024-10-27
[INFO] 2024-10-29 01:49:05 Ran test case in 2024-10-27 in 0.4s, pass=False
[INFO] 2024-10-29 01:49:05 stdout: OCI runtime exec failed: exec failed: unable to start container process: exec: "command-you-forgot-to-install": executable file not found in $PATH: unknown
[INFO] 2024-10-29 01:49:05 stderr:
[INFO] 2024-10-29 01:49:05 IMPORTANT: you should check that the test output above shows the *expected* failure of your test case in the 2024-10-27 container. It is very easy to accidentally provide a test case that fails for the wrong reason, which will not triage the correct issue!
[INFO] 2024-10-29 01:49:06 Starting coarse search with 2024-10-26 based on end_date=2024-10-27
[INFO] 2024-10-29 01:49:06 Ran test case in 2024-10-26 in 0.4s, pass=False
```
where, notably, the triage search is continuing.

### Optimising container-level search performance

Expand Down

0 comments on commit 1e95fa5

Please sign in to comment.