Skip to content

Commit

Permalink
Init test for NMS kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Stonepia committed Dec 3, 2024
1 parent 8646bf2 commit a8427fb
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 0 deletions.
7 changes: 7 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
XPU_NOT_AVAILABLE_MSG = "XPU device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."


Expand Down Expand Up @@ -141,6 +142,12 @@ def needs_mps(test_func):
return pytest.mark.needs_mps(test_func)


def needs_xpu(test_func):
import pytest # noqa

return pytest.mark.needs_xpu(test_func)


def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
Expand Down
8 changes: 8 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
IN_RE_WORKER,
MPS_NOT_AVAILABLE_MSG,
OSS_CI_GPU_NO_CUDA_MSG,
XPU_NOT_AVAILABLE_MSG,
)


def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
config.addinivalue_line("markers", "needs_xpu: mark for tests that rely on a XPU device")
config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")
config.addinivalue_line("markers", "opcheck_only_one: only opcheck one parametrization")

Expand All @@ -43,12 +45,18 @@ def pytest_collection_modifyitems(items):
# and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None
needs_mps = item.get_closest_marker("needs_mps") is not None
needs_xpu = item.get_closest_marker("needs_xpu") is not None

if needs_cuda and not torch.cuda.is_available():
# In general, we skip cuda tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

if needs_xpu and not torch.xpu.is_available():
# In general, we skip xpu tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=XPU_NOT_AVAILABLE_MSG))

if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

Expand Down
1 change: 1 addition & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def test_qnms(self, iou, scale, zero_point):
(
pytest.param("cuda", marks=pytest.mark.needs_cuda),
pytest.param("mps", marks=pytest.mark.needs_mps),
pytest.param("xpu", marks=pytest.mark.needs_xpu),
),
)
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
Expand Down

0 comments on commit a8427fb

Please sign in to comment.