In [2]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import linear_model
from sklearn.metrics import accuracy_score
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TensorFlow tutorial helper, using MNIST_DIR as its
# data directory. one_hot=False keeps labels as class indices rather than
# one-hot vectors, which is the form the sklearn classifiers below consume.
MNIST_DIR = 'MNIST_data'
mnist = input_data.read_data_sets(MNIST_DIR, one_hot=False)
In [3]:
# Display one training digit with its label, then keep it in `img` as a
# (28, 28, 1) array, which the next cell passes to dvd.get_embedding_x.
sample_idx = 123  # arbitrary training example to inspect; used at both sites below
sample_label = mnist.train.labels[sample_idx]

img = np.reshape(mnist.train.images[sample_idx], (28, 28))
fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.set_title('label = %s' % sample_label)
ax.axis('off')
plt.show()

# Add a trailing channel axis for the single-image embedding call.
img = np.reshape(img, (28, 28, 1))
# Single-argument print: same output as the old py2 `print a, b, c`, and also valid in py3.
print('%s %s %s' % (img.shape, 'label = ', sample_label))
In [6]:
# Compute dvd's embedding of the single sample image from the previous cell
# and show the resulting feature shape.
# NOTE(review): `img` is (28, 28, 1) here, while the full-dataset cell feeds
# (N, 28, 28) arrays (no channel axis) to get_embedding_X — confirm which
# layout dvd expects.
from dvd import dvd

img_embedding = dvd.get_embedding_x(img)
print(img_embedding.shape)
In [3]:
# Baseline: logistic regression fitted directly on the flat raw-pixel vectors
# (28*28 values per image) and scored on the test split. This is the reference
# accuracy for the embedding-based model further down.
# NOTE(review): LogisticRegression's default solver / max_iter differ across
# sklearn versions (newer ones may warn about convergence on raw pixels);
# set them explicitly if the number must be reproducible.
raw_clf = linear_model.LogisticRegression()
raw_clf.fit(mnist.train.images, mnist.train.labels)
raw_preds = raw_clf.predict(mnist.test.images)
# accuracy_score's signature is (y_true, y_pred).
print(accuracy_score(mnist.test.labels, raw_preds))
In [6]:
# Embed the whole train and test splits with dvd's batch API.
# The flat pixel rows are first reshaped back to (N, 28, 28) images. The
# reshaped pixels and the embeddings have separate names so each name means
# one thing; `train` / `test` (read by the classifier cell) always hold the
# embeddings.
# NOTE(review): the single-image cell passes a (28, 28, 1) array to
# get_embedding_x, but here images are (N, 28, 28) with no channel axis —
# confirm which layout dvd expects.
train_images = np.reshape(mnist.train.images, (-1, 28, 28))
print('%s %s' % ('initial training shape = ', train_images.shape))
train = dvd.get_embedding_X(train_images)
print('%s %s' % ('training shape after embedding =', train.shape))

test_images = np.reshape(mnist.test.images, (-1, 28, 28))
test = dvd.get_embedding_X(test_images)

# The embeddings must stay row-aligned with the labels they are fitted/scored against.
assert len(train) == len(mnist.train.labels), 'train embeddings / labels misaligned'
assert len(test) == len(mnist.test.labels), 'test embeddings / labels misaligned'
In [7]:
# Same logistic-regression setup as the raw-pixel baseline, but fitted on the
# dvd embeddings (`train` / `test` from the embedding cell) instead of pixels.
# NOTE(review): LogisticRegression's default solver / max_iter differ across
# sklearn versions; set them explicitly if the number must be reproducible.
clf = linear_model.LogisticRegression()
clf.fit(train, mnist.train.labels)
preds = clf.predict(test)
# accuracy_score's signature is (y_true, y_pred).
print(accuracy_score(mnist.test.labels, preds))
That is an improvement of 8% from adding just one line of code, which puts this model on par with a CNN. Remember, we are still only training a logistic regression model. How good is that!
The real USP (unique selling point) of transfer learning is that it is generic: you can feed in images from any classification task and use the embedding as a feature-learning step. Now, let us take a moment to marvel at transfer learning.
In [ ]: