K-SVD et voitures de luxe...

Ce notebook se présente comme une proof of concept de l'utilisation que l'on peut faire du traitement des images par représentation sparse en général et de l'algorithme k-svd en particulier.

La méthode utilisée est la suivante:

  • on crée un dictionnaire à partir d'une base de données initiales constituée de petites images (ou patchs) issus d'une séri d'images représentant des voitures de luxe
  • on choisit une image nouvelle image de voiture sur internet que l'on bruite artificiellement
  • on calcule ensuite sa représentation sparse grâce à un algorithme du type batch-OMP et on voit ce que ça donne !

Constitution de la base d'entraînement


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import rgb2gray
%matplotlib inline

In [2]:
path2db = "../Database/Cars/"

Un exemple d'image au hasard, juste pour ce faire une idée:


In [3]:
imgs = !ls $path2db
plt.imshow(plt.imread(path2db+imgs[1]))
plt.title("Un exemple de voiture")


Out[3]:
<matplotlib.text.Text at 0xaff6758c>

Fonction pour redimensioner les images

Sans cette fonction, on risque de générer trop de patchs. Dans un premier temps, je préfère diminuer un peu la taille des images. On verra ensuite si on peut lever cette restriction.


In [4]:
from skimage.transform import rescale

In [5]:
def resize_car(img):
    """Allow at most images of size (250, 250)"""
    x, y = img.shape[0], img.shape[1]
    if max(x, y) > 250:
        l = 250./max(x,y)
        img = rescale(img, l)
    return img

Division en patchs


In [6]:
def div_patchs(img, patch_dim = (8, 8), gray = True):
    """Return a 2d array the columns of which are the patches extracted from img.
    Args:
    - img: representation of an image as a 2d array
    - path_dim: dimensions of one patch (in pixels)
    - gray: wether to convert or not img to grey levels"""
    
    output = []
    xdim, ydim = patch_dim
    
    #if necessary convert img to gray
    if gray:
        img = rgb2gray(img)
        
    #extract patches
    for i in range(img.shape[0]/xdim):
        for j in range(img.shape[1]/ydim):
            output.append(img[i*xdim:(i+1)*xdim,j*ydim:(j+1)*ydim,].ravel())
            
    return np.array(output).T

In [7]:
def assemble_patches(patches, original_dim, patch_dim = (8, 8), gray = True):
    """Reassemble patches, given original dimesion of the image
    that was divided into patches"""
    
    #If gray, adapt original dimension
    if gray:
        original_dim = original_dim[:2]
        
    #a list to store all reshaped patches 
    temp = []
    #shortcuts
    xdim, ydim = patch_dim
    
    #reshape all patches (current columns of 'patches' matrix)
    for j in range(patches.shape[1]):
        temp.append(patches[:,j].reshape(patch_dim))
        
    #compute output image
    output_img = np.zeros(original_dim)
    k = 0
    #re-order patches
    for i in range(original_dim[0]/xdim):
        for j in range(original_dim[1]/ydim):
            output_img[i*xdim:(i+1)*xdim,j*ydim:(j+1)*ydim] = temp[k]
            k += 1
            
    return output_img

Une petite démonstration pour s'assurer que les fonctions fonctionnent bien ;) On remarque une bordure noir sur l'image réassemblée qui vient du fait que les dimensions originales de l'image ne sont pas nécesaairement des multiples des dimensions des patches.


In [8]:
img = rgb2gray(resize_car(plt.imread(path2db+imgs[1])))
patches = div_patchs(img)
reassembled = assemble_patches(patches, img.shape)
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.imshow(img, cmap = 'gray')
plt.title("Original image")
plt.subplot(1, 2, 2)
plt.imshow(reassembled, cmap = 'gray')
plt.title("Reassembled image")


Out[8]:
<matplotlib.text.Text at 0xacbb59cc>

Un aperçu de quelques patchs...


In [9]:
n = 20
plt.figure(figsize=(15, 12))
for i in range(n):
    plt.subplot(4, 5, i)
    plt.imshow(patches[:,np.random.randint(low = 0, high = patches.shape[1])].reshape((8, 8)),
                                           cmap = 'gray')


Amélioration possible: montrer d'où viennent les patchs dans l'image!

Assemblage de la base totale


In [15]:
X = np.hstack([div_patchs(resize_car(plt.imread(path2db+name))) for name in imgs[1:]])

In [16]:
X.shape


Out[16]:
(64, 21545)

On sélectionne au hasard 5000 patchs parmi ceux que l'on généré (c'est visiblement comme cela qu'ils ont fait dans l'article...)


In [17]:
mask = [True for _ in range(5000)] + [False for _ in range(X.shape[1] - 5000)]

In [18]:
mask = np.array(mask)
np.random.shuffle(mask)

In [19]:
Xtrain = X[:,mask]

Construction du dictionnaire


In [8]:
import imp
ksvd = imp.load_source('ksvd', '../Source/ksvd.py')

In [21]:
model = ksvd.KSVD(D = (64, 200), K = 10)

In [25]:
%%time
model.fit(Xtrain)


Training dictionary over 200 iterations
 [-----------------100%-----------------] 201 of 200 complete in 1653.5 sec   Done!
CPU times: user 27min 20s, sys: 11.6 s, total: 27min 31s
Wall time: 27min 33s

In [26]:
#Comme ce genre de calcul prend pas mal de temps,
#autant sauvegarder les résultats
import pickle
with open ('model.pickle', 'wb') as f:
    pickle.dump(model, f)

In [9]:
#Pour charger le modèle sans avoir à le ré-entraîner
import pickle
with open('model.pickle', 'rb') as f:
    model = pickle.load(f)

Test sur une image intégrale...

Sans bruit (commençons lentement mais surement)


In [10]:
img = resize_car(plt.imread(path2db+imgs[0]))
plt.imshow(img)


Out[10]:
<matplotlib.image.AxesImage at 0xac880ccc>

In [11]:
pp = div_patchs(img)
gamma = model.sparse_rep(pp)
new_patches = model.D.dot(gamma)
resulting_img = assemble_patches(new_patches, img.shape[:2])

In [12]:
plt.imshow(resulting_img, cmap = 'gray')


Out[12]:
<matplotlib.image.AxesImage at 0xac85f36c>

Avec bruit

Tout d'abord on définit une fonction permettant de calculer l'erreur de reconstruction (on utilise exactement la fonction de l'article): $$\sqrt{||B - \tilde{B}||^2 _F /64}$$ avec $B$ l'image initiale non bruitée et $\tilde{B}$ l'image débruitée.


In [21]:
def rec_err(B, B_):
    """Returns mean reconstrution error for all patches.
    Args: - B original image with no noise
          - B_ denoised image
    """
    
    return np.linalg.norm(div_patchs(B) - div_patchs(B_), axis=0).mean()/8.

On commence par bruiter très simplement une image


In [15]:
mask = np.random.random(img.shape[:2])
noisy = img.copy()
noisy[(mask > .9)] = 0

In [16]:
pp = div_patchs(noisy)
new_patches = model.denoise_patches(pp, .01)
resulting_img = assemble_patches(new_patches, img.shape[:2])

In [22]:
from skimage.filter import denoise_bilateral

plt.figure(figsize=(15, 8))
plt.subplot(131)
plt.imshow(resulting_img, cmap = 'gray')
plt.title("K-SVD denoising (err: {})".format(rec_err(rgb2gray(img), resulting_img)))
plt.subplot(132)
plt.imshow(denoise_bilateral(noisy), cmap = 'gray')
plt.title("Bilateral denoisng (err: {})".format(rec_err(rgb2gray(img), 
                                                    rgb2gray(denoise_bilateral(noisy)))))
plt.subplot(133)
plt.imshow(noisy, cmap = 'gray')
plt.title("Noisy image")


Out[22]:
<matplotlib.text.Text at 0xac0e81ec>

Un autre essai avec plus de bruit...


In [25]:
mask = np.random.random(img.shape[:2])
noisy = img.copy()
noisy[(mask > .6)] = 0
pp = div_patchs(noisy)
new_patches = model.denoise_patches(pp, .01)
resulting_img = assemble_patches(new_patches, img.shape[:2])

In [26]:
plt.figure(figsize=(10, 8))
plt.subplot(121)
plt.imshow(resulting_img, cmap = 'gray')
plt.title("K-SVD denoising (err: {})".format(rec_err(rgb2gray(img), resulting_img)))
plt.subplot(122)
plt.imshow(noisy, cmap = 'gray')
plt.title("Noisy image")


Out[26]:
<matplotlib.text.Text at 0xabb28e2c>

Pour finir, une courbe représentant l'erreur de reconstitution en fonction de la densité des pixels supprimés.


In [27]:
xx = [x/10. for x in range(1, 10)]
RMSE = []

for x in xx:
    #On construit une image bruitée
    mask = np.random.random(img.shape[:2])
    noisy = img.copy()
    noisy[(mask < x)] = 0
    pp = div_patchs(noisy)
    new_patches = model.denoise_patches(pp, .01)
    resulting_img = assemble_patches(new_patches, img.shape[:2])
    
    RMSE.append(rec_err(resulting_img, rgb2gray(img)))

In [30]:
plt.plot(xx, RMSE)
plt.title(u"Qualité de reconstruction")
plt.xlabel(u"Ratio de pixels corrompus")
plt.ylabel(u"Erreur de reconstruction")
plt.axis([.1, .9, 0., .55])


Out[30]:
[0.1, 0.9, 0.0, 0.55]

Bien que l'on obtienne des ratios du même ordre de grandeur que dans l'article, les nôtres sont un peu moins bons... Cela peut être dû:

  • au faut que notre dictionnaire contienne nettement moins d'atoms
  • au fait que nous avons construit notre dictionnaire à partir d'un nombre plus faible de patches