In [722]:
import numpy as np

import scipy.linalg
import scipy.weave

import sklearn
from sklearn.base import BaseEstimator

class RobustPCA(BaseEstimator):
    '''Robust PCA
    
    Candès, Emmanuel J., et al. "Robust principal component analysis?." Journal of the ACM (JACM) 58.3 (2011): 11.
    
    http://arxiv.org/abs/0912.3599
    
    '''
    
    def __nuclear_prox(self, A, r=1.0):
        '''Proximal operator for scaled nuclear norm:
        Y* <- argmin_Y  r * ||Y||_* + 1/2 * ||Y - A||_F^2

        Arguments:
            A    -- (ndarray) input matrix
            r    -- (float>0) scaling factor

        Returns:
            Y    -- (ndarray) if A = USV', then Y = UTV'
                              where T = max(S - r, 0)
        '''
    
        U, S, V = scipy.linalg.svd(A, full_matrices=False)
        
        T = np.maximum(S - r, 0.0)
        
        Y = (U * T).dot(V)
        
        return Y

    def __l1_prox(self, A, r=1.0):
        '''Proximal operator for entry-wise matrix l1 norm:
        Y* <- argmin_Y r * ||Y||_1 + 1/2 * ||Y - A||_F^2

        Arguments:
            A    -- (ndarray) input matrix
            r    -- (float>0) scaling factor

        Returns:
            Y    -- (ndarray) Y = A after shrinkage
        '''
    
        Y = np.zeros_like(A)
    
        numel = A.size
        
        shrinkage = r"""
            for (int i = 0; i < numel; i++) {
                Y[i] = 0;
    
                if (A[i] - r > 0) {
                    Y[i] = A[i] - r;
                } else if (A[i] + r <= 0) {
                    Y[i] = A[i] + r;
                }
            }
        """
    
        scipy.weave.inline(shrinkage, ['numel', 'A', 'r', 'Y'])
        
        return Y

    def __cost(self, Y, Z):
        '''Get the cost of an RPCA solution.

        Arguments:
            Y       -- (ndarray)    the low-rank component
            Z       -- (ndarray)    the sparse component
            alpha   -- (float>0)    the balancing factor

        Returns:
            total, nuclear_norm, l1_norm -- (list of floats)
        '''
        nuclear_norm = scipy.linalg.svd(Y,
                                        full_matrices=False,
                                        compute_uv=False).sum()
        
        l1_norm = np.abs(Z).sum()
    
        return nuclear_norm + self.alpha_ * l1_norm, nuclear_norm, l1_norm

    def __init__(self, alpha=None, max_iter=100, verbose=False):
        '''
        
        Arguments:
            alpha -- (float > 0) weight between low-rank and noise term
                     If left as None, alpha will be automatically set to
                     
                     sqrt(max(X.shape))
                     
            max_iter -- (int > 0) maximum number of iterations
        '''
        
        self.alpha = alpha
        self.max_iter = max_iter
        self.verbose = verbose
        
    def fit(self, X):
        '''Fit the robust PCA model to a matrix X'''
        
        self.fit_transform(X)
        return self
    
    def fit_transform(self, X):
        
        # Some magic numbers for dynamic augmenting penalties in ADMM.
        # Changing these shouldn't effect correctness, only convergence rate.
        
        RHO_MIN      = 1e0
        RHO_MAX      = 1e5
        MAX_RATIO    = 2e0
        SCALE_FACTOR = 1.5e0
        
        ABS_TOL      = 1e-4
        REL_TOL      = 1e-3
    
        # update rules:
        #  Y+ <- nuclear_prox(X - Z - W, 1/rho)
        #  Z+ <- l1_prox(X - Y - W, alpha/rho)
        #  W+ <- W + Y + Z - X
    
        # Initialize
        rho = RHO_MIN
    
        # Scale the data to a workable range
        X = X.astype(np.float)
        Xmin = np.min(X)
        rescale = max(1e-8, np.max(X - Xmin))
        
        Xt = (X - Xmin) / rescale
        
        Y   = Xt.copy()
        Z   = np.zeros_like(Xt)
        W   = np.zeros_like(Xt)
    
        norm_X = scipy.linalg.norm(Xt)
    
        if self.alpha is None:
            self.alpha_ = max(Xt.shape)**(-0.5)
        else:
            self.alpha_ = self.alpha

        m   = X.size

        _DIAG = {
             'err_primal': [],
             'err_dual':   [],
             'eps_primal': [],
             'eps_dual':   [],
             'rho':        []
        }
        
        for t in range(self.max_iter):
            Y = self.__nuclear_prox(Xt - Z - W, 1.0/rho)
            Z_old = Z.copy()
            Z = self.__l1_prox(Xt - Y - W, self.alpha_ /  rho)

            residual_pri  = Y + Z - Xt
            residual_dual = Z - Z_old
        
            res_norm_pri  = scipy.linalg.norm(residual_pri)
            res_norm_dual = rho * scipy.linalg.norm(residual_dual)
        
            W = W + residual_pri
        
            eps_pri  = np.sqrt(m) * ABS_TOL + REL_TOL * max(scipy.linalg.norm(Y), scipy.linalg.norm(Z), norm_X)
            eps_dual = np.sqrt(m) * ABS_TOL + REL_TOL * scipy.linalg.norm(W)
        
            _DIAG['eps_primal'].append(eps_pri)
            _DIAG['eps_dual'  ].append(eps_dual)
            _DIAG['err_primal'].append(res_norm_pri)
            _DIAG['err_dual'  ].append(res_norm_dual)
            _DIAG['rho'       ].append(rho)
        
            if res_norm_pri <= eps_pri and res_norm_dual <= eps_dual:
                break
            
            if res_norm_pri > MAX_RATIO * res_norm_dual and rho * SCALE_FACTOR <= RHO_MAX:
                rho = rho * SCALE_FACTOR
                W   = W / SCALE_FACTOR
            
            elif res_norm_dual > MAX_RATIO * res_norm_pri and rho / SCALE_FACTOR >= RHO_MIN:
                rho = rho / SCALE_FACTOR
                W   = W * SCALE_FACTOR
       
        if self.verbose:
            if t < self.max_iter - 1:
                print 'Converged in %d steps' % t
            else:
                print 'Reached maximum iterations'
    
        # Scale back up to the original data scale
        Z = (Z + Xmin) * rescale
        self.embedding_= X - Z
        
        _DIAG['cost'] = self.__cost(self.embedding_, Z)
        
        self.diagnostics_ = _DIAG
        
        return self.embedding_

Example application: image denoising


In [723]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn
seaborn.set(style='white')
%matplotlib inline

In [724]:
import sklearn.datasets

In [725]:
# Load in the example images from sklearn
data = sklearn.datasets.load_sample_images()

In [726]:
# Convert each image from RGB to grayscale
X = [np.mean(D, axis=-1) for D in data.images]

In [736]:
# Pull out the first image + laplace noise
dx = X[0] + np.random.laplace(scale=5, size=X[0].shape)

In [737]:
plt.subplot(121)
plt.imshow(X[0], interpolation='none')
plt.axis('off')
plt.title('Original')

plt.subplot(122)
plt.imshow(dx, interpolation='none')
plt.axis('off')
plt.title('Noisy')

plt.tight_layout()



In [738]:
# Build a model object
M = RobustPCA(verbose=True)

In [739]:
M.fit(dx)


Converged in 23 steps
Out[739]:
RobustPCA(alpha=None, max_iter=100, verbose=True)

In [740]:
plt.semilogy(M.diagnostics_['err_primal'], label='Primal error')
plt.semilogy(M.diagnostics_['err_dual'], label='Dual error')
plt.xlabel('Iterations')
plt.legend()
plt.tight_layout()



In [741]:
plt.plot(M.diagnostics_['rho'], label=r'$\rho$')
plt.legend()
plt.title('Augmenting term factor')
plt.tight_layout()



In [742]:
# Helper function to pull out the normalized spectrum of a matrix
def spectrum(X, norm=True):
    
    v = scipy.linalg.svd(X, compute_uv=False)
    if norm:
        v = v / v.max()
    return v

In [743]:
# How do the spectra compare?

plt.bar(np.arange(10), spectrum(dx)[:10], width=0.45, label='Input X')
plt.bar(np.arange(10) + 0.5, spectrum(M.embedding_)[:10], width=0.45, color='r', label='Low-rank approximation')

plt.xticks(0.45 + np.arange(10), range(1,11))
plt.xlabel(r'$i$')

plt.ylabel(r'$\sigma_i / \sigma_1$')

plt.title('Normalized singular value distribution')

plt.legend()
plt.tight_layout()



In [744]:
# How do the images look?
plt.subplot(131)
plt.imshow(dx, interpolation='none')
plt.title('Input')
plt.axis('off')

plt.subplot(132)
plt.imshow(M.embedding_, interpolation='none')
plt.title('Low-rank approximation')
plt.axis('off')

plt.subplot(133)
plt.imshow(dx - M.embedding_, interpolation='none')
plt.title('Residual')
plt.axis('off')

plt.tight_layout()