In [1]:
%matplotlib inline
import math,sys,os,numpy as np
from numpy.linalg import norm
from PIL import Image
from matplotlib import pyplot as plt, rcParams, rc
from scipy.ndimage import imread
from skimage.measure import block_reduce
import cPickle as pickle
from scipy.ndimage.filters import correlate, convolve
from ipywidgets import interact, interactive, fixed
from ipywidgets.widgets import *
rc('animation', html='html5')
rcParams['figure.figsize'] = 3, 6
%precision 4
np.set_printoptions(precision=4, linewidth=100)

In [2]:
"""
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/")
images, labels = mnist.train.images, mnist.train.labels
images = images.reshape((55000,28,28))
np.savez_compressed("MNIST_data/train", images=images, labels=labels)
"""
1


Out[2]:
1

In [7]:
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() # saves to /root/.keras/datasets/mnist.pkl.gz

(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# Alternate:
#import cPickle
#import gzip
#f = gzip.open('/root/.keras/datasets/mnist.pkl.gz', 'rb')

#if sys.version_info < (3,):
#    data = cPickle.load(f)
#else:
#    data = cPickle.load(f, encoding='bytes')
#f.close()
#(x_train, y_train), (x_test, y_test) = data


Out[7]:
((60000L, 28L, 28L), (60000L,), (10000L, 28L, 28L), (10000L,))

In [8]:
# Add 1 dimension for the color channel.
x_train = np.expand_dims(x_train, 1) 
x_test = np.expand_dims(x_test, 1) 

# Also need to one-hot encode the training labels
from utils import onehot
y_train = onehot(y_train)
y_test = onehot(y_test)

(x_train.shape, y_train.shape, x_test.shape, y_test.shape)


Out[8]:
((60000L, 1L, 28L, 28L), (60000L, 10L), (10000L, 1L, 28L, 28L), (10000L, 10L))

In [15]:
def plots(ims, interp=False, titles=None):
    ims=np.array(ims)
    mn,mx=ims.min(),ims.max()
    f = plt.figure(figsize=(12,24))
    for i in range(len(ims)):
        sp=f.add_subplot(1, len(ims), i+1)
        if not titles is None: sp.set_title(titles[i], fontsize=18)
        plt.imshow(ims[i], interpolation=None if interp else 'none', vmin=mn,vmax=mx)

def plot(im, interp=False):
    f = plt.figure(figsize=(3,6), frameon=True)
    plt.imshow(im, interpolation=None if interp else 'none')

plt.gray()
plt.close()

In [26]:
??plt

In [16]:
#data = np.load("MNIST_data/train.npz")
#images=data['images']
#labels=data['labels']
images=x_train
labels=y_train
n=len(images)
images.shape


Out[16]:
(60000L, 1L, 28L, 28L)

In [22]:
inspect_idx = 15
inspect_slice = slice(inspect_idx-5,inspect_idx+5)
#images[inspect_idx]

In [28]:
plot(images[inspect_idx])
#plt.plot(images[inspect_idx])


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-28-47d82c6728db> in <module>()
----> 1 plot(images[inspect_idx])
      2 #plt.plot(images[inspect_idx])

<ipython-input-15-def34cadf765> in plot(im, interp)
     10 def plot(im, interp=False):
     11     f = plt.figure(figsize=(3,6), frameon=True)
---> 12     plt.imshow(im, interpolation=None if interp else 'none')
     13 
     14 plt.gray()

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\pyplot.pyc in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
   3155                         filternorm=filternorm, filterrad=filterrad,
   3156                         imlim=imlim, resample=resample, url=url, data=data,
-> 3157                         **kwargs)
   3158     finally:
   3159         ax._hold = washold

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\__init__.pyc in inner(ax, *args, **kwargs)
   1895                     warnings.warn(msg % (label_namer, func.__name__),
   1896                                   RuntimeWarning, stacklevel=2)
-> 1897             return func(ax, *args, **kwargs)
   1898         pre_doc = inner.__doc__
   1899         if pre_doc is None:

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\axes\_axes.pyc in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
   5122                               resample=resample, **kwargs)
   5123 
-> 5124         im.set_data(X)
   5125         im.set_alpha(alpha)
   5126         if im.get_clip_path() is None:

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\image.pyc in set_data(self, A)
    598         if (self._A.ndim not in (2, 3) or
    599                 (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))):
--> 600             raise TypeError("Invalid dimensions for image data")
    601 
    602         self._imcache = None

TypeError: Invalid dimensions for image data

In [24]:
labels[inspect_idx]


Out[24]:
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.])

