Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow for specifying the order of population labels in Breakpoints.encode() #262

Merged
merged 4 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions haptools/data/breakpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,31 @@ def __iter__(self, samples: set[str] = None) -> Iterable[str, SampleBlocks]:
yield samp, [np.array(b, dtype=HapBlock) for b in blocks]
bps.close()

def encode(self) -> dict[int, str]:
def encode(self, labels: tuple[str] = None):
"""
Replace each ancestral label in :py:attr:`~.Breakpoints.data` with an
equivalent integer. Store a dictionary mapping these integers back to their
respective labels.
respective labels in :py:attr:`~.Breakpoints.labels`.

This method modifies :py:attr:`~.Breakpoints.data` in place.

Returns
-------
dict[int, str]
A dictionary mapping each integer back to its ancestral label
Parameters
----------
labels: tuple[str], optional
A list of population labels. The order of the labels in this list will be
kept in the respective labels.
"""
if not (self.labels is None):
raise ValueError("The data has already been encoded.")
# save the order of the fields for later reordering
names = [f[0] for f in HapBlock]
# initialize labels dict and label counter
labels = {}
pop_count = 0
if labels is None:
labels = {}
else:
labels = {pop: i for i, pop in enumerate(labels)}
pop_count = len(labels)
seen = set()
for sample, blocks in self.data.items():
for strand_num in range(len(blocks)):
# initialize and fill the array of integers
Expand All @@ -189,10 +194,11 @@ def encode(self) -> dict[int, str]:
labels[pop] = pop_count
pop_count += 1
ints[i] = labels[pop]
seen.add(pop)
# replace the "pop" labels
arr = rcf.drop_fields(blocks[strand_num], ["pop"])
blocks[strand_num] = rcf.merge_arrays((arr, ints), flatten=True)[names]
self.labels = labels
self.labels = {k: v for k, v in labels.items() if k in seen}

def recode(self):
"""
Expand Down Expand Up @@ -332,6 +338,7 @@ def write(self):
--------
To write to a file, you must first initialize a Breakpoints object and then
fill out the names, data, and samples properties:

>>> from haptools.data import Breakpoints, HapBlock
>>> breakpoints = Breakpoints('simple.bp')
>>> breakpoints.data = {
Expand Down
34 changes: 34 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2075,6 +2075,40 @@ def test_encode(self):
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

def test_encode_reorder(self):
expected = self._get_expected_breakpoints()
expected.labels = {"CEU": 0, "YRI": 1}

observed = self._get_expected_breakpoints()
observed.encode(labels=("CEU", "YRI", "AMR"))

assert observed.labels == expected.labels
assert len(expected.data) == len(observed.data)
for sample in expected.data:
for strand in range(len(expected.data[sample])):
exp_strand = expected.data[sample][strand]
obs_strand = observed.data[sample][strand]
assert len(exp_strand) == len(observed.data[sample][strand])
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

# now try again with AMR in the middle
# In that case, it should keep the ordering when deciding the integers
# but the final labels should include the AMR key
expected.labels = {"CEU": 0, "YRI": 2}
observed = self._get_expected_breakpoints()
observed.encode(labels=("CEU", "AMR", "YRI"))

assert observed.labels == expected.labels
assert len(expected.data) == len(observed.data)
for sample in expected.data:
for strand in range(len(expected.data[sample])):
exp_strand = expected.data[sample][strand]
obs_strand = observed.data[sample][strand]
assert len(exp_strand) == len(observed.data[sample][strand])
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

def test_recode(self):
expected = self._get_expected_breakpoints()

Expand Down
Loading