forked from kornia/kornia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
196 lines (150 loc) · 6.25 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import sys
from itertools import product
from typing import Dict, Set, Tuple

import numpy as np
import pytest
import torch

import kornia
from kornia.utils._compat import torch_version
def get_test_devices() -> Dict[str, torch.device]:
    """Collect the devices the test-suite can run on.

    The CPU is always included; CUDA (``cuda:0``), XLA (reported as ``"tpu"``) and Apple MPS are added
    only when the current machine/software stack reports them as available.

    Return:
        dict(str, torch.device): mapping from device name (as accepted by ``--device``) to the device.
    """
    found: Dict[str, torch.device] = {"cpu": torch.device("cpu")}

    if torch.cuda.is_available():
        found["cuda"] = torch.device("cuda:0")

    if kornia.xla_is_available():
        # torch_xla is an optional dependency, so it is only imported once XLA is known to be usable.
        import torch_xla.core.xla_model as xm

        found["tpu"] = xm.xla_device()

    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        found["mps"] = torch.device("mps")

    return found
def get_test_dtypes() -> Dict[str, torch.dtype]:
    """Collect the floating-point dtypes the test-suite can be parametrized over.

    Return:
        dict(str, torch.dtype): mapping from dtype name (as accepted by ``--dtype``) to the ``torch.dtype``,
        ordered from lowest to highest precision.
    """
    dtype_names = ("bfloat16", "float16", "float32", "float64")
    return {name: getattr(torch, name) for name in dtype_names}
# Devices and dtypes the test-suite can be parametrized over (see ``pytest_generate_tests``).
TEST_DEVICES: Dict[str, torch.device] = get_test_devices()
TEST_DTYPES: Dict[str, torch.dtype] = get_test_dtypes()

# (device_name, dtype_name) pairs excluded from the device x dtype test matrix, for example:
#     DEVICE_DTYPE_BLACKLIST = {("cpu", "float16")}
# Note: ``set()`` rather than ``{}`` -- an empty ``{}`` literal is a dict, not a set.
DEVICE_DTYPE_BLACKLIST: Set[Tuple[str, str]] = set()
@pytest.fixture()
def device(device_name) -> torch.device:
    """Map the parametrized ``device_name`` to its ``torch.device`` entry in ``TEST_DEVICES``."""
    selected_device = TEST_DEVICES[device_name]
    return selected_device
@pytest.fixture()
def dtype(dtype_name) -> torch.dtype:
    """Map the parametrized ``dtype_name`` to its ``torch.dtype`` entry in ``TEST_DTYPES``."""
    selected_dtype = TEST_DTYPES[dtype_name]
    return selected_dtype
@pytest.fixture(scope="session")
def torch_optimizer():
    """Provide ``torch.compile`` to a test, or skip the test when it cannot be used.

    ``torch.compile`` is returned only on Linux, when the attribute exists, and never for the
    Python 3.11 + torch 2.0.x combination; in every other case the requesting test is skipped.
    """
    compile_available = hasattr(torch, "compile") and sys.platform == "linux"
    # torch.compile only supports Python 3.11 starting with torch 2.1.0
    if compile_available and not (sys.version_info[:2] == (3, 11) and torch_version() in {"2.0.0", "2.0.1"}):
        return torch.compile

    pytest.skip(f"skipped because {torch.__version__} not have `compile` available! Failed to setup dynamo.")
def _resolve_option_names(raw_value, available):
    """Turn a ``--device`` / ``--dtype`` command-line value into a list of names.

    ``"all"`` expands to every key of ``available`` (in insertion order). Any other value is treated
    as a comma-separated list. Each entry is stripped of surrounding whitespace, so ``"cpu, cuda"``
    yields ``["cpu", "cuda"]``, and empty entries are dropped.
    """
    value = raw_value.strip()
    if value == "all":
        return list(available.keys())
    return [name for name in (part.strip() for part in value.split(",")) if name]


def pytest_generate_tests(metafunc):
    """Parametrize tests that request the ``device_name`` and/or ``dtype_name`` fixtures.

    The names come from the ``--device`` and ``--dtype`` options (registered in ``pytest_addoption``).
    A test that requests both fixtures is parametrized over their cartesian product, minus the pairs
    listed in ``DEVICE_DTYPE_BLACKLIST``.
    """
    device_names = None
    dtype_names = None

    if "device_name" in metafunc.fixturenames:
        device_names = _resolve_option_names(metafunc.config.getoption("--device"), TEST_DEVICES)
    if "dtype_name" in metafunc.fixturenames:
        dtype_names = _resolve_option_names(metafunc.config.getoption("--dtype"), TEST_DTYPES)

    if device_names is not None and dtype_names is not None:
        # Exclude any blacklisted device/dtype combinations.
        params = [combo for combo in product(device_names, dtype_names) if combo not in DEVICE_DTYPE_BLACKLIST]
        metafunc.parametrize("device_name,dtype_name", params)
    elif device_names is not None:
        metafunc.parametrize("device_name", device_names)
    elif dtype_names is not None:
        metafunc.parametrize("dtype_name", dtype_names)
def pytest_collection_modifyitems(config, items):
    """Mark every collected test carrying the ``slow`` keyword as skipped, unless ``--runslow`` was given."""
    if not config.getoption("--runslow"):
        # Without --runslow on the command line, slow tests are skipped.
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for slow_item in (candidate for candidate in items if "slow" in candidate.keywords):
            slow_item.add_marker(skip_slow)