In [25]:
plots(images[inspect_slice], titles=labels[inspect_slice])


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-25-18a96457e9c4> in <module>()
----> 1 plots(images[inspect_slice], titles=labels[inspect_slice])

<ipython-input-15-def34cadf765> in plots(ims, interp, titles)
      6         sp=f.add_subplot(1, len(ims), i+1)
      7         if not titles is None: sp.set_title(titles[i], fontsize=18)
----> 8         plt.imshow(ims[i], interpolation=None if interp else 'none', vmin=mn,vmax=mx)
      9 
     10 def plot(im, interp=False):

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\pyplot.pyc in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
   3155                         filternorm=filternorm, filterrad=filterrad,
   3156                         imlim=imlim, resample=resample, url=url, data=data,
-> 3157                         **kwargs)
   3158     finally:
   3159         ax._hold = washold

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\__init__.pyc in inner(ax, *args, **kwargs)
   1895                     warnings.warn(msg % (label_namer, func.__name__),
   1896                                   RuntimeWarning, stacklevel=2)
-> 1897             return func(ax, *args, **kwargs)
   1898         pre_doc = inner.__doc__
   1899         if pre_doc is None:

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\axes\_axes.pyc in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
   5122                               resample=resample, **kwargs)
   5123 
-> 5124         im.set_data(X)
   5125         im.set_alpha(alpha)
   5126         if im.get_clip_path() is None:

C:\Users\matsaleh\AppData\Local\conda\conda\envs\fastai2\lib\site-packages\matplotlib\image.pyc in set_data(self, A)
    598         if (self._A.ndim not in (2, 3) or
    599                 (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))):
--> 600             raise TypeError("Invalid dimensions for image data")
    601 
    602         self._imcache = None

TypeError: Invalid dimensions for image data

In [10]:
top=[[-1,-1,-1],
     [ 1, 1, 1],
     [ 0, 0, 0]]

plot(top)



In [11]:
r=(0,28)
def zoomim(x1=0,x2=28,y1=0,y2=28):
    plot(images[inspect_idx,y1:y2,x1:x2])
w=interactive(zoomim, x1=r,x2=r,y1=r,y2=r)
w



In [12]:
k=w.kwargs
dims = np.index_exp[k['y1']:k['y2']:1,k['x1']:k['x2']]
images[inspect_idx][dims]


Out[12]:
array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 115, 121, 162,
        253, 253, 213,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  63, 107, 170, 251, 252, 252,
        252, 252, 250, 214,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  25, 192, 226, 226, 241, 252, 253, 202, 252, 252,
        252, 252, 252, 225,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  68, 223, 252, 252, 252, 252, 252,  39,  19,  39,  65,
        224, 252, 252, 183,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0, 186, 252, 252, 252, 245, 108,  53,   0,   0,   0, 150,
        252, 252, 220,  20,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,  70, 242, 252, 252, 222,  59,   0,   0,   0,   0,   0, 178,
        252, 252, 141,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0, 185, 252, 252, 194,  67,   0,   0,   0,   0,  17,  90, 240,
        252, 194,  67,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,  83, 205, 190,  24,   0,   0,   0,   0,   0, 121, 252, 252,
        209,  24,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  77, 247, 252, 248,
        106,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 253, 252, 252, 102,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 134, 255, 253, 253,  39,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   6, 183, 253, 252, 107,   2,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  10, 102, 252, 253, 163,  16,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  13, 168, 252, 252, 110,   2,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  41, 252, 252, 217,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  40, 155, 252, 214,  31,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0, 165, 252, 252, 106,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  43, 179, 252, 150,  39,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0, 137, 252, 221,  39,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  67, 252,  79,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=uint8)

In [13]:
corrtop = correlate(images[inspect_idx], top)

In [14]:
corrtop[dims]


