A toy example on machine learning:
"Learning" a colour image using non-negative matrix approximation

I am going to use an accessible example to try to explain my PhD research on x-ray spectromicroscopy imaging and classifying absorption spectra using non-negative matrix approximation (NNMA). (Don't worry if it doesn't resonate with you immediately — keep reading! The purpose of this page is to show a tangible example!)

I'll start with an introduction to what NNMA does, relate it to experimental x-ray samples, introduce an accessible analogy, show some Python code that performs the NNMA itself, present the results, and finish with some concluding comments.

Non-negative matrix approximation (NNMA, also known as NMF, for non-negative matrix factorization) has been used as a learning algorithm in facial recognition. The idea is to "factorize" a database of faces into component features (parts of the eyes, nose, mouth, etc.), with each feature given some weighting in each face. This reduces a jumble of faces to a simpler set of component features. Then, given a new face, can we figure out whether it belongs to the database (perhaps photographed under different lighting conditions, or tilted at a different angle) or is an entirely new face? If we can add up the features we've just learned in some weighted combination to match this face, then we have recognized it as one already in the database. If not, it is likely not a face from the original database.
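The recognition step described above (can the learned features be added up, with non-negative weights, to match a new sample?) can be sketched as a non-negative fit of the new sample onto fixed features, followed by a look at the leftover residual. This is a minimal sketch, not the method used in actual face-recognition work: the helper `fit_nonneg_weights`, the six-pixel "features" and both samples are hypothetical, and the fit simply reuses NNMA's multiplicative-update idea with the features held fixed.

```python
import numpy as np

def fit_nonneg_weights(features, sample, n_iters=500):
    """Hypothetical helper: non-negative w such that features @ w approximates sample.

    features: (n_values, n_features) non-negative array, one learned feature per column
    sample:   (n_values,) non-negative array, e.g. a flattened new face
    Returns the weights and the relative residual ||sample - features @ w|| / ||sample||.
    """
    w = np.full(features.shape[1], 0.5)      # strictly positive starting guess
    ftf = features.T @ features              # (n_features, n_features)
    fts = features.T @ sample                # (n_features,)
    for _ in range(n_iters):
        w *= fts / (ftf @ w + 1e-12)         # multiplicative update keeps w >= 0
    residual = np.linalg.norm(sample - features @ w) / np.linalg.norm(sample)
    return w, residual

# Three made-up "features" over six pixels, with slight overlaps
features = np.array([[1.0, 0.1, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.1, 1.0, 0.0],
                     [0.0, 1.0, 0.1],
                     [0.0, 0.0, 1.0],
                     [0.0, 0.1, 1.0]])
known = features @ np.array([0.2, 0.7, 0.4])          # built from the features
unknown = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0])     # no non-negative mix matches this

w_known, r_known = fit_nonneg_weights(features, known)
w_unknown, r_unknown = fit_nonneg_weights(features, unknown)
print(np.round(w_known, 3), r_known)
print(np.round(w_unknown, 3), r_unknown)
```

A near-zero residual means the sample is well explained by the learned features ("recognized"); a large one suggests something new. Where exactly to draw that line is a judgement call.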

Now, instead of a database of faces, we have a stack of 2D x-ray absorption images of some sample (say, a biological cell). Every image in the stack shows the same sample; the images differ only in the x-ray energy at which they were taken. If you pick a fixed pixel in the image and traverse down the stack, you can see how the absorption at that pixel changes as a function of energy: this is the absorption spectrum at that pixel. As an analogy to the face database described above: each spectrum is like a series of faces of the same person under different conditions, and each pixel is like the face of a different person. So we have many, many spectra (as many as there are pixels in the image, which could be more than 10⁶), and we would like to decompose them into a smaller, simpler set that could still be added up in some weighted combination to reconstruct the spectrum of any pixel. In a nutshell: we can use NNMA to unmix the large set of mixed spectra into the simplest set possible that still contains the information needed to reconstruct any of the original spectra. This simpler set gives a clearer picture of the chemical components that make up the sample.
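In array terms, the stack becomes a data matrix with one row per energy and one column per pixel, the same channels-by-pixels layout that the `NNMATest` class below uses for its `data` matrix. A minimal sketch, assuming the stack is stored as a NumPy array of shape (n_energies, n_rows, n_cols); the tiny 4×2×3 stack is made up.

```python
import numpy as np

# Made-up stack: 4 energies, each a 2x3 absorption image
n_energies, n_rows, n_cols = 4, 2, 3
stack = np.arange(n_energies * n_rows * n_cols, dtype=float).reshape(n_energies, n_rows, n_cols)

# One column per pixel: column j is the absorption spectrum of pixel j
# (pixels are numbered row by row, as flatten() does in NNMATest)
spectra = stack.reshape(n_energies, n_rows * n_cols)

# "Traversing down the stack" at pixel (row=1, col=2) gives the same spectrum
row, col = 1, 2
pixel_index = row * n_cols + col
print(spectra.shape)
print(np.array_equal(spectra[:, pixel_index], stack[:, row, col]))
```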

OK, enough with the words — onto the promised tangible example! (Note: as with any simplified analogy, the correspondence is not exact and some details are not accounted for; but simplifications often offer insights that can help clarify a more complex problem.)

Suppose you have an RGB colour image (of anything colourful; say, peppers):

Each pixel contains some colour, but even two red pixels may not be the exact same shade of red. So here we have 503×302 pixels — potentially 151,906 different shades of different colours. Can we reduce this set into a smaller set that could still be used to reconstruct the colour of any given pixel? For an RGB image, the answer is already given: yes! Each pixel is an (R, G, B) tuple (where each of the R, G, and B can take on any value from 0 to 255 inclusive) and so can be decomposed into its R (red), G (green), and B (blue) components! The values in the tuples are just the weightings of the respective components in the pixel.
E.g., (255, 128, 0) represents fully saturated red + half-saturated green + no blue = an orangey colour.

So instead of describing the image using 151,906 different colours, each pixel is now specified as some weighted combination of 3 colours. Thus we have reduced a big, chaotic set of colours to a simpler set of just 3!
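The decomposition above is a matrix product: put the basis colours in the columns of a matrix, and use the tuple values as the weights. A minimal NumPy sketch, with the basis taken as unit red, green and blue vectors; the random 4×5 image is made up.

```python
import numpy as np

# Each column is one basis colour, written as a unit (R, G, B) vector
basis = np.eye(3)                          # columns: red, green, blue

# The tuple values are the weights of the three components
weights = np.array([255.0, 128.0, 0.0])
orange = basis @ weights                   # weighted sum of the columns
print(orange)

# A whole image works the same way: one column of weights per pixel
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 5, 3))      # made-up 4x5 RGB image
weights_all = image.reshape(-1, 3).T              # (3 components, 20 pixels)
rebuilt = (basis @ weights_all).T.reshape(image.shape)
print(np.array_equal(rebuilt, image))
```

Here the basis is known in advance; NNMA has to find both the basis and the weights from the data alone.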

The above may sound obvious (or maybe not), but let's relate it back to our x-ray spectromicroscopy application: at each pixel in the image, the absorption spectrum is the result of the absorption characteristics of a mixture of materials. We don't necessarily know what the materials are, nor their ratios. Since the amount of each material may differ from pixel to pixel, we have as many different absorption spectra as there are pixels (as mentioned above, possibly more than 10⁶). We would like to reduce this vast set of spectra to a smaller set that can be used (in some weighted combination) to describe the absorption at any pixel. This smaller set, like the {R, G, B} set of colours, can shed light on the material composition of a sample. (Small note: instead of a set like {R, G, B}, in this case we're looking for a set of spectra; each spectrum is an array of numbers that varies with energy, not a scalar value like R, G, or B. But this detail is not important for the purpose of this illustration.)
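To make "weighted combination" concrete for spectra, here is a sketch of the linear mixing described above, with two made-up component spectra (Gaussian peaks on a flat baseline, not real absorption spectra; the energy range is arbitrary). Every one of the 1000 pixel spectra is different, yet the whole set has rank 2: two component spectra are enough to rebuild all of them, which is the kind of structure NNMA tries to recover.

```python
import numpy as np

# Arbitrary energy axis (eV) and two made-up component spectra
energies = np.linspace(280.0, 300.0, 101)

def peak(centre, width):
    """Gaussian peak on the energy axis (illustrative stand-in for a real spectrum)."""
    return np.exp(-0.5 * ((energies - centre) / width) ** 2)

components = np.column_stack([peak(285.0, 1.0) + 0.2,    # hypothetical material A
                              peak(290.0, 2.0) + 0.1])   # hypothetical material B

# Each pixel has its own non-negative amounts of A and B
rng = np.random.default_rng(1)
n_pixels = 1000
weights = rng.random((2, n_pixels))

# Linear mixing: every pixel spectrum is a weighted sum of the component spectra
spectra = components @ weights             # (n_energies, n_pixels)
print(spectra.shape)
print(np.linalg.matrix_rank(spectra))      # only 2 independent spectra underneath
```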

Let's run an NNMA test on an RGB image to decompose it into its components and respective weightings, then use them to reconstruct the image and compare it with the original.

We will write an NNMATest class, which contains the iterative update functions and cost function calculations used in NNMA. It also plots the results. If you are running an active version of this page, you can change the parameters in the main function at the end to see how results vary (e.g., you can insert your favourite RGB image, or change the number of components to look for, etc. — the parameters will be explained below).
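Written out as equations (read off the code below, not taken from the paper), the class minimizes the cost computed in `calcCost` with the updates in `wUpdate` and `cUpdate`. Here X is `self.data` (nFrequencies × nPixels), C is the colour matrix `c` (nFrequencies × nComponents), W is the weighting matrix `w` (nComponents × nPixels), λ is `sparseParam`, ε is the `1e-9` guard against division by zero, ∘ is element-wise multiplication and the fractions are element-wise divisions. The updates have the form of Lee and Seung's multiplicative updates, with the sparseness term adding λ to the denominator of the W update; the C update uses the freshly updated W and is followed by rescaling each column of C to sum to 255.

```latex
\begin{aligned}
F(C, W) &= \tfrac{1}{2}\,\lVert X - CW \rVert_F^2 + \lambda \sum_{k,j} \lvert W_{kj} \rvert \\
W &\leftarrow W \circ \frac{C^{\mathsf T} X}{C^{\mathsf T} C W + \lambda + \varepsilon} \\
C &\leftarrow C \circ \frac{X W^{\mathsf T}}{C W W^{\mathsf T} + \varepsilon},
\qquad C_{:,k} \leftarrow 255\,\frac{C_{:,k}}{\sum_i C_{ik}}
\end{aligned}
```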


In [1]:
%matplotlib inline
import numpy as np
from PIL import Image  # Pillow; old standalone PIL installs used "import Image"
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython import display
import time

In [2]:
class NNMATest():

    # Assume an RGB image file is given
    def __init__(self, fileName, nComponents=3, sparseParam=0.5, maxIters=10, plotPixComp=False):
        # Convert to 3-channel RGB so that greyscale or RGBA files also work
        self.imgOrig = Image.open(fileName).convert("RGB")
        imArray = np.array(self.imgOrig) / 255.  # scale all RGB values to [0, 1]
        self.nRows = imArray.shape[0]
        self.nCols = imArray.shape[1]
        self.nPixels = self.nRows * self.nCols
        data = np.zeros((3, self.nPixels))
        for i in range(3):
            data[i, :] = imArray[:, :, i].flatten()
        self.data = data  # data matrix we want to reconstruct
        self.nComponents = nComponents  # no. of "hidden" components to extract
        self.sparseParam = sparseParam  # sparseness parameter
        self.maxIters = maxIters  # max no. of iterations to run
        self.frequencies = np.array([0., 1., 2.])  # our independent variable: in this case, the RGB colours
        self.nFrequencies = len(self.frequencies)
        self.plotPixComp = plotPixComp  # toggle whether to plot component composition of each pixel

        
    # Initialize the colour and weighting matrices
    def initMatrices(self):
        cInit = np.random.rand(self.nFrequencies, self.nComponents)
        wInit = np.random.rand(self.nComponents, self.nPixels)
        return cInit, wInit
    
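    # Update the weighting matrix (multiplicative update; sparseParam adds an L1 penalty)
    def wUpdate(self, c, w):
        wUpdateFactor = np.dot(c.T, self.data) / (np.dot(c.T, np.dot(c, w))
                                                  + self.sparseParam + 1e-9)
        wUpdated = w * wUpdateFactor
        # Replace any negative entries with a small positive value
        # (should not happen when all inputs are non-negative)
        wUpdated[wUpdated < 0.] = 1e-5
        return wUpdated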

    # Update the colour matrix (multiplicative update, using the freshly updated w)
    def cUpdate(self, c, wUpdated):
        cUpdateFactor = np.dot(self.data, wUpdated.T) / (np.dot(c, np.dot(wUpdated, wUpdated.T))
                                                         + 1e-9)
        cUpdated = c * cUpdateFactor
        # Replace any negative entries with a small positive value
        cUpdated[cUpdated < 0.] = 1e-5
        # Normalize each column of c to sum to 255 (w is not rescaled here;
        # the following w updates adapt to the new scale)
        cUpdated = cUpdated / np.sum(cUpdated, axis=0) * 255.
        return cUpdated
    
    # Calculate the cost function to be minimized:
    # 0.5 * ||data - c w||_F^2 + sparseParam * sum(|w|)
    def calcCost(self, c, w, count):
        dataRecon = np.dot(c, w)
        costBasic = 0.5 * np.linalg.norm(self.data - dataRecon)**2
        costSparse = np.sum(np.abs(w))
        costTotal = costBasic + self.sparseParam * costSparse
        self.costArray[count, 0] = costTotal
        if count > 0:
            deltaCost = self.costArray[count, 0] - self.costArray[count-1, 0]
        else:
            deltaCost = 1e-13  # small positive placeholder so that the main loop starts
        self.costArray[count, 1] = deltaCost
        return costTotal, deltaCost
    
    # Plot results
    def plotResults(self):
        # Plot original image (imshow's default origin="upper" keeps it the right way up)
        plt.figure(0, figsize=(15, 15))
        plt.subplot(1, 3, 1)
        plt.title("Original image")
        plt.imshow(self.imgOrig)
        
        # Plot reconstructed image, scaled so that its maximum maps to 255
        imgReconTemp = self.dataRecon.reshape((self.nFrequencies, self.nRows, self.nCols))
        imgReconTemp = imgReconTemp / np.amax(imgReconTemp)
        imgRecon = np.zeros((self.nRows, self.nCols, self.nFrequencies))
        for k in range(self.nFrequencies):
            imgRecon[:, :, k] = imgReconTemp[k, :, :]
        imgRecon = Image.fromarray(np.uint8(imgRecon*255.), "RGB")
        plt.subplot(1, 3, 2)
        plt.title("Reconstructed image")
        plt.imshow(imgRecon)
        
        # Plot difference between original and reconstructed images (scaled to its maximum)
        dataDiffTemp = np.abs(self.data - self.dataRecon).reshape(
            (self.nFrequencies, self.nRows, self.nCols))
        dataDiffTemp = dataDiffTemp / np.amax(dataDiffTemp)
        dataDiff = np.zeros((self.nRows, self.nCols, self.nFrequencies))
        for k in range(self.nFrequencies):
            dataDiff[:, :, k] = dataDiffTemp[k, :, :]
        dataDiff = Image.fromarray(np.uint8(dataDiff*255.), "RGB")
        plt.subplot(1, 3, 3)
        plt.title("Difference image")
        plt.imshow(dataDiff)

        # Plot each component's "spectrum" and weighting map
        plt.figure(1, figsize=(15, 15))
        # Reshape into a local variable so that self.wRecon keeps its 2D shape
        wMaps = self.wRecon.reshape((self.nComponents, self.nRows, self.nCols))
        colours = ['r', 'g', 'b', 'm', 'c', 'y', 'k']
        for k in range(self.nComponents):
            lineColour = colours[k] if k < len(colours) else 'k'
            plt.subplot(self.nComponents, 2, k*2+1)
            plt.title("Component %s spectrum" %(k+1))
            plt.plot(self.frequencies, self.cRecon[:, k], linewidth=3., color=lineColour)
            plt.subplot(self.nComponents, 2, (k+1)*2)
            plt.title("Component %s image" %(k+1))
            plt.imshow(wMaps[k, :, :], cmap=cm.gray)
        
        display.display(plt.gcf())
        plt.close()
    
    # Visualize how each pixel's component content changes during iterations (assumes nComponents >= 3)
    def plotWPixels(self, ind, count):
        wReconNorm = self.wRecon / np.amax(self.wRecon)
        wRecon_0 = wReconNorm[0, ind]
        wRecon_1 = wReconNorm[1, ind]
        wRecon_2 = wReconNorm[2, ind]
        plt.figure(3, figsize=(15, 8))
        axMin = -0.1
        axMax = 1.
        plt.subplot(1, 3, 1)
        plt.cla()
        plt.scatter(wRecon_0, wRecon_1, marker='o', color='brown', edgecolor="none")
        plt.xlabel("Component 1")
        plt.ylabel("Component 2")
        plt.xlim([axMin, axMax])
        plt.ylim([axMin, axMax])
        plt.subplot(1, 3, 2)
        plt.cla()
        plt.scatter(wRecon_0, wRecon_2, marker='o', color='m', edgecolor="none")
        plt.title("Iteration progress: {0}/{1}".format(count, self.maxIters))
        plt.xlabel("Component 1")
        plt.ylabel("Component 3")
        plt.xlim([axMin, axMax])
        plt.ylim([axMin, axMax])
        plt.subplot(1, 3, 3)
        plt.cla()
        plt.scatter(wRecon_1, wRecon_2, marker='o', color='c', edgecolor="none")
        plt.xlabel("Component 2")
        plt.ylabel("Component 3")
        plt.xlim([axMin, axMax])
        plt.ylim([axMin, axMax])
        display.clear_output(wait=True)
        display.display(plt.gcf())
        time.sleep(0.2)        
        plt.close()
    
    # Begin NNMA calculations
    def calcNNMA(self):
        print("Starting NNMA:")
        cInit, wInit = self.initMatrices()
        self.costArray = np.zeros((self.maxIters+1, 2))
        self.cRecon = cInit
        self.wRecon = wInit
        count = 0
        cost, deltaCost = self.calcCost(cInit, wInit, count)
        # If plotPixComp is True, pick 1000 random pixels to plot their component distribution
        # (randint's upper bound is exclusive, so every index is a valid pixel)
        ind = np.random.randint(0, self.nPixels, 1000)
        if self.plotPixComp: self.plotWPixels(ind, count)
        while (count < self.maxIters) and ((deltaCost < -1e-6) or (deltaCost > 0.)):
            count = count + 1
            wUpdated = self.wUpdate(self.cRecon, self.wRecon)
            cUpdated = self.cUpdate(self.cRecon, wUpdated)
            cost, deltaCost = self.calcCost(cUpdated, wUpdated, count)
            self.cRecon = cUpdated
            self.wRecon = wUpdated
            self.dataRecon = np.dot(self.cRecon, self.wRecon)
            if (count%10 == 0):
                if self.plotPixComp == True:
                    self.plotWPixels(ind, count)
                else:
                    print("Iteration progress: {0}/{1}".format(count, self.maxIters))
                    display.clear_output(wait=True)
                            
            
        self.plotResults()
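The update rules from `wUpdate` and `cUpdate` can also be tried without an image or any plotting. This minimal sketch (the function `nnma` and the synthetic data are hypothetical, not part of the original code) applies the same two multiplicative updates to a random non-negative rank-3 matrix and records the cost after every iteration. It deliberately leaves out the class's rescaling of the columns of C: without it, the cost should not increase from one iteration to the next (up to floating-point rounding), which makes a convenient sanity check.

```python
import numpy as np

def nnma(data, n_components, sparse_param=0.0, n_iters=300, seed=0):
    """Hypothetical minimal NNMA using the same multiplicative updates as NNMATest.

    data: non-negative array of shape (n_channels, n_pixels).
    Returns c (n_channels, n_components), w (n_components, n_pixels) and the
    cost 0.5*||data - c w||_F^2 + sparse_param*sum(w) after each iteration.
    Unlike NNMATest, the columns of c are NOT rescaled.
    """
    rng = np.random.default_rng(seed)
    eps = 1e-12                                          # guards against division by zero
    c = rng.random((data.shape[0], n_components)) + 0.1  # strictly positive start
    w = rng.random((n_components, data.shape[1])) + 0.1
    costs = []
    for _ in range(n_iters):
        # Same form as NNMATest.wUpdate: the sparseness term adds to the denominator
        w *= (c.T @ data) / (c.T @ c @ w + sparse_param + eps)
        # Same form as NNMATest.cUpdate, using the freshly updated w
        c *= (data @ w.T) / (c @ (w @ w.T) + eps)
        resid = data - c @ w
        # w >= 0, so sum(w) equals sum(|w|)
        costs.append(0.5 * np.sum(resid ** 2) + sparse_param * np.sum(w))
    return c, w, np.array(costs)

# Synthetic data: 3 hidden non-negative components mixed over 500 "pixels"
rng = np.random.default_rng(42)
data = rng.random((20, 3)) @ rng.random((3, 500))

c, w, costs = nnma(data, n_components=3)
rel_err = np.linalg.norm(data - c @ w) / np.linalg.norm(data)
print(f"cost: {costs[0]:.3g} -> {costs[-1]:.3g}, relative error {rel_err:.2%}")
```

Leaving out the rescaling has a price when `sparse_param > 0`: the L1 term can then be lowered just by scaling C up and W down without changing C W, which is presumably why the class normalizes the columns of C.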

If you're running an active version of this page, you can run the main function in the cell below to display results. The arguments and keyword arguments can be customized to see how the results vary.

  • fileName: insert the file name of your favourite RGB image;
  • nComponents: the (unknown, estimated) number of components to look for;
  • sparseParam: the sparseness parameter, i.e. the weight of the L1 penalty on the weighting maps (larger values push each pixel towards containing a single component);
  • maxIters: the maximum number of iterations to run;
  • plotPixComp: whether to plot the distribution of the pixels' component content during the iterations. The larger you set sparseParam, the more you will see the pixels migrate towards one of the two axes in each scatter plot (which means that you're trying to force each pixel to belong purely to one component, rather than to some mixture). (Note: it's nice to watch the pixels move, but rendering the plots slows down the computation.)

In [ ]:
if __name__ == "__main__":
    fileName = "figures/peppers.jpg"
    test = NNMATest(fileName, nComponents=3, sparseParam=0.3, maxIters=1000, plotPixComp=False)
    test.calcNNMA()

If you're not running an active version of this page, here are the results of running the above code. First, the individual components (representing R, G and B) are on the left, with their corresponding weighting maps on the right:

Wherever a weighting map is brighter, those pixels contain more of the corresponding component. For example, in the component 3 weighting map the red peppers show up brightest, so component 3 represents R in this case (even though its spectrum is drawn in blue). Similarly, component 1 represents G, and component 2 represents B. The white garlic shows up quite strongly in all of the components because white contains all three colours.
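Reading the assignment off the plots by eye works here; a quick numerical check is to ask, for each column of the fitted colour matrix (`cRecon` in `NNMATest`), which channel it peaks in. A minimal sketch, assuming each component is dominated by a single channel; `label_components` and the numbers in `c_example` are made up for illustration (its columns sum to 255, like the normalized columns produced by `cUpdate`).

```python
import numpy as np

def label_components(c, channel_names=("R", "G", "B")):
    """Hypothetical helper: label each component by the channel where its 'spectrum' peaks.

    c: array of shape (3, n_components), e.g. NNMATest.cRecon after a run.
    """
    peak_channel = np.argmax(c, axis=0)   # row index of the largest value in each column
    return [channel_names[i] for i in peak_channel]

# Made-up result resembling the one described: component 1 ~ G, 2 ~ B, 3 ~ R
c_example = np.array([[ 10.,   5., 240.],
                      [230.,  20.,  10.],
                      [ 15., 230.,   5.]])
print(label_components(c_example))
```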

Now this is what happens when we combine (matrix multiply) the discovered components with their weighting maps:

Not bad! We can reduce the error even further by running more iterations.
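The difference image gives a visual impression; a single number is handier for tracking how the error falls as you run more iterations. A sketch, assuming the relative Frobenius norm as the error measure (the original code does not compute this number, and `reconstruction_error` is a hypothetical helper). If you have the 3D weighting maps, flatten them first with `w.reshape(n_components, -1)`.

```python
import numpy as np

def reconstruction_error(data, c, w):
    """Hypothetical helper: relative Frobenius-norm error of the reconstruction c @ w.

    data: (n_channels, n_pixels), as in NNMATest.data
    c:    (n_channels, n_components), as in NNMATest.cRecon
    w:    (n_components, n_pixels) weights
    """
    return np.linalg.norm(data - c @ w) / np.linalg.norm(data)

# Tiny made-up example: 2 channels, 3 pixels, identity basis
data = np.array([[1.0, 0.0, 0.5],
                 [0.0, 1.0, 0.5]])
c = np.eye(2)
w_exact = data.copy()
w_rough = np.array([[0.9, 0.0, 0.5],
                    [0.0, 1.0, 0.5]])
print(reconstruction_error(data, c, w_exact))
print(reconstruction_error(data, c, w_rough))
```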

Some comments:

We mentioned before that there are some limits to the comparison between running NNMA on RGB images and running it on x-ray spectromicroscopy images. One such limit is this:

We've interpreted the 3 components that we extracted using NNMA as the 3 basic RGB colours that make up each pixel. This is cheating a bit, since we knew beforehand that the format of the image is RGB and that each pixel is composed of an RGB tuple, and so we knew to decompose the image into 3 components.

In x-ray spectromicroscopy, we don't know beforehand how many components to look for, since we don't usually know what the sample is composed of exactly. We may have some prior knowledge about the sample — say, it contains different compounds of copper — but we don't know exactly which ones, nor how many different kinds, so there is some guesswork involved (fortunately, not pure guesswork as there are other methods we could run on the data before NNMA, such as principal component analysis and cluster analysis, to get an idea beforehand about the number and variation of components in the sample).
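As a rough illustration of using PCA to guess the number of components before running NNMA, one can count how many principal components are needed to explain most of the variance in the spectra. This is a minimal sketch, not necessarily the procedure used in the paper referenced at the end of this page: the 99% threshold, the helper `estimate_n_components` and the synthetic Gaussian "spectra" are all assumptions, and with real, noisy data the threshold is a judgement call.

```python
import numpy as np

def estimate_n_components(data, variance_fraction=0.99):
    """Hypothetical PCA-based guess at the number of components.

    data: (n_energies, n_pixels); each column is one pixel's spectrum.
    Returns the smallest number of principal components that together explain
    at least `variance_fraction` of the total variance.
    """
    centred = data - data.mean(axis=1, keepdims=True)   # remove the mean spectrum
    s = np.linalg.svd(centred, compute_uv=False)        # singular values, descending
    explained = np.cumsum(s ** 2) / np.sum(s ** 2)
    return int(np.searchsorted(explained, variance_fraction) + 1)

# Synthetic data: three well-separated made-up component spectra, 2000 pixels, small noise
x = np.arange(50)
components = np.column_stack([np.exp(-0.5 * ((x - mu) / 3.0) ** 2) for mu in (10, 25, 40)])
rng = np.random.default_rng(0)
thickness = rng.random((3, 2000))
data = components @ thickness + 1e-3 * rng.standard_normal((50, 2000))
print(estimate_n_components(data))
```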

This is like trying to run NNMA on an RGB image not knowing what objects are in it, and seeing if we could find components that represent peppers, garlic, the backdrop, etc. But the image itself does not contain this information (a pixel representing a purple background would contain exactly the same RGB tuple as a pixel representing a pepper with the same purple colour, for example) — so we cannot run NNMA on an RGB image and expect to discover this type of information (though there are other image classification techniques that could be used for this). X-ray absorption spectra, however, do contain information about the chemical content of the material at each pixel — and that's why it's interesting to apply NNMA to x-ray spectromicroscopy images to study the composition of materials!

So I'll end with one of the samples that I've used NNMA on: a human sperm cell (x-ray spectromicroscopy experiment undertaken by H. Fleckenstein). The component absorption spectra are on the left and the component weighting maps on the right; the components correspond to the background, flagellum, mitochondrial sheath and posterior ring, acrosomal cap, and nucleus:

There are a few things I have not discussed in detail above, such as how to choose the number of components to start with, smoothness and sparseness regularizations, comparison with cluster analysis spectra as a basis set, etc. If you are interested, you can read more about it in this paper:
Mak, R., Lerotic, M., Fleckenstein, H., Vogt, S., Wild, S. M., Leyffer, S., Sheynkin, Y., and Jacobsen, C., "Non-negative matrix analysis for effective feature extraction in x-ray spectromicroscopy", Faraday Discuss. 171, doi:10.1039/c4fd00023d (2014)