From f0c4c9256c9b1a6a5337762d763e4910631c65c4 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh Date: Mon, 4 Apr 2022 13:18:29 -0700 Subject: [PATCH] Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target. PiperOrigin-RevId: 439395101 --- tensorflow_graphics/projects/nasa/lib/models.py | 5 +++-- tensorflow_graphics/projects/nasa/train.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow_graphics/projects/nasa/lib/models.py b/tensorflow_graphics/projects/nasa/lib/models.py index af540491b..f164a1cda 100644 --- a/tensorflow_graphics/projects/nasa/lib/models.py +++ b/tensorflow_graphics/projects/nasa/lib/models.py @@ -14,6 +14,7 @@ """Model Implementations.""" import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator from tensorflow_graphics.projects.nasa.lib import model_utils tf.disable_eager_execution() @@ -37,7 +38,7 @@ def nasa(hparams): sample_bbox = hparams.sample_bbox def _model_fn(features, labels, mode, params=None): - is_training = (mode == tf.estimator.ModeKeys.TRAIN) + is_training = (mode == tf_estimator.ModeKeys.TRAIN) batch_size = features['point'].shape[0] n_sample_frames = features['point'].shape[1] accum_size = batch_size * n_sample_frames @@ -116,7 +117,7 @@ def _model_fn(features, labels, mode, params=None): train_op = optimizer.minimize( indicator_loss, global_step=global_step, name='optimizer_shape') - return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + return tf_estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) return _model_fn diff --git a/tensorflow_graphics/projects/nasa/train.py b/tensorflow_graphics/projects/nasa/train.py index 1d6be17d5..a05bef4bd 100644 --- a/tensorflow_graphics/projects/nasa/train.py +++ b/tensorflow_graphics/projects/nasa/train.py @@ -15,6 +15,7 @@ import numpy as np import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator from tensorflow_graphics.projects.nasa.lib import datasets from tensorflow_graphics.projects.nasa.lib import models from tensorflow_graphics.projects.nasa.lib import utils @@ -45,13 +46,13 @@ def main(unused_argv): # Set up training. logging.info("=> Setting up training ...") - run_config = tf.estimator.RunConfig( + run_config = tf_estimator.RunConfig( model_dir=FLAGS.train_dir, save_checkpoints_steps=FLAGS.save_every, save_summary_steps=FLAGS.summary_every, keep_checkpoint_max=None, ) - trainer = tf.estimator.Estimator( + trainer = tf_estimator.Estimator( model_fn=model_fn, config=run_config, )