In [1]:
%matplotlib inline
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import widgets,interact
from IPython.display import clear_output

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [2]:
# Paths to the CIFAR-10 TFRecord files ("dane" is Polish for "data").
dane, dane_train = (
    "/DATA/shared/datasets/cifar10/cifar10_test.tfrecord",   # test split
    "/DATA/shared/datasets/cifar10/cifar10_train.tfrecord",  # training split
)

In [3]:
# Human-readable class names; label2txt[k] is the name of integer class label k.
label2txt = [
    "airplane",    # 0
    "automobile",  # 1
    "bird",        # 2
    "cat",         # 3
    "deer",        # 4
    "dog",         # 5
    "frog",        # 6
    "horse",       # 7
    "ship",        # 8
    "truck",       # 9
]

In [4]:
def read_data(filename_queue):
    """Read and decode a single example from a queue of TFRecord file names.

    Args:
        filename_queue: queue of TFRecord file names, e.g. the one returned by
            tf.train.string_input_producer.

    Returns:
        (image, label): the PNG-decoded 3-channel image tensor with static
        shape (32, 32, 3), and the int64 class label.
    """
    # Every record must contain all four fields. Height and width are parsed
    # (and therefore required) but are not used further here.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
    }
    _key, serialized_example = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    decoded = tf.image.decode_png(parsed['image/encoded'], channels=3)
    # The decoder's output shape is not known statically; pin it so the
    # tensor can be batched.
    decoded.set_shape((32, 32, 3))
    return decoded, parsed['image/class/label']

In [ ]:


In [5]:
# Queue of input file names; here it holds only the training TFRecord file.
fq = tf.train.string_input_producer(string_tensor=[dane_train])
# Symbolic tensors for one decoded (image, label) example read from that queue.
image_data, label = read_data(fq)

In [6]:
batch_size = 128
# Group single examples into shuffled batches. 1000 is the minimum number of
# examples kept in the shuffle buffer; capacity adds room for 3 more batches.
images, sparse_labels = tf.train.shuffle_batch(
    tensors=[image_data, label],
    batch_size=batch_size,
    capacity=1000 + 3 * batch_size,
    min_after_dequeue=1000,
    num_threads=2,
)
# Convert uint8 pixels to float32, shift by 128 and scale by 1/33.
# The display cell inverts this with 128 + 33 * x.
images = tf.cast(images, tf.float32)
images = (images - 128.0) / 33.0

In [7]:
# Session options: allocate GPU memory on demand (only relevant if GPUs are visible).
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# InteractiveSession installs itself as the default session.
sess = tf.InteractiveSession(config=config)
sess.run(tf.global_variables_initializer())

# Start the threads that fill the input queues built above, coordinated so
# they can be stopped together later.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

In [8]:
# Fetch images and labels in ONE run call so both come from the same dequeued batch.
im, l = sess.run(fetches=[images, sparse_labels])

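Caution: the code below fetches images and labels in two separate `sess.run` calls. Each call dequeues a *different* batch, so the labels would no longer match the images. Fetch both in a single `sess.run` call, as in the cell above.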

im = sess.run(images)
l = sess.run(sparse_labels)

In [9]:
# IntSlider bounds are inclusive, so the last valid index is shape[0] - 1.
@interact(ith=widgets.IntSlider(min=0, max=im.shape[0] - 1))
def show_ith(ith):
    """Plot image number `ith` of the fetched batch, titled with its class name.

    `im` holds normalized pixels, (x - 128) / 33, so the normalization is
    inverted before plotting.
    """
    clear_output(wait=True)
    # Round and clip before the uint8 cast: plain truncation can turn
    # 254.9999 into 254, and casting small negative floats to uint8 is undefined.
    pixels = np.clip(np.rint(128 + 33 * im[ith, :, :, :]), 0, 255).astype(np.uint8)
    plt.imshow(pixels)
    plt.title(label2txt[l[ith]])
    plt.show()


Widget Javascript not detected.  It may not be installed or enabled properly.

In [ ]:


In [ ]: