Skip to content

Commit

Permalink
Adding pagination in list_training_jobs (aws#323)
Browse files Browse the repository at this point in the history
* Adding pagination in list_training_jobs
  • Loading branch information
Vikas-kum authored Aug 13, 2020
1 parent 247f9c8 commit cb45e75
Showing 1 changed file with 47 additions and 22 deletions.
69 changes: 47 additions & 22 deletions smdebug/rules/action/stop_training_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,55 @@ def __init__(self, rule_name, training_job_prefix):
self._found_jobs = self._get_sm_tj_jobs_with_prefix()

def _get_sm_tj_jobs_with_prefix(self):
    """Return the names of in-progress training jobs that start with the prefix.

    Pages through ``self._sm_client.list_training_jobs`` (presumably a boto3
    SageMaker client -- confirm against ``__init__``), asking only for
    ``InProgress`` jobs whose name contains ``self._training_job_prefix``,
    newest first. ``NameContains`` is a substring match, so each returned
    name is additionally checked with ``startswith``.

    Returns:
        list[str]: Unique matching job names in the order they were first
        seen. The list is empty when nothing matches, when the response has
        no ``TrainingJobSummaries`` key, or when the first call fails. A
        list is always returned (never ``None``) so that callers can safely
        use ``len()`` and indexing on the result.
    """
    # Upper bound on the number of pages fetched, so a misbehaving service
    # that keeps handing back tokens cannot loop forever.
    max_pages = 50
    prefix = self._training_job_prefix
    request = {
        "NameContains": prefix,
        "SortBy": "CreationTime",
        "SortOrder": "Descending",
        "StatusEquals": "InProgress",
    }
    # A dict is used as an insertion-ordered set of job names.
    found_job_dict = {}

    for attempt in range(max_pages):
        try:
            res = self._sm_client.list_training_jobs(**request)
        except Exception as e:
            # Best effort: without a response there is no continuation
            # token, so log and return whatever was collected so far.
            self._logger.info(
                f"Caught exception while getting list_training_job exception is: \n {e}. Attempt:{attempt}"
            )
            break

        if "TrainingJobSummaries" not in res:
            self._logger.info(
                f"No TrainingJob summaries found: list_training_jobs output is : {res}"
            )
            break

        for job in res["TrainingJobSummaries"]:
            # .get() keeps one malformed summary from discarding the rest
            # of the page.
            tj_status = job.get("TrainingJobStatus")
            tj_name = job.get("TrainingJobName")
            self._logger.info(f"TrainingJob name: {tj_name} , status:{tj_status}")
            if tj_name is not None and tj_name.startswith(prefix):
                found_job_dict[tj_name] = 1
        self._logger.info(f"found_training job {found_job_dict.keys()}")

        next_token = res.get("NextToken")
        if not next_token:
            # No (or empty) token means this was the last page.
            break
        request["NextToken"] = next_token

    return list(found_job_dict)

def _stop_training_job(self):
if len(self._found_jobs) != 1:
Expand Down

0 comments on commit cb45e75

Please sign in to comment.