Out[14]:
array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 115, 236, 142,  24,
        156, 207, 210, 213,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  63, 170,  84, 157, 181, 101, 220,
         88,  35, 250, 251, 214,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  25, 217, 187, 132, 118,  37, 150, 179,  34, 207,   0,
          0,   2,  13,  13,  11,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,  68,  10,  70,  28, 112,  63,  37,  53, 115, 158, 185,  84,
         41, 228, 214, 214, 214,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0, 118, 147, 147,  29, 249, 105, 162, 130, 255, 159,  27,  74,
        113, 252,  61,  61,  93,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,  70, 126, 126,  56, 226,  40, 188, 165,  95, 203,   0,  28,  28,
         28, 177, 157, 157, 236,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0, 115, 125, 125, 208,  43, 240,  42, 197,   0,  17, 107, 169, 152,
          4, 124, 124, 182,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0, 154, 107,  45, 233, 213,  19, 189,   0,   0, 104,  10,  22, 131,
         55, 232,  19, 189,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0, 173, 224,  34,  93,  42, 232,   0,   0,  77, 203, 203, 122, 149,
        125, 129, 232,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 176, 181, 181, 115,   4,
          4, 150,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 134, 136, 137,   4, 195, 194,
        193,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   6,  55,  53,  46, 107,  72,  73,
        219,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  10, 106, 175, 165, 236,  76,  74, 163,
        254,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  13, 171,  65,  52,   7, 208, 192,  79, 240,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  28, 112, 112,  49, 111, 109, 144, 254,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  40, 154, 154,  76,  32,  32,  70,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0, 125, 222, 222, 245, 117, 117, 225,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  43,  57,  57, 168,  87,  87, 189,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  94, 167, 136, 187,  75, 106, 217,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0, 186, 186,  44,  75,  75, 217,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0, 189, 193, 114, 181, 177,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=uint8)

In [20]:
plot(corrtop[dims])



In [21]:
plot(corrtop)



In [45]:
np.rot90(top, 1)


Out[45]:
array([[-1,  1,  0],
       [-1,  1,  0],
       [-1,  1,  0]])

In [46]:
convtop = convolve(images[inspect_idx], np.rot90(top,2))
plot(convtop)
np.allclose(convtop, corrtop)


Out[46]:
True

In [47]:
straights=[np.rot90(top,i) for i in range(4)]
plots(straights)



In [48]:
br=[[ 0, 0, 1],
    [ 0, 1,-1.5],
    [ 1,-1.5, 0]]

diags = [np.rot90(br,i) for i in range(4)]
plots(diags)



In [49]:
rots = straights + diags
corrs = [correlate(images[inspect_idx], rot) for rot in rots]
plots(corrs)



In [50]:
def pool(im): return block_reduce(im, (7,7), np.max)

plots([pool(im) for im in corrs])



In [51]:
eights=[images[i] for i in xrange(n) if labels[i]==8]
ones=[images[i] for i in xrange(n) if labels[i]==1]

In [24]:
plots(eights[:5])
plots(ones[:5])



In [25]:
pool8 = [np.array([pool(correlate(im, rot)) for im in eights]) for rot in rots]

In [ ]:
len(pool8), pool8[0].shape

In [26]:
plots(pool8[0][0:5])



In [27]:
def normalize(arr): return (arr-arr.mean())/arr.std()

In [28]:
filts8 = np.array([ims.mean(axis=0) for ims in pool8])
filts8 = normalize(filts8)

In [29]:
plots(filts8)



In [30]:
pool1 = [np.array([pool(correlate(im, rot)) for im in ones]) for rot in rots]
filts1 = np.array([ims.mean(axis=0) for ims in pool1])
filts1 = normalize(filts1)

In [31]:
plots(filts1)



In [32]:
def pool_corr(im): return np.array([pool(correlate(im, rot)) for rot in rots])

In [33]:
plots(pool_corr(eights[0]))



In [35]:
def sse(a,b): return ((a-b)**2).sum()
def is8_n2(im): return 1 if sse(pool_corr(im),filts1) > sse(pool_corr(im),filts8) else 0

In [36]:
sse(pool_corr(eights[0]), filts8), sse(pool_corr(eights[0]), filts1)


Out[36]:
(126.77776, 181.26105)

In [37]:
[np.array([is8_n2(im) for im in ims]).sum() for ims in [eights,ones]]


Out[37]:
[5223, 287]

In [38]:
[np.array([(1-is8_n2(im)) for im in ims]).sum() for ims in [eights,ones]]


Out[38]:
[166, 5892]

In [ ]:
def n1(a,b): return (np.fabs(a-b)).sum()
def is8_n1(im): return 1 if n1(pool_corr(im),filts1) > n1(pool_corr(im),filts8) else 0

In [ ]:
[np.array([is8_n1(im) for im in ims]).sum() for ims in [eights,ones]]

In [ ]:
[np.array([(1-is8_n1(im)) for im in ims]).sum() for ims in [eights,ones]]

Scratch


In [8]:
??correlate

In [13]:
correlate([1, 2], [0, 3]) # [1x0 + 1x3, 2x0 + 2x3]
correlate([1,2,3], [0,4,5]) # [1x0 + 1x4 + 1x5, 2x0 + 2x4 + 2x5, 3x0 + 3x4 + 3x5]


Out[13]:
array([14, 23, 27])

In [ ]: