Computation and comparision of the bispectrum and the rotational bispectrum

We show how to compute the bispectrum and the rotational bispectrum, as presented in the paper

  • Image processing in the semidiscrete group of rototranslations by D. Prandi, U. Boscain and J.-P. Gauthier.

In [1]:
import numpy as np
from numpy import fft
from numpy import linalg as LA
from scipy import ndimage
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os

%matplotlib inline

Auxiliary functions


In [2]:
def int2intvec(a):
    """
    Auxiliary function to recover a vector with the digits of a 
    given integer (in inverse order)
    
    `a` : integer
    """
    digit = a%10
    vec = np.array([digit],dtype=int)
    a = (a-digit)/10
    while a!=0:
        digit = a%10
        vec = np.append(vec,int(digit))
        a = (a-digit)/10
    return vec

In [3]:
ALPHABET7 = "0123456"
ALPHABET10 = "0123456789"

def base_encode(num, alphabet):
    """
    Encode a number in Base X

    `num`: The number to encode
    """
    if (str(num) == alphabet[0]):
        return int(0)
    arr = []
    base = len(alphabet)
    while num:
        rem = num % base
        num = num // base
        arr.append(alphabet[rem])
    arr.reverse()
    return int(''.join(arr))

def base7to10(num):
    """
    Convert a number from base 10 to base 7
    
    `num`: The number to convert    
    """
    arr = int2intvec(num)
    num = 0
    for i in range(len(arr)):
        num += arr[i]*(7**(i))
    return num
    
def base10to7(num):
    """
    Convert a number from base 7 to base 10
    
    `num`: The number to convert    
    """
    
    return base_encode(num, ALPHABET7)

In [4]:
def rgb2gray(rgb):
    """
    Convert an image from RGB to grayscale
    
    `rgb`: The image to convert    
    """
    r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray

In [5]:
def oversampling(image, factor = 7):
    """
    Oversample a grayscale image by a certain factor, dividing each
    pixel in factor*factor subpixels with the same intensity.
    
    `image`:  The image to oversample
    `factor`: The oversampling factor
    """
    old_shape = image.shape
    new_shape = (factor*old_shape[0], factor*old_shape[1])
    new_image = np.zeros(new_shape, dtype = image.dtype)
    for i in range(old_shape[0]):
        for j in range(old_shape[1]):
            new_image[factor*i:factor*i+factor,factor*j:factor*j+factor] = image[i,j]*np.ones((factor,factor))
    return new_image

Spiral architecture implementation

Spiral architecture has been introduced by Sheridan in

The implementation with hyperpels that we use in the following is presented in

For a more detailed implementation, see the notebook Hexagonal grid.

We start by defining the centered hyperpel, which is defined on a 9x9 grid and is composed of 56 pixels. It has the shape

# o o x x x x x o o
# o x x x x x x x o
# o x x x x x x x o
# x x x x x x x x x
# x x x C x x x x x 
# o x x x x x x x o
# o x x x x x x x o
# o o x x x x x o o

In [6]:
# The centered hyperpel
hyperpel = np.array([\
                [-1,4],[0,4],[1,4],[2,4],[3,4],\
                [-2,3],[-1,3], [0,3], [1,3], [2,3], [3,3], [4,3],\
                [-2,2],[-1,2], [0,2], [1,2], [2,2], [3,2], [4,2],\
                [-3,1],[-2,1],[-1,1], [0,1], [1,1], [2,1], [3,1], [4,1],[5,1],\
                [-3,0],[-2,0],[-1,0], [0,0], [1,0], [2,0], [3,0], [4,0],[5,0],\
                [-2,-1],[-1,-1], [0,-1], [1,-1], [2,-1], [3,-1], [4,-1],\
                [-2,-2],[-1,-2], [0,-2], [1,-2], [2,-2], [3,-2], [4,-2],\
                [-1,-3], [0,-3], [1,-3], [2,-3], [3,-3]])

hyperpel_sa = hyperpel - np.array([1,1])

We now compute, in sa2hex, the address of the center of the hyperpel corresponding to a certain spiral address.


In [7]:
def sa2hex(spiral_address):
    # Split the number in basic unit and call the auxiliary function
    # Here we reverse the order, so that the index corresponds to the 
    # decimal position
    digits = str(spiral_address)[::-1] 
    
    hex_address = np.array([0,0])
    
    for i in range(len(digits)):
        if int(digits[i])<0 or int(digits[i])>6:
            print("Invalid spiral address!")
            return 
        elif digits[i]!= '0':
            hex_address += sa2hex_aux(int(digits[i]),i)
    return hex_address
        
# This computes the row/column positions of the base cases,
# that is, in the form a*10^(zeros).
def sa2hex_aux(a, zeros):
    # Base cases
    if zeros == 0:
        if a == 0:
            return np.array([0,0])
        elif a == 1:
            return np.array([0,8])
        elif a == 2:
            return np.array([-7,4])
        elif a == 3:
            return np.array([-7,-4])
        elif a == 4:
            return np.array([0,-8])
        elif a == 5:
            return np.array([7,-4])
        elif a == 6:
            return np.array([7,4])
    
    return sa2hex_aux(a,zeros-1)+ 2*sa2hex_aux(a%6 +1,zeros-1)

Then, we compute the value of the hyperpel corresponding to the spiral address, by averaging the values on the subpixels.


In [8]:
def sa_value(oversampled_image,spiral_address):
    """
    Computes the value of the hyperpel corresponding to the given
    spiral coordinate.
    """
    hp = hyperpel_sa + sa2hex(spiral_address)
    val = 0.
    for i in range(56):
        val += oversampled_image[hp[i,0],hp[i,1]]
    
    return val/56

Spiral addition and multiplication


In [10]:
def spiral_add(a,b,mod=0):
    addition_table = [
    [0,1,2,3,4,5,6],
    [1,63,15,2,0,6,64],
    [2,15,14,26,3,0,1],
    [3,2,26,25,31,4,0],
    [4,0,3,31,36,42,5],
    [5,6,0,4,42,41,53],
    [6,64,1,0,5,53,52]
    ]
    
    dig_a = int2intvec(a)
    dig_b = int2intvec(b) 
    
    if (dig_a<0).any() or (dig_a>7).any() \
      or (dig_b<0).any() or (dig_b>7).any():
        print("Invalid spiral address!")
        return
    
    if len(dig_a) == 1 and len(dig_b)==1:
        return addition_table[a][b]
    
    if len(dig_a) < len(dig_b):
        dig_a.resize(len(dig_b))
    elif len(dig_b) < len(dig_a):
        dig_b.resize(len(dig_a))
        
    res = 0
    
    for i in range(len(dig_a)):
        
        if i == len(dig_a)-1:
            res += spiral_add(dig_a[i],dig_b[i])*(10**i)
        else:
            temp = spiral_add(dig_a[i],dig_b[i])
            res += (temp%10)*(10**i)
        
            carry_on = spiral_add(dig_a[i+1],(temp - temp%10)/10)
            dig_a[i+1] = str(carry_on)
    
    if mod!=0:
        return res%mod
    
    return res

In [11]:
def spiral_mult(a,b, mod=0):
    multiplication_table = [
    [0,0,0,0,0,0,0],
    [0,1,2,3,4,5,6],
    [0,2,3,4,5,6,1],
    [0,3,4,5,6,1,2],
    [0,4,5,6,1,2,3],
    [0,5,6,1,2,3,4],
    [0,6,1,2,3,4,5],
    ]
    
    dig_a = int2intvec(a)
    dig_b = int2intvec(b) 
    
    if (dig_a<0).any() or (dig_a>7).any() \
      or (dig_b<0).any() or (dig_b>7).any():
        print("Invalid spiral address!")
        return
    
    sa_mult = int(0)
    
    for i in range(len(dig_b)):
        for j in range(len(dig_a)):
            temp = multiplication_table[dig_a[j]][dig_b[i]]*(10**(i+j))
            sa_mult=spiral_add(sa_mult,temp)
    
    if mod!=0:
        return sa_mult%mod
    
    return sa_mult

Computation of the bispectrum

We start by computing the vector $\omega_f(\lambda)$, where $\lambda$ is a certain spiral address.


In [12]:
def omegaf(fft_oversampled, sa):
    """
    Evaluates the vector omegaf corresponding to the given 
    spiral address sa.
    
    `fft_oversampled`: the oversampled FFT of the image
    `sa`: the spiral address where to compute the vector
    """
    
    omegaf = np.zeros(6, dtype=fft_oversampled.dtype)
    
    for i in range(1,7):
        omegaf[i-1] = sa_value(fft_oversampled,spiral_mult(sa,i))
    
    return omegaf

