Using Convolutional Neural Networks

This is running on Theano!

Welcome to the first week of the first deep learning certificate! We're going to use convolutional neural networks (CNNs) to allow our computer to see - something that, at the level of accuracy we'll reach, has only become possible thanks to deep learning.

Introduction to this week's task: 'Dogs vs Cats'

We're going to try to create a model to enter the Dogs vs Cats competition at Kaggle. There are 25,000 labelled dog and cat photos available for training, and 12,500 in the test set that we have to try to label for this competition. According to the Kaggle website, when this competition was launched (end of 2013): "State of the art: The current literature suggests machine classifiers can score above 80% accuracy on this task". So if we can beat 80%, then we will be at the cutting edge as of 2013!

Basic setup

There isn't too much to do to get started - just a few simple configuration steps.

This shows plots in the web page itself - we always want to use this when using Jupyter Notebook:


In [1]:
%matplotlib inline

Define path to data: (It's a good idea to put it in a subdirectory of your notebooks folder, and then exclude that directory from git control by adding it to .gitignore.)


In [2]:
#path = "data/dogscats/"
path = "data/dogscats/sample/"

A few basic libraries that we'll need for the initial exercises:


In [3]:
from __future__ import division,print_function

import os, json
from glob import glob
import numpy as np
np.set_printoptions(precision=4, linewidth=100)
from matplotlib import pyplot as plt

We have created a file most imaginatively called 'utils.py' to store any little convenience functions we'll want to use. We will discuss these as we use them.


In [4]:
import utils
import importlib
importlib.reload(utils)
from utils import plots


Using gpu device 0: GeForce GTX 750 (CNMeM is disabled, cuDNN not available)
Using Theano backend.
Out[4]:
<module 'utils' from '/home/tw/nbs/utils.py'>

Use a pretrained VGG model with our Vgg16 class

Our first step is simply to use a model that has been fully created for us, which can recognise a wide variety (1,000 categories) of images. We will use 'VGG', which won the localization task (and came second in classification) at the 2014 ImageNet competition, and is a very simple model to create and understand. The VGG ImageNet team created both a larger, slower, slightly more accurate model (VGG19) and a smaller, faster model (VGG16). We will be using VGG16, since the much slower performance of VGG19 is generally not worth the very minor improvement in accuracy.

We have created a Python class, Vgg16, which makes using the VGG16 model very straightforward.

The punchline: state-of-the-art custom model in 7 lines of code

Here's everything you need to do to get >97% accuracy on the Dogs vs Cats dataset - we won't analyze how it works behind the scenes yet, since at this stage we're just going to focus on the minimum necessary to actually do useful work.


In [5]:
# As large as you can, but no larger than 64 is recommended. 
# If you have an older or cheaper GPU, you'll run out of memory, so will have to decrease this.
# batch_size=64
batch_size=2

In [6]:
# Import our class, and instantiate
import vgg16
from vgg16 import Vgg16
vgg = Vgg16()

In [12]:
vgg.classes


Out[12]:
['tench',
 'goldfish',
 'great_white_shark',
 'tiger_shark',
 'hammerhead',
 'electric_ray',
 'stingray',
 'cock',
 'hen',
 'ostrich',
 'brambling',
 'goldfinch',
 'house_finch',
 'junco',
 'indigo_bunting',
 'robin',
 'bulbul',
 'jay',
 'magpie',
 'chickadee',
 'water_ouzel',
 'kite',
 'bald_eagle',
 'vulture',
 'great_grey_owl',
 'European_fire_salamander',
 'common_newt',
 'eft',
 'spotted_salamander',
 'axolotl',
 'bullfrog',
 'tree_frog',
 'tailed_frog',
 'loggerhead',
 'leatherback_turtle',
 'mud_turtle',
 'terrapin',
 'box_turtle',
 'banded_gecko',
 'common_iguana',
 'American_chameleon',
 'whiptail',
 'agama',
 'frilled_lizard',
 'alligator_lizard',
 'Gila_monster',
 'green_lizard',
 'African_chameleon',
 'Komodo_dragon',
 'African_crocodile',
 'American_alligator',
 'triceratops',
 'thunder_snake',
 'ringneck_snake',
 'hognose_snake',
 'green_snake',
 'king_snake',
 'garter_snake',
 'water_snake',
 'vine_snake',
 'night_snake',
 'boa_constrictor',
 'rock_python',
 'Indian_cobra',
 'green_mamba',
 'sea_snake',
 'horned_viper',
 'diamondback',
 'sidewinder',
 'trilobite',
 'harvestman',
 'scorpion',
 'black_and_gold_garden_spider',
 'barn_spider',
 'garden_spider',
 'black_widow',
 'tarantula',
 'wolf_spider',
 'tick',
 'centipede',
 'black_grouse',
 'ptarmigan',
 'ruffed_grouse',
 'prairie_chicken',
 'peacock',
 'quail',
 'partridge',
 'African_grey',
 'macaw',
 'sulphur-crested_cockatoo',
 'lorikeet',
 'coucal',
 'bee_eater',
 'hornbill',
 'hummingbird',
 'jacamar',
 'toucan',
 'drake',
 'red-breasted_merganser',
 'goose',
 'black_swan',
 'tusker',
 'echidna',
 'platypus',
 'wallaby',
 'koala',
 'wombat',
 'jellyfish',
 'sea_anemone',
 'brain_coral',
 'flatworm',
 'nematode',
 'conch',
 'snail',
 'slug',
 'sea_slug',
 'chiton',
 'chambered_nautilus',
 'Dungeness_crab',
 'rock_crab',
 'fiddler_crab',
 'king_crab',
 'American_lobster',
 'spiny_lobster',
 'crayfish',
 'hermit_crab',
 'isopod',
 'white_stork',
 'black_stork',
 'spoonbill',
 'flamingo',
 'little_blue_heron',
 'American_egret',
 'bittern',
 'crane',
 'limpkin',
 'European_gallinule',
 'American_coot',
 'bustard',
 'ruddy_turnstone',
 'red-backed_sandpiper',
 'redshank',
 'dowitcher',
 'oystercatcher',
 'pelican',
 'king_penguin',
 'albatross',
 'grey_whale',
 'killer_whale',
 'dugong',
 'sea_lion',
 'Chihuahua',
 'Japanese_spaniel',
 'Maltese_dog',
 'Pekinese',
 'Shih-Tzu',
 'Blenheim_spaniel',
 'papillon',
 'toy_terrier',
 'Rhodesian_ridgeback',
 'Afghan_hound',
 'basset',
 'beagle',
 'bloodhound',
 'bluetick',
 'black-and-tan_coonhound',
 'Walker_hound',
 'English_foxhound',
 'redbone',
 'borzoi',
 'Irish_wolfhound',
 'Italian_greyhound',
 'whippet',
 'Ibizan_hound',
 'Norwegian_elkhound',
 'otterhound',
 'Saluki',
 'Scottish_deerhound',
 'Weimaraner',
 'Staffordshire_bullterrier',
 'American_Staffordshire_terrier',
 'Bedlington_terrier',
 'Border_terrier',
 'Kerry_blue_terrier',
 'Irish_terrier',
 'Norfolk_terrier',
 'Norwich_terrier',
 'Yorkshire_terrier',
 'wire-haired_fox_terrier',
 'Lakeland_terrier',
 'Sealyham_terrier',
 'Airedale',
 'cairn',
 'Australian_terrier',
 'Dandie_Dinmont',
 'Boston_bull',
 'miniature_schnauzer',
 'giant_schnauzer',
 'standard_schnauzer',
 'Scotch_terrier',
 'Tibetan_terrier',
 'silky_terrier',
 'soft-coated_wheaten_terrier',
 'West_Highland_white_terrier',
 'Lhasa',
 'flat-coated_retriever',
 'curly-coated_retriever',
 'golden_retriever',
 'Labrador_retriever',
 'Chesapeake_Bay_retriever',
 'German_short-haired_pointer',
 'vizsla',
 'English_setter',
 'Irish_setter',
 'Gordon_setter',
 'Brittany_spaniel',
 'clumber',
 'English_springer',
 'Welsh_springer_spaniel',
 'cocker_spaniel',
 'Sussex_spaniel',
 'Irish_water_spaniel',
 'kuvasz',
 'schipperke',
 'groenendael',
 'malinois',
 'briard',
 'kelpie',
 'komondor',
 'Old_English_sheepdog',
 'Shetland_sheepdog',
 'collie',
 'Border_collie',
 'Bouvier_des_Flandres',
 'Rottweiler',
 'German_shepherd',
 'Doberman',
 'miniature_pinscher',
 'Greater_Swiss_Mountain_dog',
 'Bernese_mountain_dog',
 'Appenzeller',
 'EntleBucher',
 'boxer',
 'bull_mastiff',
 'Tibetan_mastiff',
 'French_bulldog',
 'Great_Dane',
 'Saint_Bernard',
 'Eskimo_dog',
 'malamute',
 'Siberian_husky',
 'dalmatian',
 'affenpinscher',
 'basenji',
 'pug',
 'Leonberg',
 'Newfoundland',
 'Great_Pyrenees',
 'Samoyed',
 'Pomeranian',
 'chow',
 'keeshond',
 'Brabancon_griffon',
 'Pembroke',
 'Cardigan',
 'toy_poodle',
 'miniature_poodle',
 'standard_poodle',
 'Mexican_hairless',
 'timber_wolf',
 'white_wolf',
 'red_wolf',
 'coyote',
 'dingo',
 'dhole',
 'African_hunting_dog',
 'hyena',
 'red_fox',
 'kit_fox',
 'Arctic_fox',
 'grey_fox',
 'tabby',
 'tiger_cat',
 'Persian_cat',
 'Siamese_cat',
 'Egyptian_cat',
 'cougar',
 'lynx',
 'leopard',
 'snow_leopard',
 'jaguar',
 'lion',
 'tiger',
 'cheetah',
 'brown_bear',
 'American_black_bear',
 'ice_bear',
 'sloth_bear',
 'mongoose',
 'meerkat',
 'tiger_beetle',
 'ladybug',
 'ground_beetle',
 'long-horned_beetle',
 'leaf_beetle',
 'dung_beetle',
 'rhinoceros_beetle',
 'weevil',
 'fly',
 'bee',
 'ant',
 'grasshopper',
 'cricket',
 'walking_stick',
 'cockroach',
 'mantis',
 'cicada',
 'leafhopper',
 'lacewing',
 'dragonfly',
 'damselfly',
 'admiral',
 'ringlet',
 'monarch',
 'cabbage_butterfly',
 'sulphur_butterfly',
 'lycaenid',
 'starfish',
 'sea_urchin',
 'sea_cucumber',
 'wood_rabbit',
 'hare',
 'Angora',
 'hamster',
 'porcupine',
 'fox_squirrel',
 'marmot',
 'beaver',
 'guinea_pig',
 'sorrel',
 'zebra',
 'hog',
 'wild_boar',
 'warthog',
 'hippopotamus',
 'ox',
 'water_buffalo',
 'bison',
 'ram',
 'bighorn',
 'ibex',
 'hartebeest',
 'impala',
 'gazelle',
 'Arabian_camel',
 'llama',
 'weasel',
 'mink',
 'polecat',
 'black-footed_ferret',
 'otter',
 'skunk',
 'badger',
 'armadillo',
 'three-toed_sloth',
 'orangutan',
 'gorilla',
 'chimpanzee',
 'gibbon',
 'siamang',
 'guenon',
 'patas',
 'baboon',
 'macaque',
 'langur',
 'colobus',
 'proboscis_monkey',
 'marmoset',
 'capuchin',
 'howler_monkey',
 'titi',
 'spider_monkey',
 'squirrel_monkey',
 'Madagascar_cat',
 'indri',
 'Indian_elephant',
 'African_elephant',
 'lesser_panda',
 'giant_panda',
 'barracouta',
 'eel',
 'coho',
 'rock_beauty',
 'anemone_fish',
 'sturgeon',
 'gar',
 'lionfish',
 'puffer',
 'abacus',
 'abaya',
 'academic_gown',
 'accordion',
 'acoustic_guitar',
 'aircraft_carrier',
 'airliner',
 'airship',
 'altar',
 'ambulance',
 'amphibian',
 'analog_clock',
 'apiary',
 'apron',
 'ashcan',
 'assault_rifle',
 'backpack',
 'bakery',
 'balance_beam',
 'balloon',
 'ballpoint',
 'Band_Aid',
 'banjo',
 'bannister',
 'barbell',
 'barber_chair',
 'barbershop',
 'barn',
 'barometer',
 'barrel',
 'barrow',
 'baseball',
 'basketball',
 'bassinet',
 'bassoon',
 'bathing_cap',
 'bath_towel',
 'bathtub',
 'beach_wagon',
 'beacon',
 'beaker',
 'bearskin',
 'beer_bottle',
 'beer_glass',
 'bell_cote',
 'bib',
 'bicycle-built-for-two',
 'bikini',
 'binder',
 'binoculars',
 'birdhouse',
 'boathouse',
 'bobsled',
 'bolo_tie',
 'bonnet',
 'bookcase',
 'bookshop',
 'bottlecap',
 'bow',
 'bow_tie',
 'brass',
 'brassiere',
 'breakwater',
 'breastplate',
 'broom',
 'bucket',
 'buckle',
 'bulletproof_vest',
 'bullet_train',
 'butcher_shop',
 'cab',
 'caldron',
 'candle',
 'cannon',
 'canoe',
 'can_opener',
 'cardigan',
 'car_mirror',
 'carousel',
 "carpenter's_kit",
 'carton',
 'car_wheel',
 'cash_machine',
 'cassette',
 'cassette_player',
 'castle',
 'catamaran',
 'CD_player',
 'cello',
 'cellular_telephone',
 'chain',
 'chainlink_fence',
 'chain_mail',
 'chain_saw',
 'chest',
 'chiffonier',
 'chime',
 'china_cabinet',
 'Christmas_stocking',
 'church',
 'cinema',
 'cleaver',
 'cliff_dwelling',
 'cloak',
 'clog',
 'cocktail_shaker',
 'coffee_mug',
 'coffeepot',
 'coil',
 'combination_lock',
 'computer_keyboard',
 'confectionery',
 'container_ship',
 'convertible',
 'corkscrew',
 'cornet',
 'cowboy_boot',
 'cowboy_hat',
 'cradle',
 'crane',
 'crash_helmet',
 'crate',
 'crib',
 'Crock_Pot',
 'croquet_ball',
 'crutch',
 'cuirass',
 'dam',
 'desk',
 'desktop_computer',
 'dial_telephone',
 'diaper',
 'digital_clock',
 'digital_watch',
 'dining_table',
 'dishrag',
 'dishwasher',
 'disk_brake',
 'dock',
 'dogsled',
 'dome',
 'doormat',
 'drilling_platform',
 'drum',
 'drumstick',
 'dumbbell',
 'Dutch_oven',
 'electric_fan',
 'electric_guitar',
 'electric_locomotive',
 'entertainment_center',
 'envelope',
 'espresso_maker',
 'face_powder',
 'feather_boa',
 'file',
 'fireboat',
 'fire_engine',
 'fire_screen',
 'flagpole',
 'flute',
 'folding_chair',
 'football_helmet',
 'forklift',
 'fountain',
 'fountain_pen',
 'four-poster',
 'freight_car',
 'French_horn',
 'frying_pan',
 'fur_coat',
 'garbage_truck',
 'gasmask',
 'gas_pump',
 'goblet',
 'go-kart',
 'golf_ball',
 'golfcart',
 'gondola',
 'gong',
 'gown',
 'grand_piano',
 'greenhouse',
 'grille',
 'grocery_store',
 'guillotine',
 'hair_slide',
 'hair_spray',
 'half_track',
 'hammer',
 'hamper',
 'hand_blower',
 'hand-held_computer',
 'handkerchief',
 'hard_disc',
 'harmonica',
 'harp',
 'harvester',
 'hatchet',
 'holster',
 'home_theater',
 'honeycomb',
 'hook',
 'hoopskirt',
 'horizontal_bar',
 'horse_cart',
 'hourglass',
 'iPod',
 'iron',
 "jack-o'-lantern",
 'jean',
 'jeep',
 'jersey',
 'jigsaw_puzzle',
 'jinrikisha',
 'joystick',
 'kimono',
 'knee_pad',
 'knot',
 'lab_coat',
 'ladle',
 'lampshade',
 'laptop',
 'lawn_mower',
 'lens_cap',
 'letter_opener',
 'library',
 'lifeboat',
 'lighter',
 'limousine',
 'liner',
 'lipstick',
 'Loafer',
 'lotion',
 'loudspeaker',
 'loupe',
 'lumbermill',
 'magnetic_compass',
 'mailbag',
 'mailbox',
 'maillot',
 'maillot',
 'manhole_cover',
 'maraca',
 'marimba',
 'mask',
 'matchstick',
 'maypole',
 'maze',
 'measuring_cup',
 'medicine_chest',
 'megalith',
 'microphone',
 'microwave',
 'military_uniform',
 'milk_can',
 'minibus',
 'miniskirt',
 'minivan',
 'missile',
 'mitten',
 'mixing_bowl',
 'mobile_home',
 'Model_T',
 'modem',
 'monastery',
 'monitor',
 'moped',
 'mortar',
 'mortarboard',
 'mosque',
 'mosquito_net',
 'motor_scooter',
 'mountain_bike',
 'mountain_tent',
 'mouse',
 'mousetrap',
 'moving_van',
 'muzzle',
 'nail',
 'neck_brace',
 'necklace',
 'nipple',
 'notebook',
 'obelisk',
 'oboe',
 'ocarina',
 'odometer',
 'oil_filter',
 'organ',
 'oscilloscope',
 'overskirt',
 'oxcart',
 'oxygen_mask',
 'packet',
 'paddle',
 'paddlewheel',
 'padlock',
 'paintbrush',
 'pajama',
 'palace',
 'panpipe',
 'paper_towel',
 'parachute',
 'parallel_bars',
 'park_bench',
 'parking_meter',
 'passenger_car',
 'patio',
 'pay-phone',
 'pedestal',
 'pencil_box',
 'pencil_sharpener',
 'perfume',
 'Petri_dish',
 'photocopier',
 'pick',
 'pickelhaube',
 'picket_fence',
 'pickup',
 'pier',
 'piggy_bank',
 'pill_bottle',
 'pillow',
 'ping-pong_ball',
 'pinwheel',
 'pirate',
 'pitcher',
 'plane',
 'planetarium',
 'plastic_bag',
 'plate_rack',
 'plow',
 'plunger',
 'Polaroid_camera',
 'pole',
 'police_van',
 'poncho',
 'pool_table',
 'pop_bottle',
 'pot',
 "potter's_wheel",
 'power_drill',
 'prayer_rug',
 'printer',
 'prison',
 'projectile',
 'projector',
 'puck',
 'punching_bag',
 'purse',
 'quill',
 'quilt',
 'racer',
 'racket',
 'radiator',
 'radio',
 'radio_telescope',
 'rain_barrel',
 'recreational_vehicle',
 'reel',
 'reflex_camera',
 'refrigerator',
 'remote_control',
 'restaurant',
 'revolver',
 'rifle',
 'rocking_chair',
 'rotisserie',
 'rubber_eraser',
 'rugby_ball',
 'rule',
 'running_shoe',
 'safe',
 'safety_pin',
 'saltshaker',
 'sandal',
 'sarong',
 'sax',
 'scabbard',
 'scale',
 'school_bus',
 'schooner',
 'scoreboard',
 'screen',
 'screw',
 'screwdriver',
 'seat_belt',
 'sewing_machine',
 'shield',
 'shoe_shop',
 'shoji',
 'shopping_basket',
 'shopping_cart',
 'shovel',
 'shower_cap',
 'shower_curtain',
 'ski',
 'ski_mask',
 'sleeping_bag',
 'slide_rule',
 'sliding_door',
 'slot',
 'snorkel',
 'snowmobile',
 'snowplow',
 'soap_dispenser',
 'soccer_ball',
 'sock',
 'solar_dish',
 'sombrero',
 'soup_bowl',
 'space_bar',
 'space_heater',
 'space_shuttle',
 'spatula',
 'speedboat',
 'spider_web',
 'spindle',
 'sports_car',
 'spotlight',
 'stage',
 'steam_locomotive',
 'steel_arch_bridge',
 'steel_drum',
 'stethoscope',
 'stole',
 'stone_wall',
 'stopwatch',
 'stove',
 'strainer',
 'streetcar',
 'stretcher',
 'studio_couch',
 'stupa',
 'submarine',
 'suit',
 'sundial',
 'sunglass',
 'sunglasses',
 'sunscreen',
 'suspension_bridge',
 'swab',
 'sweatshirt',
 'swimming_trunks',
 'swing',
 'switch',
 'syringe',
 'table_lamp',
 'tank',
 'tape_player',
 'teapot',
 'teddy',
 'television',
 'tennis_ball',
 'thatch',
 'theater_curtain',
 'thimble',
 'thresher',
 'throne',
 'tile_roof',
 'toaster',
 'tobacco_shop',
 'toilet_seat',
 'torch',
 'totem_pole',
 'tow_truck',
 'toyshop',
 'tractor',
 'trailer_truck',
 'tray',
 'trench_coat',
 'tricycle',
 'trimaran',
 'tripod',
 'triumphal_arch',
 'trolleybus',
 'trombone',
 'tub',
 'turnstile',
 'typewriter_keyboard',
 'umbrella',
 'unicycle',
 'upright',
 'vacuum',
 'vase',
 'vault',
 'velvet',
 'vending_machine',
 'vestment',
 'viaduct',
 'violin',
 'volleyball',
 'waffle_iron',
 'wall_clock',
 'wallet',
 'wardrobe',
 'warplane',
 'washbasin',
 'washer',
 'water_bottle',
 'water_jug',
 'water_tower',
 'whiskey_jug',
 'whistle',
 'wig',
 'window_screen',
 'window_shade',
 'Windsor_tie',
 'wine_bottle',
 'wing',
 'wok',
 'wooden_spoon',
 'wool',
 'worm_fence',
 'wreck',
 'yawl',
 'yurt',
 'web_site',
 'comic_book',
 'crossword_puzzle',
 'street_sign',
 'traffic_light',
 'book_jacket',
 'menu',
 'plate',
 'guacamole',
 'consomme',
 'hot_pot',
 'trifle',
 'ice_cream',
 'ice_lolly',
 'French_loaf',
 'bagel',
 'pretzel',
 'cheeseburger',
 'hotdog',
 'mashed_potato',
 'head_cabbage',
 'broccoli',
 'cauliflower',
 'zucchini',
 'spaghetti_squash',
 'acorn_squash',
 'butternut_squash',
 'cucumber',
 'artichoke',
 'bell_pepper',
 'cardoon',
 'mushroom',
 'Granny_Smith',
 'strawberry',
 'orange',
 'lemon',
 'fig',
 'pineapple',
 'banana',
 'jackfruit',
 'custard_apple',
 'pomegranate',
 'hay',
 'carbonara',
 'chocolate_sauce',
 'dough',
 'meat_loaf',
 'pizza',
 'potpie',
 'burrito',
 'red_wine',
 'espresso',
 'cup',
 'eggnog',
 'alp',
 'bubble',
 'cliff',
 'coral_reef',
 'geyser',
 'lakeside',
 'promontory',
 'sandbar',
 'seashore',
 'valley',
 'volcano',
 'ballplayer',
 'groom',
 'scuba_diver',
 'rapeseed',
 'daisy',
 "yellow_lady's_slipper",
 'corn',
 'acorn',
 'hip',
 'buckeye',
 'coral_fungus',
 'agaric',
 'gyromitra',
 'stinkhorn',
 'earthstar',
 'hen-of-the-woods',
 'bolete',
 'ear',
 'toilet_tissue']

In [15]:
# %%capture x # ping bug: disconnect -> reconnect kernel workaround
vgg = Vgg16()
# Grab a few images at a time for training and validation.
# NB: They must be in subdirectories named based on their category
batches = vgg.get_batches(path+'train', batch_size=batch_size)
batches.nb_class


Found 160 images belonging to 2 classes.

In [ ]:
val_batches = vgg.get_batches(path+'valid', batch_size=batch_size*2)
vgg.finetune(batches)
vgg.fit(batches, val_batches, nb_epoch=1, verbose=1)

In [14]:
#x.show()

The code above will work for any image recognition task, with any number of categories! All you have to do is to put your images into one folder per category, and run the code above.
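
Kaggle's Dogs vs Cats training archive is, as far as we know, a single flat folder of files named like cat.0.jpg and dog.0.jpg, so it has to be rearranged into one subfolder per category before the code above can use it. Here is a minimal standard-library sketch, assuming that naming convention; split_into_class_dirs is a hypothetical helper, not part of utils.py or the Vgg16 class.

```python
import os
import shutil

def split_into_class_dirs(src_dir):
    """Move files named '<label>.<id>.jpg' into src_dir/<label>s/ subfolders.

    Hypothetical helper: assumes the Kaggle naming convention (e.g. 'cat.0.jpg').
    Returns a dict mapping each subfolder name to the number of files moved into it.
    """
    counts = {}
    for fname in sorted(os.listdir(src_dir)):
        full = os.path.join(src_dir, fname)
        if not os.path.isfile(full):
            continue  # skip subfolders, e.g. ones created by an earlier run
        folder = fname.split('.')[0] + 's'  # 'cat.0.jpg' -> 'cats'
        os.makedirs(os.path.join(src_dir, folder), exist_ok=True)
        shutil.move(full, os.path.join(src_dir, folder, fname))
        counts[folder] = counts.get(folder, 0) + 1
    return counts
```

You would call it as split_into_class_dirs(path + 'train'), and likewise on valid/ after moving a random subset of the training images there.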

Let's take a look at how this works, step by step...

Use Vgg16 for basic image recognition

Let's start off by using the Vgg16 class to recognise the main ImageNet category for each image.

We won't be able to enter the Dogs vs Cats competition with an ImageNet model alone, since 'cat' and 'dog' are not categories in ImageNet - instead each individual breed is a separate category. However, we can use it to see how well it can recognise the images, which is a good first step.

First, create a Vgg16 object:


In [18]:
vgg = Vgg16()

Vgg16 is built on top of Keras (which we will be learning much more about shortly!), a flexible, easy-to-use deep learning library that sits on top of Theano or TensorFlow. Keras reads groups of images and labels in batches, using a fixed directory structure, where the training images for each category must be placed in a separate folder.
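
When no explicit list of classes is given, Keras's flow_from_directory (which get_batches presumably wraps) infers the categories from the subfolder names and assigns label indices in alphanumeric order. Assuming the folders are named cats and dogs, that is why cats get index 0 and dogs index 1 in the labels further down. A standard-library sketch of that mapping, using the hypothetical helper name infer_class_indices:

```python
import os

def infer_class_indices(directory):
    """Map each subfolder name to an integer label, in sorted order.

    Mirrors the documented default behaviour of Keras's flow_from_directory;
    the helper itself is hypothetical. Plain files are ignored.
    """
    subdirs = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: i for i, name in enumerate(subdirs)}
```

On a real Keras iterator, the same mapping is available as batches.class_indices.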

Let's grab batches of data from our training folder:


In [19]:
batches = vgg.get_batches(path+'train', batch_size=4)


Found 160 images belonging to 2 classes.

(BTW, when Keras refers to 'classes', it doesn't mean python classes - but rather it refers to the categories of the labels, such as 'pug', or 'tabby'.)

The batches object is just a regular Python iterator. Each iteration returns both the images themselves and their labels.
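
Keras's directory iterators also never run out: they loop over the data indefinitely, reshuffling on each pass when shuffling is enabled, so you pull one batch at a time with next(), as in the cell that follows. A simplified numpy sketch of such an iterator over in-memory arrays; batch_iterator is a hypothetical name, and the real Keras iterator additionally loads and resizes image files from disk.

```python
import numpy as np

def batch_iterator(images, labels, batch_size, shuffle=True, seed=0):
    """Yield (images, labels) mini-batches forever, reshuffling on each pass.

    Hypothetical, simplified stand-in for a Keras directory iterator.
    """
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield images[idx], labels[idx]

# 10 blank channels-first images (as with the Theano backend): 6 cats, then 4 dogs.
images = np.zeros((10, 3, 224, 224), dtype=np.float32)
labels = np.eye(2, dtype=np.float32)[[0] * 6 + [1] * 4]
it = batch_iterator(images, labels, batch_size=4)
imgs, lbls = next(it)
print(imgs.shape, lbls.shape)  # (4, 3, 224, 224) (4, 2)
```

With 10 images and batch_size=4, each pass yields batches of 4, 4 and 2 images before starting over.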


In [20]:
imgs,labels = next(batches)

In [30]:
imgs[0].shape
labels


Out[30]:
(3, 224, 224)
Out[30]:
array([[ 1.,  0.],
       [ 1.,  0.],
       [ 1.,  0.],
       [ 0.,  1.]], dtype=float32)

As you can see, the label for each image is an array, containing a 1 in the first position if it's a cat, and in the second position if it's a dog. This approach to encoding categorical variables, where an array contains just a single 1 in the position corresponding to the category, is very common in deep learning. It is called one hot encoding.

The arrays contain two elements, because we have two categories (cat, and dog). If we had three categories (e.g. cats, dogs, and kangaroos), then the arrays would each contain two 0's, and one 1.
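
One hot encoding can be produced by indexing the rows of an identity matrix with the category indexes. A small numpy sketch (the helper name one_hot is our own; Keras ships a similar to_categorical utility):

```python
import numpy as np

def one_hot(class_indices, n_classes):
    """Return a float32 array with a single 1 per row, at each row's class index."""
    return np.eye(n_classes, dtype=np.float32)[class_indices]

# Cats = 0 and dogs = 1, matching the label batch shown above.
print(one_hot([0, 0, 0, 1], 2))
# Three categories (cats, dogs, kangaroos): each row has two 0s and one 1.
print(one_hot([2, 0], 3))
```

Going the other way, np.argmax(labels, axis=1) recovers the category index from each one hot row.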


In [24]:
plots(imgs, titles=labels)
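
The source of plots in utils.py isn't shown here. A minimal sketch of what such a helper might do, assuming channels-first images with 0-255 pixel values like the (3, 224, 224) arrays above: matplotlib's imshow expects (height, width, channels), so each image is transposed first. plots_sketch and to_channels_last are hypothetical names.

```python
import numpy as np

def to_channels_last(img):
    """(channels, height, width) -> (height, width, channels), as imshow expects."""
    return np.transpose(img, (1, 2, 0))

def plots_sketch(ims, titles=None, figsize=(12, 6), rows=1):
    """Show a batch of channels-first images side by side, with optional titles."""
    import matplotlib.pyplot as plt  # imported lazily; to_channels_last works without it
    cols = int(np.ceil(len(ims) / rows))
    fig = plt.figure(figsize=figsize)
    for i, img in enumerate(ims):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.axis('off')
        if titles is not None:
            ax.set_title(str(titles[i]))
        # imshow reads float RGB as 0..1, so cast 0..255 float pixels to uint8.
        ax.imshow(to_channels_last(img).astype(np.uint8))
    return fig
```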


We can now pass the images to Vgg16's predict() method to get back, for each image, the probability, category index, and category name of VGG's top prediction.


In [25]:
vgg.predict(imgs, True)


Out[25]:
(array([ 0.8745,  0.2226,  0.5382,  0.5448], dtype=float32),
 array([361, 261, 281, 215]),
 ['skunk', 'keeshond', 'tabby', 'Brittany_spaniel'])
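
Conceptually, each of those three results comes from the model's 1,000-way probability output for an image: the largest probability, its index, and the class name at that index. A numpy sketch with made-up probabilities over a toy four-class vocabulary standing in for vgg.classes; top1 is a hypothetical helper, not the actual Vgg16.predict source.

```python
import numpy as np

def top1(probs, class_names):
    """Return (best probabilities, best indexes, best class names), one per row."""
    idxs = np.argmax(probs, axis=1)
    best = probs[np.arange(len(probs)), idxs]
    names = [class_names[i] for i in idxs]
    return best, idxs, names

# Made-up softmax outputs for two images.
classes = ['tench', 'goldfish', 'great_white_shark', 'tiger_shark']
probs = np.array([[0.1, 0.7, 0.1, 0.1],
                  [0.05, 0.05, 0.2, 0.7]], dtype=np.float32)
best, idxs, names = top1(probs, classes)
```

A top probability as low as 0.2226 (for 'keeshond' above) means the remaining probability is spread across other categories.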

The category indexes are based on the ordering of categories used in the VGG model - e.g. here are the first four:


In [26]:
vgg.classes[:4]


Out[26]:
['tench', 'goldfish', 'great_white_shark', 'tiger_shark']

(Note that, other than creating the Vgg16 object, none of these steps are necessary to build a model; they are just showing how to use the class to view ImageNet predictions.)

Use our Vgg16 class to finetune a Dogs vs Cats model

To change our model so that it outputs "cat" vs "dog", instead of one of 1,000 very specific categories, we need to use a process called "finetuning". From the outside, finetuning looks identical to normal machine learning training: we provide a training set with data and labels to learn from, and a validation set to test against. The model learns a set of parameters based on the data provided.

However, the difference is that we start with a model that is already trained to solve a similar problem. The idea is that many of the parameters should be very similar, or identical, between the existing model and the model we wish to create. Therefore, we only select a subset of parameters to train, and leave the rest untouched. This happens automatically when we call fit() after calling finetune().
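
The principle of training only a subset of parameters can be illustrated without Keras. In this numpy sketch, a fixed random projection followed by a ReLU stands in for VGG's pretrained layers and is never updated, while a newly added softmax layer is trained by plain gradient descent on cross-entropy. Everything here (the toy data, frozen_features, fit_new_head) is hypothetical; it is not how the Vgg16 class is implemented, only the same idea at toy scale.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the pretrained layers: fixed weights that training never touches.
W_frozen = rng.normal(size=(10, 32)) / np.sqrt(10)

def frozen_features(x):
    return np.maximum(0.0, x @ W_frozen)  # ReLU(x W) with W held fixed

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fit_new_head(x, y_onehot, n_epochs=300, lr=0.05):
    """Train only a new dense softmax layer on top of the frozen features."""
    f = frozen_features(x)
    W = np.zeros((f.shape[1], y_onehot.shape[1]))
    b = np.zeros(y_onehot.shape[1])
    for _ in range(n_epochs):
        p = softmax(f @ W + b)
        grad = (p - y_onehot) / len(x)  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * (f.T @ grad)
        b -= lr * grad.sum(axis=0)
    return W, b

def predict(x, W, b):
    return np.argmax(frozen_features(x) @ W + b, axis=1)

# Toy two-class data: class 0 centred at +1, class 1 at -1, in every dimension.
y = np.repeat([0, 1], 100)
x = rng.normal(size=(200, 10)) + np.where(y == 0, 1.0, -1.0)[:, None]
frozen_before = W_frozen.copy()
W, b = fit_new_head(x, np.eye(2)[y])
accuracy = np.mean(predict(x, W, b) == y)
print(accuracy)
```

Only W and b change during training; W_frozen is identical before and after, which is the point of finetuning.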

We create our batches just like before, and make the validation set available as well. A 'batch' (or mini-batch as it is commonly known) is simply a subset of the training data - we use a subset at a time when training or predicting, in order to speed up training, and to avoid running out of memory.
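
For example, the 160 images in the sample training folder split into three mini-batches at batch_size=64. A small sketch of that arithmetic, assuming (as Keras iterators do, to our knowledge) that the leftover images form a final, smaller batch rather than being dropped; batch_sizes is a hypothetical helper.

```python
import math

def batch_sizes(n_samples, batch_size):
    """Sizes of the mini-batches that make up one epoch; the last may be smaller."""
    n_batches = math.ceil(n_samples / batch_size)
    return [min(batch_size, n_samples - i * batch_size) for i in range(n_batches)]

print(batch_sizes(160, 64))      # [64, 64, 32]
print(len(batch_sizes(160, 2)))  # 80
```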


In [ ]:
batch_size=64

In [ ]:
batches = vgg.get_batches(path+'train', batch_size=batch_size)
val_batches = vgg.get_batches(path+'valid', batch_size=batch_size)

Calling finetune() modifies the model such that it will be trained based on the data in the batches provided - in this case, to predict either 'dog' or 'cat'.


In [ ]:
vgg.finetune(batches)

Finally, we fit() the parameters of the model using the training data, reporting the accuracy on the validation set after every epoch. (An epoch is one full pass through the training data.)
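
The validation accuracy reported after each epoch compares, for each image, the position of the largest predicted probability with the position of the 1 in its one hot label. A numpy sketch of that calculation with made-up predictions; categorical_accuracy here is our own helper (Keras has a metric of the same name, whose implementation may differ in detail).

```python
import numpy as np

def categorical_accuracy(y_true, y_pred):
    """Fraction of rows where the predicted argmax matches the one hot label's argmax."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))

y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=np.float32)
y_pred = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7]], dtype=np.float32)
print(categorical_accuracy(y_true, y_pred))  # 0.75: the second image is misclassified
```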


In [ ]:
vgg.fit(batches, val_batches, nb_epoch=1)

That shows all of the steps involved in using the Vgg16 class to create an image recognition model using whatever labels you are interested in. For instance, this process could classify paintings by style, or leaves by type of disease, or satellite photos by type of crop, and so forth.

Next up, we'll dig one level deeper to see what's going on in the Vgg16 class.


In [ ]: