diff --git a/neural_compressor/adaptor/tensorflow.py b/neural_compressor/adaptor/tensorflow.py index a6660d50e70..51fbe338da7 100644 --- a/neural_compressor/adaptor/tensorflow.py +++ b/neural_compressor/adaptor/tensorflow.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from csv import writer import os import copy import yaml @@ -87,27 +88,28 @@ def log_histogram(self, writer, tag, values, step=0, bins=1000): values = np.array(values) # Create histogram using numpy - counts, bin_edges = np.histogram(values, bins=bins) + # counts, bin_edges = np.histogram(values, bins=bins) # Fill fields of histogram proto - hist = tf.compat.v1.HistogramProto() - hist.min = float(np.min(values)) - hist.max = float(np.max(values)) - hist.num = int(np.prod(values.shape)) - hist.sum = float(np.sum(values)) - hist.sum_squares = float(np.sum(values**2)) + # hist = tf.compat.v1.HistogramProto() + # hist.min = float(np.min(values)) + # hist.max = float(np.max(values)) + # hist.num = int(np.prod(values.shape)) + # hist.sum = float(np.sum(values)) + # hist.sum_squares = float(np.sum(values**2)) - bin_edges = bin_edges[1:] + # bin_edges = bin_edges[1:] - for edge in bin_edges: - hist.bucket_limit.append(edge) - for c in counts: - hist.bucket.append(c) + # for edge in bin_edges: + # hist.bucket_limit.append(edge) + # for c in counts: + # hist.bucket.append(c) # Create and write Summary - summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)]) - writer.add_summary(summary, step) - writer.flush() + # update using TF2.X API + with writer.as_default(): + tf.summary.histogram(tag, values, step) + writer.flush() def _pre_hook_for_hvd(self, dataloader=None): import horovod.tensorflow as hvd @@ -275,7 +277,10 @@ def evaluate(self, model, dataloader, postprocess=None, if os.path.isdir(temp_dir): import shutil shutil.rmtree(temp_dir, ignore_errors=True) - writer = tf.compat.v1.summary.FileWriter(temp_dir, model.graph) + # Create the writer using TF2.x APIs to handle eager excutions + writer = tf.summary.create_file_writer(temp_dir) # pylint: disable=no-member + with writer.as_default(): + tf.summary.graph(model.graph) # pylint: disable=no-member cur_graph = GraphAnalyzer() cur_graph.graph = model.graph_def @@ -1310,7 +1315,8 @@ def get_optype_wise_ability(self): def _pre_eval_hook(self, model): return model - def _post_eval_hook(self, model): + # Add keyword arguments unpacking + def _post_eval_hook(self, model, **kwargs): pass def save(self, model, path):