Then, we can compute the "generalized invariant" corresponding to $\lambda_1$, $\lambda_2$ and $\lambda_3$, starting from the FFT of the image. That is

$$ I^3_f(\lambda_1,\lambda_2,\lambda_3) = \langle\omega_f(\lambda_1)\odot\omega_f(\lambda_2),\omega_f(\lambda_3)\rangle. $$

In [13]:
def invariant(fft_oversampled, sa1,sa2,sa3):
    """
    Evaluates the generalized invariant of f on sa1, sa2 and sa3
    
    `fft_oversampled`: the oversampled FFT of the image
    `sa1`, `sa2`, `sa3`: the spiral addresses where to compute the invariant
    """
    
    omega1 = omegaf(fft_oversampled,sa1)
    omega2 = omegaf(fft_oversampled,sa2)
    omega3 = omegaf(fft_oversampled,sa3)
    
    # Attention: np.vdot uses the scalar product with the complex 
    # conjugation at the first place!
    return np.vdot(omega1*omega2,omega3)

Finally, this function computes the bispectrum (or the rotational bispectrum) corresponding to the spiral addresses in the following picture.


In [14]:
def bispectral_inv(fft_oversampled_example, rotational = False):
    """
    Computes the (rotational) bispectral invariants for any sa1 
    and any sa2 in the above picture.
    
    `fft_oversampled_example`: oversampled FFT of the image
    `rotational`: if True, we compute the rotational bispectrum
    """
    
    if rotational == True:
        bispectrum = np.zeros(9**2*6,dtype = fft_oversampled_example.dtype)
    else:
        bispectrum = np.zeros(9**2,dtype = fft_oversampled_example.dtype)
        
    indexes = [0,1,10,11,12,13,14,15,16]
    
    count = 0
    for i in range(9):
        sa1 = indexes[i]
        sa1_base10 = base7to10(sa1)
        for k in range(9):
            sa2 = indexes[k]
            if rotational == True:
                for r in range(6):
                    sa2_rot = spiral_mult(sa2,r)
                    sa2_rot_base10 = base7to10(sa2_rot)
                    sa3 = base10to7(sa1_base10+sa2_rot_base10)
                    bispectrum[count]=invariant(fft_oversampled_example,sa1,sa2,sa3)
                    count += 1
            else:
                sa2_base10 = base7to10(sa2)
                sa3 = base10to7(sa1_base10+sa2_base10)
                bispectrum[count]=invariant(fft_oversampled_example,sa1,sa2,sa3)
                count += 1
    
    return bispectrum

Some timing tests.


In [15]:
example =  1 - rgb2gray(plt.imread('./test-images/butterfly.png'))
fft_example = np.fft.fftshift(np.fft.fft2(example))
fft_oversampled_example = oversampling(fft_example)

In [16]:
%%timeit
bispectral_inv(fft_oversampled_example)


1 loops, best of 3: 372 ms per loop

In [17]:
%%timeit
bispectral_inv(fft_oversampled_example, rotational=True)


1 loops, best of 3: 2.35 s per loop

Tests

Here we define various functions to batch test the images in the test folder.


In [18]:
folder = './test-images'

In [19]:
def evaluate_invariants(image, rot = False):
    """
    Evaluates the invariants of the given image.
    
    `image`: the matrix representing the image (not oversampled)
    `rot`: if True we compute the rotational bispectrum
    """
    
    # compute the normalized FFT
    fft = np.fft.fftshift(np.fft.fft2(image))
    fft /= fft / LA.norm(fft)
    
    # oversample it
    fft_oversampled = oversampling(fft)
    
    return bispectral_inv(fft_oversampled, rotational = rot)

Some timing tests.


In [20]:
%%timeit
evaluate_invariants(example)


1 loops, best of 3: 1.07 s per loop

In [21]:
%%timeit
evaluate_invariants(example, rot = True)


1 loops, best of 3: 3.09 s per loop

In [22]:
def bispectral_folder(folder_name = folder, rot = False): 
    """
    Evaluates all the invariants of the images in the selected folder, 
    storing them in a dictionary with their names as keys.
    
    `folder_name`: path to the folder
    `rot`: if True we compute the rotational bispectrum
    """
    
    # we store the results in a dictionary
    results = {}
    
    for filename in os.listdir(folder_name):
        infilename = os.path.join(folder_name, filename)
        if not os.path.isfile(infilename): 
            continue

        base, extension = os.path.splitext(infilename)
        if extension == '.png':
            test_img = 1 - rgb2gray(plt.imread(infilename))
            bispectrum = evaluate_invariants(test_img, rot = rot)
            
            results[os.path.splitext(filename)[0]] = bispectrum
            
    return results

In [24]:
def bispectral_comparison(bispectrums, comparison = 'triangle', plot = True, log_scale = True):
    """
    Returns the difference of the norms of the given invariants w.r.t. the 
    comparison element.
    
    `bispectrums`: a dictionary with as keys the names of the images and 
                    as values their invariants
    `comparison`:  the element to use as comparison
    """
    
    if comparison not in bispectrums:
        print("The requested comparison is not in the folder")    
        return
    
    
    bispectrum_diff = {}
    for elem in bispectrums:
        diff = LA.norm(bispectrums[elem]-bispectrums[comparison])
        # we remove nan results
        if not np.isnan(diff):
            bispectrum_diff[elem] = diff
        
    return bispectrum_diff

In [25]:
def bispectral_plot(bispectrums, comparison = 'triangle', log_scale = True):
    """
    Plots the difference of the norms of the given invariants w.r.t. the 
    comparison element (by default in logarithmic scale).
    
    `bispectrums`: a dictionary with as keys the names of the images and 
                    as values their invariants
    `comparison`:  the element to use as comparison
    `log_scale`:   wheter the plot should be in log_scale
    """
    
    bispectrum_diff = bispectral_comparison(bispectrums, comparison = comparison)

    plt.plot(bispectrum_diff.values(),'ro')
    if log_scale == True:
        plt.yscale('log')
    for i in range(len(bispectrum_diff.values())):
        # if we plot in log scale, we do not put labels on items that are
        # too small, otherwise they exit the plot area.
        if log_scale and bispectrum_diff.values()[i] < 10**(-3):
            continue
        plt.text(i,bispectrum_diff.values()[i],bispectrum_diff.keys()[i][:3])
        plt.title("Comparison with as reference '"+ comparison +"'")

Construction of the table for the paper


In [26]:
comparisons_paper = ['triangle', 'rectangle', 'ellipse', 'etoile', 'diamond']

def extract_table_values(bispectrums, comparisons = comparisons_paper):
    """
    Extract the values for the table of the paper.
    
    `bispectrums`: a dictionary with as keys the names of the images and 
                    as values their invariants
    `comparison`:  list of elements to use as comparison
    
    Returns a list of tuples. Each tuple contains the name of the comparison 
    element, the maximal value of the difference of the norm of the invariants 
    with its rotated and the minimal values of the same difference with the 
    other images.
    """
    table_values = []
    for elem in comparisons:
        diff = bispectral_comparison(bispectrums, comparison= elem, plot=False)

        l = len(elem)
        match = [x for x in diff.keys() if x[:l]==elem]
        not_match = [x for x in diff.keys() if x[:l]!=elem]

        max_match = max([ diff[k] for k in match ])
        min_not_match = min([ diff[k] for k in not_match ])
        
        table_values.append((elem,'%.2E' % (max_match),'%.2E' % min_not_match))
        
    return table_values

In [23]:
bispectrums = bispectral_folder()
bispectrums_rotational = bispectral_folder(rot=True)


/usr/local/lib/python2.7/site-packages/IPython/kernel/__main__.py:6: RuntimeWarning: invalid value encountered in divide

In [27]:
extract_table_values(bispectrums)


Out[27]:
[('triangle', '9.23E+10', '7.03E+12'),
 ('rectangle', '7.93E+10', '8.22E+12'),
 ('ellipse', '7.42E+10', '7.11E+12'),
 ('etoile', '7.27E+10', '5.54E+12'),
 ('diamond', '3.78E+10', '5.47E+12')]

In [28]:
extract_table_values(bispectrums_rotational)


Out[28]:
[('triangle', '2.26E+11', '1.72E+13'),
 ('rectangle', '1.94E+11', '2.01E+13'),
 ('ellipse', '1.82E+11', '1.74E+13'),
 ('etoile', '1.78E+11', '1.36E+13'),
 ('diamond', '9.27E+10', '1.34E+13')]