In [ ]:
import tensorflow as tf
import keras
In [ ]:
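''' The stock TensorBoard callback crashes on value.item() when a logged value is a plain Python number; the only change in this subclass is how that scalar value is written. '''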
class MyTensorBoard(keras.callbacks.TensorBoard):
    """TensorBoard callback that tolerates plain Python numbers in ``logs``.

    The parent implementation writes each logged scalar with
    ``value.item()``, which fails when the value has no ``item`` method
    (e.g. a built-in ``float``).  This subclass overrides ``on_epoch_end``
    so that both NumPy-style scalars and plain Python numbers are accepted.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Write histogram, embedding and scalar summaries for one epoch.

        Args:
            epoch: Index of the epoch that just finished; used as the
                summary step and as the embeddings checkpoint step.
            logs: Optional mapping of metric name to value.  ``None`` is
                treated as an empty mapping.  The ``'batch'`` and ``'size'``
                entries are skipped.
        """
        logs = logs or {}

        # Histogram summaries: only when validation data is available and
        # the epoch falls on the configured histogram frequency.
        if (self.validation_data and self.histogram_freq
                and epoch % self.histogram_freq == 0):
            # TODO: implement batched calls to sess.run
            # (current call will likely go OOM on GPU)
            if self.model.uses_learning_phase:
                # Keep only the model inputs and append 0 as the value fed
                # to the learning-phase tensor (Keras convention: 0 = test).
                cut_v_data = len(self.model.inputs)
                val_data = self.validation_data[:cut_v_data] + [0]
                # Use keras.backend directly: no ``K`` alias is imported
                # in this file, so ``K.learning_phase()`` would raise
                # NameError.
                tensors = self.model.inputs + [keras.backend.learning_phase()]
            else:
                val_data = self.validation_data
                tensors = self.model.inputs
            feed_dict = dict(zip(tensors, val_data))
            result = self.sess.run([self.merged], feed_dict=feed_dict)
            summary_str = result[0]
            self.writer.add_summary(summary_str, epoch)

        # Embeddings: save the session to every configured log path.
        if (self.embeddings_freq and self.embeddings_logs
                and epoch % self.embeddings_freq == 0):
            for log in self.embeddings_logs:
                self.saver.save(self.sess, log, epoch)

        # Scalar summaries: one summary per logged metric.
        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            # NumPy scalars expose ``item()`` to get a native Python number;
            # plain Python numbers do not, so pass those through unchanged.
            summary_value.simple_value = (
                value.item() if hasattr(value, 'item') else value
            )
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()