In [18]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os
import sys
from six.moves import cPickle as pickle
%matplotlib inline

Load the training data (a small subset is sampled from it below)


In [35]:
# Dataset to load. The commented-out name is an alternative pickle file;
# swap the two assignments to use it instead.
#pickle_file = 'mini_train.pickle'
pickle_file = 'train.pickle'

# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load files from a trusted source.
with open(pickle_file, 'rb') as f:
    save = pickle.load(f)

# The unpickled object is indexed by 'data' and 'outcome'.
mini_X_0, mini_outcome = save['data'], save['outcome']
del save  # hint to help gc free up memory

In [36]:
def label_reformat(label, max_size = 5):
    """Convert variable-length digit labels into one-hot training targets.

    Args:
        label: sequence of digit sequences, one per example. Each entry must
            support ``len()`` and integer indexing.
        max_size: number of digit positions to encode; sequence lengths above
            this value are clamped to it.

    Returns:
        A pair ``(digit_size, digits)``:

        * ``digit_size`` -- float32 array of shape ``(len(label), max_size)``,
          a one-hot coding of the (clamped) sequence length, where column
          ``k`` stands for length ``k + 1``.
        * ``digits`` -- dict with keys ``'digit_0'`` .. ``'digit_<max_size-1>'``.
          Each value is a float32 array of shape ``(len(label), 11)`` that
          one-hot codes the digit at that position. Positions past the end of
          a sequence are coded as the extra "end digit" class 10.

    NOTE(review): if the raw labels can themselves contain the value 10, it
    would be indistinguishable from the end marker -- confirm against the data
    preparation step.
    """
    end_digit = 10.0
    digit_classes = np.arange(end_digit + 1)   # 0..10, class 10 is the end marker
    size_classes = np.arange(max_size) + 1     # 1..max_size

    # Sequence lengths, capped at max_size, then one-hot coded by comparison
    # against the row of possible lengths.
    lengths = np.minimum(np.asarray([len(seq) for seq in label]), max_size)
    digit_size = (lengths[:, None] == size_classes).astype(np.float32)

    digits = {}
    for pos in range(max_size):
        # Digit found at this position, or the end marker for short sequences.
        column = np.asarray(
            [seq[pos] if pos < len(seq) else end_digit for seq in label])
        digits['digit_' + str(pos)] = (
            column[:, None] == digit_classes).astype(np.float32)

    return digit_size, digits

Sample a smaller dataset


In [37]:
# Number of leading examples kept for the small-data experiment. A single
# constant keeps the label slice and the image slice aligned.
num_samples = 100

label = mini_outcome['label'][:num_samples]
# One-hot targets: the sequence length plus one entry per digit position.
digit_size, digits = label_reformat(label)
mini_X = mini_X_0[:num_samples]

In [38]:
# Sanity-check the array shapes. The call form of print gives the same output
# under Python 2 (a single parenthesised expression) and also runs on Python 3.
print(digit_size.shape)
print(digits['digit_0'].shape)
print(mini_X.shape)


(100, 5)
(100, 11)
(100, 64, 64, 3)

Start a TensorFlow session


In [39]:
# An InteractiveSession installs itself as the default session, which is what
# lets the training cell call `.eval()` / `.run()` without passing a session.
sess = tf.InteractiveSession()

In [40]:
def weight_variable(shape):
    """Create a weight variable of the given shape.

    Initial values are drawn from a truncated normal distribution with a
    standard deviation of 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Create a bias variable of the given shape, initialised to 0.1 everywhere."""
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x, W):
    """Convolve `x` with the filter bank `W` using stride 1 and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')

def max_pool_2x2(x):
    """Max-pool `x` over 2x2 windows with stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]  # used both as the pooling window and as the stride
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

In [41]:
# Input geometry is read from the sampled image array (axis 1 and axis 3).
image_size, num_channels = mini_X.shape[1], mini_X.shape[3]
batch_size = 20

# Image batch fed to the network.
x_image = tf.placeholder(
    tf.float32,
    shape=(batch_size, image_size, image_size, num_channels))

# One target placeholder per digit position; 11 columns = digits 0-9 plus the
# end-of-number marker class. Created in order y_d1 .. y_d5.
y_d1, y_d2, y_d3, y_d4, y_d5 = (
    tf.placeholder(tf.float32, shape=(batch_size, 11)) for _ in range(5))

# Target for the number of digits (5 possible sizes).
y_dsize = tf.placeholder(tf.float32, shape=(batch_size, 5))

In [42]:
def next_batch(X, y_dsize, y_ds, batch_size=50):
    """Draw a random mini-batch of images and their targets.

    Row indices are drawn with ``np.random.choice`` (its default samples with
    replacement, so a batch may contain the same example more than once), and
    the same indices are applied to every array.

    Args:
        X: 4-D array of images, indexed by example along axis 0.
        y_dsize: 2-D array of digit-size targets, one row per example.
        y_ds: dict with keys 'digit_0' .. 'digit_4', each a 2-D array of
            per-position digit targets, one row per example.
        batch_size: number of rows to draw.

    Returns:
        A 7-tuple ``(batch_x, batch_y_dsize, batch_y_d1, ..., batch_y_d5)``.
    """
    rows = np.random.choice(X.shape[0], batch_size)

    images = X[rows, :, :, :]
    sizes = y_dsize[rows, :]
    # Targets for the five digit positions, in order digit_0 .. digit_4.
    per_digit = tuple(y_ds['digit_' + str(k)][rows, :] for k in range(5))

    return (images, sizes) + per_digit

Construct CNN


In [43]:
# First convolutional stage: 5x5 kernels mapping the input channels to 32 feature maps.
W_conv1 = weight_variable([5, 5, num_channels, 32])
b_conv1 = bias_variable([32])

# Convolve, add the bias, apply ReLU, then downsample with 2x2 max pooling.
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

In [44]:
# Second convolutional stage: 5x5 kernels mapping the 32 feature maps to 64.
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

# Same pattern as the first stage, applied to the pooled output h_pool1.
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

In [45]:
# Fully connected layer on top of the second pooling stage.
# The flattened feature size is taken from h_pool2's static shape
# (height * width * channels) instead of hard-coding 16 * 16 * 64, so the layer
# stays correct if the image size or the conv/pool stack changes.
flat_dim = int(np.prod(h_pool2.get_shape().as_list()[1:]))

W_fc1 = weight_variable([flat_dim, 1024])
b_fc1 = bias_variable([1024])

# Flatten each example's feature maps to a vector, then apply the dense layer + ReLU.
h_pool2_flat = tf.reshape(h_pool2, [-1, flat_dim])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

Dropout layer


In [46]:
# Dropout keep-probability is a placeholder so it can be set per run
# (the training cell feeds 0.5 for training steps and 1.0 for evaluation).
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

Fully connected output layers: several separate softmax heads, one for each digit position and one for the digit size


In [47]:
#first digit
W_fc2_d1 = weight_variable([1024, 11])
b_fc2_d1 = bias_variable([11])

y_conv_d1 = tf.matmul(h_fc1_drop, W_fc2_d1) + b_fc2_d1

#second digit
W_fc2_d2 = weight_variable([1024, 11])
b_fc2_d2 = bias_variable([11])

y_conv_d2 = tf.matmul(h_fc1_drop, W_fc2_d2) + b_fc2_d2

#third digit
W_fc2_d3 = weight_variable([1024, 11])
b_fc2_d3 = bias_variable([11])

y_conv_d3 = tf.matmul(h_fc1_drop, W_fc2_d3) + b_fc2_d3

#fourth digit
W_fc2_d4 = weight_variable([1024, 11])
b_fc2_d4 = bias_variable([11])

y_conv_d4 = tf.matmul(h_fc1_drop, W_fc2_d4) + b_fc2_d4

#fifth digit
W_fc2_d5 = weight_variable([1024, 11])
b_fc2_d5 = bias_variable([11])

y_conv_d5 = tf.matmul(h_fc1_drop, W_fc2_d5) + b_fc2_d5

#digit size
W_fc2_dsize = weight_variable([1024, 5])
b_fc2_dsize = bias_variable([5])

y_conv_dsize = tf.matmul(h_fc1_drop, W_fc2_dsize) + b_fc2_dsize

In [48]:
def _head_loss(logits, labels):
    """Mean softmax cross-entropy of one output head over the batch.

    The arguments are passed by keyword: the original positional
    (logits, labels) call is rejected by later TensorFlow releases, which
    require named arguments for this op.
    """
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))

# Total loss: unweighted sum of the five per-digit losses and the digit-size loss.
cross_entropy = (_head_loss(y_conv_d1, y_d1)
                 + _head_loss(y_conv_d2, y_d2)
                 + _head_loss(y_conv_d3, y_d3)
                 + _head_loss(y_conv_d4, y_d4)
                 + _head_loss(y_conv_d5, y_d5)
                 + _head_loss(y_conv_dsize, y_dsize)
                 )

# Adam optimiser with learning rate 1e-4 minimising the combined loss.
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

