Skip to content

Commit

Permalink
Explicitly import estimator from tensorflow as a separate import inst…
Browse files Browse the repository at this point in the history
…ead of

accessing it via tf.estimator and depend on the tensorflow estimator target.

PiperOrigin-RevId: 439395101
  • Loading branch information
hertschuh authored and copybara-github committed Apr 4, 2022
1 parent 6fc7fd5 commit f0c4c92
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions tensorflow_graphics/projects/nasa/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Model Implementations."""
import tensorflow.compat.v1 as tf

from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow_graphics.projects.nasa.lib import model_utils

tf.disable_eager_execution()
Expand All @@ -37,7 +38,7 @@ def nasa(hparams):
sample_bbox = hparams.sample_bbox

def _model_fn(features, labels, mode, params=None):
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
is_training = (mode == tf_estimator.ModeKeys.TRAIN)
batch_size = features['point'].shape[0]
n_sample_frames = features['point'].shape[1]
accum_size = batch_size * n_sample_frames
Expand Down Expand Up @@ -116,7 +117,7 @@ def _model_fn(features, labels, mode, params=None):
train_op = optimizer.minimize(
indicator_loss, global_step=global_step, name='optimizer_shape')

return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
return tf_estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

return _model_fn

Expand Down
5 changes: 3 additions & 2 deletions tensorflow_graphics/projects/nasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import tensorflow.compat.v1 as tf

from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow_graphics.projects.nasa.lib import datasets
from tensorflow_graphics.projects.nasa.lib import models
from tensorflow_graphics.projects.nasa.lib import utils
Expand Down Expand Up @@ -45,13 +46,13 @@ def main(unused_argv):

# Set up training.
logging.info("=> Setting up training ...")
run_config = tf.estimator.RunConfig(
run_config = tf_estimator.RunConfig(
model_dir=FLAGS.train_dir,
save_checkpoints_steps=FLAGS.save_every,
save_summary_steps=FLAGS.summary_every,
keep_checkpoint_max=None,
)
trainer = tf.estimator.Estimator(
trainer = tf_estimator.Estimator(
model_fn=model_fn,
config=run_config,
)
Expand Down

0 comments on commit f0c4c92

Please sign in to comment.