Implement weighted sampling
Adds a new option --group-by-weights and some additional options to
support the new feature.
victorlin committed Aug 5, 2024
1 parent a20e727 commit 08e922d
Showing 7 changed files with 208 additions and 3 deletions.
31 changes: 30 additions & 1 deletion augur/filter/__init__.py
@@ -68,8 +68,31 @@ def register_arguments(parser):
subsample_limits_group.add_argument('--sequences-per-group', type=int, help="subsample to no more than this number of sequences per category")
subsample_limits_group.add_argument('--subsample-max-sequences', type=int, help="subsample to no more than this number of sequences; can be used without the group_by argument")
group_size_options = subsample_group.add_mutually_exclusive_group()
-group_size_options.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided.")
+group_size_options.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided. Cannot be used with `--group-by-weights`.")
group_size_options.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
group_size_options.add_argument('--group-by-weights', type=str, metavar="FILE", help="""
TSV file defining weights for grouping. Requirements:
(1) The first row must be a header.
(2) There must be a numeric ``weight`` column (weights can take on any
non-negative values).
(3) Other columns must be a subset of columns used in ``--group-by``,
with combinations of values covering all combinations present in the
metadata.
(4) This option only applies when ``--group-by`` and
``--subsample-max-sequences`` are provided.
(5) This option cannot be used with ``--probabilistic-sampling``.
Notes:
(1) Any ``--group-by`` columns absent from this file will be given equal
weighting across all values *within* groups defined by the other
weighted columns.
(2) All combinations of weighted column values that are present in the
metadata must be included in this file. Absence from this file will
cause augur filter to exit with an error describing how to add the
weights explicitly.
""")
subsample_group.add_argument('--priority', type=str, help="""tab-delimited file with list of priority scores for strains (e.g., "<strain>\\t<priority>") and no header.
When scores are provided, Augur converts scores to floating point values, sorts strains within each subsampling group from highest to lowest priority, and selects the top N strains per group where N is the calculated or requested number of strains per group.
Higher numbers indicate higher priority.
@@ -81,6 +104,12 @@ def register_arguments(parser):
output_group.add_argument('--output-metadata', help="metadata for strains that passed filters")
output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)")
output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.")
output_group.add_argument('--output-group-by-missing-weights', type=str, metavar="FILE", help="""
TSV file formatted for --group-by-weights with an empty weight column.
Represents groups with entries in --metadata but absent from
--group-by-weights.
""")
output_group.add_argument('--output-group-by-sizes', help="tab-delimited file with one row per group and its target size.")
output_group.add_argument(
'--empty-output-reporting',
type=EmptyOutputReportingMethod.argtype,
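For illustration: if the metadata contained a region missing from the weights file (say, a hypothetical Oceania), the file written by ``--output-group-by-missing-weights`` would list that group with an empty ``weight`` column, ready to be filled in and passed back via ``--group-by-weights``:

    region	weight
    Oceania	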
15 changes: 13 additions & 2 deletions augur/filter/_run.py
@@ -23,7 +23,7 @@
from . import include_exclude_rules
from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
from .include_exclude_rules import apply_filters, construct_filters
-from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling
+from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling, get_weighted_group_sizes


def run(args):
@@ -279,7 +279,18 @@ def run(args):
if queues_by_group is None:
# We know all of the possible groups now from the first pass through
# the metadata, so we can create queues for all groups at once.
-if (probabilistic_used):
+if args.group_by_weights:
+    print_err(f"Sampling with weights defined by {args.group_by_weights}.")
+    group_sizes = get_weighted_group_sizes(
+        records_per_group.keys(),
+        group_by,
+        args.group_by_weights,
+        args.subsample_max_sequences,
+        args.output_group_by_missing_weights,
+        args.output_group_by_sizes,
+        args.subsample_seed,
+    )
+elif (probabilistic_used):
print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
group_sizes = get_probabilistic_group_sizes(
records_per_group.keys(),
106 changes: 106 additions & 0 deletions augur/filter/subsample.py
@@ -1,3 +1,4 @@
from collections import defaultdict
import heapq
import itertools
import uuid
@@ -10,6 +11,7 @@
from augur.io.metadata import METADATA_DATE_COLUMN
from augur.io.print import print_err
from . import constants
from .weights_file import WEIGHTS_COLUMN, get_weighted_columns, read_weights_file


def get_groups_for_subsampling(strains, metadata, group_by=None):
@@ -297,6 +299,110 @@ def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None):
return max_sizes_per_group


def get_weighted_group_sizes(groups, group_by, weights_file, target_total_size, output_missing_weights, output_sizes_file, random_seed):
"""Return group sizes based on weights defined in ``weights_file``.
Returns
-------
dict :
Mapping from groups (combinations of grouping column values in tuple form) to group sizes
"""
weights = read_weights_file(weights_file)

weighted_columns = get_weighted_columns(weights_file)

# Allow other columns in group_by to be equally weighted (uniform sampling)
unweighted_columns = list(set(group_by) - set(weighted_columns))

if unweighted_columns:
# Augment the weights DataFrame with equal weighting for unweighted columns.

# Get unique values for each unweighted column.
values_for_unweighted_columns = defaultdict(set)
for group in groups:
# NOTE: The ordering of entries in `group` corresponds to the column
# names in `group_by`, but only because `get_groups_for_subsampling`
# conveniently retains the order. This could be more tightly coupled,
# but it works.
column_to_value_map = dict(zip(group_by, group))
for column in unweighted_columns:
values_for_unweighted_columns[column].add(column_to_value_map[column])

# Create a DataFrame for the Cartesian product of values in unweighted columns.
lists = [list(values_for_unweighted_columns[column]) for column in unweighted_columns]
unweighted_permutations = pd.DataFrame(list(itertools.product(*lists)), columns=unweighted_columns)

# Add the unweighted columns to the weights DataFrame.
# This extends the existing weights to the unweighted groups but with
# the side effect of weighting the values *alongside* (rather than within)
# each weighted group.
# After dropping unused groups, these weights will be adjusted to ensure
# equal weighting of unweighted columns *within* each weighted group
# defined by the weighted columns.
weights = pd.merge(unweighted_permutations, weights, how='cross')

# Drop any groups that don't appear in metadata.
# This must be done even if all columns are weighted.
weights.set_index(group_by, inplace=True)
valid_index = set(groups) if len(group_by) > 1 else set(group[0] for group in groups)
extra_groups = set(weights.index) - valid_index
if extra_groups:
count = len(extra_groups)
unit = "group" if count == 1 else "groups"
print_err(f"NOTE: Skipping {count} {unit} due to lack of entries in metadata.")
weights = weights[weights.index.isin(valid_index)]
weights.reset_index(inplace=True)

# Adjust weights for unweighted columns to reflect equal weighting within each weighted group.
# This must happen after dropping groups that don't appear in metadata.
if unweighted_columns:
columns = 'column' if len(unweighted_columns) == 1 else 'columns'
those = 'that' if len(unweighted_columns) == 1 else 'those'
print_err(f"NOTE: Weights were not provided for the {columns} {', '.join(repr(col) for col in unweighted_columns)}. Using equal weights across values in {those} {columns}.")

weights_grouped = weights.groupby(weighted_columns)
weights[WEIGHTS_COLUMN] = weights_grouped[WEIGHTS_COLUMN].transform(lambda x: x / len(x))
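# For example (hypothetical values): with --group-by region month and weights
# defined only for 'region', three months observed in region 'Asia' would each
# end up with weight_Asia / 3, splitting that region's weight equally.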

# Calculate maximum group sizes based on weights
SIZE_COLUMN_FLOAT = '_augur_filter_target_size_float'
SIZE_COLUMN_INT = '_augur_filter_target_size_int'
weights[SIZE_COLUMN_FLOAT] = weights[WEIGHTS_COLUMN] / weights[WEIGHTS_COLUMN].sum() * target_total_size

# Group sizes need to be whole numbers. Round probabilistically by adding
# a random number in [0, 1) and truncating the decimal part.
rng = np.random.default_rng(random_seed)
weights[SIZE_COLUMN_INT] = (weights[SIZE_COLUMN_FLOAT].add(rng.random(len(weights)))).astype(int)
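# For example, a fractional target of 10.3 becomes 10 with probability 0.7 and
# 11 with probability 0.3 (truncating 10.3 + U for U drawn from [0, 1)), so each
# group's expected size matches its fractional target.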

missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1))
if missing_groups:
n_missing = len(missing_groups)
group_s = "group" if n_missing == 1 else "groups"
are = "is" if n_missing == 1 else "are"
these = "this" if n_missing == 1 else "these"
if not output_missing_weights:
raise AugurError(f"The input metadata contains {n_missing} {group_s} that {are} missing from the weights file. Re-run with --output-group-by-missing-weights to continue.")
else:
print_err(f"WARNING: The input metadata contains {n_missing} {group_s} that {are} missing from the weights file. Sequences from {these} {group_s} will be dropped.")
missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by)
missing_weights_weighted_columns_only = missing_weights[weighted_columns].drop_duplicates()
missing_weights_weighted_columns_only[WEIGHTS_COLUMN] = ''
missing_weights_weighted_columns_only.to_csv(output_missing_weights, index=False, sep='\t')
print_err(f"All missing groups added to {output_missing_weights!r}.")

# Drop sequences with missing weights
missing_weights[SIZE_COLUMN_INT] = 0
weights = pd.merge(weights, missing_weights, on=[*group_by, SIZE_COLUMN_INT], how='outer')

if output_sizes_file:
# TODO: Add another column representing actual number of sequences per
# group. It would be useful for debugging discrepancies between target sizes
# and actual sizes. Not possible with the current inputs to this function.
weights[[*group_by, WEIGHTS_COLUMN, SIZE_COLUMN_INT]].to_csv(output_sizes_file, index=False, sep='\t')

return dict(zip(weights[group_by].apply(tuple, axis=1), weights[SIZE_COLUMN_INT]))


def create_queues_by_group(max_sizes_per_group):
return {group: PriorityQueue(max_size)
for group, max_size in max_sizes_per_group.items()}
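As a rough sketch of the new function's end-to-end behavior (hypothetical data; the module path and signature are taken from the diff above, and the expected output assumes no probabilistic rounding is needed):

    import tempfile
    from augur.filter.subsample import get_weighted_group_sizes

    # Write a small weights file: region 'Asia' weighted 2, 'Europe' weighted 1.
    with tempfile.NamedTemporaryFile('w', suffix='.tsv', delete=False) as f:
        f.write("region\tweight\nAsia\t2\nEurope\t1\n")

    sizes = get_weighted_group_sizes(
        groups=[('Asia',), ('Europe',)],
        group_by=['region'],
        weights_file=f.name,
        target_total_size=9,
        output_missing_weights=None,
        output_sizes_file=None,
        random_seed=0,
    )
    # Fractional targets are 6.0 and 3.0 (a 2:1 split of 9), already whole
    # numbers, so: sizes == {('Asia',): 6, ('Europe',): 3}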
14 changes: 14 additions & 0 deletions augur/filter/validate_arguments.py
@@ -1,4 +1,5 @@
from augur.errors import AugurError
from augur.filter.weights_file import get_weighted_columns
from augur.io.vcf import is_vcf as filename_is_vcf


@@ -43,3 +44,16 @@ def validate_arguments(args):
# If user requested grouping, confirm that other required inputs are provided, too.
if args.group_by and not any((args.sequences_per_group, args.subsample_max_sequences)):
raise AugurError("You must specify a number of sequences per group or maximum sequences to subsample.")

# Columns in the weights file must be a subset of --group-by columns.
if args.group_by_weights:
weighted_columns = get_weighted_columns(args.group_by_weights)
if (not set(weighted_columns) <= set(args.group_by)):
raise AugurError("Columns in --group-by-weights must be a subset of columns provided in --group-by.")

# --output-group-by-sizes is only available for --group-by-weights.
if args.output_group_by_sizes and not args.group_by_weights:
raise AugurError(
"--output-group-by-sizes is only available for --group-by-weights. "
"It may be added to other sampling methods in the future - see <https://github.com/nextstrain/augur/issues/new>" # FIXME: create a GitHub issue and link it here
)
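For example (hypothetical): a weights file with columns ``region``, ``year``, and ``weight`` used alongside ``--group-by region`` would fail this check, since ``year`` is not among the ``--group-by`` columns.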
37 changes: 37 additions & 0 deletions augur/filter/weights_file.py
@@ -0,0 +1,37 @@
import csv
import pandas as pd
from textwrap import dedent
from augur.errors import AugurError


WEIGHTS_COLUMN = 'weight'


class InvalidWeightsFile(AugurError):
def __init__(self, file, error_message):
super().__init__(f"Bad weights file {file!r}.\n{error_message}")


def read_weights_file(weights_file):
weights = pd.read_csv(weights_file, delimiter='\t')

if not pd.api.types.is_numeric_dtype(weights[WEIGHTS_COLUMN]):
non_numeric_weight_lines = [index + 2 for index in weights[~weights[WEIGHTS_COLUMN].str.isnumeric()].index.tolist()]
raise InvalidWeightsFile(weights_file, dedent(f"""\
Found non-numeric weights on the following lines: {non_numeric_weight_lines}
{WEIGHTS_COLUMN!r} column must be numeric."""))

if any(weights[WEIGHTS_COLUMN] < 0):
negative_weight_lines = [index + 2 for index in weights[weights[WEIGHTS_COLUMN] < 0].index.tolist()]
raise InvalidWeightsFile(weights_file, dedent(f"""\
Found negative weights on the following lines: {negative_weight_lines}
{WEIGHTS_COLUMN!r} column must be non-negative."""))

return weights


def get_weighted_columns(weights_file):
with open(weights_file) as f:
weighted_columns = next(csv.reader(f, delimiter='\t'))
weighted_columns.remove('weight')
return weighted_columns
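A quick sketch of these helpers in use (hypothetical file contents):

    # weights.tsv (tab-separated):
    # region	weight
    # Asia	2
    # Europe	1
    from augur.filter.weights_file import read_weights_file, get_weighted_columns

    weights = read_weights_file("weights.tsv")     # DataFrame with a numeric 'weight' column
    columns = get_weighted_columns("weights.tsv")  # ['region']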
1 change: 1 addition & 0 deletions docs/api/developer/augur.filter.rst
@@ -17,3 +17,4 @@ Submodules
augur.filter.io
augur.filter.subsample
augur.filter.validate_arguments
augur.filter.weights_file
7 changes: 7 additions & 0 deletions docs/api/developer/augur.filter.weights_file.rst
@@ -0,0 +1,7 @@
augur.filter.weights\_file module
=================================

.. automodule:: augur.filter.weights_file
:members:
:undoc-members:
:show-inheritance:
