In [1]:
import numpy as np
import glob

In [2]:
# Training input shards. Sorted because glob.glob returns matches in arbitrary,
# filesystem-dependent order; a fixed order keeps the accumulation in the next
# cell reproducible across machines.
data_files = sorted(glob.glob('../../exp-vqa-shape/vqa_shape_dataset/train.*.input.npy'))

In [3]:
# Compute the per-pixel mean over every training input file: sum along the
# first axis of each loaded array, count its rows, and divide at the end.
# NOTE(review): assumes the first axis of each saved array indexes images
# (both the sum and len() use axis 0) — confirm against the dataset writer.
count = 0
image_sum = 0

for path in data_files:
    data = np.load(path)
    # Accumulate in float64 so many files of float32 inputs don't lose
    # precision; integer inputs are summed exactly as well.
    image_sum += np.sum(data, axis=0, dtype=np.float64)
    count += len(data)

# Without this guard an empty glob (or empty files) ends in a bare
# ZeroDivisionError / NaN image instead of an actionable message.
if count == 0:
    raise ValueError('no training images found; check the data_files glob path')

image_mean = image_sum / count

In [4]:
# Persist the mean image to disk, cast to float32 before saving.
np.save(
    arr=image_mean.astype(dtype=np.float32),
    file='../../exp-vqa-shape/data/image_mean.npy',
)

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(image_mean.astype(np.uint8))


Out[5]:
<matplotlib.image.AxesImage at 0x7f8a69b51780>