In [1]:
import torch
from torch.utils.model_zoo import load_url
from seq2seq.models import Img2Seq
from seq2seq.tools.inference import CaptionGenerator
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# scipy.misc.imresize was removed in SciPy 1.3; define a small PIL-based replacement
def imresize(arr, size):
    h, w = size
    arr = np.asarray(arr)
    if arr.dtype != np.uint8:
        arr = arr.astype(np.float32)  # PIL resizes float32 ('F' mode) maps, not float64
    return np.asarray(Image.fromarray(arr).resize((w, h)))

In [2]:
checkpoint = load_url('https://dl.dropboxusercontent.com/s/05dvriaiqk74cum/caption_resnet50-4c0fa803.pth',
                      map_location='cpu')  # load the GPU-trained weights onto the CPU

model = Img2Seq(**checkpoint['config'].model_config)
model.load_state_dict(checkpoint['state_dict'])
img_transform, target_tok = checkpoint['tokenizers'].values()

caption_model = CaptionGenerator(model,
                                 img_transform=img_transform(allow_var_size=True, train=False),
                                 target_tok=target_tok,
                                 beam_size=12,                      # beam-search width
                                 get_attention=True,                # also return attention maps
                                 length_normalization_factor=0.4,
                                 cuda=False)                        # run inference on the CPU

def show_and_tell(filename):
    # Display the image and return it together with the generated caption
    # (and the attention maps, since get_attention=True).
    img = Image.open(filename, 'r')

    %matplotlib inline
    plt.figure()
    plt.imshow(np.asarray(img))

    return img, caption_model.describe(img)

def visualize_attention(img, attention, max_size=128., thresh=0.5):
    img = np.asarray(img)
    W, H = img.shape[1], img.shape[0]
    ratio = max_size / max(W,H)
    W, H = int(W*ratio), int(H*ratio)
    img = imresize(img, (H,W))
    attention, preds = attention            # per-token attention maps and predicted words
    fig, plots = plt.subplots(len(preds), 1, figsize=(100, 40))
    for i, p in enumerate(preds):
        # upsample the attention map to image size and normalize it to [0, 1]
        resized_attention = imresize(attention[i].data.cpu().numpy(), (H,W))
        resized_attention = resized_attention / resized_attention.max()
        # keep only the strongly attended part of the image
        mask = resized_attention > thresh
        masked_img = img * mask[:,:,None]
        plots[i].set_title(p)
        plots[i].imshow(masked_img)
    plt.tight_layout()
    plt.show()
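
CaptionGenerator exposes the beam-search settings used above. As a minimal sketch (same constructor and arguments, only the values changed for illustration), a second generator with beam_size=1 decodes greedily, which is faster but usually gives weaker captions:

In [ ]:
# Sketch only: same CaptionGenerator API as above, with illustrative values.
# beam_size=1 is effectively greedy decoding; length_normalization_factor=0.0
# turns off the length normalization applied to beam scores.
greedy_caption_model = CaptionGenerator(model,
                                        img_transform=img_transform(allow_var_size=True, train=False),
                                        target_tok=target_tok,
                                        beam_size=1,
                                        get_attention=True,
                                        length_normalization_factor=0.0,
                                        cuda=False)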

In [3]:
img, (caption, attentions) = show_and_tell('/media/ssd/Datasets/COCO/val2014/COCO_val2014_000000200348.jpg')
print(caption)


A man holding a baseball bat on a field.

In [4]:
visualize_attention(img, attentions)



In [5]:
img, (caption, attentions) = show_and_tell('/media/ssd/Datasets/COCO/val2014/COCO_val2014_000000450004.jpg')
print(caption)


A black bear is standing in the dirt.

In [6]:
visualize_attention(img, attentions)



In [7]:
img, (caption, attentions) = show_and_tell('/media/ssd/Datasets/COCO/val2014/COCO_val2014_000000400033.jpg')
print(caption)


A group of people standing on top of a beach.

In [8]:
visualize_attention(img, attentions)



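The same two helpers work on any local image. As a usage sketch (the path below is a placeholder, not from the original run):

In [ ]:
# Placeholder path: point this at any image on disk.
img, (caption, attentions) = show_and_tell('path/to/your_image.jpg')
print(caption)
visualize_attention(img, attentions)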