In [1]:
from proton_decay_study.generators.gen3d import Gen3D
from proton_decay_study.models.kevnet import Kevnet
import tensorflow as tf
import logging
import glob
In [2]:
# Configure the root logger at DEBUG so debug-level records from any logger
# that propagates to root (including library loggers) are printed.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

# Input HDF5 files, relative to the notebook's working directory.
DATA_GLOB = '../../*.h5'

# glob() returns matches in arbitrary, filesystem-dependent order; sort them
# so the order of training files (and therefore batches) is reproducible.
input_files = sorted(glob.glob(DATA_GLOB))
if not input_files:
    # Fail fast with a clear message instead of handing Gen3D an empty list
    # (e.g. when the notebook is started from an unexpected directory).
    raise OSError("No HDF5 input files matched %r" % DATA_GLOB)

# 'image/wires' and 'label/type' are presumably the HDF5 dataset paths for
# the inputs and labels respectively -- TODO confirm against Gen3D.
generator = Gen3D(input_files, 'image/wires', 'label/type', batch_size=1)
model = Kevnet(generator)
In [ ]:
# Optional: draw and discard the first 11 batches before training (uncomment to use).
# for i in range(11):
#     generator.next()
In [ ]:
# Pull one batch from the generator -- presumably (inputs, labels) -- and
# train the model on it for a single epoch. Every fit() option is listed
# explicitly so the full training configuration is visible in one place.
X, Y = generator.next()
fit_options = dict(
    batch_size=1,
    epochs=1,
    verbose=1,
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
)
model.fit(X, Y, **fit_options)
In [ ]:
"""
training_output = model.fit_generator(generator,
steps_per_epoch=1,
epochs=1,
workers=1,
verbose=0,
max_q_size=1,
pickle_safe=False
)
"""