Commit

🚧 add types and split up large function
victorlin committed Aug 7, 2024
1 parent 7c2def6 commit c305fe8
Showing 2 changed files with 165 additions and 96 deletions.
253 changes: 161 additions & 92 deletions augur/filter/subsample.py
@@ -5,7 +5,7 @@
import numpy as np
import pandas as pd
from textwrap import dedent
from typing import Collection
from typing import Collection, Dict, Iterable, List, Optional, Tuple

from augur.dates import get_year_month, get_year_week
from augur.errors import AugurError
@@ -14,6 +14,9 @@
from . import constants
from .weights_file import WEIGHTS_COLUMN, get_weighted_columns, read_weights_file

Group = Tuple[str, ...]
"""Combinations of grouping column values in tuple form."""


def get_groups_for_subsampling(strains, metadata, group_by=None):
"""Return a list of groups for each given strain based on the corresponding
@@ -301,123 +304,189 @@ def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None):
return max_sizes_per_group


def get_weighted_group_sizes(groups, group_by, weights_file, target_total_size, output_missing_weights, output_sizes_file, random_seed):
"""Return group sizes based on weights defined in ``weights_file``.
TARGET_SIZE_COLUMN = '_augur_filter_target_size'

Returns
-------
dict :
Mapping between groups (combinations of grouping column values in tuple
form) to group sizes

def get_weighted_group_sizes(
groups: Collection[Group],
group_by: List[str],
weights_file: str,
target_total_size: int,
output_missing_weights: Optional[str],
output_sizes_file: Optional[str],
random_seed: Optional[int],
) -> Dict[Group, int]:
"""Return target group sizes based on weights defined in ``weights_file``.
"""
weights = read_weights_file(weights_file)

weighted_columns = get_weighted_columns(weights_file)

# Allow other columns in group_by to be equally weighted (uniform sampling)
# Other columns in group_by are considered unweighted.
unweighted_columns = list(set(group_by) - set(weighted_columns))

if unweighted_columns:
# Augment the weights DataFrame with equal weighting for unweighted columns.

# Get unique values for each unweighted column.
values_for_unweighted_columns = defaultdict(set)
for group in groups:
# NOTE: The ordering of entries in `group` corresponds to the column
# names in `group_by`, but only because `get_groups_for_subsampling`
# conveniently retains the order. This could be more tightly coupled,
# but it works.
column_to_value_map = dict(zip(group_by, group))
for column in unweighted_columns:
values_for_unweighted_columns[column].add(column_to_value_map[column])

# Create a DataFrame for all permutations of values in unweighted columns.
lists = [list(values_for_unweighted_columns[column]) for column in unweighted_columns]
unweighted_permutations = pd.DataFrame(list(itertools.product(*lists)), columns=unweighted_columns)

# Add the unweighted columns to the weights DataFrame.
# This extends the existing weights to the unweighted groups but with
# the side effect of weighting the values *alongside* (rather than within)
# each weighted group.
# After dropping unused groups, these weights will be adjusted to ensure
# equal weighting of unweighted columns *within* each weighted group
# defined by the weighted columns.
weights = pd.merge(unweighted_permutations, weights, how='cross')

# Drop any groups that don't appear in metadata.
# This has the side effect of weighting the values *alongside* (rather
# than within) each weighted group. After dropping unused groups, these
# weights should be adjusted to ensure equal weighting of unweighted
# columns *within* each weighted group defined by the weighted columns.
weights = _add_unweighted_columns(weights, groups, group_by, unweighted_columns)

# This must be done even if all columns are weighted.
weights = _drop_unused_groups(weights, groups, group_by)

# This must happen after dropping unused groups.
if unweighted_columns:
weights = _adjust_weights_for_unweighted_columns(weights, weighted_columns, unweighted_columns)

weights = _calculate_weighted_group_sizes(weights, target_total_size, random_seed)

missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1))
if missing_groups:
weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, missing_groups, output_missing_weights)

if output_sizes_file:
# TODO: Add another column representing actual number of sequences per
# group. It would be useful for debugging discrepancies between target sizes
# and actual sizes. Not possible with the current inputs to this function.
weights[[*group_by, WEIGHTS_COLUMN, TARGET_SIZE_COLUMN]].to_csv(output_sizes_file, index=False, sep='\t')

return dict(zip(weights[group_by].apply(tuple, axis=1), weights[TARGET_SIZE_COLUMN]))
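
To see the overall effect, here is a minimal sketch of the proportional step using the 2:1 example from the cram test below; the column names are illustrative rather than the module's constants, and the snippet is not part of this diff:

    import pandas as pd

    weights = pd.DataFrame({"location": ["A", "B"], "weight": [2, 1]})
    target_total_size = 100
    weights["target_size"] = weights["weight"] / weights["weight"].sum() * target_total_size
    # -> A: 66.67, B: 33.33, which probabilistic rounding turns into 67 and 33,
    #    matching target_group_sizes.tsv in the test output below.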


def _add_unweighted_columns(
weights: pd.DataFrame,
groups: Iterable[Group],
group_by: List[str],
unweighted_columns: List[str],
) -> pd.DataFrame:
"""Add the unweighted columns to the weights DataFrame.
This is done by extending the existing weights to the newly created groups.
"""

# Get unique values for each unweighted column.
values_for_unweighted_columns = defaultdict(set)
for group in groups:
# NOTE: The ordering of entries in `group` corresponds to the column
# names in `group_by`, but only because `get_groups_for_subsampling`
# conveniently retains the order. This could be more tightly coupled,
# but it works.
column_to_value_map = dict(zip(group_by, group))
for column in unweighted_columns:
values_for_unweighted_columns[column].add(column_to_value_map[column])

# Create a DataFrame for all permutations of values in unweighted columns.
lists = [list(values_for_unweighted_columns[column]) for column in unweighted_columns]
unweighted_permutations = pd.DataFrame(list(itertools.product(*lists)), columns=unweighted_columns)

return pd.merge(unweighted_permutations, weights, how='cross')
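
For intuition, a small sketch of the cross merge with made-up values (not part of the diff): a weights table keyed only by 'location' is extended over the unique values of an unweighted 'year' column.

    import pandas as pd

    weights = pd.DataFrame({"location": ["A", "B"], "weight": [2, 1]})
    years = pd.DataFrame({"year": ["2000", "2001"]})
    extended = pd.merge(years, weights, how="cross")
    # -> four rows: (2000, A, 2), (2000, B, 1), (2001, A, 2), (2001, B, 1),
    #    i.e. every (year, location) combination carries its location's weight.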


def _drop_unused_groups(
weights: pd.DataFrame,
groups: Collection[Group],
group_by: List[str],
) -> pd.DataFrame:
"""Drop any groups from ``weights`` that don't appear in ``groups``.
"""
weights.set_index(group_by, inplace=True)
valid_index = set(groups) if len(group_by) > 1 else set(group[0] for group in groups)

# Pandas only uses MultiIndex if there is more than one column in the index.
valid_index: set[Group] | set[str]
if len(group_by) > 1:
valid_index = set(groups)
else:
valid_index = set(group[0] for group in groups)
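
The MultiIndex note above is the reason for the two branches; a quick illustration with made-up values (not part of the diff):

    import pandas as pd

    df = pd.DataFrame({"year": ["2000", "2000"], "location": ["A", "B"], "weight": [1, 1]})

    df.set_index(["year", "location"], inplace=True)
    set(df.index)   # {('2000', 'A'), ('2000', 'B')} -- tuples from a MultiIndex

    df.reset_index(inplace=True)
    df.set_index(["location"], inplace=True)
    set(df.index)   # {'A', 'B'} -- plain values, not 1-tuples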

extra_groups = set(weights.index) - valid_index
if extra_groups:
count = len(extra_groups)
unit = "group" if count == 1 else "groups"
print_err(f"NOTE: Skipping {count} {unit} due to lack of entries in metadata.")
weights = weights[weights.index.isin(valid_index)]

weights.reset_index(inplace=True)

# Adjust weights for unweighted columns to reflect equal weighting within each weighted group.
# This must happen after dropping groups that don't appear in metadata.
if unweighted_columns:
columns = 'column' if len(unweighted_columns) == 1 else 'columns'
those = 'that' if len(unweighted_columns) == 1 else 'those'
print_err(f"NOTE: Weights were not provided for the {columns} {', '.join(repr(col) for col in unweighted_columns)}. Using equal weights across values in {those} {columns}.")
return weights

weights_grouped = weights.groupby(weighted_columns)
weights[WEIGHTS_COLUMN] = weights_grouped[WEIGHTS_COLUMN].transform(lambda x: x / len(x))

# Calculate maximum group sizes based on weights
SIZE_COLUMN_FLOAT = '_augur_filter_target_size_float'
SIZE_COLUMN_INT = '_augur_filter_target_size_int'
weights[SIZE_COLUMN_FLOAT] = weights[WEIGHTS_COLUMN] / weights[WEIGHTS_COLUMN].sum() * target_total_size
def _adjust_weights_for_unweighted_columns(
weights: pd.DataFrame,
weighted_columns: List[str],
unweighted_columns: Collection[str],
) -> pd.DataFrame:
"""Adjust weights for unweighted columns to reflect equal weighting within each weighted group.
"""
columns = 'column' if len(unweighted_columns) == 1 else 'columns'
those = 'that' if len(unweighted_columns) == 1 else 'those'
print_err(f"NOTE: Weights were not provided for the {columns} {', '.join(repr(col) for col in unweighted_columns)}. Using equal weights across values in {those} {columns}.")

weights_grouped = weights.groupby(weighted_columns)
weights[WEIGHTS_COLUMN] = weights_grouped[WEIGHTS_COLUMN].transform(lambda x: x / len(x))

return weights
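
A worked example of the adjustment, with made-up values (not part of the diff): weights of 2 for location A and 1 for location B, extended over two years, are divided by the number of rows in each location group so that years are weighted equally within a location while the 2:1 ratio between locations is preserved.

    import pandas as pd

    weights = pd.DataFrame({
        "year":     ["2000", "2001", "2000", "2001"],
        "location": ["A",    "A",    "B",    "B"],
        "weight":   [2,      2,      1,      1],
    })
    weights["weight"] = weights.groupby("location")["weight"].transform(lambda x: x / len(x))
    # -> location A rows now carry weight 1 each, location B rows 0.5 each.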

# Group sizes need to be whole numbers. Round probabilistically by adding
# a random number between [0,1) and truncating the decimal part.

def _calculate_weighted_group_sizes(
weights: pd.DataFrame,
target_total_size: int,
random_seed: Optional[int],
) -> pd.DataFrame:
"""Calculate maximum group sizes based on weights.
"""
weights[TARGET_SIZE_COLUMN] = pd.Series(weights[WEIGHTS_COLUMN] / weights[WEIGHTS_COLUMN].sum() * target_total_size)

# Group sizes must be whole numbers. Round probabilistically by adding a
# random number between [0,1) and truncating the decimal part.
rng = np.random.default_rng(random_seed)
weights[SIZE_COLUMN_INT] = (weights[SIZE_COLUMN_FLOAT].add(rng.random(len(weights)))).astype(int)
weights[TARGET_SIZE_COLUMN] = (weights[TARGET_SIZE_COLUMN].add(pd.Series(rng.random(len(weights))))).astype(int)
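
The probabilistic rounding can be seen in isolation with a small sketch (illustrative values, not part of the diff):

    import numpy as np

    rng = np.random.default_rng(0)
    float_sizes = np.array([66.67, 33.33])
    int_sizes = (float_sizes + rng.random(len(float_sizes))).astype(int)
    # Each size is rounded up with probability equal to its fractional part:
    # 66.67 becomes 67 about 67% of the time and 66 otherwise, so the expected
    # total stays close to the requested target_total_size.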

missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1))
if missing_groups:
# Collect the column values that are missing weights.
missing_values_by_column = defaultdict(set)
for group in missing_groups:
# NOTE: The ordering of entries in `group` corresponds to the column
# names in `group_by`, but only because `get_groups_for_subsampling`
# conveniently retains the order. This could be more tightly coupled,
# but it works.
column_to_value_map = dict(zip(group_by, group))
for column in weighted_columns:
missing_values_by_column[column].add(column_to_value_map[column])

columns_with_values = '\n - '.join(f'{column!r}: {list(values)}' for column, values in missing_values_by_column.items())
if not output_missing_weights:
raise AugurError(dedent(f"""\
The input metadata contains these values under the following columns that are not covered by {weights_file!r}:
- {columns_with_values}
Re-run with --output-group-by-missing-weights to continue."""))
else:
missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by)
missing_weights_weighted_columns_only = missing_weights[weighted_columns].drop_duplicates()
missing_weights_weighted_columns_only[WEIGHTS_COLUMN] = ''
missing_weights_weighted_columns_only.to_csv(output_missing_weights, index=False, sep='\t')
print_err(dedent(f"""\
The input metadata contains these values under the following columns that are not covered by {weights_file!r}:
- {columns_with_values}
Sequences associated with these values will be dropped.
A separate weights file has been generated with implicit weight of zero for these values: {output_missing_weights!r}
Consider updating {weights_file!r} with nonzero weights and re-running without --output-group-by-missing-weights."""))

# Drop sequences with missing weights
missing_weights[SIZE_COLUMN_INT] = 0
weights = pd.merge(weights, missing_weights, on=[*group_by, SIZE_COLUMN_INT], how='outer')
return weights

if output_sizes_file:
# TODO: Add another column representing actual number of sequences per
# group. It would be useful for debugging discrepancies between target sizes
# and actual sizes. Not possible with the current inputs to this function.
weights[[*group_by, WEIGHTS_COLUMN, SIZE_COLUMN_INT]].to_csv(output_sizes_file, index=False, sep='\t')

return dict(zip(weights[group_by].apply(tuple, axis=1), weights[SIZE_COLUMN_INT]))
def _handle_incomplete_weights(
weights: pd.DataFrame,
weights_file: str,
weighted_columns: List[str],
group_by: List[str],
missing_groups: Collection[Group],
output_missing_weights: Optional[str],
) -> pd.DataFrame:
"""Handle the case where the weights file does not cover all rows in the metadata.
"""
# Collect the column values that are missing weights.
missing_values_by_column = defaultdict(set)
for group in missing_groups:
# NOTE: The ordering of entries in `group` corresponds to the column
# names in `group_by`, but only because `get_groups_for_subsampling`
# conveniently retains the order. This could be more tightly coupled,
# but it works.
column_to_value_map = dict(zip(group_by, group))
for column in weighted_columns:
missing_values_by_column[column].add(column_to_value_map[column])

columns_with_values = '\n - '.join(f'{column!r}: {list(values)}' for column, values in missing_values_by_column.items())
if not output_missing_weights:
raise AugurError(dedent(f"""\
The input metadata contains these values under the following columns that are not covered by {weights_file!r}:
- {columns_with_values}
Re-run with --output-group-by-missing-weights to continue."""))
else:
missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by)
missing_weights_weighted_columns_only = missing_weights[weighted_columns].drop_duplicates()
missing_weights_weighted_columns_only[WEIGHTS_COLUMN] = ''
missing_weights_weighted_columns_only.to_csv(output_missing_weights, index=False, sep='\t')
print_err(dedent(f"""\
The input metadata contains these values under the following columns that are not covered by {weights_file!r}:
- {columns_with_values}
Sequences associated with these values will be dropped.
A separate weights file has been generated with implicit weight of zero for these values: {output_missing_weights!r}
Consider updating {weights_file!r} with nonzero weights and re-running without --output-group-by-missing-weights."""))

# Set the weight for these groups to zero, effectively dropping all sequences.
missing_weights[TARGET_SIZE_COLUMN] = 0
return pd.merge(weights, missing_weights, on=[*group_by, TARGET_SIZE_COLUMN], how='outer')
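
For intuition, a sketch of this outer merge with made-up values (not part of the diff): groups missing from the weights file are appended with a target size of 0, so their sequences are effectively dropped.

    import pandas as pd

    TARGET_SIZE_COLUMN = '_augur_filter_target_size'
    weights = pd.DataFrame({"location": ["A", "B"], TARGET_SIZE_COLUMN: [67, 33]})
    missing = pd.DataFrame({"location": ["C"], TARGET_SIZE_COLUMN: [0]})
    combined = pd.merge(weights, missing, on=["location", TARGET_SIZE_COLUMN], how="outer")
    # -> three rows; location C appears with a target size of 0, so none of its
    #    sequences are sampled.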


def create_queues_by_group(max_sizes_per_group):
Second changed file: 4 additions & 4 deletions
@@ -31,7 +31,7 @@ Weight locations A:B as 2:1. This is reflected in target_group_sizes.tsv below.
> --output-metadata filtered.tsv 2>/dev/null

$ cat target_group_sizes.tsv
location weight _augur_filter_target_size_int
location weight _augur_filter_target_size
A 2 67
B 1 33

@@ -60,7 +60,7 @@ Using 1:1 weights is similarly straightforward, with 50 sequences from each loca
> --output-strains strains.txt 2>/dev/null

$ cat target_group_sizes.tsv
location weight _augur_filter_target_size_int
location weight _augur_filter_target_size
A 1 50
B 1 50

@@ -79,7 +79,7 @@ available per location.
> --output-strains strains.txt 2>/dev/null

$ cat target_group_sizes.tsv
year location weight _augur_filter_target_size_int
year location weight _augur_filter_target_size
2000 A 0.5 25
2000 B 0.3333333333333333 16
2001 A 0.5 25
@@ -104,7 +104,7 @@ requested 17, so the total number of sequences outputted is lower than requested
> --output-strains strains.txt 2>/dev/null

$ cat target_group_sizes.tsv
year location weight _augur_filter_target_size_int
year location weight _augur_filter_target_size
2000 A 0.3333333333333333 17
2000 B 0.3333333333333333 16
2001 A 0.3333333333333333 16