In [1]:
%matplotlib inline
from lab006 import *
import numpy as np
import matplotlib.pyplot as plt
# Question 1
# Materialise the MNIST *training* reader into a list so individual samples can
# be fetched by position; each element is unpacked below as (label, image).
target = [*mnist_read('train-images-idx3-ubyte', 'train-labels-idx1-ubyte')]
# Compare the two pooling operators on the first N_SAMPLES training digits:
# top row = max pooling, bottom row = average pooling (one column per digit).
N_SAMPLES = 5

plt.figure()
for i, (_label, img) in enumerate(target[:N_SAMPLES]):
    # Max Pooling -- `astype` returns a NEW array, so the converted result must
    # be kept; the original code discarded it and the conversion never happened.
    maxPoolResult = max_pooling(img).astype(np.uint8)
    plt.subplot(2, N_SAMPLES, i + 1)
    # NOTE(review): assumes mnist_show draws into the current axes and does not
    # call plt.show() itself -- confirm in lab006.
    mnist_show(maxPoolResult)

    # Average Pooling -- same conversion; uint8 truncates fractional averages.
    avgPoolResult = avg_pooling(img).astype(np.uint8)
    plt.subplot(2, N_SAMPLES, N_SAMPLES + i + 1)
    mnist_show(avgPoolResult)

# Show once, after every panel is drawn; calling show() inside the loop would
# close the figure and scatter the panels across separate figures.
plt.show()