In [0]:
import corruptions
from importlib import reload
import matplotlib.pyplot as plt
import torchvision
import numpy as np
%matplotlib inline
In [0]:
# Load both MNIST splits from ../data/ (download=True asks torchvision to fetch
# the files when they are not already present in that directory).
train_mnist = torchvision.datasets.MNIST(root="../data/", train=True, download=True)
test_mnist = torchvision.datasets.MNIST(root="../data/", train=False, download=True)
# Take the first 50 test samples and split each (image, label) pair into the
# two parallel lists that the inspection helpers iterate over.
IMAGES, LABELS = map(list, zip(*(test_mnist[i] for i in range(50))))
In [0]:
def show(x):
    """Render ``x`` as a grayscale image with the colour scale fixed to 0-255.

    The axes are hidden and the figure is displayed immediately.
    """
    render_options = {"cmap": "gray", "vmin": 0, "vmax": 255}
    plt.imshow(x, **render_options)
    plt.axis("off")
    plt.show()
def round_and_astype(x):
    """Round ``x`` to the nearest integer and convert it to ``np.uint8``.

    Values are clipped to the representable range [0, 255] before the cast.
    Without the clip, out-of-range results (e.g. a corruption that overshoots
    255 or dips below 0) would wrap around instead of saturating, showing up
    as bright/dark speckles in the displayed image.

    Args:
        x: Array-like of numeric pixel values.

    Returns:
        A ``np.uint8`` result with the same shape as ``x``.
    """
    rounded = np.round(x)
    # Saturate rather than wrap: uint8 cannot hold values outside 0..255.
    return np.clip(rounded, 0, 255).astype(np.uint8)
def inspect(corruption):
    """Apply ``corruption`` to every sample in ``IMAGES`` and display the result.

    For each (image, label) pair from the module-level ``IMAGES`` / ``LABELS``
    lists, the label is printed and the corrupted image is shown.

    Args:
        corruption: Callable taking an image and returning something
            convertible with ``np.array``.
    """
    for image, label in zip(IMAGES, LABELS):
        print(f"Label: {label!s}")
        corrupted = round_and_astype(np.array(corruption(image)))
        show(corrupted)
def inspect_single(image, corruption):
    """Apply ``corruption`` to a single ``image`` and display the result.

    Args:
        image: Input image passed unchanged to ``corruption``.
        corruption: Callable taking an image and returning something
            convertible with ``np.array``.
    """
    corrupted = corruption(image)
    show(round_and_astype(np.array(corrupted)))
def save(image, corruption, filename):
    """Apply ``corruption`` to ``image`` and write the result to ``filename``.

    The corrupted image is drawn as a grayscale picture (colour scale fixed
    to 0-255) with no axes or frame, and saved with a tight bounding box, no
    padding and a transparent background.

    The figure is closed once it has been written (also if saving fails), so
    calling this repeatedly does not accumulate open figures. As a result the
    figure is not left open for the notebook to render afterwards.

    Args:
        image: Input image passed unchanged to ``corruption``.
        corruption: Callable taking an image and returning something
            convertible with ``np.array``.
        filename: Destination passed to ``Figure.savefig``.
    """
    x = round_and_astype(np.array(corruption(image)))
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_frame_on(False)
        ax.imshow(x, cmap='gray', vmin=0, vmax=255)
        # Save this specific figure instead of relying on pyplot's notion of
        # the "current" figure.
        fig.savefig(filename, bbox_inches='tight', transparent=True, pad_inches=0)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
In [0]:
# Re-import the corruptions module so that edits made to it since the first
# import take effect, then preview glass blur on the sample images.
reload(corruptions)
inspect(corruption=corruptions.glass_blur)