Skip to content
This repository has been archived by the owner on May 25, 2022. It is now read-only.

Commit

Permalink
[RDST-111] keep tensorboard directory clean during s3 sync (#3)
Browse files Browse the repository at this point in the history
* copy tf files to clean directory for tensorboard

* address comments
  • Loading branch information
Ben Cook authored Jan 24, 2018
1 parent 6577a90 commit 6acf497
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, estimator, logdir=None):
threading.Thread.__init__(self)
self.event = threading.Event()
self.estimator = estimator
self._aws_sync_dir = tempfile.mkdtemp()
self.logdir = logdir or tempfile.mkdtemp()

@staticmethod
Expand All @@ -47,6 +48,31 @@ def _cmd_exists(cmd):
for path in os.environ["PATH"].split(os.pathsep)
)

@staticmethod
def _sync_directories(from_directory, to_directory):
    """Mirror every file under ``from_directory`` into ``to_directory``.

    The directory layout of ``from_directory`` is recreated below
    ``to_directory`` and each file is overwritten in place with the current
    contents of its source. Files that exist only in ``to_directory`` are
    left untouched.

    Why do this? Because TensorBoard picks up temp files from
    `aws s3 sync` and then stops reading the correct tfevent files. This is
    probably related to tensorflow/tensorboard#349.

    Args:
        from_directory (str): The directory with updated files.
        to_directory (str): The directory to be synced. It is created,
            including any missing parent directories, if it does not exist.
    """
    # Function-scope import so this block does not depend on the module's
    # import list.
    import shutil

    if not os.path.isdir(to_directory):
        os.makedirs(to_directory)

    for root, dirs, files in os.walk(from_directory):
        # Map ``root`` into the destination tree through its path relative to
        # the source. A plain ``str.replace`` would also rewrite later
        # occurrences of ``from_directory`` inside the path (for example
        # 'a/data' -> 'b/dbtb' when syncing 'a' into 'b').
        rel_root = os.path.relpath(root, from_directory)
        if rel_root == os.curdir:
            to_root = to_directory
        else:
            to_root = os.path.join(to_directory, rel_root)

        for directory in dirs:
            to_child_dir = os.path.join(to_root, directory)
            if not os.path.isdir(to_child_dir):
                os.makedirs(to_child_dir)

        for fname in files:
            from_file = os.path.join(root, fname)
            to_file = os.path.join(to_root, fname)
            # Stream the copy in chunks so large event/checkpoint files are
            # never held in memory all at once.
            with open(from_file, 'rb') as src, open(to_file, 'wb') as dst:
                shutil.copyfileobj(src, dst)

def validate_requirements(self):
"""Ensure that TensorBoard and the AWS CLI are installed.
Expand Down Expand Up @@ -98,8 +124,9 @@ def run(self):
while not self.estimator.checkpoint_path:
self.event.wait(1)
while not self.event.is_set():
args = ['aws', 's3', 'sync', self.estimator.checkpoint_path, self.logdir]
args = ['aws', 's3', 'sync', self.estimator.checkpoint_path, self._aws_sync_dir]
subprocess.call(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
self._sync_directories(self._aws_sync_dir, self.logdir)
self.event.wait(10)
tensorboard_process.terminate()

Expand Down

0 comments on commit 6acf497

Please sign in to comment.