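Object detection using the TensorFlow Object Detection API

This notebook converts the bounding-box annotations in `annotation.csv` (one box per row) into a single TFRecord file, `data.tfrecords`.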

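Reference: TensorFlow Object Detection API guide "Using your own dataset" (tensorflow/models, `research/object_detection/g3doc/using_your_own_dataset.md`). It defines the `create_tf_example` template and the sharded TFRecord snippet used below.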


In [ ]:
import io
import csv
import os

from PIL import Image, ImageDraw
import matplotlib.pylab as plt
# %matplotlib widget
%matplotlib inline

import tensorflow as tf
import tensorflow_hub as hub

In [ ]:
def create_tf_example(info):
    """Build a ``tf.train.Example`` for a single annotated bounding box.

    The feature keys follow the TensorFlow Object Detection API
    ``create_tf_example`` template (``image/...`` and ``image/object/...``).

    Args:
        info: One annotation row, e.g. a dict from ``csv.DictReader``, with keys
            ``'path'`` (image file path), ``'xmin'``, ``'ymin'``, ``'xmax'``,
            ``'ymax'`` (box corners in pixels, as numeric strings), ``'label'``
            (class name) and ``'idx'`` (integer class id).

    Returns:
        A ``tf.train.Example`` containing the encoded image and exactly one box.
    """
    path = info['path']

    # Store the file's original bytes instead of re-encoding through PIL:
    # re-saving a JPEG is lossy and would silently degrade the training images.
    with open(path, 'rb') as image_file:
        encoded_image_data = image_file.read()

    # PIL is only needed to read the header (size and format). The context
    # manager closes the image instead of leaking it.
    with Image.open(io.BytesIO(encoded_image_data)) as image:
        width, height = image.size
        image_format = image.format.lower().encode('utf-8')  # e.g. b'jpeg' or b'png'

    filename = os.path.basename(path).encode('utf-8')  # also used as image/source_id

    # Pixel corners -> coordinates divided by the image size. float() accepts
    # both integer ("12") and decimal ("12.5") CSV values.
    xmin = float(info['xmin']) / width
    xmax = float(info['xmax']) / width
    ymin = float(info['ymin']) / height
    ymax = float(info['ymax']) / height
    class_text = info['label'].encode('utf-8')  # class name of the single box
    class_id = int(info['idx'])  # integer class id of the single box

    tf_example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
      'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
      'image/filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename])),
      'image/source_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename])),
      'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_image_data])),
      'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),
      # Each image/object/* list holds one entry per box; here there is one box.
      'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=[xmin])),
      'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=[xmax])),
      'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=[ymin])),
      'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=[ymax])),
      'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[class_text])),
      'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[class_id])),
    }))

    return tf_example

In [ ]:
def show_image_objects(info, outline=None, width=16):
    """Display the image referenced by an annotation row with its box drawn on it.

    Args:
        info: Annotation row with keys ``'path'``, ``'xmin'``, ``'ymin'``,
            ``'xmax'`` and ``'ymax'`` (box corners in pixels, numeric strings).
        outline: Box colour passed to ``ImageDraw.rectangle``; ``None`` keeps
            PIL's default ink, as before.
        width: Box line width in pixels.
    """
    # convert() returns an independent in-memory copy, so the file handle can be
    # closed. Converting to RGB also stops grayscale images from being rendered
    # with matplotlib's default colormap.
    with Image.open(info['path']) as source:
        image = source.convert('RGB')

    # int(float(...)) accepts decimal strings too; integer strings are unchanged.
    box = tuple(int(float(info[key])) for key in ('xmin', 'ymin', 'xmax', 'ymax'))
    ImageDraw.Draw(image).rectangle(box, outline=outline, width=width)

    fig, ax = plt.subplots()
    ax.axis('off')
    ax.imshow(image)
    # plt.show() works with both the inline and widget backends, whereas
    # fig.show() may warn on non-GUI backends such as inline.
    plt.show()

In [ ]:
# Load only the first annotation row, used below for a quick visual sanity check.
# newline='' is what the csv module expects for correct handling of quoted fields.
with open('annotation.csv', 'r', newline='') as f:
    data = next(csv.DictReader(f), None)

# Fail here with a clear message instead of a NameError in a later cell.
assert data is not None, 'annotation.csv contains no annotation rows'

In [ ]:
show_image_objects(info=data)  # draw the first annotation's box on its image

In [ ]:
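# Reference sketch only (not executed): writing sharded TFRecords, following the
# TF Object Detection API dataset guide. The names `base_path`, `num_shards`,
# `output_filebase`, `examples`, `exit_stack`, `contextlib2` and
# `tf_record_creation_util` are not defined in this notebook. The loop would
# need `enumerate(examples)`, and in TF2 `tf.python_io.TFRecordWriter` is
# `tf.io.TFRecordWriter`.
# tf_record_output_filenames = [
#     '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
#     for idx in range(num_shards)
# ]
# tfrecords = [
#     exit_stack.enter_context(tf.python_io.TFRecordWriter(file_name))
#     for file_name in tf_record_output_filenames
# ]
# with contextlib2.ExitStack() as tf_record_close_stack:
#     output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
#         tf_record_close_stack, output_filebase, num_shards)
#     for index, example in examples:
#         tf_example = create_tf_example(example)
#         output_shard_index = index % num_shards
#         output_tfrecords[output_shard_index].write(tf_example.SerializeToString())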

In [ ]:
# Convert every annotation row into a tf.train.Example (one box per example).
# NOTE(review): create_tf_example writes a single box per example, so an image
# with several annotated rows yields several examples that each list only one
# object. The Object Detection API format keeps all boxes of an image in one
# example, so consider grouping rows by 'path' first.
# newline='' is what the csv module expects for correct handling of quoted fields.
with open('annotation.csv', 'r', newline='') as f:
    l = [create_tf_example(row) for row in csv.DictReader(f)]

In [ ]:
# Serialize all examples into one unsharded TFRecord file.
with tf.io.TFRecordWriter('data.tfrecords') as writer:
    for serialized in (example.SerializeToString() for example in l):
        writer.write(serialized)