In [1]:
import torch
from torch.utils.model_zoo import load_url
from seq2seq.models import Img2Seq
from seq2seq.tools.inference import CaptionGenerator
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# scipy.misc.imresize was removed in SciPy 1.3; fall back to a PIL-based
# equivalent so the notebook also runs on newer SciPy releases.
try:
    from scipy.misc import imresize
except ImportError:
    def imresize(arr, size):
        # size is (H, W); PIL's resize takes (W, H)
        arr = arr if arr.dtype == np.uint8 else np.asarray(arr, dtype=np.float32)
        return np.array(Image.fromarray(arr).resize(size[::-1], Image.BILINEAR))
%matplotlib inline
In [2]:
checkpoint = load_url('https://dl.dropboxusercontent.com/s/05dvriaiqk74cum/caption_resnet50-4c0fa803.pth',
                      map_location='cpu')  # 'gpu:0' is not a valid map_location tag; remap GPU-trained weights to the CPU
model = Img2Seq(**checkpoint['config'].model_config)
model.load_state_dict(checkpoint['state_dict'])
# Unpack in the checkpoint's stored order: the image transform comes first,
# then the target-side (caption) tokenizer.
img_transform, target_tok = checkpoint['tokenizers'].values()
caption_model = CaptionGenerator(model,
                                 img_transform=img_transform(allow_var_size=True, train=False),
                                 target_tok=target_tok,
                                 beam_size=12,
                                 get_attention=True,
                                 length_normalization_factor=0.4,
                                 cuda=False)
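A note on the decoding parameters: beam_size=12 keeps the 12 highest-scoring partial captions at each step, and length_normalization_factor rescales each hypothesis' score by a power of its length so longer captions are not unduly penalized. The exact formula inside seq2seq.tools.inference is not shown here; the sketch below (helper name and formula are assumptions) illustrates the common scheme:
In [ ]:
import math

# Hypothetical helper illustrating the usual length-normalization scheme:
# the summed token log-probability is divided by length ** alpha
# (alpha=0 leaves scores unnormalized, alpha=1 averages per token).
def normalized_score(log_prob, length, alpha=0.4):
    return log_prob / (length ** alpha)

short = math.log(0.4 ** 5)  # 5-token hypothesis, per-token prob 0.4
long = math.log(0.4 ** 7)   # 7-token hypothesis, same per-token prob
print(short - long)         # raw score gap, ~1.83 in favor of the short caption
print(normalized_score(short, 5) - normalized_score(long, 7))  # gap shrinks to ~0.54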
def show_and_tell(filename):
    # Display the image and return it along with the generated caption
    # (and, since get_attention=True, the per-word attention maps).
    img = Image.open(filename, 'r')
    plt.figure()
    plt.imshow(np.asarray(img))
    return img, caption_model.describe(img)
def visualize_attention(img, attention, max_size=128., thresh=0.5):
    # Downscale the image so the per-word attention maps are cheap to resize.
    img = np.asarray(img)
    W, H = img.shape[1], img.shape[0]
    ratio = max_size / max(W, H)
    W, H = int(W * ratio), int(H * ratio)
    img = imresize(img, (H, W))
    attention, preds = attention
    # squeeze=False keeps a 2-D axes array even when there is a single word
    fig, plots = plt.subplots(len(preds), 1, figsize=(100, 40), squeeze=False)
    for i, p in enumerate(preds):
        # Upsample each word's attention map to the image size, normalize it,
        # and black out pixels whose attention falls below the threshold.
        resized_attention = imresize(attention[i].data.cpu().numpy(), (H, W))
        resized_attention = resized_attention / resized_attention.max()
        mask = resized_attention > thresh
        masked_img = img * mask[:, :, None]
        plots[i, 0].set_title(p)
        plots[i, 0].imshow(masked_img)
    plt.tight_layout()
    plt.show()
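The hard threshold blanks out everything outside the attended region. A softer alternative (a sketch, not part of the original notebook) is to alpha-blend the normalized attention map over the full image; the helper below reuses the same local names as visualize_attention above:
In [ ]:
def overlay_attention(ax, img, resized_attention, alpha=0.6):
    # Instead of masking pixels below the threshold, keep the whole image
    # visible and draw a translucent 'jet' heatmap of the attention on top.
    ax.imshow(img)
    ax.imshow(resized_attention, cmap='jet', alpha=alpha)
    ax.axis('off')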
In [3]:
img, (caption, attentions) = show_and_tell('/media/ssd/Datasets/COCO/val2014/COCO_val2014_000000200348.jpg')
print(caption)
In [4]:
visualize_attention(img, attentions)
In [5]:
img, (caption, attentions) = show_and_tell('/media/ssd/Datasets/COCO/val2014/COCO_val2014_000000450004.jpg')
print(caption)
In [6]:
visualize_attention(img, attentions)
In [7]:
img, (caption, attentions) = show_and_tell('/media/ssd/Datasets/COCO/val2014/COCO_val2014_000000400033.jpg')
print(caption)
In [8]:
visualize_attention(img, attentions)