In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
# Load the MNIST splits from a local directory; one_hot=True requests one-hot encoded labels.
# NOTE(review): tensorflow.examples.tutorials.mnist is not shipped with TensorFlow 2.x — confirm the TF version in use.
MNIST_DIR = "../../data/MNIST"
mnist = input_data.read_data_sets(MNIST_DIR, one_hot=True)
In [42]:
# Size of the training split, shown as the cell output.
n_train = mnist.train.num_examples
n_train
Out[42]:
In [43]:
# Size of the validation split, shown as the cell output.
n_validation = mnist.validation.num_examples
n_validation
Out[43]:
In [44]:
# Size of the test split, shown as the cell output.
n_test = mnist.test.num_examples
n_test
Out[44]:
In [45]:
# Training images: report the container type and its dimensions, one per line.
train_images = mnist.train.images
print(type(train_images), train_images.shape, sep="\n")
In [46]:
# Training labels: report the container type and its dimensions, one per line.
train_labels = mnist.train.labels
print(type(train_labels), train_labels.shape, sep="\n")
In [51]:
# Draw a random sample of training digits and restore their 2-D image shape.
num_examples = 8
SEED = None    # set to an int to make the sample reproducible across runs
IMG_SIDE = 28  # each flattened row is reshaped to IMG_SIDE x IMG_SIDE below

# Sample row indices rather than permuting the whole training matrix: a full
# permutation copies every image just to keep the first few, while
# choice(replace=False) yields the same kind of sample (distinct rows drawn
# uniformly at random) and only gathers the rows actually needed.
rng = np.random.RandomState(SEED)
sample_idx = rng.choice(train_images.shape[0], size=num_examples, replace=False)
batch_images = train_images[sample_idx]
res_images = np.reshape(batch_images, (num_examples, IMG_SIDE, IMG_SIDE))
res_images.shape
Out[51]:
In [52]:
# Stack the sampled digits side by side into a single (height, n_imgs * width) strip.
# transpose(1, 0, 2) reorders the axes to (height, n_imgs, width); the C-order reshape
# then puts image i in columns [i * width, (i + 1) * width), the same layout the
# per-image slice assignment produced, without a Python loop.
n_imgs, height, width = res_images.shape
final_img = res_images.transpose(1, 0, 2).reshape(height, n_imgs * width)
final_img.shape
Out[52]:
In [53]:
# Show the strip of sampled digits in grayscale.
# skimage.io.imshow is deprecated in recent scikit-image releases, so draw with
# matplotlib's explicit Axes interface instead.
fig, ax = plt.subplots(figsize=(15, 8))
ax.imshow(final_img, cmap="gray")
ax.set_title("Randomly sampled MNIST training digits")
ax.axis("off")  # pixel-index ticks carry no information here
plt.show()
Out[53]:
In [ ]: