In [ ]:
from tensorflow.keras.datasets import mnist, fashion_mnist
_ = mnist.load_data()
_ = fashion_mnist.load_data()
In [ ]:
if not op.exists("BSR_bsds500.tgz"):
urlretrieve("http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz", "BSR_bsds500.tgz")
In [ ]:
import tarfile
import os
import pickle as pkl
import numpy as np
import skimage
import skimage.io
import skimage.transform
BST_PATH = 'BSR_bsds500.tgz'
rand = np.random.RandomState(42)
f = tarfile.open(BST_PATH)
train_files = []
for name in f.getnames():
if name.startswith('BSR/BSDS500/data/images/train/'):
train_files.append(name)
print('Loading BSR training images')
background_data = []
for name in train_files:
try:
fp = f.extractfile(name)
bg_img = skimage.io.imread(fp)
background_data.append(bg_img)
except:
continue
def compose_image(digit, background):
"""Difference-blend a digit and a random patch from a background image."""
w, h, _ = background.shape
dw, dh, _ = digit.shape
x = np.random.randint(0, w - dw)
y = np.random.randint(0, h - dh)
bg = background[x:x+dw, y:y+dh]
return np.abs(bg - digit).astype(np.uint8)
def mnist_to_img(x):
"""Binarize MNIST digit and convert to RGB."""
x = (x > 0).astype(np.float32)
d = x.reshape([28, 28, 1]) * 255
return np.concatenate([d, d, d], 2)
def create_mnistm(X):
"""
Give an array of MNIST digits, blend random background patches to
build the MNIST-M dataset as described in
http://jmlr.org/papers/volume17/15-239/15-239.pdf
"""
X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
for i in range(X.shape[0]):
if i % 10000 == 0:
print('Processing example', i)
bg_img = rand.choice(background_data)
d = mnist_to_img(X[i])
d = compose_image(d, bg_img)
X_[i] = d
return X_
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print('Building train set...')
train = create_mnistm(x_train)
print('Building test set...')
test = create_mnistm(x_test)
# Save dataset as pickle
with open('mnistm_data.pkl', 'wb+') as f:
pkl.dump({ 'x_train': train, 'x_test': test, "y_train": y_train, "y_test": y_test}, f, pkl.HIGHEST_PROTOCOL)
print("Done!")
In [ ]: