In [1]:
# Notebook setup: Py2/Py3 compatibility, inline figures, and plotting/RNG defaults.
from __future__ import division, print_function
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
# Default colormap for every imshow call below.
plt.rcParams['image.cmap'] = 'gist_earth'
# Seed the global NumPy RNG so the synthetic data (and hence training) is reproducible.
np.random.seed(98765)

In [2]:
from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

In [3]:
# Spatial size of the generated (square) test images, in pixels.
nx = ny = 572

In [4]:
# Synthetic data provider producing nx-by-ny grayscale images plus label masks.
# NOTE(review): cnt=20 presumably controls the number of objects per image —
# confirm against tf_unet.image_gen.GrayScaleDataProvider.
generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20)

In [5]:
# Draw a batch of one sample: x_test is the input image batch, y_test the
# per-class label batch (both 4-D — indexed [0, ..., channel] below).
x_test, y_test = generator(1)

In [6]:
# Show one sample: input image (left) and its class-1 label mask (right).
fig, ax = plt.subplots(1, 2, sharey=True, figsize=(8, 4))
ax[0].imshow(x_test[0, ..., 0], aspect="auto")
ax[1].imshow(y_test[0, ..., 1], aspect="auto")
# Titles so the figure stands alone when the notebook is skimmed.
ax[0].set_title("Input")
ax[1].set_title("Ground truth");  # ';' suppresses the bare AxesImage repr (was Out[6])


Out[6]:
<matplotlib.image.AxesImage at 0x112595da0>

In [7]:
# Build a 3-layer U-Net with 16 feature maps at the root resolution; channel and
# class counts are taken from the data provider so the net matches the data.
net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)


2017-03-27 21:11:45,858 Layers 3, features 16, filter size 3x3, pool size: 2x2

In [8]:
# Momentum-SGD trainer (momentum=0.2). Per the training log below, the learning
# rate starts at 0.2 and decays each epoch (0.2000 -> 0.1900 -> ...).
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))

In [9]:
# Train for 10 epochs x 32 iterations, logging every 2 iterations; checkpoints
# are written under ./unet_trained.
# NOTE(review): `path` (presumably the saved checkpoint path) is never reused —
# the prediction cell below hardcodes "./unet_trained/model.ckpt" instead; confirm
# they match and prefer `path` there.
path = trainer.train(generator, "./unet_trained", training_iters=32, epochs=10, display_step=2)


2017-03-27 21:11:51,941 Removing '/Users/jakeret/workspace/tf_unet/demo/prediction'
2017-03-27 21:11:51,944 Removing '/Users/jakeret/workspace/tf_unet/demo/unet_trained'
2017-03-27 21:11:51,946 Allocating '/Users/jakeret/workspace/tf_unet/demo/prediction'
2017-03-27 21:11:51,947 Allocating '/Users/jakeret/workspace/tf_unet/demo/unet_trained'
2017-03-27 21:12:00,946 Verification error= 19.4%, loss= 0.6514
2017-03-27 21:12:04,516 Start optimization
2017-03-27 21:12:09,380 Iter 0, Minibatch Loss= 0.6272, Training Accuracy= 0.7682, Minibatch error= 23.2%
2017-03-27 21:12:17,544 Iter 2, Minibatch Loss= 0.5162, Training Accuracy= 0.8452, Minibatch error= 15.5%
2017-03-27 21:12:25,593 Iter 4, Minibatch Loss= 0.4626, Training Accuracy= 0.8439, Minibatch error= 15.6%
2017-03-27 21:12:33,677 Iter 6, Minibatch Loss= 0.4321, Training Accuracy= 0.8621, Minibatch error= 13.8%
2017-03-27 21:12:41,630 Iter 8, Minibatch Loss= 0.4044, Training Accuracy= 0.8681, Minibatch error= 13.2%
2017-03-27 21:12:49,660 Iter 10, Minibatch Loss= 0.4302, Training Accuracy= 0.8460, Minibatch error= 15.4%
2017-03-27 21:12:57,779 Iter 12, Minibatch Loss= 0.4194, Training Accuracy= 0.8436, Minibatch error= 15.6%
2017-03-27 21:13:07,131 Iter 14, Minibatch Loss= 0.4451, Training Accuracy= 0.8125, Minibatch error= 18.8%
2017-03-27 21:13:15,672 Iter 16, Minibatch Loss= 0.4421, Training Accuracy= 0.7920, Minibatch error= 20.8%
2017-03-27 21:13:24,290 Iter 18, Minibatch Loss= 0.3628, Training Accuracy= 0.8238, Minibatch error= 17.6%
2017-03-27 21:13:28,321 Epoch 0, Average loss: 0.4787, learning rate: 0.2000
2017-03-27 21:13:35,880 Verification error= 19.4%, loss= 0.3911
2017-03-27 21:13:43,159 Iter 20, Minibatch Loss= 0.3515, Training Accuracy= 0.8206, Minibatch error= 17.9%
2017-03-27 21:13:51,190 Iter 22, Minibatch Loss= 0.3391, Training Accuracy= 0.8178, Minibatch error= 18.2%
2017-03-27 21:13:59,170 Iter 24, Minibatch Loss= 0.4205, Training Accuracy= 0.8146, Minibatch error= 18.5%
2017-03-27 21:14:07,244 Iter 26, Minibatch Loss= 0.2659, Training Accuracy= 0.8682, Minibatch error= 13.2%
2017-03-27 21:14:15,954 Iter 28, Minibatch Loss= 0.3302, Training Accuracy= 0.7807, Minibatch error= 21.9%
2017-03-27 21:14:23,971 Iter 30, Minibatch Loss= 0.2170, Training Accuracy= 0.8774, Minibatch error= 12.3%
2017-03-27 21:14:32,109 Iter 32, Minibatch Loss= 0.2471, Training Accuracy= 0.8493, Minibatch error= 15.1%
2017-03-27 21:14:40,141 Iter 34, Minibatch Loss= 0.2075, Training Accuracy= 0.8305, Minibatch error= 16.9%
2017-03-27 21:14:48,089 Iter 36, Minibatch Loss= 0.2787, Training Accuracy= 0.9224, Minibatch error= 7.8%
2017-03-27 21:14:56,550 Iter 38, Minibatch Loss= 0.3094, Training Accuracy= 0.9062, Minibatch error= 9.4%
2017-03-27 21:15:00,023 Epoch 1, Average loss: 0.3196, learning rate: 0.1900
2017-03-27 21:15:07,417 Verification error= 4.9%, loss= 0.2191