In [49]:
# Accuracy is measured on the first digit only: a prediction counts as correct
# when the arg-max of the first-digit logits matches the arg-max of its one-hot target.
correct_prediction = tf.equal(tf.argmax(y_conv_d1,1), tf.argmax(y_d1,1))
# Fraction of correct predictions in the batch.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Train the model on a small dataset and see whether it overfits.

If it overfits, good. If not, check for bugs.


In [50]:
# Initialise every graph variable, then run 1000 steps on random mini-batches.
sess.run(tf.initialize_all_variables())
for i in range(1000):
    (batch_x, batch_y_dsize,
     batch_y_d1, batch_y_d2, batch_y_d3, batch_y_d4, batch_y_d5) = next_batch(
        mini_X, digit_size, digits, batch_size)

    # Inputs and targets shared by the evaluation feed and the training feed.
    batch_feed = {
        x_image: batch_x,
        y_dsize: batch_y_dsize,
        y_d1: batch_y_d1,
        y_d2: batch_y_d2,
        y_d3: batch_y_d3,
        y_d4: batch_y_d4,
        y_d5: batch_y_d5,
    }

    # Every 10th step, report `accuracy` on the current batch (before training
    # on it) with dropout disabled.
    if i % 10 == 0:
        eval_feed = dict(batch_feed)
        eval_feed[keep_prob] = 1.0
        train_accuracy = accuracy.eval(feed_dict=eval_feed)
        print("step %d, training accuracy %g"%(i, train_accuracy))

    # One optimisation step with a dropout keep-probability of 0.5.
    train_feed = dict(batch_feed)
    train_feed[keep_prob] = 0.5
    train_step.run(feed_dict=train_feed)


step 0, training accuracy 0.05
step 10, training accuracy 0.3
step 20, training accuracy 0.3
step 30, training accuracy 0.3
step 40, training accuracy 0.25
step 50, training accuracy 0.45
step 60, training accuracy 0.3
step 70, training accuracy 0.2
step 80, training accuracy 0.45
step 90, training accuracy 0.45
step 100, training accuracy 0.45
step 110, training accuracy 0.6
step 120, training accuracy 0.5
step 130, training accuracy 0.8
step 140, training accuracy 0.6
step 150, training accuracy 0.85
step 160, training accuracy 0.75
step 170, training accuracy 0.7
step 180, training accuracy 0.9
step 190, training accuracy 0.75
step 200, training accuracy 1
step 210, training accuracy 0.75
step 220, training accuracy 1
step 230, training accuracy 0.75
step 240, training accuracy 0.9
step 250, training accuracy 0.9
step 260, training accuracy 0.9
step 270, training accuracy 0.9
step 280, training accuracy 0.95
step 290, training accuracy 1
step 300, training accuracy 0.95
step 310, training accuracy 1
step 320, training accuracy 0.95
step 330, training accuracy 1
step 340, training accuracy 0.95
step 350, training accuracy 1
step 360, training accuracy 1
step 370, training accuracy 1
step 380, training accuracy 1
step 390, training accuracy 1
step 400, training accuracy 1
step 410, training accuracy 1
step 420, training accuracy 1
step 430, training accuracy 1
step 440, training accuracy 1
step 450, training accuracy 1
step 460, training accuracy 1
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-50-e1195155f09e> in <module>()
     14             y_d1: batch_y_d1, y_d2: batch_y_d2, y_d3: batch_y_d3,
     15             y_d4: batch_y_d4, y_d5: batch_y_d5,
---> 16             keep_prob: 0.5})

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in run(self, feed_dict, session)
   1617         none, the default session will be used.
   1618     """
-> 1619     _run_using_default_session(self, feed_dict, self.graph, session)
   1620 
   1621 

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in _run_using_default_session(operation, feed_dict, graph, session)
   3794                        "the operation's graph is different from the session's "
   3795                        "graph.")
-> 3796   session.run(operation, feed_dict)
   3797 
   3798 

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    715     try:
    716       result = self._run(None, fetches, feed_dict, options_ptr,
--> 717                          run_metadata_ptr)
    718       if run_metadata:
    719         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    913     if final_fetches or final_targets:
    914       results = self._do_run(handle, final_targets, final_fetches,
--> 915                              feed_dict_string, options, run_metadata)
    916     else:
    917       results = []

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
    963     if handle is None:
    964       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
--> 965                            target_list, options, run_metadata)
    966     else:
    967       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
    970   def _do_call(self, fn, *args):
    971     try:
--> 972       return fn(*args)
    973     except errors.OpError as e:
    974       message = compat.as_text(e.message)

/home/josh/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
    952         return tf_session.TF_Run(session, options,
    953                                  feed_dict, fetch_list, target_list,
--> 954                                  status, run_metadata)
    955 
    956     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]:


In [ ]: