diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 39f9601b3c..1545b57bf2 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -38,6 +38,7 @@ def __init__(self, estimator, logdir=None): threading.Thread.__init__(self) self.event = threading.Event() self.estimator = estimator + self._aws_sync_dir = tempfile.mkdtemp() self.logdir = logdir or tempfile.mkdtemp() @staticmethod @@ -47,6 +48,31 @@ def _cmd_exists(cmd): for path in os.environ["PATH"].split(os.pathsep) ) + @staticmethod + def _sync_directories(from_directory, to_directory): + """Sync to_directory with from_directory by copying each file in + to_directory with new contents. Why do this? Because TensorBoard picks + up temp files from `aws s3 sync` and then stops reading the correct + tfevent files. This is probably related to tensorflow/tensorboard#349. + + Args: + from_directory (str): The directory with updated files. + to_directory (str): The directory to be synced. + """ + if not os.path.exists(to_directory): + os.mkdir(to_directory) + for root, dirs, files in os.walk(from_directory): + to_root = root.replace(from_directory, to_directory) + for directory in dirs: + to_child_dir = os.path.join(to_root, directory) + if not os.path.exists(to_child_dir): + os.mkdir(to_child_dir) + for fname in files: + from_file = os.path.join(root, fname) + to_file = os.path.join(to_root, fname) + with open(from_file, 'rb') as a, open(to_file, 'wb') as b: + b.write(a.read()) + def validate_requirements(self): """Ensure that TensorBoard and the AWS CLI are installed. @@ -98,8 +124,9 @@ def run(self): while not self.estimator.checkpoint_path: self.event.wait(1) while not self.event.is_set(): - args = ['aws', 's3', 'sync', self.estimator.checkpoint_path, self.logdir] + args = ['aws', 's3', 'sync', self.estimator.checkpoint_path, self._aws_sync_dir] subprocess.call(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self._sync_directories(self._aws_sync_dir, self.logdir) self.event.wait(10) tensorboard_process.terminate()