Skip to content

Commit

Permalink
Merge pull request #1080 from facebookresearch/prevent-study-oversamp…
Browse files Browse the repository at this point in the history
…ling-by-one-worker

Prevent oversampling of study submissions by any single worker
  • Loading branch information
meta-paul authored Nov 30, 2023
2 parents 650326a + c5c5bc0 commit e177e7b
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 27 deletions.
6 changes: 5 additions & 1 deletion mephisto/abstractions/providers/mturk/mturk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ def create_hit_type(
has_locale_qual = True
locale_requirements += existing_qualifications

if not has_locale_qual and not client_is_sandbox(client):
is_sandbox = client_is_sandbox(client)
if not has_locale_qual and not is_sandbox:
allowed_locales = get_config_arg("mturk", "allowed_locales")
if allowed_locales is None:
allowed_locales = [
Expand All @@ -458,6 +459,9 @@ def create_hit_type(
}
)

if is_sandbox:
hit_reward = 0

# Create the HIT type
response = client.create_hit_type(
AutoApprovalDelayInSeconds=auto_approve_delay,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ def _base_request(
else:
result = response.json()

logger.debug(f"{log_prefix} Response: {result}")

return result

except ProlificException:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def remove_participants_from_group(
https://docs.prolific.co/docs/api-docs/public/#tag/
Participant-Groups/paths/~1api~1v1~1participant-groups~1%7Bid%7D~1participants~1/delete
"""
from mephisto.utils.logger_core import get_logger

logger = get_logger(name=__name__)
endpoint = cls.list_participants_for_group_api_endpoint.format(id=id)
params = dict(participant_ids=participant_ids)
response_json = cls.delete(endpoint, params=params)
Expand Down
22 changes: 20 additions & 2 deletions mephisto/abstractions/providers/prolific/prolific_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,29 @@ def new_from_provider_data(

assert isinstance(unit, ProlificUnit), "Can only register Prolific agents to Prolific units"

agent = cls.new(db, worker, unit)
unit.worker_id = worker.db_id
agent._unit = unit
task_run: "TaskRun" = agent.get_task_run()

prolific_study_id = provider_data["prolific_study_id"]
prolific_submission_id = provider_data["assignment_id"]
unit.register_from_provider_data(prolific_study_id, prolific_submission_id)

logger.debug("Prolific Submission has been registered successfully")

return super().new_from_provider_data(db, worker, unit, provider_data)
# Check whether we need to prevent this worker from future submissions in this Task
if not worker.can_send_more_submissions_for_task(task_run):
# Excluding worker from Participant Group (instead of adding to Block List)
# only because Prolific cannot update Block List for an in-progress Study
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)

return agent

def approve_work(
self,
Expand Down Expand Up @@ -241,7 +257,6 @@ def get_status(self) -> str:
if prolific_submission_id:
prolific_submission = prolific_utils.get_submission(client, prolific_submission_id)
else:
# TODO: Not sure about this
self.update_status(AgentState.STATUS_EXPIRED)
return self.db_status

Expand All @@ -251,6 +266,9 @@ def get_status(self) -> str:

if prolific_submission.status == SubmissionStatus.RESERVED:
provider_status = local_status
elif prolific_submission.status == SubmissionStatus.ACTIVE:
# We don't need to map this status in our DB
pass
else:
provider_status = SUBMISSION_STATUS_TO_AGENT_STATE_MAP.get(
prolific_submission.status,
Expand Down
10 changes: 6 additions & 4 deletions mephisto/abstractions/providers/prolific/prolific_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def get_blocked_workers(self) -> List[dict]:
results = c.fetchall()
return results

def get_bloked_participant_ids(self) -> List[str]:
def get_blocked_participant_ids(self) -> List[str]:
return [w["worker_id"] for w in self.get_blocked_workers()]

def ensure_unit_exists(self, unit_id: str) -> None:
Expand Down Expand Up @@ -629,7 +629,7 @@ def find_qualifications_by_ids(
task_run_ids: Optional[List[str]] = None,
) -> List[dict]:
"""Find qualifications by Mephisto ids of qualifications and task runs"""
if not qualification_ids:
if not (qualification_ids or task_run_ids):
return []

with self.table_access_condition, self._get_connection() as conn:
Expand All @@ -645,12 +645,14 @@ def find_qualifications_by_ids(
task_run_ids_block = ""
if task_run_ids:
task_run_ids_str = ",".join([f'"{tid}"' for tid in task_run_ids])
task_run_ids_block = f"AND task_run_id IN ({task_run_ids_str})"
task_run_ids_block = f"task_run_id IN ({task_run_ids_str})"

where_block = " AND ".join(filter(bool, [qualification_ids_block, task_run_ids_block]))

c.execute(
f"""
SELECT * FROM qualifications
WHERE {qualification_ids_block} {task_run_ids_block};
WHERE {where_block};
"""
)
results = c.fetchall()
Expand Down
42 changes: 31 additions & 11 deletions mephisto/abstractions/providers/prolific/prolific_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mephisto.abstractions.providers.prolific.prolific_unit import ProlificUnit
from mephisto.abstractions.providers.prolific.prolific_worker import ProlificWorker
from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE
from mephisto.data_model.worker import Worker
from mephisto.operations.registry import register_mephisto_abstraction
from mephisto.utils.logger_core import get_logger
from mephisto.utils.qualifications import QualificationType
Expand All @@ -44,14 +45,13 @@
from .api.exceptions import ProlificException

if TYPE_CHECKING:
from mephisto.data_model.task import Task
from mephisto.data_model.task_run import TaskRun
from mephisto.data_model.unit import Unit
from mephisto.data_model.worker import Worker
from mephisto.data_model.requester import Requester
from mephisto.data_model.agent import Agent
from mephisto.abstractions.blueprint import SharedTaskState


DEFAULT_FRAME_HEIGHT = 0
DEFAULT_PROLIFIC_GROUP_NAME_ALLOW_LIST = "Allow list"
DEFAULT_PROLIFIC_GROUP_NAME_BLOCK_LIST = "Block list"
Expand Down Expand Up @@ -173,12 +173,13 @@ def _get_client(self, requester_name: str) -> ProlificClient:
def _get_qualified_workers(
self,
qualifications: List[QualificationType],
bloked_participant_ids: List[str],
blocked_participant_ids: List[str],
task_run: "TaskRun",
) -> List["Worker"]:
qualified_workers = []
workers: List[Worker] = self.db.find_workers(provider_type="prolific")
# `worker_name` is Prolific Participant ID in provider-specific datastore
available_workers = [w for w in workers if w.worker_name not in bloked_participant_ids]
available_workers = [w for w in workers if w.worker_name not in blocked_participant_ids]

for worker in available_workers:
if worker_is_qualified(worker, qualifications):
Expand Down Expand Up @@ -213,6 +214,20 @@ def _create_participant_group_with_qualified_workers(
)
return prolific_participant_group

def _get_excluded_participant_ids(self, task_run: "TaskRun") -> List[str]:
"""Find participant_ids that exceeded `maximum_units_per_worker` cap within this Task"""
task: "Task" = task_run.get_task()
task_units: List["Unit"] = self.db.find_units(task_id=task.db_id)

excluded_participant_ids: List[str] = []
for unit in task_units:
if unit.worker_id:
worker: "Worker" = Worker.get(self.db, unit.worker_id)
if not worker.can_send_more_submissions_for_task(task_run):
excluded_participant_ids.append(worker.worker_name)

return list(set(excluded_participant_ids))

def setup_resources_for_task_run(
self,
task_run: "TaskRun",
Expand Down Expand Up @@ -261,11 +276,12 @@ def setup_resources_for_task_run(
title=args.provider.prolific_project_name,
)

blocked_participant_ids = self.datastore.get_bloked_participant_ids()

blocked_participant_ids: List[str] = self.datastore.get_blocked_participant_ids()
excluded_participant_ids: List[str] = self._get_excluded_participant_ids(task_run)
# If no Mephisto qualifications found,
# we need to block Mephisto workers on Prolific as well
if blocked_participant_ids:
participant_ids_to_add_to_block_list = blocked_participant_ids + excluded_participant_ids
if participant_ids_to_add_to_block_list:
new_prolific_specific_qualifications = []
# Add empty Blacklist in case if there is not in state or config
blacklist_qualification = DictConfig(
Expand All @@ -285,27 +301,31 @@ def setup_resources_for_task_run(
whitelist_qualification = prolific_specific_qualification
prev_value = whitelist_qualification["white_list"]
whitelist_qualification["white_list"] = [
p for p in prev_value if p not in blocked_participant_ids
p for p in prev_value if p not in participant_ids_to_add_to_block_list
]
new_prolific_specific_qualifications.append(whitelist_qualification)
elif name == ParticipantGroupEligibilityRequirement.name:
# Remove blocked Participat IDs from Participant Group Eligibility Requirement
client.ParticipantGroups.remove_participants_from_group(
id=prolific_specific_qualification["id"],
participant_ids=blocked_participant_ids,
participant_ids=participant_ids_to_add_to_block_list,
)
else:
new_prolific_specific_qualifications.append(prolific_specific_qualification)

# Set Blacklist Eligibility Requirement
blacklist_qualification["black_list"] = list(
set(blacklist_qualification["black_list"] + blocked_participant_ids)
set(blacklist_qualification["black_list"] + participant_ids_to_add_to_block_list)
)
new_prolific_specific_qualifications.append(blacklist_qualification)
prolific_specific_qualifications = new_prolific_specific_qualifications

if qualifications:
qualified_workers = self._get_qualified_workers(qualifications, blocked_participant_ids)
qualified_workers = self._get_qualified_workers(
qualifications,
participant_ids_to_add_to_block_list,
task_run,
)

if qualified_workers:
prolific_workers_ids = [w.worker_name for w in qualified_workers]
Expand Down
22 changes: 21 additions & 1 deletion mephisto/abstractions/providers/prolific/prolific_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,15 @@ def compose_completion_codes(code_suffix: str) -> List[dict]:
),
],
),
dict(
code=f"{constants.StudyCodeType.OTHER}_{code_suffix}",
code_type=constants.StudyCodeType.OTHER,
actions=[
dict(
action=constants.StudyAction.MANUALLY_REVIEW,
),
],
),
]

# Task info
Expand Down Expand Up @@ -579,7 +588,10 @@ def remove_worker_qualification(
*args,
**kwargs,
) -> None:
"""Remove a qualification for the given worker (remove a worker from a Participant Group)"""
"""
Remove a qualification for the given worker (remove a worker from a Participant Group).
NOTE: If a participant is not a member of the group, they will be ignored (from API Docs)
"""
try:
client.ParticipantGroups.remove_participants_from_group(
id=qualification_id,
Expand All @@ -592,6 +604,14 @@ def remove_worker_qualification(
raise


def exclude_worker_from_participant_group(
client: ProlificClient,
worker_id: str,
participant_group_id: str,
):
remove_worker_qualification(client, worker_id, participant_group_id)


def pay_bonus(
client: ProlificClient,
task_run_config: "DictConfig",
Expand Down
43 changes: 41 additions & 2 deletions mephisto/abstractions/providers/prolific/prolific_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from typing import Tuple
from typing import TYPE_CHECKING

from omegaconf import DictConfig

from mephisto.abstractions.providers.prolific import prolific_utils
from mephisto.abstractions.providers.prolific.api.client import ProlificClient
from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE
Expand All @@ -28,6 +26,7 @@
from mephisto.abstractions.providers.prolific.prolific_requester import ProlificRequester
from mephisto.abstractions.providers.prolific.prolific_unit import ProlificUnit
from mephisto.data_model.requester import Requester
from mephisto.data_model.task import Task
from mephisto.data_model.task_run import TaskRun
from mephisto.data_model.unit import Unit

Expand Down Expand Up @@ -181,6 +180,46 @@ def unblock_worker(self, reason: str, requester: "Requester") -> Tuple[bool, str

return True, ""

def exclude_worker_from_task(
self,
task_run: Optional["TaskRun"] = None,
) -> Tuple[bool, str]:
"""Exclude this worker from current Task"""
logger.debug(f"{self.log_prefix}Excluding worker {self.worker_name} from Prolific")

# 1. Get Client
requester: "ProlificRequester" = task_run.get_requester()
client = self._get_client(requester.requester_name)

# 2. Find TaskRun IDs that are related to current Task
task: "Task" = task_run.get_task()
all_task_run_ids_for_task: List[str] = [t.db_id for t in task.get_runs()]

# 3. Select all Participant Group IDs that are related to the Task
datastore_qualifications = self.datastore.find_qualifications_by_ids(
task_run_ids=all_task_run_ids_for_task,
)
prolific_participant_group_ids = [
q["prolific_participant_group_id"] for q in datastore_qualifications
]

logger.debug(
f"{self.log_prefix}Found {len(prolific_participant_group_ids)} Participant Groups: "
f"{prolific_participant_group_ids}"
)

# 4. Exclude the Worker from Prolific Participant Groups
for prolific_participant_group_id in prolific_participant_group_ids:
prolific_utils.exclude_worker_from_participant_group(
client,
self.worker_name,
prolific_participant_group_id,
)

logger.debug(f"{self.log_prefix}Worker {self.worker_name} excluded")

return True, ""

def is_blocked(self, requester: "Requester") -> bool:
"""Determine if a worker is blocked"""
task_run = self._get_first_task_run(requester)
Expand Down
11 changes: 11 additions & 0 deletions mephisto/data_model/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,17 @@ def new_from_provider_data(
agent = cls.new(db, worker, unit)
unit.worker_id = worker.db_id
agent._unit = unit

# Prevent sending more units to worker if worker exceeded submission cap within this Task
task_run: "TaskRun" = agent.get_task_run()
if not worker.can_send_more_submissions_for_task(task_run):
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)

return agent

def get_status(self) -> str:
Expand Down
12 changes: 8 additions & 4 deletions mephisto/data_model/task_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,14 @@ def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]:

# Cannot pair with self
units: List["Unit"] = []
for unit_set in unit_assigns.values():
is_self_set = map(lambda u: u.worker_id == worker.db_id, unit_set)
if not any(is_self_set):
units += unit_set
for unit_list in unit_assigns.values():
self_linked_units = [
u
for u in unit_list
if u.worker_id == worker.db_id and u.db_status == AssignmentState.LAUNCHED
]
if not self_linked_units:
units += unit_list

# Valid units must be launched and must not be special units (negative indices)
# Can use db_status directly rather than polling in the critical path, as in
Expand Down
Loading

0 comments on commit e177e7b

Please sign in to comment.