In [ ]:
# Upgrade Pillow to the latest version (solves a Colab issue):
! pip install -U 'Pillow>=5.2.0'
In [ ]:
import os, sys
from matplotlib import pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # Cleaner demo output : don't do this normally...
In [ ]:
if not os.path.isfile('./pytorch-vqa/README.md'):
    !git clone https://github.com/Cyanogenoid/pytorch-vqa.git
sys.path.append(os.path.realpath('./pytorch-vqa'))
In [ ]:
# https://github.com/Cyanogenoid/pytorch-vqa/releases
if not os.path.isfile('./2017-08-04_00.55.19.pth'):  # 81Mb model
    !wget https://github.com/Cyanogenoid/pytorch-vqa/releases/download/v1.0/2017-08-04_00.55.19.pth
In [ ]:
try:
    import torch
except ImportError:
    # Colab-era fallback: install a PyTorch 0.4.1 wheel matching this runtime.
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    accelerator = 'cu80' if os.path.exists('/opt/bin/nvidia-smi') else 'cpu'
    !pip install -q \
        http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl \
        torchvision
In [ ]:
import torch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
In [ ]:
import model # from pytorch-vqa
#saved_state = torch.load('logs/2017-08-04_00:55:19.pth')
saved_state = torch.load('./2017-08-04_00.55.19.pth', map_location=device)
tokens = len(saved_state['vocab']['question']) + 1
saved_state.keys() # See what's in the saved state
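In [ ]:
# Added sanity check (not in the original notebook): 'tokens' is the question-vocab
# size + 1, leaving index 0 free for out-of-vocabulary words (see encode_question below).
len(saved_state['vocab']['question']), len(saved_state['vocab']['answer']), tokens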
In [ ]:
# Load the predefined model
vqa_net = torch.nn.DataParallel(model.Net(tokens))
vqa_net.load_state_dict(saved_state['weights'])
vqa_net.to(device)
vqa_net.eval()
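In [ ]:
# Optional extra (an addition): a rough size check on the loaded network.
sum(p.numel() for p in vqa_net.parameters())  # total number of parameters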
In [ ]:
if not os.path.isfile('./pytorch-resnet/README.md'):
    !git clone https://github.com/Cyanogenoid/pytorch-resnet.git
sys.path.append(os.path.realpath('./pytorch-resnet'))
In [ ]:
import resnet  # from pytorch-resnet
import torchvision.transforms as transforms
from PIL import Image

def get_transform(target_size, central_fraction=1.0):
    return transforms.Compose([
        transforms.Scale(int(target_size / central_fraction)),  # renamed transforms.Resize in newer torchvision
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

class ResNetLayer4(torch.nn.Module):
    def __init__(self):
        super(ResNetLayer4, self).__init__()
        self.model = resnet.resnet152(pretrained=True)

        # from visual_qa_analysis/config.py
        image_size = 448          # scale shorter end of image to this size and centre crop
        #output_size = image_size // 32  # size of the feature maps after processing through a network
        output_features = 2048    # number of feature maps thereof
        central_fraction = 0.875  # only take this much of the centre when scaling and centre cropping

        self.transform = get_transform(image_size, central_fraction)

        # Capture layer4's activations (the final conv feature maps) with a forward hook
        def save_output(module, input, output):
            self.buffer = output
        self.model.layer4.register_forward_hook(save_output)

    def forward(self, x):
        self.model(x)       # run the full network; the hook stashes layer4's output
        return self.buffer

    def image_to_features(self, img_file):
        img = Image.open(img_file).convert('RGB')
        img_transformed = self.transform(img)
        #print(img_transformed.size())
        img_batch = img_transformed.unsqueeze(0).to(device)
        return self.forward(img_batch)

resnet_layer4 = ResNetLayer4().to(device)  # Downloads 241Mb model when first run
In [ ]:
# Sample images :
image_urls, image_path, image_files = [
    'https://www.pets4homes.co.uk/images/articles/2709/large/tabby-cat-colour-and-pattern-genetics-5516c44dbd383.jpg',
    'https://imgc.allpostersimages.com/img/print/posters/cat-black-jumping-off-wall_a-G-12469828-14258383.jpg',
    'https://i.ytimg.com/vi/AIwlyly7Eso/hqdefault.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/9/9b/Black_pussy_-_panoramio.jpg',
    'https://www.thehappycatsite.com/wp-content/uploads/2017/06/siamese5.jpg',
    'https://c.pxhere.com/photos/15/e5/cat_roof_home_architecture_building_roofs_animal_sit-536976.jpg!d',
    'http://kitticats.com/wp-content/uploads/2015/05/cat4.jpg',
], './img/', []

os.makedirs('./img', exist_ok=True)
for url in image_urls:
    image_file = os.path.join(image_path, os.path.basename(url))
    image_files.append(image_file)
    if not os.path.isfile(image_file):
        !wget {url} --directory-prefix ./img/
image_files
In [ ]:
v = resnet_layer4.image_to_features(image_files[0])
v.size()
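In [ ]:
# Added check: the features should be (1, 2048, 14, 14) -- output_features=2048
# channels, and 448 // 32 = 14 spatial positions, since layer4 sits after
# ResNet-152's five stride-2 stages (a 32x downsample overall).
assert tuple(v.size()) == (1, 2048, 448 // 32, 448 // 32)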
In [ ]:
vocab = saved_state['vocab']
vocab.keys() # dict_keys(['question', 'answer'])
list(vocab['question'].items())[:5] # [('the', 1), ('is', 2), ('what', 3), ('are', 4), ('this', 5)]
In [ ]:
qtoken_to_index = vocab['question']
QUESTION_LENGTH_MAX = 30 # say...
def encode_question(question_str):
    """ Turn a question into a vector of indices and a question length """
    question_arr = question_str.lower().split(' ')
    #vec = torch.zeros(QUESTION_LENGTH_MAX).long()
    vec = torch.zeros(len(question_arr)).long()
    for i, token in enumerate(question_arr):
        vec[i] = qtoken_to_index.get(token, 0)  # out-of-vocab words fall back to index 0
    return vec.to(device), torch.tensor( len(question_arr) ).to(device)
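In [ ]:
# Quick check of the encoder (an added example): known words map to their
# vocab indices, while a made-up token falls back to index 0.
encode_question("is this a zzzmadeupzzz cat")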
In [ ]:
list(vocab['answer'].items())[:5] # [('yes', 0), ('no', 1), ('2', 2), ('1', 3), ('white', 4)]
In [ ]:
answer_words = ['UNDEF'] * len(vocab['answer'])
for w, idx in vocab['answer'].items():
    answer_words[idx] = w
len(answer_words), answer_words[:10] # 3000, ['yes', 'no', '2', '1', 'white', '3', 'red', 'blue', '4', 'green']
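In [ ]:
# Added check: if the answer indices densely cover 0..2999, no 'UNDEF'
# placeholders should survive the loop above.
answer_words.count('UNDEF')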
In [ ]:
# Important things to know: does the question vocab use American spellings
# ('color' vs 'colour'), and is 'tabby' among the candidate answers?
'colour' in qtoken_to_index, 'color' in qtoken_to_index, 'tabby' in answer_words
In [ ]:
image_idx = 1
image_filename = image_files[image_idx]
img = Image.open(image_filename).convert('RGB')
plt.imshow(img)
In [ ]:
v0 = resnet_layer4.image_to_features(image_filename)
In [ ]:
q, q_len = encode_question("is there a cat in the picture")
#q, q_len = encode_question("what color is the cat's fur")
#q, q_len = encode_question("is the cat jumping up or down")
q, q_len
In [ ]:
ans = vqa_net(v0, q.unsqueeze(0), q_len.unsqueeze(0))
ans.data.cpu()[0, :10]  # first 10 of the 3000 raw answer scores
In [ ]:
_, answer_idx = ans.data.cpu().max(dim=1)
answer_words[ answer_idx ]
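In [ ]:
# A top-5 readout (an added sketch): 'ans' holds raw scores over the 3000
# answers, so a softmax turns them into comparable probabilities.
probs = torch.nn.functional.softmax(ans.data.cpu(), dim=1)
top_p, top_idx = probs.topk(5, dim=1)
[(answer_words[int(i)], float(p)) for i, p in zip(top_idx[0], top_p[0])]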
In [ ]:
def vqa_single_softmax(im_features, q_str):
    q, q_len = encode_question(q_str)
    ans = vqa_net(im_features, q.unsqueeze(0), q_len.unsqueeze(0))
    return ans.data.cpu()

def vqa(image_filename, question_arr):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()
    image_features = resnet_layer4.image_to_features(image_filename)
    for question_str in question_arr:
        _, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
        #print(question_str+" -> "+answer_words[ answer_idx ])
        print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str)
In [ ]:
image_idx = 0 # 6
vqa(image_files[image_idx], [
    "is there a cat in the picture",
    "is this a picture of a cat",
    "is the animal in the picture a cat or a dog",
    "what color is the cat",
    "what color are the cat's eyes",
])
In [ ]:
def leave_one_out(image_filename, question_base):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()
    image_features = resnet_layer4.image_to_features(image_filename)
    question_arr = question_base.lower().split(' ')
    for i, word_omit in enumerate(question_arr):
        question_str = ' '.join( question_arr[:i]+question_arr[i+1:] )
        score, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
        #print(question_str+" -> "+answer_words[ answer_idx ])
        print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str) #, score
In [ ]:
image_idx = 0
leave_one_out(image_files[image_idx], "is there a cat in the picture") # mouse? dog?
In [ ]:
def leave_out_combos(image_filename, question_base):
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()
    image_features = resnet_layer4.image_to_features(image_filename)
    question_arr = question_base.lower().split(' ')
    for i in range(2 ** len(question_arr)):  # one bitmask per subset of words
        q_arr = [question_arr[j] for j in range(len(question_arr)) if (i & (2**j))==0 ]
        question_str = ' '.join( q_arr )
        _, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
        print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str)
In [ ]:
image_idx = 4
leave_out_combos(image_files[image_idx], "is there a cat in the picture")
#leave_out_combos(image_files[image_idx], "what color are cat's eyes")
In [ ]:
def leave_out_best(image_filename, question_base):
    """ Greedily remove words while the original answer is preserved """
    plt.imshow(Image.open(image_filename).convert('RGB')); plt.show()
    image_features = resnet_layer4.image_to_features(image_filename)
    _, answer_true = vqa_single_softmax(image_features, question_base).max(dim=1)
    print((answer_words[ answer_true ]+' '*8)[:8]+" <- "+question_base)
    print()
    while True:
        question_arr = question_base.lower().split(' ')
        score_best, q_best = None, ''
        for i, word_omit in enumerate(question_arr):
            question_str = ' '.join( question_arr[:i]+question_arr[i+1:] )
            score, answer_idx = vqa_single_softmax(image_features, question_str).max(dim=1)
            if answer_idx==answer_true:
                print((answer_words[ answer_idx ]+' '*8)[:8]+" <- "+question_str) #, score
                if (score_best is None or score>score_best):
                    score_best, question_base = score, question_str
        print()
        if score_best is None or len(question_base)==0: break
In [ ]:
image_idx = 3
leave_out_best(image_files[image_idx], "is there a cat in the picture")
In [ ]: