diff --git a/mephisto/abstractions/databases/local_singleton_database.py b/mephisto/abstractions/databases/local_singleton_database.py index 8286d8511..c0ee7eb69 100644 --- a/mephisto/abstractions/databases/local_singleton_database.py +++ b/mephisto/abstractions/databases/local_singleton_database.py @@ -68,6 +68,7 @@ def __init__(self, database_path=None): # Create singleton dictionaries for entries self._singleton_cache = {k: dict() for k in self._cached_classes} + self._assignment_to_unit_mapping: Dict[str, List[Unit]] = {} def shutdown(self) -> None: """Close all open connections""" @@ -128,3 +129,94 @@ def new_agent( unit.db_status = AssignmentState.ASSIGNED unit.worker_id = agent.worker_id return agent_id + + def find_units( + self, + task_id: Optional[str] = None, + task_run_id: Optional[str] = None, + requester_id: Optional[str] = None, + assignment_id: Optional[str] = None, + unit_index: Optional[int] = None, + provider_type: Optional[str] = None, + task_type: Optional[str] = None, + agent_id: Optional[str] = None, + worker_id: Optional[str] = None, + sandbox: Optional[bool] = None, + status: Optional[str] = None, + ) -> List[Unit]: + """ + Uses caching to offset the cost of the most common queries. Defers + to the underlying DB for anything outside of those cases. + Cached hits are returned as a copy so callers cannot mutate the cache. + """ + + # Finding units is the most common small DB call to be optimized, as + # every assignment has multiple units. Thus, we try to break up the + # units to be queried by assignment, ensuring the most commonly + # queried edge is improved. 
+ if assignment_id is not None: + if all( + v is None + for v in [ + task_id, + task_run_id, + requester_id, + unit_index, + provider_type, + task_type, + agent_id, + worker_id, + sandbox, + status, + ] + ): + units = self._assignment_to_unit_mapping.get(assignment_id) + if units is None: + units = super().find_units(assignment_id=assignment_id) + self._assignment_to_unit_mapping[assignment_id] = units + return list(units) + + # Any other cases are less common and more complicated, and so we don't cache + return super().find_units( + task_id=task_id, + task_run_id=task_run_id, + requester_id=requester_id, + assignment_id=assignment_id, + unit_index=unit_index, + provider_type=provider_type, + task_type=task_type, + agent_id=agent_id, + worker_id=worker_id, + sandbox=sandbox, + status=status, + ) + + def new_unit( + self, + task_id: str, + task_run_id: str, + requester_id: str, + assignment_id: str, + unit_index: int, + pay_amount: float, + provider_type: str, + task_type: str, + sandbox: bool = True, + ) -> str: + """ + Create a new unit with the given index. Raises EntryAlreadyExistsException + if there is already a unit for the given assignment with the given index. 
+ """ + # Drop any cached unit list for this assignment; pop() is a no-op when absent. + self._assignment_to_unit_mapping.pop(assignment_id, None) + return super().new_unit( + task_id=task_id, + task_run_id=task_run_id, + requester_id=requester_id, + assignment_id=assignment_id, + unit_index=unit_index, + pay_amount=pay_amount, + provider_type=provider_type, + task_type=task_type, + sandbox=sandbox, + ) diff --git a/mephisto/abstractions/providers/mturk/mturk_unit.py b/mephisto/abstractions/providers/mturk/mturk_unit.py index 496b9ab7a..800d9edd7 100644 --- a/mephisto/abstractions/providers/mturk/mturk_unit.py +++ b/mephisto/abstractions/providers/mturk/mturk_unit.py @@ -84,10 +84,7 @@ def register_from_provider_data( self.datastore.register_assignment_to_hit( hit_id, self.db_id, mturk_assignment_id ) - self.hit_id = hit_id - self.mturk_assignment_id = mturk_assignment_id - # We made the change, so we can set the sync time. - self._last_sync_time = time.monotonic() + self._sync_hit_mapping() def get_mturk_assignment_id(self) -> Optional[str]: """ diff --git a/mephisto/operations/operator.py b/mephisto/operations/operator.py index 9a138bbef..7d3c9efdd 100644 --- a/mephisto/operations/operator.py +++ b/mephisto/operations/operator.py @@ -48,6 +48,9 @@ from argparse import Namespace +RUN_STATUS_POLL_TIME = 10 + + + class TrackedRun(NamedTuple): task_run: TaskRun architect: "Architect" @@ -306,7 +309,7 @@ def _track_and_kill_runs(self): tracked_run.architect.shutdown() tracked_run.task_launcher.shutdown() del self._task_runs_tracked[task_run.db_id] - time.sleep(2) + time.sleep(RUN_STATUS_POLL_TIME) def force_shutdown(self, timeout=5): """ @@ -460,7 +463,7 @@ def wait_for_runs_then_shutdown( if time.time() - last_log > log_rate: last_log = time.time() self.print_run_details() - time.sleep(10) + time.sleep(RUN_STATUS_POLL_TIME) except Exception as e: if skip_input: