Skip to content

Commit

Permalink
Stop Waiting For Collection Files If Training Has Ended (aws#51)
Browse files Browse the repository at this point in the history
* stop waiting if training has ended

* fix incorrect merge

* Fail if collection files missing
  • Loading branch information
NihalHarish authored and rahul003 committed Nov 26, 2019
1 parent 57bb732 commit a99e163
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 12 deletions.
8 changes: 8 additions & 0 deletions smdebug/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ def __str__(self):
return "Step {} of mode {} not yet available".format(self.step, self.mode.name)


class MissingCollectionFiles(Exception):
def __init__(self):
pass

def __str__(self):
return "Training job has ended. All the collection files could not be loaded"


class IndexReaderException(Exception):
def __init__(self, message):
self.message = message
Expand Down
28 changes: 16 additions & 12 deletions smdebug/trials/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
match_inc,
serialize_tf_device,
)
from smdebug.exceptions import NoMoreData, StepUnavailable, TensorUnavailable
from smdebug.exceptions import (
MissingCollectionFiles,
NoMoreData,
StepUnavailable,
TensorUnavailable,
)


class Trial(ABC):
Expand Down Expand Up @@ -149,22 +154,21 @@ def _fetch():
"Waiting to read collections files generated by the training job."
)

def _wait_for_first_collection_file():
while len(collection_files) == 0:
time.sleep(2)
_fetch()

def _wait_for_all_collection_files():
while len(collection_files) < self.num_workers:
def _wait_for_collection_files(number_of_collection_file_to_wait_for):
while len(collection_files) < number_of_collection_file_to_wait_for:
time.sleep(2)
_fetch()
for collection_file in collection_files:
self.worker_set.add(get_worker_name_from_collection_file(collection_file))
if has_training_ended(self.path):
""" _fetch should have returned all the collection files if the training job has ended """
if len(collection_files) < number_of_collection_file_to_wait_for:
raise MissingCollectionFiles

_fetch()
_wait_for_first_collection_file()
_wait_for_collection_files(1) # wait for the first collection file
self._read_collections(collection_files)
_wait_for_all_collection_files()
_wait_for_collection_files(self.num_workers) # wait for all the collection files
for collection_file in collection_files:
self.worker_set.add(get_worker_name_from_collection_file(collection_file))

@abstractmethod
def _load_tensors_from_index_tensors(self, index_tensors_dict):
Expand Down
65 changes: 65 additions & 0 deletions tests/analysis/trials/test_load_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Standard Library

# Third Party
import pytest

# First Party
from smdebug.exceptions import MissingCollectionFiles
from smdebug.trials import create_trial


@pytest.mark.slow
def test_load_collection_files_from_completed_job():
"""
Number of collection files : 2001
Training_has_ended.ts : Present
All the collection files have been written in the test dataset
and the training_has_ended file is present
:return:
"""
path = "s3://tornasole-testing/collection-tests/all-collection-files-present/"
try:
trial = create_trial(path)
except MissingCollectionFiles:
assert False
assert len(trial.workers()) == 2001


@pytest.mark.slow
def test_load_collection_files_from_completed_job_with_missing_files():
"""
Number of collection files : 1446
Training_has_ended.ts : Present
Some of the collection files have been removed in the test dataset.
The number of expected collection files is supposed to 2001
but the training_has_ended file is present so we stop waiting
:return:
"""
path = "s3://tornasole-testing/collection-tests/collection-files-missing/"
try:
trial = create_trial(path)
assert False
except MissingCollectionFiles:
assert True


@pytest.mark.slow
def test_load_collection_files_from_incomplete_job():
"""
Number of collection files : 2001
Training_has_ended.ts : Absent
All the collection files have been written in the test dataset
and the training_has_ended file is absent
:return:
"""
path = "s3://tornasole-testing/collection-tests/all-collection-files-present-job-incomplete/"
try:
trial = create_trial(path)
except MissingCollectionFiles:
assert False
assert len(trial.workers()) == 2001

0 comments on commit a99e163

Please sign in to comment.