diff --git a/src/biokbase/narrative/jobs/job.py b/src/biokbase/narrative/jobs/job.py index a33bde4680..7499ae655c 100644 --- a/src/biokbase/narrative/jobs/job.py +++ b/src/biokbase/narrative/jobs/job.py @@ -194,16 +194,7 @@ def __getattr__(self, name): .get("narrative_cell_info", {}) .get("cell_id", JOB_ATTR_DEFAULTS["cell_id"]), "child_jobs": lambda: copy.deepcopy( - # TODO - # Only batch container jobs have a child_jobs field - # and need the state refresh. - # But KBParallel/KB Batch App jobs may not have the - # batch_job field - self.state(force_refresh=True).get( - "child_jobs", JOB_ATTR_DEFAULTS["child_jobs"] - ) - if self.batch_job - else self._acc_state.get("child_jobs", JOB_ATTR_DEFAULTS["child_jobs"]) + self._acc_state.get("child_jobs", JOB_ATTR_DEFAULTS["child_jobs"]) ), "job_id": lambda: self._acc_state.get("job_id"), "params": lambda: copy.deepcopy( @@ -212,13 +203,7 @@ def __getattr__(self, name): ) ), "retry_ids": lambda: copy.deepcopy( - # Batch container and retry jobs don't have a - # retry_ids field so skip the state refresh self._acc_state.get("retry_ids", JOB_ATTR_DEFAULTS["retry_ids"]) - if self.batch_job or self.retry_parent - else self.state(force_refresh=True).get( - "retry_ids", JOB_ATTR_DEFAULTS["retry_ids"] - ) ), "retry_parent": lambda: self._acc_state.get( "retry_parent", JOB_ATTR_DEFAULTS["retry_parent"] @@ -226,7 +211,7 @@ def __getattr__(self, name): "run_id": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("run_id", JOB_ATTR_DEFAULTS["run_id"]), - # TODO: add the status attribute! + "status": lambda: self._acc_state.get("status", ""), "tag": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("tag", JOB_ATTR_DEFAULTS["tag"]), @@ -264,20 +249,20 @@ def was_terminal(self): # add in a check for the case where this is a batch parent job # batch parent jobs with where all children have status "completed" are in a terminal state # otherwise, child jobs may be retried - if self._acc_state.get("batch_job"): + if self.batch_job: for child_job in self.children: - if child_job._acc_state.get("status") != COMPLETED_STATUS: + if child_job.status != COMPLETED_STATUS: return False return True else: - return self._acc_state.get("status") in TERMINAL_STATUSES + return self.status in TERMINAL_STATUSES def is_terminal(self): self.state() - if self._acc_state.get("batch_job"): + if self.batch_job: for child_job in self.children: - if child_job._acc_state.get("status") != COMPLETED_STATUS: + if child_job.status != COMPLETED_STATUS: child_job.state(force_refresh=True) return self.was_terminal() @@ -560,7 +545,7 @@ def _verify_children(self, children: List["Job"]) -> None: ) inst_child_ids = [job.job_id for job in children] - if sorted(inst_child_ids) != sorted(self._acc_state.get("child_jobs")): + if sorted(inst_child_ids) != sorted(self.child_jobs): raise ValueError("Child job id mismatch") def update_children(self, children: List["Job"]) -> None: diff --git a/src/biokbase/narrative/jobs/jobmanager.py b/src/biokbase/narrative/jobs/jobmanager.py index 0c80738f11..1630ccff24 100644 --- a/src/biokbase/narrative/jobs/jobmanager.py +++ b/src/biokbase/narrative/jobs/jobmanager.py @@ -742,6 +742,8 @@ def update_batch_job(self, batch_id: str) -> List[str]: if not batch_job.batch_job: raise JobRequestException(JOB_NOT_BATCH_ERR, batch_id) + # update the batch job + batch_job.state(force_refresh=True) child_ids = batch_job.child_jobs reg_child_jobs = [] diff --git a/src/biokbase/narrative/tests/test_job.py b/src/biokbase/narrative/tests/test_job.py index 7f8bf5dc3f..527756801e 100644 --- a/src/biokbase/narrative/tests/test_job.py +++ b/src/biokbase/narrative/tests/test_job.py @@ -23,7 +23,6 @@ ALL_JOBS, BATCH_CHILDREN, BATCH_PARENT, - BATCH_RETRY_RUNNING, CLIENTS, JOB_COMPLETED, JOB_CREATED, @@ -206,11 +205,8 @@ def check_job_attrs(self, job, job_id, exp_attrs=None, skip_state=False): attrs = create_attrs_from_ee2(job_id) attrs.update(exp_attrs) - # Mock here because job.child_jobs and job.retry_ids can - # cause EE2 query - with mock.patch(CLIENTS, get_mock_client): - for name, value in attrs.items(): - self.assertEqual(value, getattr(job, name)) + for name, value in attrs.items(): + self.assertEqual(value, getattr(job, name)) def test_job_init__error_no_job_id(self): @@ -322,7 +318,10 @@ def test_state__non_terminal(self): # ee2_state is fully populated (includes job_input, no job_output) job = create_job_from_ee2(JOB_CREATED) self.assertFalse(job.was_terminal()) - state = job.state() + + with assert_obj_method_called(MockClients, "check_job", call_status=True): + state = job.state() + self.assertFalse(job.was_terminal()) self.assertEqual(state["status"], "created") @@ -651,62 +650,6 @@ def test_query_job_states(self): ) self.assertEqual(exp, got) - def test_refresh_attrs__non_batch_active(self): - """ - retry_ids should be refreshed - """ - job_id = JOB_CREATED - job = create_job_from_ee2(job_id) - self.check_job_attrs(job, job_id) - - def mock_check_job(self_, params): - self.assertEqual(params["job_id"], job_id) - return {"retry_ids": self.NEW_RETRY_IDS} - - with mock.patch.object(MockClients, "check_job", mock_check_job): - self.check_job_attrs(job, job_id, {"retry_ids": self.NEW_RETRY_IDS}) - - def test_refresh_attrs__non_batch_terminal(self): - """ - retry_ids should be refreshed - """ - job_id = JOB_TERMINATED - job = create_job_from_ee2(job_id) - self.check_job_attrs(job, job_id) - - def mock_check_job(self_, params): - self.assertEqual(params["job_id"], job_id) - return {"retry_ids": self.NEW_RETRY_IDS} - - with mock.patch.object(MockClients, "check_job", mock_check_job): - self.check_job_attrs(job, job_id, {"retry_ids": self.NEW_RETRY_IDS}) - - def test_refresh_attrs__non_batch__is_retry(self): - """ - neither retry_ids/child_jobs should be refreshed - """ - job_id = BATCH_RETRY_RUNNING - job = create_job_from_ee2(job_id) - self.check_job_attrs(job, job_id) - - with assert_obj_method_called(MockClients, "check_job", call_status=False): - self.check_job_attrs(job, job_id, skip_state=True) - - def test_refresh_attrs__batch(self): - """ - child_jobs should be refreshed - """ - job_id = BATCH_PARENT - job = get_batch_family_jobs()[job_id] - self.check_job_attrs(job, job_id) - - def mock_check_job(self_, params): - self.assertEqual(params["job_id"], job_id) - return {"child_jobs": self.NEW_CHILD_JOBS} - - with mock.patch.object(MockClients, "check_job", mock_check_job): - self.check_job_attrs(job, job_id, {"child_jobs": self.NEW_CHILD_JOBS}) - def test_was_terminal(self): all_jobs = get_all_jobs() diff --git a/src/biokbase/narrative/tests/test_jobmanager.py b/src/biokbase/narrative/tests/test_jobmanager.py index dc2c36bb05..7dad51e394 100644 --- a/src/biokbase/narrative/tests/test_jobmanager.py +++ b/src/biokbase/narrative/tests/test_jobmanager.py @@ -776,14 +776,11 @@ def mock_check_job(params): reg_child_jobs = [ self.jm.get_job(job_id) for job_id in batch_job._acc_state["child_jobs"] ] - self.assertCountEqual(batch_job.children, reg_child_jobs) - self.assertCountEqual(batch_job._acc_state["child_jobs"], new_child_ids) - with mock.patch.object( - MockClients, "check_job", side_effect=mock_check_job - ) as m: + with assert_obj_method_called(MockClients, "check_job", call_status=False): self.assertCountEqual(batch_job.child_jobs, new_child_ids) + self.assertCountEqual(batch_job.child_jobs, batch_job._acc_state["child_jobs"]) def test_modify_job_refresh(self): for job_id, refreshing in REFRESH_STATE.items():