diff --git a/haptools/data/covariates.py b/haptools/data/covariates.py index 47de25e4..3aa2e50a 100644 --- a/haptools/data/covariates.py +++ b/haptools/data/covariates.py @@ -28,5 +28,5 @@ class Covariates(Phenotypes): """ def __init__(self, fname: Path | str, log: Logger = None): - super(Phenotypes, self).__init__(fname, log) + super().__init__(fname, log) self._ext = "covar" diff --git a/tests/test_data.py b/tests/test_data.py index 3a7e1752..bcfb0d25 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -774,6 +774,40 @@ def test_load_covariates_subset(self): np.testing.assert_allclose(covars.data, expected) assert covars.samples == tuple(samples) + def test_subset_covariates(self): + cvs = self._get_fake_covariates() + + # subset to just the samples we want + expected_data = cvs.data[:3] + expected_names = cvs.names + samples = ("HG00096", "HG00097", "HG00099") + cvs_sub = cvs.subset(samples=samples) + assert cvs_sub.samples == samples + np.testing.assert_allclose(cvs_sub.data, expected_data) + assert np.array_equal(cvs_sub.names, expected_names) + + # subset to just the names we want + expected_data = cvs.data[:, [1]] + assert len(expected_data.shape) == 2 + expected_names = (cvs.names[1],) + names = ("age",) + cvs_sub = cvs.subset(names=names) + assert cvs_sub.samples == cvs.samples + np.testing.assert_allclose(cvs_sub.data, expected_data) + assert np.array_equal(cvs_sub.names, expected_names) + + # subset both: samples and names + expected_data = cvs.data[[3, 4], [1]] + expected_data = expected_data[:, np.newaxis] + assert len(expected_data.shape) == 2 + expected_names = (cvs.names[1],) + samples = ("HG00100", "HG00101") + names = ("age",) + cvs_sub = cvs.subset(samples=samples, names=names) + assert cvs_sub.samples == samples + np.testing.assert_allclose(cvs_sub.data, expected_data) + assert np.array_equal(cvs_sub.names, expected_names) + class TestHaplotypes: def _basic_haps(self):