def pytest_addoption(parser):
    """Register the command-line options used by this conftest.

    ``--device`` and ``--dtype`` feed ``pytest_generate_tests``; ``--runslow`` is read by
    ``pytest_collection_modifyitems``.
    """
    parser.addoption(
        "--device",
        action="store",
        default="cpu",
        help="comma-separated list of devices to run the tests on (e.g. 'cpu,cuda'), "
        "or 'all' for every device available on this machine",
    )
    parser.addoption(
        "--dtype",
        action="store",
        default="float32",
        help="comma-separated list of dtypes to run the tests with (e.g. 'float32,float64'), "
        "or 'all' for every supported dtype",
    )
    parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
def _setup_torch_compile():
    """Prepare ``torch.compile`` for the test session (Linux only; a no-op elsewhere).

    Sets the float32 matmul precision to ``"high"`` and wraps a tiny function and a tiny module with
    ``torch.compile``. Any ``RuntimeError`` raised by that call (for an unsupported setup) therefore
    surfaces at session start.
    """
    if not (hasattr(torch, "compile") and sys.platform == "linux"):
        return

    print("Setting up torch compile...")
    torch.set_float32_matmul_precision("high")

    def _dummy_function(x, y):
        return (x + y).sum()

    class _DummyModule(torch.nn.Module):
        def forward(self, x):
            return (x**2).sum()

    torch.compile(_dummy_function)
    torch.compile(_DummyModule())
def pytest_sessionstart(session):
    """Run ``_setup_torch_compile`` once when the test session starts.

    A ``RuntimeError`` whose message contains "not yet supported for torch.compile" is tolerated;
    any other ``RuntimeError`` propagates.
    """
    try:
        _setup_torch_compile()
    except RuntimeError as ex:
        if "not yet supported for torch.compile" not in str(ex):
            # Bare ``raise`` re-raises the active exception with its traceback untouched.
            raise

    # TODO: cache all torch.load weights/states here so that loading them does not impact the test suite.
def pytest_report_header(config):
    """Build the extra header pytest prints at session start, listing dependency versions.

    Optional/dev dependencies (``accelerate``, ``kornia_rs``, ``onnx``) are each reported as
    ``"`<name>` not found"`` when they cannot be imported, instead of raising ``ImportError``. A
    missing dev dependency therefore no longer breaks header generation.
    """
    import importlib

    def _dep_info(module_name: str) -> str:
        # "<name>-<version>" for an importable package, a placeholder for a missing one.
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            return f"`{module_name}` not found"
        return f"{module_name}-{module.__version__}"

    accelerate_info = _dep_info("accelerate")
    kornia_rs_info = _dep_info("kornia_rs")
    onnx_info = _dep_info("onnx")

    return f"""
main deps:
- kornia-{kornia.__version__}
- torch-{torch.__version__}
- commit: {torch.version.git_version}
- cuda: {torch.version.cuda}
x deps:
- {accelerate_info}
dev deps:
- {kornia_rs_info}
- {onnx_info}
"""
@pytest.fixture(autouse=True)
def add_doctest_deps(doctest_namespace):
    """Make ``np``, ``torch`` and ``kornia`` available in the doctest namespace without explicit imports."""
    doctest_namespace.update(np=np, torch=torch, kornia=kornia)
# Commit hashes of the kornia/data_test repository that pin the version of each test-data file
# downloaded by the ``data`` fixture below.
# sha: pins the LoFTR files ("loftr_homo", "loftr_fund").
sha: str = "cb8f42bf28b9f347df6afba5558738f62a11f28a"
# sha2: pins adalam_test.pt ("adalam_idxs", and also "lightglue_idxs").
sha2: str = "f7d8da661701424babb64850e03c5e8faec7ea62"
# sha3: pins knchurch_disk.pt ("disk_outdoor").
sha3: str = "8b98f44abbe92b7a84631ed06613b08fee7dae14"
@pytest.fixture(scope="session")
def data(request):
    """Load a test-data state dict, selected by ``request.param``, onto the CPU.

    The fixture declares no ``params`` itself, so ``request.param`` is expected to come from indirect
    parametrization and must be one of the keys below. Loading goes through
    ``torch.hub.load_state_dict_from_url``.
    """
    # NOTE(review): "lightglue_idxs" points to the same adalam_test.pt file as "adalam_idxs" -- confirm intended.
    urls_by_key = {
        "loftr_homo": f"https://github.com/kornia/data_test/blob/{sha}/loftr_outdoor_and_homography_data.pt?raw=true",
        "loftr_fund": f"https://github.com/kornia/data_test/blob/{sha}/loftr_indoor_and_fundamental_data.pt?raw=true",
        "adalam_idxs": f"https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true",
        "lightglue_idxs": f"https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true",
        "disk_outdoor": f"https://github.com/kornia/data_test/blob/{sha3}/knchurch_disk.pt?raw=true",
        "dexined": "https://cmp.felk.cvut.cz/~mishkdmy/models/DexiNed_BIPED_10.pth",
    }
    source_url = urls_by_key[request.param]
    cpu = torch.device("cpu")
    return torch.hub.load_state_dict_from_url(source_url, map_location=cpu)