diff --git a/tfx_bsl/beam/run_inference_test.py b/tfx_bsl/beam/run_inference_test.py index 4b027969..34a41e48 100644 --- a/tfx_bsl/beam/run_inference_test.py +++ b/tfx_bsl/beam/run_inference_test.py @@ -27,6 +27,7 @@ from googleapiclient import discovery from googleapiclient import http import tensorflow as tf +from tensorflow.compat.v1 import estimator as tf_estimator from tfx_bsl.beam import run_inference from tfx_bsl.beam import test_helpers from tfx_bsl.public.proto import model_spec_pb2 @@ -87,12 +88,12 @@ def _build_predict_model(self, model_path): default_value=0) } serving_receiver = ( - tf.compat.v1.estimator.export.build_parsing_serving_input_receiver_fn( + tf_estimator.export.build_parsing_serving_input_receiver_fn( input_tensors)()) output_tensors = {'y': serving_receiver.features['x'] * 2} sess = tf.compat.v1.Session() sess.run(tf.compat.v1.initializers.global_variables()) - signature_def = tf.compat.v1.estimator.export.PredictOutput( + signature_def = tf_estimator.export.PredictOutput( output_tensors).as_signature_def(serving_receiver.receiver_tensors) builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(model_path) builder.add_meta_graph_and_variables(