Skip to content

Commit

Permalink
replace all get_covmats with get_mats in tests for utils (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
qbarthelemy authored Oct 11, 2023
1 parent a10af0b commit 3c5db03
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 90 deletions.
4 changes: 2 additions & 2 deletions tests/test_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def test_check_raise():
mean_riemann(C)


def test_nearest_sym_pos_def(get_covmats):
def test_nearest_sym_pos_def(get_mats):
n_matrices = 3
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
D = mats.diagonal(axis1=1, axis2=2)
psd = np.array([mat - np.diag(d) for mat, d in zip(mats, D)])

Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,12 @@ def test_normalize_shapes(norm, rndstate):
assert mat.shape == mat_n.shape


def test_normalize_values(rndstate, get_covmats):
def test_normalize_values(rndstate, get_mats):
"""Test normalize values"""
n_matrices, n_channels = 20, 3

# after corr-normalization => diags = 1 and values in [-1, 1]
mat = get_covmats(n_channels, n_channels)
mat = get_mats(n_channels, n_channels, "spd")
mat_cn = normalize(mat, "corr")
assert_array_almost_equal(np.ones(mat_cn.shape[:-1]),
np.diagonal(mat_cn, axis1=-2, axis2=-1))
Expand Down
40 changes: 20 additions & 20 deletions tests/test_utils_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_distances_metric(kind, metric, dist, get_mats):
assert np.isreal(d)


def test_distances_metric_error(get_covmats):
def test_distances_metric_error(get_mats):
n_matrices, n_channels = 2, 2
A = get_covmats(n_matrices, n_channels)
A = get_mats(n_matrices, n_channels, "spd")
with pytest.raises(ValueError):
distance(A[0], A[1], metric="universe")
with pytest.raises(ValueError):
Expand All @@ -89,18 +89,18 @@ def test_distances_squared(kind, dist, get_mats):


@pytest.mark.parametrize("dist", get_dist_func())
def test_distances_all_error(dist, get_covmats):
def test_distances_all_error(dist, get_mats):
n_matrices, n_channels = 3, 3
A = get_covmats(n_matrices, n_channels)
A = get_mats(n_matrices, n_channels, "spd")
with pytest.raises(ValueError):
dist(A, A[0])


@pytest.mark.parametrize("dist", get_dist_func())
def test_distances_all_ndarray(dist, get_covmats):
def test_distances_all_ndarray(dist, get_mats):
n_matrices, n_channels = 5, 3
A = get_covmats(n_matrices, n_channels)
B = get_covmats(n_matrices, n_channels)
A = get_mats(n_matrices, n_channels, "spd")
B = get_mats(n_matrices, n_channels, "spd")
assert isinstance(dist(A[0], B[0]), float) # 2D arrays
assert dist(A, B).shape == (n_matrices,) # 3D arrays

Expand Down Expand Up @@ -177,18 +177,18 @@ def test_distance_harmonic(kind, get_mats):
distance_harmonic(A, B)


def test_distance_kullback_implementation(get_covmats):
def test_distance_kullback_implementation(get_mats):
n_matrices, n_channels = 2, 6
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
A, B = mats[0], mats[1]
d = 0.5*(np.trace(np.linalg.inv(B) @ A) - n_channels
+ np.log(np.linalg.det(B) / np.linalg.det(A)))
assert distance_kullback(A, B) == approx(d)


def test_distance_logdet_implementation(get_covmats):
def test_distance_logdet_implementation(get_mats):
n_matrices, n_channels = 2, 6
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
A, B = mats[0], mats[1]
d = np.sqrt(np.log(np.linalg.det((A + B) / 2.0))
- 0.5 * np.log(np.linalg.det(A)*np.linalg.det(B)))
Expand Down Expand Up @@ -223,17 +223,17 @@ def test_distance_riemann_properties(kind, get_mats):


@pytest.mark.parametrize("dist, dfunc", zip(get_distances(), get_dist_func()))
def test_distance_wrapper(dist, dfunc, get_covmats):
def test_distance_wrapper(dist, dfunc, get_mats):
n_matrices, n_channels = 2, 5
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
A, B = mats[0], mats[1]
assert distance(A, B, metric=dist) == dfunc(A, B)


@pytest.mark.parametrize("dist", get_dist_func())
def test_distance_wrapper_between_set_and_matrix(dist, get_covmats):
def test_distance_wrapper_between_set_and_matrix(dist, get_mats):
n_matrices, n_channels = 10, 4
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
assert distance(mats, mats[-1], metric=dist).shape == (n_matrices, 1)

n_sets = 5
Expand All @@ -245,14 +245,14 @@ def test_distance_wrapper_between_set_and_matrix(dist, get_covmats):
@pytest.mark.parametrize("dist", get_distances())
@pytest.mark.parametrize("Y", [None, True])
@pytest.mark.parametrize("squared", [False, True])
def test_pairwise_distance_matrix(get_covmats, dist, Y, squared):
def test_pairwise_distance_matrix(get_mats, dist, Y, squared):
n_matrices_X, n_matrices_Y, n_channels = 6, 4, 5
X = get_covmats(n_matrices_X, n_channels)
X = get_mats(n_matrices_X, n_channels, "spd")
if Y is None:
n_matrices_Y = n_matrices_X
Y_ = X
else:
Y = get_covmats(n_matrices_Y, n_channels)
Y = get_mats(n_matrices_Y, n_channels, "spd")
Y_ = Y

pdist = pairwise_distance(X, Y, metric=dist, squared=squared)
Expand Down Expand Up @@ -286,11 +286,11 @@ def test_distance_mahalanobis(rndstate, complex_valued):


@pytest.mark.parametrize("mean", [True, None])
def test_distance_mahalanobis_scipy(rndstate, get_covmats, mean):
def test_distance_mahalanobis_scipy(rndstate, get_mats, mean):
"""Test equivalence between pyriemann and scipy for real data"""
n_channels, n_times = 3, 100
X = rndstate.randn(n_channels, n_times)
C = get_covmats(1, n_channels)[0]
C = get_mats(1, n_channels, "spd")[0]

Cinv = np.linalg.inv(C)
y = np.zeros(n_channels)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_utils_geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def test_geodesic_euclid(rndstate, complex_valued):


@pytest.mark.parametrize("metric", get_geod_name())
def test_geodesic_wrapper_ndarray(metric, get_covmats):
def test_geodesic_wrapper_ndarray(metric, get_mats):
n_matrices, n_channels = 5, 3
A = get_covmats(n_matrices, n_channels)
B = get_covmats(n_matrices, n_channels)
A = get_mats(n_matrices, n_channels, "spd")
B = get_mats(n_matrices, n_channels, "spd")
assert geodesic(A[0], B[0], .3, metric=metric).shape == A[0].shape
assert geodesic(A, B, .2, metric=metric).shape == A.shape # 3D arrays

Expand All @@ -109,9 +109,9 @@ def test_geodesic_wrapper_simple(metric):


@pytest.mark.parametrize("metric, gfun", zip(get_geod_name(), get_geod_func()))
def test_geodesic_wrapper_random(metric, gfun, get_covmats):
def test_geodesic_wrapper_random(metric, gfun, get_mats):
n_matrices, n_channels = 2, 5
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
A, B = mats[0], mats[1]
if gfun is geodesic_euclid:
Ctrue = mean_euclid(mats)
Expand Down
34 changes: 17 additions & 17 deletions tests/test_utils_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,62 @@


@pytest.mark.parametrize("ker", rker_fct)
def test_kernel_x_x(ker, get_covmats):
def test_kernel_x_x(ker, get_mats):
"""Test kernel build"""
n_matrices, n_channels = 7, 3
X = get_covmats(n_matrices, n_channels)
X = get_mats(n_matrices, n_channels, "spd")
K = ker(X, X)
assert K.shape == (n_matrices, n_matrices)
assert is_spsd(K)
assert_array_almost_equal(K, ker(X))


@pytest.mark.parametrize("ker", rker_str)
def test_kernel_cref(ker, get_covmats):
def test_kernel_cref(ker, get_mats):
"""Test kernel reference"""
n_matrices, n_channels = 5, 3
X = get_covmats(n_matrices, n_channels)
X = get_mats(n_matrices, n_channels, "spd")
cref = mean_covariance(X, metric=ker)
K = kernel(X, X, metric=ker)
K1 = kernel(X, X, Cref=cref, metric=ker)
assert_array_equal(K, K1)


@pytest.mark.parametrize("ker", rker_str)
def test_kernel_x_y(ker, get_covmats):
def test_kernel_x_y(ker, get_mats):
"""Test kernel for different X and Y"""
n_matrices_X, n_matrices_Y, n_channels = 6, 5, 3
X = get_covmats(n_matrices_X, n_channels)
Y = get_covmats(n_matrices_Y, n_channels)
X = get_mats(n_matrices_X, n_channels, "spd")
Y = get_mats(n_matrices_Y, n_channels, "spd")
K = kernel(X, Y, metric=ker)
assert K.shape == (n_matrices_X, n_matrices_Y)


@pytest.mark.parametrize("ker", rker_str)
def test_metric_string(ker, get_covmats):
def test_metric_string(ker, get_mats):
"""Test generic kernel function"""
n_matrices, n_channels = 5, 3
X = get_covmats(n_matrices, n_channels)
X = get_mats(n_matrices, n_channels, "spd")
K = globals()[f'kernel_{ker}'](X)
K1 = kernel(X, metric=ker)
assert_array_equal(K, K1)


def test_metric_string_error(get_covmats):
def test_metric_string_error(get_mats):
"""Test generic kernel function error raise"""
n_matrices, n_channels = 5, 3
X = get_covmats(n_matrices, n_channels)
X = get_mats(n_matrices, n_channels, "spd")
with pytest.raises(ValueError):
kernel(X, metric='foo')


@pytest.mark.parametrize("ker", rker_str)
def test_input_dimension_error(ker, get_covmats):
def test_input_dimension_error(ker, get_mats):
"""Test errors for incorrect dimension"""
n_matrices, n_channels = 5, 3
X = get_covmats(n_matrices, n_channels)
Y = get_covmats(n_matrices, n_channels + 1)
cref = get_covmats(1, n_channels + 1)[0]
X = get_mats(n_matrices, n_channels, "spd")
Y = get_mats(n_matrices, n_channels + 1, "spd")
cref = get_mats(1, n_channels + 1, "spd")[0]
if ker == 'riemann':
with pytest.raises(AssertionError):
kernel(X, Cref=cref, metric=ker)
Expand All @@ -96,10 +96,10 @@ def test_euclid(n_dim0, n_dim1, rndstate):
assert_array_almost_equal(K, K1)


def test_riemann_correctness(get_covmats):
def test_riemann_correctness(get_mats):
"""Test Riemannian kernel correctness"""
n_matrices, n_channels = 5, 3
X = get_covmats(n_matrices, n_channels)
X = get_mats(n_matrices, n_channels, "spd")
K = kernel_riemann(X, Cref=np.eye(n_channels), reg=0)

log_X = logm(X)
Expand Down
36 changes: 18 additions & 18 deletions tests/test_utils_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def test_mean_weight_zero(kind, mean, get_mats):
nanmean_riemann,
],
)
def test_mean_weight_len_error(mean, get_covmats):
def test_mean_weight_len_error(mean, get_mats):
n_matrices, n_channels = 3, 2
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
with pytest.raises(ValueError):
mean(mats, sample_weight=np.ones(n_matrices + 1))

