In [0]:
import corruptions
from importlib import reload
import matplotlib.pyplot as plt
import torchvision
import numpy as np
%matplotlib inline

In [0]:
# Load the MNIST splits (downloaded into DATA_DIR if missing) and keep the first
# N_EXAMPLES test digits, with their labels, as the preview set used by `inspect`.
DATA_DIR = "../data/"
N_EXAMPLES = 50  # how many test digits to preview

train_mnist = torchvision.datasets.MNIST(DATA_DIR, train=True, download=True)
test_mnist = torchvision.datasets.MNIST(DATA_DIR, train=False, download=True)

# Index each sample exactly once: `test_mnist[i]` returns an (image, label) pair,
# so splitting one lookup avoids running the dataset's __getitem__ twice per sample.
# min(...) guards against asking for more samples than the split contains.
_samples = [test_mnist[i] for i in range(min(N_EXAMPLES, len(test_mnist)))]
IMAGES = [image for image, _ in _samples]
LABELS = [label for _, label in _samples]

In [0]:
def show(x):
    """Display ``x`` as a grayscale image on a fixed 0-255 scale, with axes hidden."""
    # Pin the colour range so every image uses the same intensity mapping
    # instead of being auto-stretched to its own min/max.
    display_range = {"vmin": 0, "vmax": 255}
    plt.imshow(x, cmap='gray', **display_range)
    plt.axis("off")
    plt.show()
    
def round_and_astype(x):
    """Convert pixel values to an 8-bit image.

    Rounds to the nearest integer (NumPy's round-half-to-even), clips to the
    uint8 range [0, 255], then casts to ``np.uint8``.

    Args:
        x: array-like of pixel values, nominally in [0, 255].

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``x``.
    """
    # Clip before casting: a bare astype(np.uint8) does not saturate values
    # outside [0, 255] (they wrap around or are platform-dependent), which would
    # turn slightly overshooting pixels into the wrong intensity.
    return np.clip(np.round(x), 0, 255).astype(np.uint8)

def inspect(corruption, images=None, labels=None):
    """Print each sample's label and display the corrupted version of its image.

    Args:
        corruption: callable taking one image and returning something
            ``np.array`` can convert.
        images: images to corrupt; defaults to the module-level ``IMAGES``.
        labels: labels printed before each image; defaults to ``LABELS``.
            Paired with ``images`` by position; iteration stops at the shorter.
    """
    if images is None:
        images = IMAGES
    if labels is None:
        labels = LABELS
    for image, label in zip(images, labels):
        print("Label: " + str(label))
        # Reuse the single-image path so rendering logic lives in one place.
        inspect_single(image, corruption)
        
def inspect_single(image, corruption):
    """Apply ``corruption`` to one ``image`` and display the 8-bit result."""
    corrupted = corruption(image)
    show(round_and_astype(np.array(corrupted)))
    
def save(image, corruption, filename):
    """Apply ``corruption`` to ``image`` and write the rendered result to ``filename``.

    The image is drawn in grayscale on a fixed 0-255 scale with both axes and
    the frame hidden, and saved with a tight bounding box, no padding and a
    transparent background, so the file contains just the image.

    Args:
        image: input image passed to ``corruption``.
        corruption: callable returning something ``np.array`` can convert.
        filename: output path handed to ``savefig``.
    """
    x = round_and_astype(np.array(corruption(image)))
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_frame_on(False)
        ax.imshow(x, cmap='gray', vmin=0, vmax=255)
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(filename, bbox_inches='tight', transparent=True, pad_inches=0)
    finally:
        # Always release the figure: otherwise each call leaves an open figure
        # behind (memory grows when saving in a loop, and the inline backend
        # would also render it at the end of the cell).
        plt.close(fig)

In [0]:
# Re-execute the `corruptions` module so edits to it are picked up without
# restarting the kernel, then preview glass blur on the sample digits.
corruptions = reload(corruptions)
inspect(corruptions.glass_blur)