diff --git a/locust/distribution.py b/locust/distribution.py index 05aff62b0e..912db25607 100644 --- a/locust/distribution.py +++ b/locust/distribution.py @@ -1,4 +1,5 @@ import math +from itertools import combinations_with_replacement from operator import attrgetter from typing import ( Dict, @@ -59,82 +60,64 @@ def weight_users( return user_class_occurrences elif sum(user_class_occurrences.values()) > number_of_users: - user_class_occurrences_candidates: Dict[float, Dict[str, int]] = {} - _recursive_remove_users( + return _find_ideal_users_to_remove( user_classes, - number_of_users, + sum(user_class_occurrences.values()) - number_of_users, user_class_occurrences.copy(), - user_class_occurrences_candidates, ) - return user_class_occurrences_candidates[min(user_class_occurrences_candidates.keys())] elif sum(user_class_occurrences.values()) < number_of_users: - user_class_occurrences_candidates: Dict[float, Dict[str, int]] = {} - _recursive_add_users( + return _find_ideal_users_to_add( user_classes, - number_of_users, + number_of_users - sum(user_class_occurrences.values()), user_class_occurrences.copy(), - user_class_occurrences_candidates, ) - return user_class_occurrences_candidates[min(user_class_occurrences_candidates.keys())] -def _recursive_add_users( +def _find_ideal_users_to_add( user_classes: List[Type[User]], - number_of_users: int, - user_class_occurrences_candidate: Dict[str, int], - user_class_occurrences_candidates: Dict[float, Dict[str, int]], -): - if sum(user_class_occurrences_candidate.values()) == number_of_users: + number_of_users_to_add: int, + user_class_occurrences: Dict[str, int], +) -> Dict[str, int]: + user_class_occurrences_candidates: Dict[float, Dict[str, int]] = {} + + for user_classes_combination in combinations_with_replacement(user_classes, number_of_users_to_add): + user_class_occurrences_candidate = { + user_class.__name__: user_class_occurrences[user_class.__name__] + + sum(1 for user_class_ in user_classes_combination if user_class_.__name__ == user_class.__name__) + for user_class in user_classes + } distance = distance_from_desired_distribution( user_classes, user_class_occurrences_candidate, ) if distance not in user_class_occurrences_candidates: - user_class_occurrences_candidates[distance] = user_class_occurrences_candidate - return - elif sum(user_class_occurrences_candidate.values()) > number_of_users: - return - - for user_class in user_classes: - user_class_occurrences_candidate_ = user_class_occurrences_candidate.copy() - user_class_occurrences_candidate_[user_class.__name__] += 1 - _recursive_add_users( - user_classes, - number_of_users, - user_class_occurrences_candidate_, - user_class_occurrences_candidates, - ) + user_class_occurrences_candidates[distance] = user_class_occurrences_candidate.copy() + + return user_class_occurrences_candidates[min(user_class_occurrences_candidates.keys())] -def _recursive_remove_users( +def _find_ideal_users_to_remove( user_classes: List[Type[User]], - number_of_users: int, - user_class_occurrences_candidate: Dict[str, int], - user_class_occurrences_candidates: Dict[float, Dict[str, int]], -): - if sum(user_class_occurrences_candidate.values()) == number_of_users: + number_of_users_to_remove: int, + user_class_occurrences: Dict[str, int], +) -> Dict[str, int]: + user_class_occurrences_candidates: Dict[float, Dict[str, int]] = {} + + for user_classes_combination in combinations_with_replacement(user_classes, number_of_users_to_remove): + user_class_occurrences_candidate = { + user_class.__name__: user_class_occurrences[user_class.__name__] + - sum(1 for user_class_ in user_classes_combination if user_class_.__name__ == user_class.__name__) + for user_class in user_classes + } distance = distance_from_desired_distribution( user_classes, user_class_occurrences_candidate, ) if distance not in user_class_occurrences_candidates: - user_class_occurrences_candidates[distance] = user_class_occurrences_candidate - return - elif sum(user_class_occurrences_candidate.values()) < number_of_users: - return - - for user_class in sorted(user_classes, key=lambda u: u.__name__, reverse=True): - if user_class_occurrences_candidate[user_class.__name__] == 1: - continue - user_class_occurrences_candidate_ = user_class_occurrences_candidate.copy() - user_class_occurrences_candidate_[user_class.__name__] -= 1 - _recursive_remove_users( - user_classes, - number_of_users, - user_class_occurrences_candidate_, - user_class_occurrences_candidates, - ) + user_class_occurrences_candidates[distance] = user_class_occurrences_candidate.copy() + + return user_class_occurrences_candidates[min(user_class_occurrences_candidates.keys())] def distance_from_desired_distribution( diff --git a/locust/test/test_distribution.py b/locust/test/test_distribution.py index 7614cadac5..3d55b7d643 100644 --- a/locust/test/test_distribution.py +++ b/locust/test/test_distribution.py @@ -74,7 +74,7 @@ class User3(User): user_classes=[User1, User2, User3], number_of_users=5, ) - self.assertDictEqual(user_class_occurrences, {"User1": 2, "User2": 2, "User3": 1}) + self.assertDictEqual(user_class_occurrences, {"User1": 1, "User2": 2, "User3": 2}) user_class_occurrences = weight_users( user_classes=[User1, User2, User3], @@ -204,7 +204,7 @@ class User3(User): user_classes=[User1, User2, User3], number_of_users=4, ) - self.assertDictEqual(user_class_occurrences, {"User1": 1, "User2": 2, "User3": 1}) + self.assertDictEqual(user_class_occurrences, {"User1": 1, "User2": 1, "User3": 2}) user_class_occurrences = weight_users( user_classes=[User1, User2, User3],