Reduce flaky tests (#1535)
* improve lpips
* improve clip test time
* fix
SkafteNicki authored Feb 22, 2023
1 parent 92d15f3 commit 6c3e1a5
Showing 4 changed files with 26 additions and 21 deletions.
13 changes: 5 additions & 8 deletions tests/unittests/audio/test_snr.py
@@ -27,7 +27,7 @@

seed_all(42)

- TIME = 100
+ TIME = 25

Input = namedtuple("Input", ["preds", "target"])

@@ -38,7 +38,7 @@
)


- def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, zero_mean: bool):
+ def bss_eval_images_snr(preds: Tensor, target: Tensor, zero_mean: bool):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
if zero_mean:
@@ -50,10 +50,7 @@ def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, ze
for i in range(preds.shape[0]):
ms = []
for j in range(preds.shape[1]):
- if metric_func == mir_eval_bss_eval_images:
-     snr_v = metric_func([target[i, j]], [preds[i, j]])[0][0]
- else:
-     snr_v = metric_func([target[i, j]], [preds[i, j]])[0][0][0]
+ snr_v = mir_eval_bss_eval_images([target[i, j]], [preds[i, j]], compute_permutation=True)[0][0]
ms.append(snr_v)
mss.append(ms)
return torch.tensor(mss)
@@ -65,8 +62,8 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable):
return metric_func(preds, target).mean()


- mireval_snr_zeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=True)
- mireval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=False)
+ mireval_snr_zeromean = partial(bss_eval_images_snr, zero_mean=True)
+ mireval_snr_nozeromean = partial(bss_eval_images_snr, zero_mean=False)


@pytest.mark.parametrize(
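
For context, a self-contained sketch of what the simplified reference helper now does: every call routes straight through mir_eval's bss_eval_images, so the old metric_func branching is gone. The helper name, the zero-mean handling, and the import alias are illustrative assumptions; only the mir_eval call itself mirrors the diff above.

import torch
from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images

def reference_snr(preds: torch.Tensor, target: torch.Tensor, zero_mean: bool) -> torch.Tensor:
    # preds / target shape: [BATCH_SIZE, 1, TIME]
    if zero_mean:  # assumed body of the zero-mean branch that the diff elides
        preds = preds - preds.mean(dim=-1, keepdim=True)
        target = target - target.mean(dim=-1, keepdim=True)
    scores = []
    for i in range(preds.shape[0]):
        row = []
        for j in range(preds.shape[1]):
            # bss_eval_images returns (sdr, isr, sir, sar, perm); [0][0] picks the SDR of the single source
            row.append(mir_eval_bss_eval_images([target[i, j]], [preds[i, j]], compute_permutation=True)[0][0])
        scores.append(row)
    return torch.tensor(scores)
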
18 changes: 13 additions & 5 deletions tests/unittests/image/test_lpips.py
@@ -34,7 +34,7 @@
)


- def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool, reduction: str = "mean") -> Tensor:
+ def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool = False, reduction: str = "mean") -> Tensor:
"""Comparison function for tm implementation."""
ref = LPIPS_reference(net=net_type)
res = ref(img1, img2, normalize=normalize).detach().cpu().numpy()
@@ -48,19 +48,18 @@ class TestLPIPS(MetricTester):
atol: float = 1e-4

@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
- @pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("ddp", [True, False])
- def test_lpips(self, net_type, normalize, ddp):
+ def test_lpips(self, net_type, ddp):
"""test modular implementation for correctness."""
self.run_class_metric_test(
ddp=ddp,
preds=_inputs.img1,
target=_inputs.img2,
metric_class=LearnedPerceptualImagePatchSimilarity,
- reference_metric=partial(_compare_fn, net_type=net_type, normalize=normalize),
+ reference_metric=partial(_compare_fn, net_type=net_type),
check_scriptable=False,
check_state_dict=False,
- metric_args={"net_type": net_type, "normalize": normalize},
+ metric_args={"net_type": net_type},
)

def test_lpips_differentiability(self):
@@ -82,6 +81,15 @@ def test_lpips_half_gpu(self):
self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LearnedPerceptualImagePatchSimilarity)


+ @pytest.mark.parametrize("normalize", [False, True])
+ def test_normalize_arg(normalize):
+     """Test that normalize argument works as expected."""
+     metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=normalize)
+     res = metric(_inputs.img1[0], _inputs.img2[1])
+     res2 = _compare_fn(_inputs.img1[0], _inputs.img2[1], net_type="vgg", normalize=normalize)
+     assert res == res2


@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
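
A hedged usage sketch of the normalize flag that the new standalone test above exercises. The convention assumed here (it is not spelled out in the diff) follows the lpips reference package: normalize=True means the images are given in [0, 1] and are rescaled internally, while the default normalize=False expects inputs already in [-1, 1]. The import path is an assumption about this torchmetrics version.

import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Random images with values in [0, 1]; normalize=True tells the metric to handle the rescaling.
img1 = torch.rand(4, 3, 100, 100)
img2 = torch.rand(4, 3, 100, 100)

metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True)
score = metric(img1, img2)  # scalar tensor; lower means the images are perceptually more similar
print(score)
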
15 changes: 7 additions & 8 deletions tests/unittests/multimodal/test_clip_score.py
@@ -34,14 +34,12 @@

captions = [
"28-year-old chef found dead in San Francisco mall",
- "A 28-year-old chef who recently moved to San Francisco was "
- "found dead in the staircase of a local shopping center.",
- "The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at "
- 'him."',
- "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto .",
+ "A 28-year-old chef who recently moved to San Francisco was found dead.",
+ "The victim's brother said he cannot imagine anyone who would want to harm him",
+ "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto.",
]

- _random_input = Input(images=torch.randint(255, (2, 2, 3, 224, 224)), captions=[captions[0:2], captions[2:]])
+ _random_input = Input(images=torch.randint(255, (2, 2, 3, 64, 64)), captions=[captions[0:2], captions[2:]])


def _compare_fn(preds, target, model_name_or_path):
@@ -72,6 +70,7 @@ def test_clip_score(self, input, model_name_or_path, ddp):
metric_args={"model_name_or_path": model_name_or_path},
check_scriptable=False,
check_state_dict=False,
+ check_batch=False,
)

@skip_on_connection_issues()
@@ -101,7 +100,7 @@ def test_error_on_not_same_amount_of_input(self, input, model_name_or_path):
"""Test that an error is raised if the number of images and text examples does not match."""
metric = CLIPScore(model_name_or_path=model_name_or_path)
with pytest.raises(ValueError, match="Expected the number of images and text examples to be the same.*"):
- metric(torch.randint(255, (2, 3, 224, 224)), "28-year-old chef found dead in San Francisco mall")
+ metric(torch.randint(255, (2, 3, 64, 64)), "28-year-old chef found dead in San Francisco mall")

@skip_on_connection_issues()
def test_error_on_wrong_image_format(self, input, model_name_or_path):
@@ -110,4 +109,4 @@ def test_error_on_wrong_image_format(self, input, model_name_or_path):
with pytest.raises(
ValueError, match="Expected all images to be 3d but found image that has either more or less"
):
- metric(torch.randint(255, (224, 224)), "28-year-old chef found dead in San Francisco mall")
+ metric(torch.randint(255, (64, 64)), "28-year-old chef found dead in San Francisco mall")
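
For orientation, a hedged sketch of why shrinking the random inputs from 224x224 to 64x64 is safe in these tests: CLIPScore hands raw image tensors to the Hugging Face CLIP processor, which resizes them to the model's expected resolution, so the change only saves on tensor size, not on model behaviour. The checkpoint name below is an illustrative assumption, not taken from this diff.

import torch
from torchmetrics.multimodal.clip_score import CLIPScore

metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")  # assumed checkpoint
# One 3x64x64 random integer-valued image paired with one caption; the processor resizes internally.
score = metric(torch.randint(255, (3, 64, 64)), "28-year-old chef found dead in San Francisco mall")
print(score)
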
1 change: 1 addition & 0 deletions tests/unittests/text/test_infolm.py
@@ -103,6 +103,7 @@ class TestInfoLM(TextTester):
atol = 1e-4

@pytest.mark.parametrize("ddp", [False, True])
+ @pytest.mark.timeout(240)  # download may be too slow for default timeout
@skip_on_connection_issues()
def test_infolm_class(self, ddp, preds, targets, information_measure, idf, alpha, beta):
metric_args = {
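
A minimal sketch of the marker added above, assuming the pytest-timeout plugin is installed: it fails the test if it runs longer than the given number of seconds, so a slow model download cannot hang the suite. The test name and body below are hypothetical placeholders for illustration.

import pytest

@pytest.mark.timeout(240)  # generous limit because the first run may need to download the model
def test_runs_within_timeout():  # hypothetical example, not part of the repository
    assert True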
