diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py index 53f1b5fdc2..e665ed0417 100644 --- a/smdebug/rules/action/stop_training_action.py +++ b/smdebug/rules/action/stop_training_action.py @@ -21,30 +21,55 @@ def __init__(self, rule_name, training_job_prefix): self._found_jobs = self._get_sm_tj_jobs_with_prefix() def _get_sm_tj_jobs_with_prefix(self): - found_jobs = [] - try: - jobs = self._sm_client.list_training_jobs() - if "TrainingJobSummaries" in jobs: - jobs = jobs["TrainingJobSummaries"] - else: - self._logger.info( - f"No TrainingJob summaries found: list_training_jobs output is : {jobs}" - ) - return - for job in jobs: + res = {} + found_job_dict = {} + next_token = None + name = self._training_job_prefix + i = 0 + while i < 50: + try: + if next_token is None: + res = self._sm_client.list_training_jobs( + NameContains=name, + SortBy="CreationTime", + SortOrder="Descending", + StatusEquals="InProgress", + ) + else: + res = self._sm_client.list_training_jobs( + NextToken=next_token, + NameContains=name, + SortBy="CreationTime", + SortOrder="Descending", + StatusEquals="InProgress", + ) + if "TrainingJobSummaries" in res: + jobs = res["TrainingJobSummaries"] + else: + self._logger.info( + f"No TrainingJob summaries found: list_training_jobs output is : {res}" + ) + return + for job in jobs: + tj_status = job["TrainingJobStatus"] + tj_name = job["TrainingJobName"] + self._logger.info(f"TrainingJob name: {tj_name} , status:{tj_status}") + if tj_name is not None and tj_name.startswith(name): + found_job_dict[tj_name] = 1 + self._logger.info(f"found_training job {found_job_dict.keys()}") + except Exception as e: self._logger.info( - f"TrainingJob name: {job['TrainingJobName']} , status:{job['TrainingJobStatus']}" + f"Caught exception while getting list_training_job exception is: \n {e}. Attempt:{i}" ) - if job["TrainingJobName"] is not None and job["TrainingJobName"].startswith( - self._training_job_prefix - ): - found_jobs.append(job["TrainingJobName"]) - self._logger.info(f"found_training job {found_jobs}") - except Exception as e: - self._logger.info( - f"Caught exception while getting list_training_job exception is: \n {e}" - ) - return found_jobs + if "NextToken" not in res: + break + else: + next_token = res["NextToken"] + res = {} + jobs = {} + i += 1 + + return found_job_dict.keys() def _stop_training_job(self): if len(self._found_jobs) != 1: