In [1]:
import os, pickle, sys
import numpy as np
import tensorflow as tf
from scipy import misc
from glob import glob
from sklearn.metrics import roc_auc_score

from utils import rescale, weight_variable, bias_variable, conv2d, max_pool, image_path, tmp_image_file, model_file, max_shape

# PARAMETERS
n_epochs = 100

# CNN parameters
n_convo_layer1 = 32    # filters in the first convolutional layer
n_convo_layer2 = 32    # filters in the second convolutional layer
inner_layer_size = 50  # feature-map side length after the two pooling layers
percep_size = 1024     # units in each fully connected layer

# training batch size (also used below as the size of the test set)
batch_len = 50

Step 1 - Load the images, rescale them, and augment the data set with rotations


In [2]:
if os.path.exists(tmp_image_file):
    # reuse the cached, preprocessed images from an earlier run
    with open(tmp_image_file, "rb") as f:
        images = pickle.load(f)
    
else:
    images = dict()
    for what in ['sushi', 'sandwich']:
        print("Loading images for '%s'..." % what)
        files = glob(image_path % what)
        full_images = [misc.imread(f, mode="L") for f in files]
        images[what] = [rescale(img, max_shape) for img in full_images]
        
        # augment the data set with rotated copies of every image
        for angle in [90, 180, 270]:
            print("-- rotating by %d degrees..." % angle)
            images[what].extend([rescale(misc.imrotate(img, angle), max_shape) for img in full_images])
    
    with open(tmp_image_file, "wb") as f:
        pickle.dump(images, f)
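
The `rescale` helper lives in `utils` and is not shown in this notebook. A minimal sketch of what it plausibly does, assuming it simply resizes each grayscale image to the common `max_shape` (the real helper may pad or crop instead of stretching):

# hypothetical sketch of utils.rescale -- not the notebook's actual code
from scipy import misc

def rescale(img, shape):
    """Resize a 2-D grayscale array to a fixed (height, width) shape."""
    return misc.imresize(img, shape)

Whatever the exact implementation, every image comes out with identical max_shape dimensions, which is what lets np.stack build a single 4-D array in Step 3.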

Step 2 - Define the TensorFlow graph


In [3]:
sess = tf.Session()

# input layer: grayscale image batches, binary labels, and the dropout keep probability
X = tf.placeholder(tf.float32, shape=[None, max_shape[0], max_shape[1], 1])
y = tf.placeholder(tf.float32, shape=[None, 1])
keep_prob = tf.placeholder(tf.float32)

# first convolution
W_conv1 = weight_variable([10, 10, 1, n_convo_layer1])
b_conv1 = bias_variable([n_convo_layer1])

h_conv1 = tf.nn.sigmoid(conv2d(X, W_conv1) + b_conv1)
h_pool1 = max_pool(h_conv1, ksize=[1, 2, 2, 1])

# second convolution
W_conv2 = weight_variable([10, 10, n_convo_layer1, n_convo_layer2])
b_conv2 = bias_variable([n_convo_layer2])

h_conv2 = tf.nn.sigmoid(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool(h_conv2, ksize=[1, 2, 2, 1])

# first fully connected layer: each 2x2 max-pool halves the spatial
# dimensions, so each image is now inner_layer_size x inner_layer_size
# with n_convo_layer2 channels
W_fc1 = weight_variable([inner_layer_size * inner_layer_size * n_convo_layer2, percep_size])
b_fc1 = bias_variable([percep_size])

h_pool2_flat = tf.reshape(h_pool2, [-1, inner_layer_size * inner_layer_size * n_convo_layer2])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# second fully connected layer
W_fc2 = weight_variable([percep_size, percep_size])
b_fc2 = bias_variable([percep_size])

h_fc2 = tf.nn.tanh(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob)

# output layer
W_fc3 = weight_variable([percep_size, 1])
b_fc3 = bias_variable([1])

# a single logit per image; P(sushi) = sigmoid(logit)
logits = tf.matmul(h_fc2_drop, W_fc3) + b_fc3

cross_entropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits))
train_step = tf.train.RMSPropOptimizer(1e-5).minimize(cross_entropy)

sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
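
The layer helpers imported from `utils` (weight_variable, bias_variable, conv2d, max_pool) are not shown either. Below are plausible TF 1.x definitions in the style of the classic TensorFlow MNIST tutorial; these are assumptions, not the notebook's actual code:

# hypothetical sketches of the utils layer helpers (TF 1.x idiom)
def weight_variable(shape):
    # small random initial weights to break symmetry
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x, W):
    # stride-1 convolution with zero padding: output keeps the input's size
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool(x, ksize):
    # pooling stride equals the window, so ksize=[1, 2, 2, 1] halves each
    # spatial dimension
    return tf.nn.max_pool(x, ksize=ksize, strides=ksize, padding='SAME')

Under these definitions the two pooling layers shrink each side by a factor of 4, which is where inner_layer_size = 50 comes from if max_shape is (200, 200).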

Step 3 - Set aside some samples for testing


In [4]:
# stack all images into one 4-D array; label sushi = 1, sandwich = 0
X_array = np.stack(images['sushi'] + images['sandwich']).reshape(-1, max_shape[0], max_shape[1], 1)
y_array = np.array([1] * len(images['sushi']) + [0] * len(images['sandwich'])).reshape(-1, 1)

# hold out one batch-worth of images for testing, sampled without replacement
# so the test set contains no duplicates
test_ix = np.random.choice(len(X_array), batch_len, replace=False)
train_ix = np.setdiff1d(np.arange(len(X_array)), test_ix)
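
As an aside: this split draws the test set uniformly from the pooled data, so by chance it can over-represent one class. A stratified alternative using scikit-learn would look like the sketch below (not what this notebook does):

# alternative hold-out split that preserves the sushi/sandwich ratio
from sklearn.model_selection import train_test_split

train_ix, test_ix = train_test_split(
    np.arange(len(X_array)), test_size=batch_len,
    stratify=y_array.ravel(), random_state=0)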

Step 4 - Train the CNN


In [5]:
for epoch in range(n_epochs):
    np.random.shuffle(train_ix)
    for batch_ix in np.array_split(train_ix, len(train_ix) // batch_len):
        sess.run(train_step, feed_dict={X: X_array[batch_ix], y: y_array[batch_ix], keep_prob: 0.9})

    # monitor progress on a random training batch and on the held-out test set
    train_sample = np.random.choice(train_ix, batch_len, replace=False)
    ce, train_logits = sess.run([cross_entropy, logits], feed_dict={X: X_array[train_sample], y: y_array[train_sample], keep_prob: 1.0})
    train_auc = roc_auc_score(y_array[train_sample], train_logits)
    print("epoch {0:,}; cross-entropy: {1:.3f}".format(epoch, ce))

    test_logits = sess.run(logits, feed_dict={X: X_array[test_ix], y: y_array[test_ix], keep_prob: 1.0})
    test_auc = roc_auc_score(y_array[test_ix], test_logits)

    print("AUC: train {0:.2f}, test {1:.2f}".format(train_auc, test_auc))

# stash the tensors needed at inference time so a later session can fetch
# them back from the restored graph
for tensor in [X, y, keep_prob, logits]:
    tf.add_to_collection('test', tensor)
saver.save(sess, model_file)


epoch 0; cross-entropy: 0.800
AUC: train 0.48, test 0.52
epoch 1; cross-entropy: 0.819
AUC: train 0.57, test 0.59
epoch 2; cross-entropy: 0.799
AUC: train 0.56, test 0.67
epoch 3; cross-entropy: 0.656
AUC: train 0.68, test 0.64
epoch 4; cross-entropy: 0.840
AUC: train 0.55, test 0.59
epoch 5; cross-entropy: 0.682
AUC: train 0.64, test 0.72
epoch 6; cross-entropy: 0.706
AUC: train 0.59, test 0.69
epoch 7; cross-entropy: 0.607
AUC: train 0.78, test 0.82
epoch 8; cross-entropy: 0.770
AUC: train 0.63, test 0.67
epoch 9; cross-entropy: 0.701
AUC: train 0.55, test 0.68
epoch 10; cross-entropy: 0.708
AUC: train 0.54, test 0.76
epoch 11; cross-entropy: 0.774
AUC: train 0.48, test 0.73
epoch 12; cross-entropy: 0.735
AUC: train 0.61, test 0.81
epoch 13; cross-entropy: 0.728
AUC: train 0.44, test 0.75
epoch 14; cross-entropy: 0.602
AUC: train 0.73, test 0.77
epoch 15; cross-entropy: 0.703
AUC: train 0.52, test 0.71
epoch 16; cross-entropy: 0.677
AUC: train 0.63, test 0.72
epoch 17; cross-entropy: 0.640
AUC: train 0.69, test 0.68
epoch 18; cross-entropy: 0.739
AUC: train 0.55, test 0.72
epoch 19; cross-entropy: 0.620
AUC: train 0.71, test 0.73
epoch 20; cross-entropy: 0.666
AUC: train 0.57, test 0.76
epoch 21; cross-entropy: 0.750
AUC: train 0.59, test 0.69
epoch 22; cross-entropy: 0.686
AUC: train 0.61, test 0.70
epoch 23; cross-entropy: 0.578
AUC: train 0.86, test 0.75
epoch 24; cross-entropy: 0.714
AUC: train 0.56, test 0.72
epoch 25; cross-entropy: 0.614
AUC: train 0.73, test 0.71
epoch 26; cross-entropy: 0.806
AUC: train 0.45, test 0.80
epoch 27; cross-entropy: 0.694
AUC: train 0.59, test 0.72
epoch 28; cross-entropy: 0.684
AUC: train 0.64, test 0.78
epoch 29; cross-entropy: 0.725
AUC: train 0.56, test 0.81
epoch 30; cross-entropy: 0.688
AUC: train 0.70, test 0.71
epoch 31; cross-entropy: 0.779
AUC: train 0.57, test 0.81
epoch 32; cross-entropy: 0.607
AUC: train 0.73, test 0.81
epoch 33; cross-entropy: 0.668
AUC: train 0.63, test 0.81
epoch 34; cross-entropy: 0.792
AUC: train 0.55, test 0.81
epoch 35; cross-entropy: 0.721
AUC: train 0.61, test 0.81
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-5-23e2333df9d0> in <module>()
      2     np.random.shuffle(train_ix)
      3     for batch_ix in np.array_split(train_ix, len(train_ix) // batch_len):
----> 4         sess.run(train_step, feed_dict={X: X_array[batch_ix], y: y_array[batch_ix], keep_prob: 0.9})

[... TensorFlow session internals elided ...]
KeyboardInterrupt: 
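
Training was interrupted by hand at epoch 35, so in this run the saver.save line never executed. Once the cell runs to completion, the add_to_collection calls make the saved graph easy to reload without rebuilding it. A sketch of a separate scoring session, assuming batch is a float array shaped like X and preprocessed with the same rescale:

# hypothetical scoring session: restore the saved graph and classify images
with tf.Session() as score_sess:
    restorer = tf.train.import_meta_graph(model_file + ".meta")
    restorer.restore(score_sess, model_file)
    X_r, y_r, keep_prob_r, logits_r = tf.get_collection('test')

    scores = score_sess.run(logits_r, feed_dict={X_r: batch, keep_prob_r: 1.0})
    probs = 1 / (1 + np.exp(-scores))  # sigmoid of the logit: P(sushi)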
