Skip to content

Commit

Permalink
👌 Use value difference to check data_vars
Browse files Browse the repository at this point in the history
This improves sensible usage of errors caused by floating-point inaccuracy and take SVD vector scaling with SV into account
  • Loading branch information
s-weigand authored and jsnel committed Aug 25, 2021
1 parent 2dd8085 commit 40b5a03
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions .github/test_result_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import os
import re
import subprocess
from collections import defaultdict
from functools import lru_cache
Expand Down Expand Up @@ -31,12 +32,14 @@
"ex_spectral_guidance",
]

SVD_PATTERN = re.compile(r"(?P<pre_fix>.+?)(right|left)_singular_vectors")


class AllCloseFixture(Protocol):
def __call__(
self,
a: np.ndarray | xr.DataArray,
b: np.ndarray | xr.DataArray,
a: float | np.ndarray | xr.DataArray,
b: float | np.ndarray | xr.DataArray,
rtol: float = 1e-5,
atol: float = 1e-8,
xtol: int = 0,
Expand Down Expand Up @@ -130,6 +133,10 @@ def coord_test(
f"data_var {data_var_name!r}"
)

# assert (
# expected_coord_value.dims == current_coords.dims
# ), f"Dimensions mismatch in {data_var_name!r}"

if exact_match or expected_coord_value.data.dtype == object:
assert np.array_equal(
expected_coord_value, current_coords[expected_coord_name]
Expand Down Expand Up @@ -274,26 +281,56 @@ def test_result_data_var_consistency(
expected_var_name in current_result.data_vars
), f"Missing data_var: {expected_var_name!r} in {file_name!r}"
current_data = current_result.data_vars[expected_var_name]
expected_values = expected_var_value
current_values = current_data

abs_tol = 1e-8 # default value
rtol = 1e-8

# due to platform specific differences in the low level code that generates the SVD
# the values might differ from expected which were generated on linux
eps = np.finfo(float).eps
if "singular_vectors" in expected_var_name: # type:ignore[operator]
abs_tol = 1e-5
rtol = 1e-5
pre_fix = SVD_PATTERN.match(expected_var_name).group( # type:ignore[operator]
"pre_fix"
)
expected_singular_values = expected_result.data_vars[
f"{pre_fix}singular_values"
]

if expected_var_value.shape[0] == expected_singular_values.shape[0]:
expected_values_scaled = np.diag(expected_singular_values).dot(
expected_var_value.data
)
else:
expected_values_scaled = expected_var_value.data.dot(
np.diag(expected_singular_values)
)

float_resolution = np.maximum(
eps * expected_values_scaled,
np.ones(expected_var_value.data.shape) * 2.0 * eps,
)
else:
float_resolution = np.maximum(
eps * expected_var_value.data,
np.ones(expected_var_value.data.shape) * 2.0 * eps,
)
abs_diff = np.abs(expected_values - current_values)

assert allclose(
expected_var_value.data,
current_data.data,
atol=abs_tol,
rtol=1e-3,
abs_diff,
float_resolution,
atol=rtol, # we compare the difference so atol -> rtol
print_fail=20,
), f"Result data_var data mismatch: {expected_var_name!r} in {file_name!r}"
), (
f"Result data_var data mismatch: {expected_var_name!r} in {file_name!r}.\n"
"With sum of absolute difference: "
f"{float(np.sum(abs_diff))} and shape: {expected_var_value.shape}"
)

coord_test(
expected_var_value.coords,
current_data.coords,
file_name,
allclose,
data_var_name=expected_var_name,
data_var_name=expected_var_name, # type:ignore[operator]
)

0 comments on commit 40b5a03

Please sign in to comment.