Prepare the data

17flowers.tgz (the Oxford VGG 17-category flower dataset) is about 60 MB, so the download may take a while.


In [1]:
import os
import tarfile

import requests

def download_data(url, timeout = 15):
    """Fetch ``url`` and return the response body as bytes.

    Redirects are followed (a host may, for example, redirect http -> https), and an
    HTTP error status raises instead of returning the error page. Otherwise a
    redirect or error body would be handed back as if it were the requested file.

    Args:
        url: address to download.
        timeout: seconds passed to ``requests.get``; requests applies it to the
            connection attempt and to each socket read, not to the whole transfer.

    Returns:
        The raw response body (``response.content``).

    Raises:
        requests.HTTPError: if the final response has a 4xx/5xx status.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    return response.content

def save_data(filename, data):
    """Write ``data`` to ``filename`` in binary mode, truncating any existing file.

    ``data`` is expected to be a bytes-like object (e.g. an HTTP response body).
    """
    with open(filename, mode="wb") as handle:
        handle.write(data)

# https avoids fetching the archive over plain HTTP (and any http -> https redirect).
url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz"

current_dir = os.getcwd()
filename = "17flowers.tgz"
filepath = os.path.join(current_dir, filename)

# The archive is ~60 MB: only download it when no readable tar archive is already on
# disk, so Restart & Run All does not re-fetch it every time. The is_tarfile check
# also forces a re-download if an earlier run saved something that is not a tarball.
if not (os.path.isfile(filepath) and tarfile.is_tarfile(filepath)):
    data = download_data(url)
    save_data(filepath, data)

Create the train / test / valid directories


In [2]:
dataset_name = "17flowers"

# <current_dir>/data/17flowers
path = os.path.join(current_dir, "data", dataset_name)

train_dir = os.path.join(path, 'train')
test_dir  = os.path.join(path, 'test')
valid_dir = os.path.join(path, 'valid')

# os.makedirs is recursive, so it also creates data/ and data/17flowers/ when missing.
# exist_ok=True replaces the exists()-then-create check and keeps the cell re-runnable.
# https://docs.python.org/3.6/library/os.html#os.makedirs
for split_dir in (train_dir, test_dir, valid_dir):
    os.makedirs(split_dir, exist_ok=True)

Extract the archive and copy the images into per-class training directories


In [3]:
import tarfile

# Open the archive saved by the download cell; it is extracted in a later cell.
tf = tarfile.open(filepath, 'r')

flower_classes = ["Tulip", "Snowdrop", "LilyValley", "Bluebell", "Crocus", "Iris", "Tigerlily", "Daffodil", "Fritillary", "Sunflower", "Daisy", "ColtsFoot", "Dandelion", "Cowslip", "Buttercup", "Windflower", "Pansy"]

# One sub-directory per class in every split. makedirs(exist_ok=True) keeps the cell
# re-runnable, whereas os.mkdir raises FileExistsError once the directories exist.
for name in flower_classes:
    for split_dir in (train_dir, test_dir, valid_dir):
        os.makedirs(os.path.join(split_dir, name), exist_ok=True)

In [4]:
# Map each class to the inclusive (first, last) image numbers it owns. The code assumes
# the archive's images are numbered in class order, 80 per class: class at position k
# gets image numbers 80*k + 1 .. 80*(k + 1).
flower_dics = {
    flower: (80 * position + 1, 80 * (position + 1))
    for position, flower in enumerate(flower_classes)
}
print(flower_dics)


{'Cowslip': (1041, 1120), 'Bluebell': (241, 320), 'Buttercup': (1121, 1200), 'Dandelion': (961, 1040), 'Snowdrop': (81, 160), 'Tulip': (1, 80), 'Fritillary': (641, 720), 'Sunflower': (721, 800), 'Tigerlily': (481, 560), 'Daisy': (801, 880), 'Iris': (401, 480), 'Crocus': (321, 400), 'LilyValley': (161, 240), 'Windflower': (1201, 1280), 'Pansy': (1281, 1360), 'ColtsFoot': (881, 960), 'Daffodil': (561, 640)}

In [5]:
# inside the tgz file, there is a jpg directory
# Inside the tgz file there is a jpg directory, so the images land in <path>/jpg.
jpg_dir = os.path.join(path, 'jpg')

# `with tf:` closes the archive opened earlier once extraction finishes (or fails).
# The archive comes from the internet, so when this Python provides tar extraction
# filters, use filter='data', which rejects absolute paths, '..' components and
# other unsafe members.
with tf:
    if hasattr(tarfile, 'data_filter'):
        tf.extractall(path, filter='data')
    else:
        tf.extractall(path)

In [6]:
import shutil

# Copy every jpg/image_NNNN.jpg into train/<class>/, choosing the class whose
# (start, end) range in flower_dics contains NNNN.
# NOTE(review): the split cells below *move* files out of train/. Re-running this cell
# after them copies those images back into train/, so train overlaps valid/test.
# Rebuild data/17flowers from scratch instead of re-running this cell on its own.
for f_str in sorted(os.listdir(jpg_dir)):
    if not f_str.endswith('.jpg'):
        continue  # only image files are distributed

    # image_0001.jpg => 1
    prefix = os.path.splitext(f_str)[0]
    idx = int(prefix.split('_')[1])

    for name, (start, end) in flower_dics.items():
        if start <= idx <= end:
            source = os.path.join(jpg_dir, f_str)
            dest = os.path.join(train_dir, name)
            shutil.copy(source, dest)
            # The ranges are disjoint, so no other class can match this image.
            break

Move 10 randomly chosen images from each training class directory into the validation set (valid_dir).


In [7]:
import random
# Number of images per class moved from train/ to valid/.
N_VALID_PER_CLASS = 10

# A seeded local RNG plus *sorted* listings make the selection reproducible across
# machines. os.listdir returns entries in arbitrary order, which would otherwise
# defeat the seed.
valid_rng = random.Random(0)

for d_str in sorted(os.listdir(train_dir)):
    class_dir = os.path.join(train_dir, d_str)
    if not os.path.isdir(class_dir):
        continue  # ignore stray files (e.g. .DS_Store) next to the class folders
    files = sorted(os.listdir(class_dir))
    valid_rng.shuffle(files)
    for f_str in files[:N_VALID_PER_CLASS]:
        source = os.path.join(class_dir, f_str)
        dest = os.path.join(valid_dir, d_str)
        shutil.move(source, dest)

Move another 10 randomly chosen images from each training class directory into the test set (test_dir).


In [8]:
# Number of images per class moved from train/ to test/.
N_TEST_PER_CLASS = 10

# A seeded local RNG plus *sorted* listings make the selection reproducible across
# machines. os.listdir returns entries in arbitrary order, which would otherwise
# defeat the seed.
test_rng = random.Random(0)

for d_str in sorted(os.listdir(train_dir)):
    class_dir = os.path.join(train_dir, d_str)
    if not os.path.isdir(class_dir):
        continue  # ignore stray files (e.g. .DS_Store) next to the class folders
    files = sorted(os.listdir(class_dir))
    test_rng.shuffle(files)
    for f_str in files[:N_TEST_PER_CLASS]:
        source = os.path.join(class_dir, f_str)
        dest = os.path.join(test_dir, d_str)
        shutil.move(source, dest)