VQA : Use and Abuse

To answer a question (the full loop is sketched below) :

  • Convert the image to features 'v'
  • Convert the question to a torch vector of longs
  • Pass both into the VQA model
  • Interpret the softmax-y answer vectors
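
In code, the whole loop looks roughly like this (a preview only : it uses helpers that are built step by step below, so run those cells first) :

In [ ]:
# Preview of the pipeline (illustrative; every name here is defined further down)
#   v = resnet_layer4.image_to_features(image_file)        # image -> feature maps
#   q, q_len = encode_question("is there a cat")           # question -> LongTensor of token indices
#   ans = vqa_net(v, q.unsqueeze(0), q_len.unsqueeze(0))   # both -> scores over the answer vocab
#   _, idx = ans.data.cpu().max(dim=1)
#   answer_words[idx]                                      # best-scoring index -> answer string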

In [ ]:
# Upgrade pillow to the latest version (fixes a Colab issue) :
! pip install -U 'Pillow>=5.2.0'

In [ ]:
import os, sys

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore", category=UserWarning) # Cleaner demos : Don't do this normally...

Download the Prebuilt VQA model and Weights


In [ ]:
if not os.path.isfile('./pytorch-vqa/README.md'):
    !git clone https://github.com/Cyanogenoid/pytorch-vqa.git
sys.path.append(os.path.realpath('./pytorch-vqa'))

In [ ]:
# https://github.com/Cyanogenoid/pytorch-vqa/releases

if not os.path.isfile('./2017-08-04_00.55.19.pth'):   # 81MB model
    !wget https://github.com/Cyanogenoid/pytorch-vqa/releases/download/v1.0/2017-08-04_00.55.19.pth

In [ ]:
try:
    import torch
except ImportError:
    # Fallback for older Colab images that don't ship with torch preinstalled (pins torch 0.4.1)
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    accelerator = 'cu80' if os.path.exists('/opt/bin/nvidia-smi') else 'cpu'
    !pip install -q \
      http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl \
      torchvision

In [ ]:
import torch

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [ ]:
import model # from pytorch-vqa

#saved_state = torch.load('logs/2017-08-04_00:55:19.pth')
saved_state = torch.load('./2017-08-04_00.55.19.pth', map_location=device)
tokens = len(saved_state['vocab']['question']) + 1  # +1 for index 0, used for padding / unknown words

saved_state.keys()  # See what's in the saved state

In [ ]:
# Load the predefined model
vqa_net = torch.nn.DataParallel(model.Net(tokens))
vqa_net.load_state_dict(saved_state['weights'])
vqa_net.to(device)
vqa_net.eval()
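
It's worth a quick peek at what was just loaded. Printing `vqa_net.module` dumps the module tree (for this repo that should be roughly a question encoder, an attention block over the image features, and a classifier, per the 'Show, Ask, Attend and Answer' design) :

In [ ]:
# Optional : inspect the architecture (DataParallel wraps the real Net in .module)
print(vqa_net.module)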

Now get the matching Image feature network (the one whose features the VQA model was trained against)


In [ ]:
if not os.path.isfile('./pytorch-resnet/README.md'):
    !git clone https://github.com/Cyanogenoid/pytorch-resnet.git
sys.path.append(os.path.realpath('./pytorch-resnet'))

In [ ]:
import resnet  # from pytorch-resnet

import torchvision.transforms as transforms
from PIL import Image

def get_transform(target_size, central_fraction=1.0):
    return transforms.Compose([
        transforms.Resize(int(target_size / central_fraction)),  # transforms.Scale is deprecated
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet channel statistics,
                             std=[0.229, 0.224, 0.225]),   #   matching the pretrained ResNet
    ])

class ResNetLayer4(torch.nn.Module):
    def __init__(self):
        super(ResNetLayer4, self).__init__()
        self.model = resnet.resnet152(pretrained=True)
        
        # from  visual_qa_analysis/config.py
        image_size = 448  # scale shorter end of image to this size and centre crop
        #output_size = image_size // 32  # size of the feature maps after processing through a network
        output_features = 2048  # number of feature maps thereof
        central_fraction = 0.875 # only take this much of the centre when scaling and centre cropping

        self.transform = get_transform(image_size, central_fraction)

        def save_output(module, input, output):
            self.buffer = output
        self.model.layer4.register_forward_hook(save_output)

    def forward(self, x):
        self.model(x)
        return self.buffer
    
    def image_to_features(self, img_file):
        img = Image.open(img_file).convert('RGB')
        img_transformed = self.transform(img)
        #print(img_transformed.size())
        img_batch = img_transformed.unsqueeze(0).to(device)
        return self.forward(img_batch) 
    
resnet_layer4 = ResNetLayer4().to(device)  # Downloads the 241MB pretrained ResNet-152 when first run
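
The `register_forward_hook` trick above deserves a moment : the hook fires every time `layer4` runs and stashes its output, so a stock pretrained ResNet can serve as a feature extractor without editing its source. A toy version of the same pattern (purely illustrative, nothing here is reused below) :

In [ ]:
# Toy demonstration of the forward-hook pattern (illustrative only)
toy = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
captured = {}
def grab(module, input, output):
    captured['hidden'] = output        # stash the intermediate activation
toy[1].register_forward_hook(grab)     # hook the middle of the network
_ = toy(torch.randn(1, 4))
captured['hidden'].size()              # torch.Size([1, 3]) : the hidden layer's output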

In [ ]:
# Sample images : 
image_urls = [
    'https://www.pets4homes.co.uk/images/articles/2709/large/tabby-cat-colour-and-pattern-genetics-5516c44dbd383.jpg',
    'https://imgc.allpostersimages.com/img/print/posters/cat-black-jumping-off-wall_a-G-12469828-14258383.jpg',
    'https://i.ytimg.com/vi/AIwlyly7Eso/hqdefault.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/9/9b/Black_pussy_-_panoramio.jpg',
    'https://www.thehappycatsite.com/wp-content/uploads/2017/06/siamese5.jpg',
    'https://c.pxhere.com/photos/15/e5/cat_roof_home_architecture_building_roofs_animal_sit-536976.jpg!d',
    'http://kitticats.com/wp-content/uploads/2015/05/cat4.jpg',
]
image_path, image_files = './img/', []
os.makedirs('./img', exist_ok=True)
for url in image_urls:
    image_file=os.path.join(image_path, os.path.basename(url))
    image_files.append(image_file)
    if not os.path.isfile(image_file):
        !wget {url} --directory-prefix ./img/
image_files

In [ ]:
v = resnet_layer4.image_to_features(image_files[0])
v.size()  # Expect torch.Size([1, 2048, 14, 14]) : 2048 feature maps on a 14x14 grid (448/32)

Have a look at how the vocab is built


In [ ]:
vocab = saved_state['vocab']
vocab.keys()  # dict_keys(['question', 'answer'])
list(vocab['question'].items())[:5]  # [('the', 1), ('is', 2), ('what', 3), ('are', 4), ('this', 5)]

In [ ]:
qtoken_to_index = vocab['question']
QUESTION_LENGTH_MAX = 30 # say... (only needed if padding to a fixed length, as in the commented-out line below)
    
def encode_question(question_str):
    """ Turn a question into a vector of indices and a question length """
    question_arr = question_str.lower().split(' ')
    #vec = torch.zeros(QUESTION_LENGTH_MAX).long()
    vec = torch.zeros(len(question_arr)).long()  
    for i, token in enumerate(question_arr):
        vec[i] = qtoken_to_index.get(token, 0)
    return vec.to(device), torch.tensor( len(question_arr) ).to(device)

In [ ]:
list(vocab['answer'].items())[:5]    # [('yes', 0), ('no', 1), ('2', 2), ('1', 3), ('white', 4)]

In [ ]:
answer_words = ['UNDEF'] * len(vocab['answer'])
for w,idx in vocab['answer'].items():
    answer_words[idx]=w
len(answer_words), answer_words[:10]  # 3000, ['yes', 'no', '2', '1', 'white', '3', 'red', 'blue', '4', 'green']

In [ ]:
# Important things to know before phrasing questions : which spellings are in
# the question vocab, and which words the model can actually produce as answers
'colour' in qtoken_to_index, 'color' in qtoken_to_index, 'tabby' in answer_words
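
Since `encode_question` silently maps unknown words to index 0, a small helper (not in the original notebook) makes it easy to see which words a question is losing before you wonder why the answers look odd :

In [ ]:
# Hypothetical helper : flag question words that fall outside the model's question vocab
def oov_words(question_str):
    return [w for w in question_str.lower().split(' ') if w not in qtoken_to_index]

oov_words("what colour is the tabby cat")  # any words listed here get encoded as 0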

Let's test a single Image


In [ ]:
image_idx = 1
image_filename = image_files[image_idx]

img = Image.open(image_filename).convert('RGB')
plt.imshow(img)

In [ ]:
v0 = resnet_layer4.image_to_features(image_filename)

In [ ]:
q, q_len = encode_question("is there a cat in the picture")
#q, q_len = encode_question("what color is the cat's fur")
#q, q_len = encode_question("is the cat jumping up or down")
q, q_len

In [ ]:
ans = vqa_net(v0, q.unsqueeze(0), q_len.unsqueeze(0))
ans.data.cpu()[0, :10]  # first 10 of the 3000 answer scores for this single-image batch

In [ ]:
_, answer_idx = ans.data.cpu().max(dim=1)
answer_words[ answer_idx ]
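
The argmax alone hides how close the competition was. A quick sketch of the top-5 candidates (the values are the network's raw output scores, so only compare them against each other) :

In [ ]:
# Sketch : look at the top-5 answers instead of just the argmax
top_scores, top_idx = ans.data.cpu().topk(5, dim=1)
[(answer_words[int(i)], float(s)) for s, i in zip(top_scores[0], top_idx[0])]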

Let's systematise a little


In [ ]:
def vqa_single_softmax(im_features, q_str):
    q, q_len = encode_question(q_str)
    ans = vqa_net(im_features, q.unsqueeze(0), q_len.unsqueeze(0))
    return ans.data.cpu()

def vqa(image_filename, question_arr):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()    
    image_features = resnet_layer4.image_to_features(image_filename)
    for question_str in question_arr:
        _, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
        #print(question_str+" -> "+answer_words[ answer_idx ])
        print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str)

In [ ]:
image_idx = 0  # 6 

vqa(image_files[image_idx], [
    "is there a cat in the picture",
    "is this a picture of a cat",
    "is the animal in the picture a cat or a dog",
    "what color is the cat",
    "what color are the cat's eyes",
])

Now let's stress the model

Leave one word out


In [ ]:
def leave_one_out(image_filename, question_base):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()    
    image_features = resnet_layer4.image_to_features(image_filename)
    question_arr = question_base.lower().split(' ')
    for i, word_omit in enumerate(question_arr):
        question_str = ' '.join( question_arr[:i]+question_arr[i+1:] )
        score, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
        #print(question_str+" -> "+answer_words[ answer_idx ])
        print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str)  #, score

In [ ]:
image_idx = 0

leave_one_out(image_files[image_idx], "is there a cat in the picture")  # mouse? dog?

Leave out every combination of words (think : binary masks)


In [ ]:
def leave_out_combos(image_filename, question_base):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()    
    image_features = resnet_layer4.image_to_features(image_filename)
    question_arr = question_base.lower().split(' ')
    for i in range(2 ** len(question_arr)):  # treat i as a bitmask : bit j set means 'omit word j'
        q_arr = [question_arr[j] for j in range(len(question_arr)) if (i & (2**j))==0 ]
        question_str = ' '.join( q_arr )
        _, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
        print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str)
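
For reference, the bitmask loop is equivalent to enumerating subsets with itertools.combinations, which groups the variants by how many words survive (a sketch only; leave_out_combos above is what we actually run) :

In [ ]:
# Equivalent subset enumeration via itertools (sketch, unused below)
from itertools import combinations

def word_subsets(question_base):
    words = question_base.lower().split(' ')
    for keep in range(len(words), -1, -1):
        for idxs in combinations(range(len(words)), keep):
            yield ' '.join(words[i] for i in idxs)

list(word_subsets("is there a cat"))[:8]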

In [ ]:
image_idx = 4

leave_out_combos(image_files[image_idx], "is there a cat in the picture")
#leave_out_combos(image_files[image_idx], "what color are cat's eyes")

Iteratively leave out the 'weakest' word : at each step, drop whichever single word can be removed while still producing the original answer with the highest score


In [ ]:
def leave_out_best(image_filename, question_base):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()    
    image_features = resnet_layer4.image_to_features(image_filename)
    _, answer_true = vqa_single_softmax(image_features, question_base).max(dim=1)
    print((answer_words[ answer_true ]+' '*8)[:8]+" <- "+question_base)
    print()
    while True:
        question_arr = question_base.lower().split(' ')
        score_best, q_best = None, ''
        for i, word_omit in enumerate(question_arr):
            question_str = ' '.join( question_arr[:i]+question_arr[i+1:] )
            score, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
            if answer_idx==answer_true:
                print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str)  #, score        
                if (score_best is None or score>score_best):
                    score_best, question_base = score, question_str
        print()
        if score_best is None or len(question_base)==0: break

In [ ]:
image_idx = 3

leave_out_best(image_files[image_idx], "is there a cat in the picture")
