diff --git a/components/clp-package-utils/clp_package_utils/scripts/native/search.py b/components/clp-package-utils/clp_package_utils/scripts/native/search.py index aa261d904..9041b0006 100755 --- a/components/clp-package-utils/clp_package_utils/scripts/native/search.py +++ b/components/clp-package-utils/clp_package_utils/scripts/native/search.py @@ -15,8 +15,8 @@ import pymongo from clp_py_utils.clp_config import Database, QUERY_JOBS_TABLE_NAME, ResultsCache from clp_py_utils.sql_adapter import SQL_Adapter -from job_orchestration.scheduler.constants import QueryJobStatus -from job_orchestration.scheduler.job_config import AggregationConfig, SearchConfig +from job_orchestration.scheduler.constants import QueryJobStatus, QueryJobType +from job_orchestration.scheduler.job_config import AggregationConfig, SearchJobConfig from clp_package_utils.general import ( CLP_DEFAULT_CONFIG_FILE_RELATIVE_PATH, @@ -83,7 +83,7 @@ def create_and_monitor_job_in_db( do_count_aggregation: bool | None, count_by_time_bucket_size: int | None, ): - search_config = SearchConfig( + search_config = SearchJobConfig( query_string=wildcard_query, begin_timestamp=begin_timestamp, end_timestamp=end_timestamp, @@ -111,8 +111,8 @@ def create_and_monitor_job_in_db( ) as db_cursor: # Create job db_cursor.execute( - f"INSERT INTO `{QUERY_JOBS_TABLE_NAME}` (`job_config`) VALUES (%s)", - (msgpack.packb(search_config.dict()),), + f"INSERT INTO `{QUERY_JOBS_TABLE_NAME}` (`job_config`, `type`) VALUES (%s, %s)", + (msgpack.packb(search_config.dict()), QueryJobType.SEARCH_OR_AGGREGATION), ) db_conn.commit() job_id = db_cursor.lastrowid diff --git a/components/clp-py-utils/clp_py_utils/initialize-orchestration-db.py b/components/clp-py-utils/clp_py_utils/initialize-orchestration-db.py index 32a285c42..1ed727367 100644 --- a/components/clp-py-utils/clp_py_utils/initialize-orchestration-db.py +++ b/components/clp-py-utils/clp_py_utils/initialize-orchestration-db.py @@ -97,6 +97,7 @@ def main(argv): f""" CREATE TABLE IF NOT EXISTS `{QUERY_JOBS_TABLE_NAME}` ( `id` INT NOT NULL AUTO_INCREMENT, + `type` INT NOT NULL, `status` INT NOT NULL DEFAULT '{QueryJobStatus.PENDING}', `creation_time` DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3), `num_tasks` INT NOT NULL DEFAULT '0', diff --git a/components/job-orchestration/job_orchestration/executor/query/fs_search_task.py b/components/job-orchestration/job_orchestration/executor/query/fs_search_task.py index f51eae407..92522a2d0 100644 --- a/components/job-orchestration/job_orchestration/executor/query/fs_search_task.py +++ b/components/job-orchestration/job_orchestration/executor/query/fs_search_task.py @@ -13,7 +13,7 @@ from clp_py_utils.clp_logging import set_logging_level from clp_py_utils.sql_adapter import SQL_Adapter from job_orchestration.executor.query.celery import app -from job_orchestration.scheduler.job_config import SearchConfig +from job_orchestration.scheduler.job_config import SearchJobConfig from job_orchestration.scheduler.scheduler_data import QueryTaskResult, QueryTaskStatus # Setup logging @@ -41,7 +41,7 @@ def make_command( clp_home: Path, archives_dir: Path, archive_id: str, - search_config: SearchConfig, + search_config: SearchJobConfig, results_cache_uri: str, results_collection: str, ): @@ -113,7 +113,7 @@ def search( self: Task, job_id: str, task_id: int, - search_config_obj: dict, + job_config_obj: dict, archive_id: str, clp_metadata_db_conn_params: dict, results_cache_uri: str, @@ -133,7 +133,7 @@ def search( logger.info(f"Started task for job {job_id}") - 
search_config = SearchConfig.parse_obj(search_config_obj) + search_config = SearchJobConfig.parse_obj(job_config_obj) sql_adapter = SQL_Adapter(Database.parse_obj(clp_metadata_db_conn_params)) start_time = datetime.datetime.now() @@ -168,7 +168,7 @@ def search( task_id=task_id, status=QueryTaskStatus.FAILED, duration=0, - error_log_path=clo_log_path, + error_log_path=str(clo_log_path), ).dict() update_search_task_metadata( @@ -231,6 +231,6 @@ def sigterm_handler(_signo, _stack_frame): ) if QueryTaskStatus.FAILED == search_status: - search_task_result.error_log_path = clo_log_path + search_task_result.error_log_path = str(clo_log_path) return search_task_result.dict() diff --git a/components/job-orchestration/job_orchestration/scheduler/constants.py b/components/job-orchestration/job_orchestration/scheduler/constants.py index 62f06f0cf..b640524d9 100644 --- a/components/job-orchestration/job_orchestration/scheduler/constants.py +++ b/components/job-orchestration/job_orchestration/scheduler/constants.py @@ -67,3 +67,13 @@ def __str__(self) -> str: def to_str(self) -> str: return str(self.name) + + +class QueryJobType(IntEnum): + SEARCH_OR_AGGREGATION = 0 + + def __str__(self) -> str: + return str(self.value) + + def to_str(self) -> str: + return str(self.name) diff --git a/components/job-orchestration/job_orchestration/scheduler/job_config.py b/components/job-orchestration/job_orchestration/scheduler/job_config.py index 93d4ede4e..528dce21a 100644 --- a/components/job-orchestration/job_orchestration/scheduler/job_config.py +++ b/components/job-orchestration/job_orchestration/scheduler/job_config.py @@ -39,7 +39,10 @@ class AggregationConfig(BaseModel): count_by_time_bucket_size: typing.Optional[int] = None # Milliseconds -class SearchConfig(BaseModel): +class QueryJobConfig(BaseModel): ... 
+ + +class SearchJobConfig(QueryJobConfig): query_string: str max_num_results: int tags: typing.Optional[typing.List[str]] = None diff --git a/components/job-orchestration/job_orchestration/scheduler/query/query_scheduler.py b/components/job-orchestration/job_orchestration/scheduler/query/query_scheduler.py index d8a045f31..5331051ae 100644 --- a/components/job-orchestration/job_orchestration/scheduler/query/query_scheduler.py +++ b/components/job-orchestration/job_orchestration/scheduler/query/query_scheduler.py @@ -24,7 +24,7 @@ import pathlib import sys from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import celery import msgpack @@ -40,22 +40,27 @@ from clp_py_utils.decorators import exception_default_value from clp_py_utils.sql_adapter import SQL_Adapter from job_orchestration.executor.query.fs_search_task import search -from job_orchestration.scheduler.constants import QueryJobStatus, QueryTaskStatus -from job_orchestration.scheduler.job_config import SearchConfig +from job_orchestration.scheduler.constants import QueryJobStatus, QueryJobType, QueryTaskStatus +from job_orchestration.scheduler.job_config import SearchJobConfig from job_orchestration.scheduler.query.reducer_handler import ( handle_reducer_connection, ReducerHandlerMessage, ReducerHandlerMessageQueues, ReducerHandlerMessageType, ) -from job_orchestration.scheduler.scheduler_data import InternalJobState, QueryTaskResult, SearchJob +from job_orchestration.scheduler.scheduler_data import ( + InternalJobState, + QueryJob, + QueryTaskResult, + SearchJob, +) from pydantic import ValidationError # Setup logging logger = get_logger("search-job-handler") # Dictionary of active jobs indexed by job id -active_jobs: Dict[str, SearchJob] = {} +active_jobs: Dict[str, QueryJob] = {} reducer_connection_queue: Optional[asyncio.Queue] = None @@ -91,18 +96,19 @@ async def release_reducer_for_job(job: SearchJob): @exception_default_value(default=[]) -def fetch_new_search_jobs(db_conn) -> list: +def fetch_new_query_jobs(db_conn) -> list: """ - Fetches search jobs with status=PENDING from the database. + Fetches query jobs with status=PENDING from the database. :param db_conn: - :return: The pending search jobs on success. An empty list if an exception occurs while + :return: The pending query jobs on success. An empty list if an exception occurs while interacting with the database. 
""" with contextlib.closing(db_conn.cursor(dictionary=True)) as db_cursor: db_cursor.execute( f""" SELECT {QUERY_JOBS_TABLE_NAME}.id as job_id, - {QUERY_JOBS_TABLE_NAME}.job_config + {QUERY_JOBS_TABLE_NAME}.job_config, + {QUERY_JOBS_TABLE_NAME}.type FROM {QUERY_JOBS_TABLE_NAME} WHERE {QUERY_JOBS_TABLE_NAME}.status={QueryJobStatus.PENDING} """ @@ -124,6 +130,7 @@ def fetch_cancelling_search_jobs(db_conn) -> list: SELECT {QUERY_JOBS_TABLE_NAME}.id as job_id FROM {QUERY_JOBS_TABLE_NAME} WHERE {QUERY_JOBS_TABLE_NAME}.status={QueryJobStatus.CANCELLING} + AND {QUERY_JOBS_TABLE_NAME}.type={QueryJobType.SEARCH_OR_AGGREGATION} """ ) return db_cursor.fetchall() @@ -222,7 +229,7 @@ async def handle_cancelling_search_jobs(db_conn_pool) -> None: logger.error(f"Failed to cancel job {job_id}.") -def insert_search_tasks_into_db(db_conn, job_id, archive_ids: List[str]) -> List[int]: +def insert_query_tasks_into_db(db_conn, job_id, archive_ids: List[str]) -> List[int]: task_ids = [] with contextlib.closing(db_conn.cursor()) as cursor: for archive_id in archive_ids: @@ -241,7 +248,7 @@ def insert_search_tasks_into_db(db_conn, job_id, archive_ids: List[str]) -> List @exception_default_value(default=[]) def get_archives_for_search( db_conn, - search_config: SearchConfig, + search_config: SearchJobConfig, ): query = f"""SELECT id as archive_id, end_timestamp FROM {CLP_METADATA_TABLE_PREFIX}archives @@ -273,18 +280,17 @@ def get_archives_for_search( def get_task_group_for_job( archive_ids: List[str], task_ids: List[int], - job_id: str, - search_config: SearchConfig, + job: QueryJob, clp_metadata_db_conn_params: Dict[str, any], results_cache_uri: str, ): - search_config_obj = search_config.dict() + job_config_obj = job.get_config().dict() return celery.group( search.s( - job_id=job_id, + job_id=job.id, archive_id=archive_ids[i], task_id=task_ids[i], - search_config_obj=search_config_obj, + job_config_obj=job_config_obj, clp_metadata_db_conn_params=clp_metadata_db_conn_params, results_cache_uri=results_cache_uri, ) @@ -292,22 +298,20 @@ def get_task_group_for_job( ) -def dispatch_search_job( +def dispatch_query_job( db_conn, - job: SearchJob, - archives_for_search: List[Dict[str, any]], + job: QueryJob, + archive_ids: List[str], clp_metadata_db_conn_params: Dict[str, any], results_cache_uri: str, ) -> None: global active_jobs - archive_ids = [archive["archive_id"] for archive in archives_for_search] - task_ids = insert_search_tasks_into_db(db_conn, job.id, archive_ids) + task_ids = insert_query_tasks_into_db(db_conn, job.id, archive_ids) task_group = get_task_group_for_job( archive_ids, task_ids, - job.id, - job.search_config, + job, clp_metadata_db_conn_params, results_cache_uri, ) @@ -360,7 +364,7 @@ async def acquire_reducer_for_job(job: SearchJob): logger.info(f"Got reducer for job {job.id} at {reducer_host}:{reducer_port}") -def handle_pending_search_jobs( +def handle_pending_query_jobs( db_conn_pool, clp_metadata_db_conn_params: Dict[str, any], results_cache_uri: str, @@ -370,57 +374,68 @@ def handle_pending_search_jobs( reducer_acquisition_tasks = [] - pending_jobs = [ - job for job in active_jobs.values() if InternalJobState.WAITING_FOR_DISPATCH == job.state + pending_search_jobs = [ + job + for job in active_jobs.values() + if InternalJobState.WAITING_FOR_DISPATCH == job.state + and job.get_type() == QueryJobType.SEARCH_OR_AGGREGATION ] with contextlib.closing(db_conn_pool.connect()) as db_conn: - for job in fetch_new_search_jobs(db_conn): + for job in fetch_new_query_jobs(db_conn): job_id = 
str(job["job_id"]) - - # Avoid double-dispatch when a job is WAITING_FOR_REDUCER - if job_id in active_jobs: - continue - - search_config = SearchConfig.parse_obj(msgpack.unpackb(job["job_config"])) - archives_for_search = get_archives_for_search(db_conn, search_config) - if len(archives_for_search) == 0: - if set_job_or_task_status( - db_conn, - QUERY_JOBS_TABLE_NAME, - job_id, - QueryJobStatus.SUCCEEDED, - QueryJobStatus.PENDING, - start_time=datetime.datetime.now(), - num_tasks=0, - duration=0, - ): - logger.info(f"No matching archives, skipping job {job['job_id']}.") - continue - - new_search_job = SearchJob( - id=job_id, - search_config=search_config, - state=InternalJobState.WAITING_FOR_DISPATCH, - num_archives_to_search=len(archives_for_search), - num_archives_searched=0, - remaining_archives_for_search=archives_for_search, - ) - - if search_config.aggregation_config is not None: - new_search_job.search_config.aggregation_config.job_id = job["job_id"] - new_search_job.state = InternalJobState.WAITING_FOR_REDUCER - new_search_job.reducer_acquisition_task = asyncio.create_task( - acquire_reducer_for_job(new_search_job) + job_type = job["type"] + job_config = job["job_config"] + + if QueryJobType.SEARCH_OR_AGGREGATION == job_type: + # Avoid double-dispatch when a job is WAITING_FOR_REDUCER + if job_id in active_jobs: + continue + + search_config = SearchJobConfig.parse_obj(msgpack.unpackb(job_config)) + archives_for_search = get_archives_for_search(db_conn, search_config) + if len(archives_for_search) == 0: + if set_job_or_task_status( + db_conn, + QUERY_JOBS_TABLE_NAME, + job_id, + QueryJobStatus.SUCCEEDED, + QueryJobStatus.PENDING, + start_time=datetime.datetime.now(), + num_tasks=0, + duration=0, + ): + logger.info(f"No matching archives, skipping job {job_id}.") + continue + + new_search_job = SearchJob( + id=job_id, + search_config=search_config, + state=InternalJobState.WAITING_FOR_DISPATCH, + num_archives_to_search=len(archives_for_search), + num_archives_searched=0, + remaining_archives_for_search=archives_for_search, ) - reducer_acquisition_tasks.append(new_search_job.reducer_acquisition_task) + + if search_config.aggregation_config is not None: + new_search_job.search_config.aggregation_config.job_id = int(job_id) + new_search_job.state = InternalJobState.WAITING_FOR_REDUCER + new_search_job.reducer_acquisition_task = asyncio.create_task( + acquire_reducer_for_job(new_search_job) + ) + reducer_acquisition_tasks.append(new_search_job.reducer_acquisition_task) + else: + pending_search_jobs.append(new_search_job) + active_jobs[job_id] = new_search_job else: - pending_jobs.append(new_search_job) - active_jobs[job_id] = new_search_job + # NOTE: We're skipping the job for this iteration, but its status will remain + # unchanged. So this log will print again in the next iteration unless the user + # cancels the job. 
+ logger.error(f"Unexpected job type: {job_type}, skipping job {job_id}") + continue - for job in pending_jobs: + for job in pending_search_jobs: job_id = job.id - if ( job.search_config.network_address is None and len(job.remaining_archives_for_search) > num_archives_to_search_per_sub_job @@ -435,11 +450,13 @@ def handle_pending_search_jobs( archives_for_search = job.remaining_archives_for_search job.remaining_archives_for_search = [] - dispatch_search_job( - db_conn, job, archives_for_search, clp_metadata_db_conn_params, results_cache_uri + archive_ids_for_search = [archive["archive_id"] for archive in archives_for_search] + + dispatch_query_job( + db_conn, job, archive_ids_for_search, clp_metadata_db_conn_params, results_cache_uri ) logger.info( - f"Dispatched job {job_id} with {len(archives_for_search)} archives to search." + f"Dispatched job {job_id} with {len(archive_ids_for_search)} archives to search." ) start_time = datetime.datetime.now() job.start_time = start_time @@ -487,6 +504,92 @@ def found_max_num_latest_results( return max_timestamp_in_remaining_archives <= min_timestamp_in_top_results +async def handle_finished_search_job( + db_conn, job: SearchJob, task_results: Optional[Any], results_cache_uri: str +) -> None: + global active_jobs + + job_id = job.id + is_reducer_job = job.reducer_handler_msg_queues is not None + new_job_status = QueryJobStatus.RUNNING + for task_result_obj in task_results: + task_result = QueryTaskResult.parse_obj(task_result_obj) + task_id = task_result.task_id + task_status = task_result.status + if not task_status == QueryTaskStatus.SUCCEEDED: + new_job_status = QueryJobStatus.FAILED + logger.error( + f"Search task job-{job_id}-task-{task_id} failed. " + f"Check {task_result.error_log_path} for details." + ) + else: + job.num_archives_searched += 1 + logger.info( + f"Search task job-{job_id}-task-{task_id} succeeded in " + f"{task_result.duration} second(s)." 
+ ) + + if new_job_status != QueryJobStatus.FAILED: + max_num_results = job.search_config.max_num_results + # Check if we've searched all archives + if len(job.remaining_archives_for_search) == 0: + new_job_status = QueryJobStatus.SUCCEEDED + # Check if we've reached max results + elif False == is_reducer_job and max_num_results > 0: + if found_max_num_latest_results( + results_cache_uri, + job_id, + max_num_results, + job.remaining_archives_for_search[0]["end_timestamp"], + ): + new_job_status = QueryJobStatus.SUCCEEDED + if new_job_status == QueryJobStatus.RUNNING: + job.current_sub_job_async_task_result = None + job.state = InternalJobState.WAITING_FOR_DISPATCH + logger.info(f"Job {job_id} waiting for more archives to search.") + set_job_or_task_status( + db_conn, + QUERY_JOBS_TABLE_NAME, + job_id, + QueryJobStatus.RUNNING, + QueryJobStatus.RUNNING, + num_tasks_completed=job.num_archives_searched, + ) + return + + reducer_failed = False + if is_reducer_job: + # Notify reducer that it should have received all results + msg = ReducerHandlerMessage(ReducerHandlerMessageType.SUCCESS) + await job.reducer_handler_msg_queues.put_to_handler(msg) + + msg = await job.reducer_handler_msg_queues.get_from_handler() + if ReducerHandlerMessageType.FAILURE == msg.msg_type: + reducer_failed = True + new_job_status = QueryJobStatus.FAILED + elif ReducerHandlerMessageType.SUCCESS != msg.msg_type: + error_msg = f"Unexpected msg_type: {msg.msg_type.name}" + raise NotImplementedError(error_msg) + + # We set the status regardless of the job's previous status to handle the case where the + # job is cancelled (status = CANCELLING) while we're in this method. + if set_job_or_task_status( + db_conn, + QUERY_JOBS_TABLE_NAME, + job_id, + new_job_status, + num_tasks_completed=job.num_archives_searched, + duration=(datetime.datetime.now() - job.start_time).total_seconds(), + ): + if new_job_status == QueryJobStatus.SUCCEEDED: + logger.info(f"Completed job {job_id}.") + elif reducer_failed: + logger.error(f"Completed job {job_id} with failing reducer.") + else: + logger.info(f"Completed job {job_id} with failing tasks.") + del active_jobs[job_id] + + async def check_job_status_and_update_db(db_conn_pool, results_cache_uri): global active_jobs @@ -495,16 +598,15 @@ async def check_job_status_and_update_db(db_conn_pool, results_cache_uri): id for id, job in active_jobs.items() if InternalJobState.RUNNING == job.state ]: job = active_jobs[job_id] - is_reducer_job = job.reducer_handler_msg_queues is not None - try: returned_results = try_getting_task_result(job.current_sub_job_async_task_result) except Exception as e: logger.error(f"Job `{job_id}` failed: {e}.") # Clean up - if is_reducer_job: - msg = ReducerHandlerMessage(ReducerHandlerMessageType.FAILURE) - await job.reducer_handler_msg_queues.put_to_handler(msg) + if QueryJobType.SEARCH_OR_AGGREGATION == job.get_type(): + if job.reducer_handler_msg_queues is not None: + msg = ReducerHandlerMessage(ReducerHandlerMessageType.FAILURE) + await job.reducer_handler_msg_queues.put_to_handler(msg) del active_jobs[job_id] set_job_or_task_status( @@ -519,84 +621,14 @@ async def check_job_status_and_update_db(db_conn_pool, results_cache_uri): if returned_results is None: continue - - new_job_status = QueryJobStatus.RUNNING - for task_result_obj in returned_results: - task_result = QueryTaskResult.parse_obj(task_result_obj) - task_id = task_result.task_id - task_status = task_result.status - if not task_status == QueryTaskStatus.SUCCEEDED: - new_job_status = QueryJobStatus.FAILED 
- logger.error( - f"Search task job-{job_id}-task-{task_id} failed. " - f"Check {task_result.error_log_path} for details." - ) - else: - job.num_archives_searched += 1 - logger.info( - f"Search task job-{job_id}-task-{task_id} succeeded in " - f"{task_result.duration} second(s)." - ) - - if new_job_status != QueryJobStatus.FAILED: - max_num_results = job.search_config.max_num_results - # Check if we've searched all archives - if len(job.remaining_archives_for_search) == 0: - new_job_status = QueryJobStatus.SUCCEEDED - # Check if we've reached max results - elif False == is_reducer_job and max_num_results > 0: - if found_max_num_latest_results( - results_cache_uri, - job_id, - max_num_results, - job.remaining_archives_for_search[0]["end_timestamp"], - ): - new_job_status = QueryJobStatus.SUCCEEDED - if new_job_status == QueryJobStatus.RUNNING: - job.current_sub_job_async_task_result = None - job.state = InternalJobState.WAITING_FOR_DISPATCH - logger.info(f"Job {job_id} waiting for more archives to search.") - set_job_or_task_status( - db_conn, - QUERY_JOBS_TABLE_NAME, - job_id, - QueryJobStatus.RUNNING, - QueryJobStatus.RUNNING, - num_tasks_completed=job.num_archives_searched, + job_type = job.get_type() + if QueryJobType.SEARCH_OR_AGGREGATION == job_type: + search_job: SearchJob = job + await handle_finished_search_job( + db_conn, search_job, returned_results, results_cache_uri ) - continue - - reducer_failed = False - if is_reducer_job: - # Notify reducer that it should have received all results - msg = ReducerHandlerMessage(ReducerHandlerMessageType.SUCCESS) - await job.reducer_handler_msg_queues.put_to_handler(msg) - - msg = await job.reducer_handler_msg_queues.get_from_handler() - if ReducerHandlerMessageType.FAILURE == msg.msg_type: - reducer_failed = True - new_job_status = QueryJobStatus.FAILED - elif ReducerHandlerMessageType.SUCCESS != msg.msg_type: - error_msg = f"Unexpected msg_type: {msg.msg_type.name}" - raise NotImplementedError(error_msg) - - # We set the status regardless of the job's previous status to handle the case where the - # job is cancelled (status = CANCELLING) while we're in this method. 
- if set_job_or_task_status( - db_conn, - QUERY_JOBS_TABLE_NAME, - job_id, - new_job_status, - num_tasks_completed=job.num_archives_searched, - duration=(datetime.datetime.now() - job.start_time).total_seconds(), - ): - if new_job_status == QueryJobStatus.SUCCEEDED: - logger.info(f"Completed job {job_id}.") - elif reducer_failed: - logger.error(f"Completed job {job_id} with failing reducer.") - else: - logger.info(f"Completed job {job_id} with failing tasks.") - del active_jobs[job_id] + else: + logger.error(f"Unexpected job type: {job_type}, skipping job {job_id}") async def handle_job_updates(db_conn_pool, results_cache_uri: str, jobs_poll_delay: float): @@ -619,7 +651,7 @@ async def handle_jobs( tasks = [handle_updating_task] while True: - reducer_acquisition_tasks = handle_pending_search_jobs( + reducer_acquisition_tasks = handle_pending_query_jobs( db_conn_pool, clp_metadata_db_conn_params, results_cache_uri, diff --git a/components/job-orchestration/job_orchestration/scheduler/scheduler_data.py b/components/job-orchestration/job_orchestration/scheduler/scheduler_data.py index a3aa5f436..d337e0806 100644 --- a/components/job-orchestration/job_orchestration/scheduler/scheduler_data.py +++ b/components/job-orchestration/job_orchestration/scheduler/scheduler_data.py @@ -1,10 +1,15 @@ import asyncio import datetime +from abc import ABC, abstractmethod from enum import auto, Enum from typing import Any, Dict, List, Optional -from job_orchestration.scheduler.constants import CompressionTaskStatus, QueryTaskStatus -from job_orchestration.scheduler.job_config import SearchConfig +from job_orchestration.scheduler.constants import ( + CompressionTaskStatus, + QueryJobType, + QueryTaskStatus, +) +from job_orchestration.scheduler.job_config import QueryJobConfig, SearchJobConfig from job_orchestration.scheduler.query.reducer_handler import ReducerHandlerMessageQueues from pydantic import BaseModel, validator @@ -35,21 +40,33 @@ class InternalJobState(Enum): RUNNING = auto() -class QueryJob(BaseModel): +class QueryJob(BaseModel, ABC): id: str state: InternalJobState start_time: Optional[datetime.datetime] current_sub_job_async_task_result: Optional[Any] + @abstractmethod + def get_type(self) -> QueryJobType: ... + + @abstractmethod + def get_config(self) -> QueryJobConfig: ... + class SearchJob(QueryJob): - search_config: SearchConfig + search_config: SearchJobConfig num_archives_to_search: int num_archives_searched: int remaining_archives_for_search: List[Dict[str, Any]] reducer_acquisition_task: Optional[asyncio.Task] reducer_handler_msg_queues: Optional[ReducerHandlerMessageQueues] + def get_type(self) -> QueryJobType: + return QueryJobType.SEARCH_OR_AGGREGATION + + def get_config(self) -> QueryJobConfig: + return self.search_config + class Config: # To allow asyncio.Task and asyncio.Queue arbitrary_types_allowed = True diff --git a/components/webui/imports/api/search/constants.js b/components/webui/imports/api/search/constants.js index ec4c13ad6..fbc0c3188 100644 --- a/components/webui/imports/api/search/constants.js +++ b/components/webui/imports/api/search/constants.js @@ -81,6 +81,19 @@ const QUERY_JOB_STATUS_WAITING_STATES = [ QUERY_JOB_STATUS.CANCELLING, ]; +/* eslint-disable sort-keys */ +let enumQueryType; +/** + * Enum of job type, matching the `QueryJobType` class in + * `job_orchestration.query_scheduler.constants`. 
+ * + * @enum {number} + */ +const QUERY_JOB_TYPE = Object.freeze({ + SEARCH_OR_AGGREGATION: (enumQueryType = 0), +}); +/* eslint-enable sort-keys */ + /** * Enum of Mongo Collection sort orders. * @@ -114,6 +127,7 @@ export { MONGO_SORT_ORDER, QUERY_JOB_STATUS, QUERY_JOB_STATUS_WAITING_STATES, + QUERY_JOB_TYPE, SEARCH_MAX_NUM_RESULTS, SEARCH_RESULTS_FIELDS, SEARCH_SIGNAL, diff --git a/components/webui/imports/api/search/server/QueryJobsDbManager.js b/components/webui/imports/api/search/server/QueryJobsDbManager.js index 4d3bed94a..835aae796 100644 --- a/components/webui/imports/api/search/server/QueryJobsDbManager.js +++ b/components/webui/imports/api/search/server/QueryJobsDbManager.js @@ -5,6 +5,7 @@ import {sleep} from "/imports/utils/misc"; import { QUERY_JOB_STATUS, QUERY_JOB_STATUS_WAITING_STATES, + QUERY_JOB_TYPE, } from "../constants"; @@ -21,6 +22,7 @@ const JOB_COMPLETION_STATUS_POLL_INTERVAL_MILLIS = 0.5; const QUERY_JOBS_TABLE_COLUMN_NAMES = Object.freeze({ ID: "id", STATUS: "status", + TYPE: "type", JOB_CONFIG: "job_config", }); @@ -52,9 +54,11 @@ class QueryJobsDbManager { async submitSearchJob (searchConfig) { const [queryInsertResults] = await this.#sqlDbConnPool.query( `INSERT INTO ${this.#queryJobsTableName} - (${QUERY_JOBS_TABLE_COLUMN_NAMES.JOB_CONFIG}) - VALUES (?)`, - [Buffer.from(msgpack.encode(searchConfig))], + (${QUERY_JOBS_TABLE_COLUMN_NAMES.JOB_CONFIG}, + ${QUERY_JOBS_TABLE_COLUMN_NAMES.TYPE}) + VALUES (?, ?)`, + [Buffer.from(msgpack.encode(searchConfig)), + QUERY_JOB_TYPE.SEARCH_OR_AGGREGATION], ); return queryInsertResults.insertId;
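
Editor's note: the sketch below is not part of the diff. It summarizes the submission-side contract after this change: every row written to the query jobs table now carries a `type` value alongside the msgpack-encoded `job_config`, mirroring the updated INSERT statements in `native/search.py` and `QueryJobsDbManager.js` above. The `submit_search_job` helper and its 1000-result limit are hypothetical, and the sketch assumes `query_string` and `max_num_results` are the only required fields of `SearchJobConfig`.

import msgpack

from job_orchestration.scheduler.constants import QueryJobType
from job_orchestration.scheduler.job_config import SearchJobConfig


def submit_search_job(db_cursor, query_jobs_table_name: str, wildcard_query: str) -> int:
    # Build the job's config; unspecified optional fields keep their defaults.
    job_config = SearchJobConfig(query_string=wildcard_query, max_num_results=1000)

    # Persist the config together with its type so the scheduler knows which
    # handler should process the job.
    db_cursor.execute(
        f"INSERT INTO `{query_jobs_table_name}` (`job_config`, `type`) VALUES (%s, %s)",
        (msgpack.packb(job_config.dict()), QueryJobType.SEARCH_OR_AGGREGATION),
    )
    return db_cursor.lastrowid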
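
On the scheduler side, jobs fetched from the table are now routed on that same `type` column before the payload is parsed into a concrete `QueryJobConfig` subclass, which is what lets future job kinds plug in without touching the search path. The condensed routing function below is a sketch of that pattern only; `dispatch_search_or_aggregation` is a hypothetical stand-in for the archive lookup, reducer acquisition, and task dispatch that `handle_pending_query_jobs` performs in the real scheduler.

import logging
from typing import Callable, Dict

import msgpack

from job_orchestration.scheduler.constants import QueryJobType
from job_orchestration.scheduler.job_config import SearchJobConfig

logger = logging.getLogger(__name__)


def route_pending_job(
    job_row: Dict,
    dispatch_search_or_aggregation: Callable[[str, SearchJobConfig], None],
) -> None:
    job_id = str(job_row["job_id"])
    job_type = job_row["type"]
    if QueryJobType.SEARCH_OR_AGGREGATION == job_type:
        search_config = SearchJobConfig.parse_obj(msgpack.unpackb(job_row["job_config"]))
        dispatch_search_or_aggregation(job_id, search_config)
    else:
        # Unknown types are logged and skipped; the row's status stays PENDING,
        # so the message repeats on each poll until the user cancels the job.
        logger.error(f"Unexpected job type: {job_type}, skipping job {job_id}")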