Skip to content

Commit

Permalink
Fixed a tensorboard bug and upgrade the summry writer object using TF…
Browse files Browse the repository at this point in the history
…2.… (#1232)
  • Loading branch information
qgao007 authored Sep 20, 2022
1 parent fc6f9a7 commit f348529
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions neural_compressor/adaptor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from csv import writer
import os
import copy
import yaml
Expand Down Expand Up @@ -87,27 +88,28 @@ def log_histogram(self, writer, tag, values, step=0, bins=1000):
values = np.array(values)

# Create histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# counts, bin_edges = np.histogram(values, bins=bins)

# Fill fields of histogram proto
hist = tf.compat.v1.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))
# hist = tf.compat.v1.HistogramProto()
# hist.min = float(np.min(values))
# hist.max = float(np.max(values))
# hist.num = int(np.prod(values.shape))
# hist.sum = float(np.sum(values))
# hist.sum_squares = float(np.sum(values**2))

bin_edges = bin_edges[1:]
# bin_edges = bin_edges[1:]

for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# for edge in bin_edges:
# hist.bucket_limit.append(edge)
# for c in counts:
# hist.bucket.append(c)

# Create and write Summary
summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)])
writer.add_summary(summary, step)
writer.flush()
# update using TF2.X API
with writer.as_default():
tf.summary.histogram(tag, values, step)
writer.flush()

def _pre_hook_for_hvd(self, dataloader=None):
import horovod.tensorflow as hvd
Expand Down Expand Up @@ -275,7 +277,10 @@ def evaluate(self, model, dataloader, postprocess=None,
if os.path.isdir(temp_dir):
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
writer = tf.compat.v1.summary.FileWriter(temp_dir, model.graph)
# Create the writer using TF2.x APIs to handle eager excutions
writer = tf.summary.create_file_writer(temp_dir) # pylint: disable=no-member
with writer.as_default():
tf.summary.graph(model.graph) # pylint: disable=no-member

cur_graph = GraphAnalyzer()
cur_graph.graph = model.graph_def
Expand Down Expand Up @@ -1310,7 +1315,8 @@ def get_optype_wise_ability(self):
def _pre_eval_hook(self, model):
return model

def _post_eval_hook(self, model):
# Add keyword arguments unpacking
def _post_eval_hook(self, model, **kwargs):
pass

def save(self, model, path):
Expand Down

0 comments on commit f348529

Please sign in to comment.