....

2017-03-27 21:26:17,925 Iter 180, Minibatch Loss= 0.0467, Training Accuracy= 0.9918, Minibatch error= 0.8%
2017-03-27 21:26:26,481 Iter 182, Minibatch Loss= 0.0567, Training Accuracy= 0.9800, Minibatch error= 2.0%
2017-03-27 21:26:36,835 Iter 184, Minibatch Loss= 0.0631, Training Accuracy= 0.9850, Minibatch error= 1.5%
2017-03-27 21:26:49,194 Iter 186, Minibatch Loss= 0.2021, Training Accuracy= 0.9514, Minibatch error= 4.9%
2017-03-27 21:26:59,758 Iter 188, Minibatch Loss= 0.2536, Training Accuracy= 0.9525, Minibatch error= 4.8%
2017-03-27 21:27:08,597 Iter 190, Minibatch Loss= 0.1040, Training Accuracy= 0.9590, Minibatch error= 4.1%
2017-03-27 21:27:19,947 Iter 192, Minibatch Loss= 0.1009, Training Accuracy= 0.9722, Minibatch error= 2.8%
2017-03-27 21:27:35,962 Iter 194, Minibatch Loss= 1.3488, Training Accuracy= 0.8229, Minibatch error= 17.7%
2017-03-27 21:27:48,464 Iter 196, Minibatch Loss= 0.2191, Training Accuracy= 0.9268, Minibatch error= 7.3%
2017-03-27 21:27:58,009 Iter 198, Minibatch Loss= 0.1754, Training Accuracy= 0.9502, Minibatch error= 5.0%
2017-03-27 21:28:04,737 Epoch 9, Average loss: 0.2360, learning rate: 0.1260
2017-03-27 21:28:13,116 Verification error= 5.1%, loss= 0.1608
2017-03-27 21:28:17,602 Optimization Finished!

In [22]:
# Draw a fresh sample and run inference with the trained weights.
x_test, y_test = generator(1)

# NOTE(review): hardcoded checkpoint path duplicates the training output location;
# the `path` returned by trainer.train() above is presumably the same file — confirm
# and reuse it so the two cells cannot drift apart.
prediction = net.predict("./unet_trained/model.ckpt", x_test)


2017-03-27 21:31:00,595 Model restored from file: ./unet_trained/model.ckpt

In [23]:
# Visualize input, ground truth, and the thresholded prediction side by side.
# Name the magic number so the decision rule is explicit and easy to tune:
# a pixel is foreground when its class-1 probability exceeds this value.
FG_THRESHOLD = 0.9

fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12, 5))
ax[0].imshow(x_test[0, ..., 0], aspect="auto")
# Channel 1 is the foreground class here (channel 0 presumably background).
ax[1].imshow(y_test[0, ..., 1], aspect="auto")
mask = prediction[0, ..., 1] > FG_THRESHOLD
ax[2].imshow(mask, aspect="auto")
ax[0].set_title("Input")
ax[1].set_title("Ground truth")
ax[2].set_title("Prediction")
fig.tight_layout()
# NOTE(review): relative path assumes the notebook runs from the demo/ directory.
fig.savefig("../docs/toy_problem.png")



In [ ]: