Skip to content

Commit

Permalink
Fix TensorBoard callback warning
Browse files Browse the repository at this point in the history
  • Loading branch information
bnaul committed May 2, 2016
1 parent 33af75a commit e04ce5e
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,15 @@ class TensorBoard(Callback):
histograms for the layers of the model. If set to 0,
histograms won't be computed.
'''
def __init__(self, log_dir='./logs', histogram_freq=0):
def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True):
super(Callback, self).__init__()
if K._BACKEND != 'tensorflow':
raise Exception('TensorBoard callback only works '
'with the TensorFlow backend.')
self.log_dir = log_dir
self.histogram_freq = histogram_freq
self.merged = None
self.write_graph = write_graph

def _set_model(self, model):
import tensorflow as tf
Expand All @@ -457,8 +458,16 @@ def _set_model(self, model):
tf.histogram_summary('{}_out'.format(layer),
layer.output)
self.merged = tf.merge_all_summaries()
self.writer = tf.train.SummaryWriter(self.log_dir,
self.sess.graph_def)
if self.write_graph:
tf_version = tuple(int(i) for i in tf.__version__.split('.'))
if tf_version >= (0, 8, 0):
self.writer = tf.train.SummaryWriter(self.log_dir,
self.sess.graph)
else:
self.writer = tf.train.SummaryWriter(self.log_dir,
self.sess.graph_def)
else:
self.writer = tf.train.SummaryWriter(self.log_dir)

def on_epoch_end(self, epoch, logs={}):
import tensorflow as tf
Expand Down

0 comments on commit e04ce5e

Please sign in to comment.