Merge pull request #1460: Prepare for weighted sampling
victorlin authored Aug 5, 2024
2 parents f49a3e4 + a8f8827 commit 79732b8
Showing 3 changed files with 44 additions and 49 deletions.
8 changes: 4 additions & 4 deletions augur/filter/__init__.py
@@ -57,7 +57,7 @@ def register_arguments(parser):
     sequence_filter_group.add_argument('--non-nucleotide', action='store_true', help="exclude sequences that contain illegal characters")

     subsample_group = parser.add_argument_group("subsampling", "options to subsample filtered data")
-    subsample_group.add_argument('--group-by', nargs='+', action='extend', help=f"""
+    subsample_group.add_argument('--group-by', nargs='+', action='extend', default=[], help=f"""
         categories with respect to subsample.
         Notes:
         (1) Grouping by {sorted(constants.GROUP_BY_GENERATED_COLUMNS)} is only supported when there is a {METADATA_DATE_COLUMN!r} column in the metadata.
@@ -67,9 +67,9 @@ def register_arguments(parser):
     subsample_limits_group = subsample_group.add_mutually_exclusive_group()
     subsample_limits_group.add_argument('--sequences-per-group', type=int, help="subsample to no more than this number of sequences per category")
     subsample_limits_group.add_argument('--subsample-max-sequences', type=int, help="subsample to no more than this number of sequences; can be used without the group_by argument")
-    probabilistic_sampling_group = subsample_group.add_mutually_exclusive_group()
-    probabilistic_sampling_group.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided.")
-    probabilistic_sampling_group.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
+    group_size_options = subsample_group.add_mutually_exclusive_group()
+    group_size_options.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided.")
+    group_size_options.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
     subsample_group.add_argument('--priority', type=str, help="""tab-delimited file with list of priority scores for strains (e.g., "<strain>\\t<priority>") and no header.
         When scores are provided, Augur converts scores to floating point values, sorts strains within each subsampling group from highest to lowest priority, and selects the top N strains per group where N is the calculated or requested number of strains per group.
         Higher numbers indicate higher priority.
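The `default=[]` added to `--group-by` above means downstream code can treat args.group_by as a list without first checking for None. A minimal, self-contained sketch (not the augur CLI itself) of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
# Mirrors the changed option above: without a default, an omitted --group-by
# parses to None; default=[] guarantees a list either way.
parser.add_argument('--group-by', nargs='+', action='extend', default=[])

print(parser.parse_args([]).group_by)                               # []
print(parser.parse_args(['--group-by', 'year', 'month']).group_by)  # ['year', 'month']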
24 changes: 13 additions & 11 deletions augur/filter/_run.py
@@ -23,7 +23,7 @@
 from . import include_exclude_rules
 from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
 from .include_exclude_rules import apply_filters, construct_filters
-from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling
+from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling


 def run(args):
@@ -276,19 +276,21 @@ def run(args):
         except TooManyGroupsError as error:
             raise AugurError(error)

-        if (probabilistic_used):
-            print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
-        else:
-            print_err(f"Sampling at {sequences_per_group} per group.")
-
     if queues_by_group is None:
         # We know all of the possible groups now from the first pass through
         # the metadata, so we can create queues for all groups at once.
-        queues_by_group = create_queues_by_group(
-            records_per_group.keys(),
-            sequences_per_group,
-            random_seed=args.subsample_seed,
-        )
+        if (probabilistic_used):
+            print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
+            group_sizes = get_probabilistic_group_sizes(
+                records_per_group.keys(),
+                sequences_per_group,
+                random_seed=args.subsample_seed,
+            )
+        else:
+            print_err(f"Sampling at {sequences_per_group} per group.")
+            assert type(sequences_per_group) is int
+            group_sizes = {group: sequences_per_group for group in records_per_group.keys()}
+        queues_by_group = create_queues_by_group(group_sizes)

     # Make a second pass through the metadata, only considering records that
     # have passed filters.
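The reshuffled branch above now computes an explicit per-group size map before any queues exist, rather than letting queue construction decide sizes. A hypothetical, simplified standalone version of that decision (plan_group_sizes and seed are illustrative names, and the zero-total retry loop from subsample.py is omitted):

import numpy as np

def plan_group_sizes(groups, sequences_per_group, seed=None):
    # Fractional target: draw each group's size from a Poisson distribution,
    # so the expected total approximates the requested maximum.
    if sequences_per_group < 1.0:
        rng = np.random.default_rng(seed)
        return {group: int(rng.poisson(sequences_per_group)) for group in sorted(groups)}
    # Integer target: every group gets the same fixed cap.
    return {group: int(sequences_per_group) for group in groups}

print(plan_group_sizes({"2015", "2016", "2017"}, 0.5, seed=314159))  # e.g. {'2015': 0, '2016': 1, '2017': 1}
print(plan_group_sizes({"2015", "2016", "2017"}, 2))                 # every group capped at 2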
61 changes: 27 additions & 34 deletions augur/filter/subsample.py
@@ -249,64 +249,57 @@ def get_items(self):
             yield item


-def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None):
-    """Create a dictionary of priority queues per group for the given maximum size.
+def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None):
+    """Create a dictionary of maximum sizes per group.

-    When the maximum size is fractional, probabilistically sample the maximum
-    size from a Poisson distribution. Make at least the given number of maximum
-    attempts to create queues for which the sum of their maximum sizes is
-    greater than zero.
+    Probabilistically generate varying sizes from a Poisson distribution. Make
+    at least the given number of maximum attempts to generate sizes for which
+    the total of all sizes is greater than zero.

     Examples
     --------
-    Create queues for two groups with a fixed maximum size.
-
-    >>> groups = ("2015", "2016")
-    >>> queues = create_queues_by_group(groups, 2)
-    >>> sum(queue.max_size for queue in queues.values())
-    4
-
-    Create queues for two groups with a fractional maximum size. Their total max
+    Get sizes for two groups with a fractional maximum size. Their total
     size should still be an integer value greater than zero.

     >>> groups = ("2015", "2016")
     >>> seed = 314159
-    >>> queues = create_queues_by_group(groups, 0.1, random_seed=seed)
-    >>> int(sum(queue.max_size for queue in queues.values())) > 0
+    >>> group_sizes = get_probabilistic_group_sizes(groups, 0.1, random_seed=seed)
+    >>> int(sum(group_sizes.values())) > 0
     True

     A subsequent run of this function with the same groups and random seed
-    should produce the same queues and queue sizes.
+    should produce the same group sizes.

-    >>> more_queues = create_queues_by_group(groups, 0.1, random_seed=seed)
-    >>> [queue.max_size for queue in queues.values()] == [queue.max_size for queue in more_queues.values()]
+    >>> more_group_sizes = get_probabilistic_group_sizes(groups, 0.1, random_seed=seed)
+    >>> list(group_sizes.values()) == list(more_group_sizes.values())
     True
     """
-    queues_by_group = {}
-    total_max_size = 0
-    attempts = 0
-
-    if max_size < 1.0:
-        random_generator = np.random.default_rng(random_seed)
+    assert target_group_size < 1.0

     # For small fractional maximum sizes, it is possible to randomly select
     # maximum queue sizes that all equal zero. When this happens, filtering
     # fails unexpectedly. We make multiple attempts to create queues with
     # maximum sizes greater than zero for at least one queue.
+    random_generator = np.random.default_rng(random_seed)
+    total_max_size = 0
+    attempts = 0
+    max_attempts = 100
+    max_sizes_per_group = {}

     while total_max_size == 0 and attempts < max_attempts:
         for group in sorted(groups):
-            if max_size < 1.0:
-                queue_max_size = random_generator.poisson(max_size)
-            else:
-                queue_max_size = max_size
-
-            queues_by_group[group] = PriorityQueue(queue_max_size)
+            max_sizes_per_group[group] = random_generator.poisson(target_group_size)

-        total_max_size = sum(queue.max_size for queue in queues_by_group.values())
+        total_max_size = sum(max_sizes_per_group.values())
         attempts += 1

-    return queues_by_group
+    return max_sizes_per_group
+
+
+def create_queues_by_group(max_sizes_per_group):
+    return {group: PriorityQueue(max_size)
+            for group, max_size in max_sizes_per_group.items()}


 def calculate_sequences_per_group(target_max_value, group_sizes, allow_probabilistic=True):
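Taken together, the split confines randomness to get_probabilistic_group_sizes, so queue construction itself is deterministic. A hypothetical doctest-style usage of the two new helpers above (assuming this module's PriorityQueue and imports):

>>> group_sizes = get_probabilistic_group_sizes(("2015", "2016"), 0.1, random_seed=314159)
>>> queues = create_queues_by_group(group_sizes)
>>> sum(queue.max_size for queue in queues.values()) == sum(group_sizes.values())
True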