This notebook demonstrates creating a custom sprite image for the TensorFlow Embedding Visualizer using the MNIST dataset.
In [ ]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from PIL import Image, ImageOps
In [ ]:
mnist = input_data.read_data_sets('/tmp/data', one_hot=True)
In [ ]:
rows = 32
cols = rows
im_size = 28
sprite = np.zeros((rows*im_size, cols*im_size))
idx = -1
for i in range(rows):
for j in range(cols):
idx +=1
image = mnist.test.images[idx].reshape((28,28))
row_coord = i * 28
col_coord = j * 28
sprite[row_coord:row_coord + 28, col_coord:col_coord + 28] = image
im = Image.fromarray(sprite * 255)
im = im.convert('RGB')
im = ImageOps.invert(im)
def get_color(lbl):
if lbl == 0: return (255, 102, 102)
if lbl == 1: return (255, 178, 102)
if lbl == 2: return (255, 255, 102)
if lbl == 3: return (178, 255, 102)
if lbl == 4: return (102, 255, 102)
if lbl == 5: return (102, 255, 178)
if lbl == 6: return (102, 255, 255)
if lbl == 7: return (102, 178, 255)
if lbl == 8: return (102, 102, 255)
if lbl == 9: return (178, 102, 255)
labels_file = open("labels.tsv", "w")
# colorize
orig_color = (255,255,255)
data = np.array(im)
idx = -1
for i in range(rows):
for j in range(cols):
idx +=1
row_coord = i * 28
col_coord = j * 28
label = np.argmax(mnist.test.labels[idx])
labels_file.write(str(label) + "\n")
replacement_color = get_color(label)
r = data[row_coord:row_coord + 28, col_coord:col_coord + 28]
r[(r == orig_color).all(axis = -1)] = replacement_color
im = Image.fromarray(data, mode='RGB')
im.save("sprite.png")
im.show()
labels_file.close()