Expand All @@ -120,10 +120,10 @@ def test_mean_weight_len_error(mean, get_covmats):
nanmean_riemann
]
)
def test_mean_warning_convergence(mean, get_covmats):
def test_mean_warning_convergence(mean, get_mats):
"""Test warning for convergence not reached """
n_matrices, n_channels = 3, 2
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
with pytest.warns(UserWarning):
if mean == mean_power:
mean(mats, 0.3, maxiter=0)
Expand Down Expand Up @@ -197,10 +197,10 @@ def test_mean_euclid(rndstate, complex_valued):
assert mean_euclid(mats) == approx(mats.mean(axis=0))


def test_mean_identity(get_covmats):
def test_mean_identity(get_mats):
"""Test the identity mean"""
n_matrices, n_channels = 2, 3
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
C = mean_identity(mats)
assert np.all(C == np.eye(n_channels))

Expand All @@ -215,10 +215,10 @@ def test_mean_power(kind, get_mats):
assert mean_power(mats, -1) == approx(mean_harmonic(mats))


def test_mean_power_errors(get_covmats):
def test_mean_power_errors(get_mats):
"""Test the power mean errors"""
n_matrices, n_channels = 3, 2
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")

with pytest.raises(ValueError): # exponent is not a scalar
mean_power(mats, [1])
Expand Down Expand Up @@ -259,10 +259,10 @@ def test_mean_riemann_properties(kind, get_mats):


@pytest.mark.parametrize("init", [True, False])
def test_mean_masked_riemann_shape(init, get_covmats, get_masks):
def test_mean_masked_riemann_shape(init, get_mats, get_masks):
"""Test the masked Riemannian mean"""
n_matrices, n_channels = 5, 3
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
masks = get_masks(n_matrices, n_channels)
if init:
C = maskedmean_riemann(mats, masks, tol=10e-3, init=mats[0])
Expand All @@ -272,10 +272,10 @@ def test_mean_masked_riemann_shape(init, get_covmats, get_masks):


@pytest.mark.parametrize("init", [True, False])
def test_mean_nan_riemann_shape(init, get_covmats, rndstate):
def test_mean_nan_riemann_shape(init, get_mats, rndstate):
"""Test the Riemannian NaN-mean"""
n_matrices, n_channels = 10, 6
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
emean = np.mean(mats, axis=0)
for i in range(n_matrices):
corrup_channels = rndstate.choice(
Expand All @@ -290,10 +290,10 @@ def test_mean_nan_riemann_shape(init, get_covmats, rndstate):
assert C.shape == (n_channels, n_channels)


def test_mean_nan_riemann_errors(get_covmats):
def test_mean_nan_riemann_errors(get_mats):
"""Test the Riemannian NaN-mean errors"""
n_matrices, n_channels = 5, 4
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")

with pytest.raises(ValueError): # not symmetric NaN values
mats_ = mats.copy()
Expand Down Expand Up @@ -325,19 +325,19 @@ def callable_np_average(X, sample_weight=None):
(callable_np_average, mean_euclid),
],
)
def test_mean_covariance_metric(metric, mean, get_covmats):
def test_mean_covariance_metric(metric, mean, get_mats):
"""Test mean_covariance for metric"""
n_matrices, n_channels = 3, 3
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
C = mean_covariance(mats, metric=metric)
Ctrue = mean(mats)
assert np.all(C == Ctrue)


def test_mean_covariance_args(get_covmats):
def test_mean_covariance_args(get_mats):
"""Test mean_covariance with different arguments"""
n_matrices, n_channels = 3, 3
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
mean_covariance(mats, metric='ale', maxiter=5)
mean_covariance(mats, metric='logdet', tol=10e-3)
mean_covariance(mats, metric='riemann', init=np.eye(n_channels))
8 changes: 4 additions & 4 deletions tests/test_utils_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def test_median_weight_zero(kind, median, get_mats):


@pytest.mark.parametrize("median", [median_euclid, median_riemann])
def test_median_warning_convergence(median, get_covmats):
def test_median_warning_convergence(median, get_mats):
"""Test warning for convergence not reached"""
n_matrices, n_channels = 3, 2
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
with pytest.warns(UserWarning):
median(mats, maxiter=0)

Expand All @@ -70,8 +70,8 @@ def test_median_euclid(rndstate, complex_valued):


@pytest.mark.parametrize("step_size", [0, 2.5])
def test_median_riemann_stepsize_error(step_size, get_covmats):
def test_median_riemann_stepsize_error(step_size, get_mats):
n_matrices, n_channels = 1, 2
mats = get_covmats(n_matrices, n_channels)
mats = get_mats(n_matrices, n_channels, "spd")
with pytest.raises(ValueError):
median_riemann(mats, step_size=step_size)
Loading

0 comments on commit 3c5db03

Please sign in to comment.