Download CIFAR-10 dataset and convert it to TFRecord format.


In [ ]:
import os
import numpy as np
from random import shuffle

import tensorflow as tf
from tensorflow.keras.datasets import cifar10

In [ ]:
# Fetch CIFAR-10 through Keras; it comes back as a (train, test) pair of
# (images, labels) tuples.
train_split, test_split = cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# Report the shape of each image array.
for template, images_array in (('x_train shape: {}', x_train),
                               ('x_test shape : {}', x_test)):
    print(template.format(images_array.shape))

In [ ]:
# Combine the test and train sets into a single pool.  The pool is shuffled
# here and redistributed into shards later.
X = np.concatenate([x_train, x_test])
Y = np.concatenate([y_train, y_test])

# Keras downloads labels into [n,1], reshape to [n].
Y = Y.reshape([Y.shape[0]])

# Create a shuffled index to reorder the data randomly.
ix = list(range(X.shape[0]))
shuffle(ix)

# Apply the same permutation to images and labels so each image keeps its
# label.  The reordering is done before the dtype casts so the temporary copy
# is made on the arrays as loaded rather than on the float-converted images.
X = X[ix].astype(float)
Y = Y[ix].astype(int)

In [ ]:
import os

# Make sure the "data" output directory is present before anything is written.
if os.path.exists('data'):
    print('directory "data" already exists')
else:
    os.makedirs('data')

In [ ]:
# Number of examples stored in each shard (the last shard may be smaller).
shard_size = 5000

# Shard boundaries: 0, shard_size, 2 * shard_size, ... followed by the total
# number of examples as the final end point.
partition = list(range(0, X.shape[0], shard_size)) + [X.shape[0]]

# Consecutive boundaries paired into (start, end) slices.  Materialized as a
# list because a bare zip object is a one-shot iterator: it would display as
# "<zip object ...>" and be empty the second time a cell iterates over it.
data_range = list(zip(partition[:-1], partition[1:]))

data_range

In [ ]:
# Helper function for creating tfrecords files

def _int64_feature(value):
    """Return a tf.train.Feature whose Int64List holds `value` as its only element."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Return a tf.train.Feature whose BytesList holds `value` as its only element."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

In [ ]:
# Write every (start, end) shard of X / Y to its own TFRecord file, one
# serialized tf.train.Example per image.

f_prefix = './data/cifar10_data_'
f_digits = 3
f_suffix = 0  # number used for the first shard file; later shards count up

# tf.python_io is not available in TensorFlow 2.x, where the writer lives
# under tf.io.  Prefer the newer location and fall back to the old one on
# releases that do not provide it.
try:
    _record_writer = tf.io.TFRecordWriter
except AttributeError:
    _record_writer = tf.python_io.TFRecordWriter

for shard_no, (start, end) in enumerate(data_range, start=f_suffix):
    f_name = f_prefix + str(shard_no).zfill(f_digits) + '.tfrecords'

    # The context manager closes the file even if serialization raises.
    with _record_writer(f_name) as writer:
        for image, label in zip(X[start:end], Y[start:end]):
            e = tf.train.Example(features=tf.train.Features(
                feature={
                    'label': _int64_feature(label),
                    # Only the raw array bytes are stored (no shape or dtype),
                    # so readers must know both to decode the image.
                    # tobytes() replaces the deprecated ndarray.tostring().
                    'image': _bytes_feature(image.tobytes())
                }))
            writer.write(e.SerializeToString())
    print('finished writing {}'.format(f_name))

In [ ]: