Transformée en ondelettes


In [ ]:
# Notebook setup: inline plotting, global figure size, and every import used below.
%matplotlib inline

import matplotlib
matplotlib.rcParams['figure.figsize'] = (7, 7)   # default size (inches) of every figure in this notebook

from astropy.io import fits      # FITS reader, used by read_fits_file()
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.colors import LogNorm   # logarithmic color scaling for the wavelet-plane plots

import PIL.Image as pil_img     # PIL.Image is a module not a class...

Wavelets: mother wavelet $\Psi$

The family $\psi_{a,b}$, where $(a, b) \in \mathbb{R}^{+*} \times \mathbb{R}$, is defined from the "mother wavelet" $\Psi$:

$$\forall t \in \mathbb{R}, \quad \psi_{\color{green}{a}, \color{blue}{b}}(t) = \frac{1}{\sqrt{\color{green}{a}}} \Psi\left(\frac{t-\color{blue}{b}}{\color{green}{a}}\right)$$

where

  • $a$ is the scale factor,
  • $b$ is the translation factor.

Wavelets: general case (1D continuous case)

The original signal $f$ is reconstructed from the wavelet coefficients as:

$$f(t) = \frac{1}{C_{\Psi}} \int_{-\infty}^{\infty} \int_{0}^{\infty} \frac{\color{red}{g(a,b)}}{|\color{green}{a}|^2} \psi_{\color{green}{a},\color{blue}{b}}(t) ~ da ~ db$$

where $C_{\Psi}$ is a constant that depends on the chosen mother wavelet $\Psi$.

Weights are given by:

$$\color{red}{g(a, b)} = \int_{-\infty}^{\infty} f(t) ~ \psi_{\color{green}{a}, \color{blue}{b}}^*(t) ~ dt$$

where $z^*$ denotes the complex conjugate of $z$.

wavelet_transform(input_image, num_scales):


scales[0] $\leftarrow$ input_image

for $i \in [0, \dots, \text{num_scales} - 2]$

$\quad$ scales[$i + 1$] $\leftarrow$ convolve(scales[$i$], $i$)

$\quad$ scales[$i$] $\leftarrow$ scales[$i$] - scales[$i + 1$]


convolve(scale, $s_i$):


$c_0 \leftarrow 3/8$

$c_1 \leftarrow 1/4$

$c_2 \leftarrow 1/16$

$s \leftarrow \lfloor 2^{s_i} + 0.5 \rfloor$

for all columns $x_i$

$\quad$ for all rows $y_i$

$\quad\quad$ buff[$x_i$, $y_i$] $\leftarrow$ $c_0$ . scale[$x_i$, $y_i$] + $c_1$ . scale[$x_i-s$, $y_i$] + $c_1$ . scale[$x_i+s$, $y_i$] + $c_2$ . scale[$x_i-2s$, $y_i$] + $c_2$ . scale[$x_i+2s$, $y_i$]

for all columns $x_i$

$\quad$ for all rows $y_i$

$\quad\quad$ out[$x_i$, $y_i$] $\leftarrow$ $c_0$ . buff[$x_i$, $y_i$] + $c_1$ . buff[$x_i$, $y_i-s$] + $c_1$ . buff[$x_i$, $y_i+s$] + $c_2$ . buff[$x_i$, $y_i-2s$] + $c_2$ . buff[$x_i$, $y_i+2s$]

return out

(buff and out are separate arrays. Updating scale in place would make each pass read neighbours it has already overwritten.)

## Default wavelet implementation

### *mr_transform()* function

Implemented in *~/bin/isap/cxx/sparse2d/src/libsparse2d/MR_Trans.cc*

```cpp
static void mr_transform (Ifloat &Image, MultiResol &MR_Transf, Bool EdgeLineTransform, type_border Border, Bool Details)
{
    // [...]
    MR_Transf.band(0) = Image;
    for (s = 0; s < Nbr_Plan -1; s++)
    {
        smooth_bspline (MR_Transf.band(s),MR_Transf.band(s+1),Border,s);
        MR_Transf.band(s) -= MR_Transf.band(s+1);
    }
    // [...]
}
```
## Default wavelet implementation

### *smooth_bspline()* function (the "convolve" function in the previous pseudo-code)

Implemented in *isap/cxx/sparse2d/src/libsparse2d/IM_Smooth.cc*

```cpp
void smooth_bspline (const Ifloat & Im_in, Ifloat &Im_out, type_border Type, int Step_trou)
{
    int Nl = Im_in.nl();
    int Nc = Im_in.nc();
    int i,j,Step;
    float Coeff_h0 = 3. / 8.;
    float Coeff_h1 = 1. / 4.;
    float Coeff_h2 = 1. / 16.;
    Ifloat Buff(Nl,Nc,"Buff smooth_bspline");

    Step = (int)(pow((double)2., (double) Step_trou) + 0.5);

    for (i = 0; i < Nl; i ++)
    for (j = 0; j < Nc; j ++)
       Buff(i,j) = Coeff_h0 * Im_in(i,j)
                 + Coeff_h1 * (  Im_in (i, j-Step, Type)
                               + Im_in (i, j+Step, Type))
                 + Coeff_h2 * (  Im_in (i, j-2*Step, Type)
                               + Im_in (i, j+2*Step, Type));

    for (i = 0; i < Nl; i ++)
    for (j = 0; j < Nc; j ++)
       Im_out(i,j) = Coeff_h0 * Buff(i,j)
                   + Coeff_h1 * (  Buff (i-Step, j, Type)
                                 + Buff (i+Step, j, Type))
                   + Coeff_h2 * (  Buff (i-2*Step, j, Type)
                                 + Buff (i+2*Step, j, Type));
}
```

In [ ]:
# Tool functions

def read_fits_file(file_path):
    """Read the data array of the primary HDU of a FITS file.

    Parameters
    ----------
    file_path : str or path-like
        Path of the FITS file to read.

    Returns
    -------
    The ``data`` attribute of the file's first (primary) HDU.
    """
    # The context manager closes the FITS file even if accessing the data raises,
    # whereas an explicit close() after the read would be skipped on error.
    with fits.open(file_path) as hdu_list:
        data = hdu_list[0].data
    return data

def plot(data, title="", cmap="gnuplot2"):
    """Show a 2D array as an image with a colorbar, then display the figure.

    Parameters
    ----------
    data : array-like
        Image accepted by ``Axes.imshow``; drawn with ``origin='lower'``
        (first row at the bottom) and no interpolation.
    title : str
        Axes title.
    cmap : str or Colormap
        Colormap; ``cm.inferno`` and ``"inferno"`` are both valid.
    """
    figure, axes = plt.subplots()
    axes.set_title(title)

    image_artist = axes.imshow(data, interpolation='nearest', origin='lower', cmap=cmap)
    figure.colorbar(image_artist)   # colorbar for the image just drawn

    plt.show()

Transformée en ondelettes "À trous"

Une implémentation Python


In [ ]:
# %load /Users/jdecock/git/pub/jdhp/snippets/science/wavelet_transform/bspline_wavelet_transform.py

def _border_index(index, size, type_border):
    """Fold a 1D index into ``[0, size)`` for border types 1 (periodic), 2 (mirror) or 3 (symmetric).

    All three extensions are periodic, with periods ``size``, ``2 * size - 2``
    and ``2 * size`` respectively.  Reducing modulo the period first makes the
    mapping valid for any offset, not only for offsets smaller than ``size``.
    """
    if 0 <= index < size:
        return index
    if type_border == 1:                 # periodic: ... n-1 | 0 1 ... n-1 | 0 ...
        return index % size
    if type_border == 2:                 # mirror, edge not repeated: ... 2 1 | 0 1 2 ...
        if size == 1:
            return 0
        period = 2 * size - 2
        index %= period
        return period - index if index >= size else index
    # type_border == 3                   # symmetric, edge repeated: ... 1 0 | 0 1 2 ...
    period = 2 * size
    index %= period
    return period - 1 - index if index >= size else index


def get_pixel_value(image, x, y, type_border):
    """Return ``image[x, y]``, extending the image beyond its borders.

    Parameters
    ----------
    image : numpy.ndarray
        Image indexed as ``image[line, column]``.
    x, y : int
        Line and column indices; they may fall outside the image.
    type_border : int
        How out-of-range indices are handled:

        * 0 -- zero padding: pixels outside the image are 0;
        * 1 -- periodic: indices wrap around (``-1`` -> last pixel);
        * 2 -- mirror, edge pixel not repeated (``-1`` -> ``1``);
        * 3 -- symmetric, edge pixel repeated (``-1`` -> ``0``).

    Returns
    -------
    The pixel value (``0`` outside the image when ``type_border == 0``).

    Raises
    ------
    ValueError
        If ``type_border`` is not 0, 1, 2 or 3.
    """
    if type_border not in (0, 1, 2, 3):
        raise ValueError("type_border must be 0, 1, 2 or 3, got {!r}".format(type_border))

    num_lines, num_col = image.shape[:2]

    if type_border == 0:
        # Bounds are tested explicitly: a negative index would not raise
        # IndexError in numpy, it would silently read from the opposite border.
        if 0 <= x < num_lines and 0 <= y < num_col:
            return image[x, y]
        return 0

    return image[_border_index(x, num_lines, type_border),
                 _border_index(y, num_col, type_border)]

In [ ]:
def _shifted_along_axis(array, offset, axis, type_border):
    """Return ``array`` sampled at ``index + offset`` along ``axis``, with border handling.

    Border types: 0 = zero padding, 1 = periodic, 2 = mirror (edge sample not
    repeated), 3 = symmetric (edge sample repeated).  Types 1-3 are folded with
    their period, so any offset (even larger than the axis length) is valid.

    Raises ValueError for any other ``type_border``.
    """
    size = array.shape[axis]
    index = np.arange(size) + offset

    if type_border == 0:
        inside = (index >= 0) & (index < size)
        values = np.take(array, np.clip(index, 0, size - 1), axis=axis)
        mask_shape = [1] * array.ndim
        mask_shape[axis] = size
        return np.where(inside.reshape(mask_shape), values, 0.)

    if type_border == 1:
        index = index % size
    elif type_border == 2:
        if size == 1:
            index = np.zeros_like(index)
        else:
            period = 2 * size - 2
            index = index % period
            index = np.where(index >= size, period - index, index)
    elif type_border == 3:
        period = 2 * size
        index = index % period
        index = np.where(index >= size, period - 1 - index, index)
    else:
        raise ValueError("type_border must be 0, 1, 2 or 3, got {!r}".format(type_border))

    return np.take(array, index, axis=axis)


def smooth_bspline(input_image, type_border, step_trou):
    """Smooth an image with the separable B3-spline "à trous" kernel.

    Python port of ``smooth_bspline()`` from sparse2d's ``IM_Smooth.cc``.  The
    1D kernel ``[1/16, 1/4, 3/8, 1/4, 1/16]`` is applied first along the column
    index (axis 1), then along the line index (axis 0).  Its taps are
    ``step = round(2 ** step_trou)`` pixels apart, i.e. the kernel has holes.

    Parameters
    ----------
    input_image : array-like
        2D image; it is converted to float64 and left unmodified.
    type_border : int
        Border handling: 0 = zero padding, 1 = periodic, 2 = mirror (edge not
        repeated), 3 = symmetric (edge repeated).
    step_trou : int
        Scale index; the distance between kernel taps is ``2 ** step_trou``.

    Returns
    -------
    numpy.ndarray
        New float64 array with the same shape as ``input_image``.

    Raises
    ------
    ValueError
        If the image is not 2D or ``type_border`` is not 0, 1, 2 or 3.
    """
    input_image = np.asarray(input_image, dtype='float64')
    if input_image.ndim != 2:
        raise ValueError("input_image must be 2D, got shape {}".format(input_image.shape))

    coeff_h0 = 3. / 8.
    coeff_h1 = 1. / 4.
    coeff_h2 = 1. / 16.

    step = int(pow(2., step_trou) + 0.5)

    def _smooth_axis(array, axis):
        # Same arithmetic (and evaluation order) as the per-pixel loop of the
        # C++ version, but applied to whole shifted copies of the array at once.
        result = coeff_h0 * array
        result += coeff_h1 * (_shifted_along_axis(array, -step, axis, type_border)
                              + _shifted_along_axis(array, step, axis, type_border))
        result += coeff_h2 * (_shifted_along_axis(array, -2 * step, axis, type_border)
                              + _shifted_along_axis(array, 2 * step, axis, type_border))
        return result

    buff = _smooth_axis(input_image, axis=1)   # along the columns index j
    return _smooth_axis(buff, axis=0)          # then along the lines index i

In [ ]:
def transform(image, num_scales, type_border=3):
    """Compute the "à trous" B3-spline wavelet transform of an image.

    Python port of sparse2d's ``mr_transform()`` loop: at each step the current
    plane is smoothed with ``smooth_bspline`` (holes of ``2 ** scale_index``),
    the smoothed image becomes the next plane, and the current plane is replaced
    by the detail it lost (current - smoothed).

    Parameters
    ----------
    image : array-like
        2D input image; it is copied as float64 and left unmodified.
    num_scales : int
        Number of planes returned; ``num_scales <= 1`` returns ``[image copy]``.
    type_border : int, optional
        Border handling passed to ``smooth_bspline`` (default 3, the value
        previously hard-coded here).

    Returns
    -------
    list of numpy.ndarray
        ``num_scales`` arrays of the same shape: the detail planes from finest
        to coarsest, then the last smoothed image.  Their sum telescopes back to
        ``image`` (up to floating-point rounding).
    """
    scale_list = [np.array(image, dtype='float64')]   # always a fresh copy

    for scale_index in range(num_scales - 1):
        smoothed = smooth_bspline(scale_list[scale_index], type_border, scale_index)
        # In place on the list's own array: plane `scale_index` becomes the detail plane.
        scale_list[scale_index] -= smoothed
        scale_list.append(smoothed)

    return scale_list

Example 1


In [ ]:
# Unit impulse at the center of an 81x81 image: each plane of its transform
# shows the shape of the wavelet at that scale.
num_scales = 6

data = np.zeros((81, 81))
data[40, 40] = 1

plt.imshow(data);

In [ ]:
#%%timeit
# Decompose the impulse image into `num_scales` planes.
scale_list = transform(data, num_scales=num_scales)

In [ ]:
# Display every plane of the decomposition, finest first.
for index, plane in enumerate(scale_list):
    plot(plane, "Scale {}".format(index))

In [ ]:
plt.plot(scale_list[3][40]);   # horizontal profile through the impulse (row 40) of plane 3

In [ ]:
# Summing all planes inverts the transform (the planes telescope).
rebuilt_data = np.stack(scale_list).sum(axis=0)
plt.imshow(rebuilt_data);

In [ ]:
# Reconstruction error; expected to stay at floating-point rounding level.
err = data - rebuilt_data
print(err.min(), err.max(), sep="\n")

Example 2


In [ ]:
# NOTE(review): absolute, machine-specific path; adjust it to your own checkout.
file_path = "/Users/jdecock/git/pub/jdhp/snippets/science/wavelet_transform/test.fits"
num_scales = 4

# Read the FITS image and show it.
data = read_fits_file(file_path)
plt.imshow(data, cmap="gnuplot2");

In [ ]:
#%%timeit
# Decompose the FITS image into `num_scales` planes.
scale_list = transform(data, num_scales=num_scales)

In [ ]:
# Display every plane of the decomposition, finest first.
for index, plane in enumerate(scale_list):
    plot(plane, "Scale {}".format(index))

In [ ]:
# Summing all planes inverts the transform (the planes telescope).
rebuilt_data = np.stack(scale_list).sum(axis=0)
plt.imshow(rebuilt_data, cmap="gnuplot2");

In [ ]:
# Reconstruction error; expected to stay at floating-point rounding level.
err = data - rebuilt_data
print(err.min(), err.max(), sep="\n")

Example 3


In [ ]:
# NOTE(review): absolute, machine-specific path; adjust it to your own checkout.
file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau_512.png"
#file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau2.jpg"

# Read data: open the image and convert it to grayscale ('L' mode).
# The `with` block closes the underlying file once the pixels are copied into
# the numpy array; the bare pil_img.open(...) call left the file handle open.
with pil_img.open(file_path) as image_file:
    data = np.array(image_file.convert('L'))

num_scales = 4

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_axis_off()
ax.imshow(data, cmap="gray", interpolation="spline36");

In [ ]:
#%%timeit
# Decompose the grayscale photo into `num_scales` planes.
scale_list = transform(data, num_scales=num_scales)

In [ ]:
# One gray-level figure per plane, each with its own colorbar.
for index, plane in enumerate(scale_list):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_axis_off()
    ax.set_title("Scale {}".format(index))

    # other cmaps: gray, Greys_r, seismic, bwr, coolwarm
    im = ax.imshow(plane, cmap="gray", interpolation="spline36")
    fig.colorbar(im, fraction=0.045, pad=0.04)

    plt.show()

In [ ]:
# Log-scale view of each plane: positive coefficients in reds, negative ones
# (negated) in blues, drawn on the same axes with a shared upper bound.
# NOTE(review): this relies on LogNorm masking values <= 0 so the two layers do
# not hide each other; confirm against the matplotlib version in use.
for index, plane in enumerate(scale_list):
    lowest, highest = plane.min(), plane.max()
    print(lowest, highest)

    fig, ax = plt.subplots(figsize=(10,10))
    ax.set_axis_off()

    amplitude = max(highest, -lowest)   # largest absolute coefficient

    im1 = ax.imshow(plane, cmap="Reds", norm=LogNorm(vmin=0.01, vmax=amplitude))
    plt.colorbar(im1, fraction=0.04125, pad=0.06)

    if lowest < 0.:
        im2 = ax.imshow(-plane, cmap="Blues", norm=LogNorm(vmin=0.01, vmax=amplitude))
        plt.colorbar(im2, fraction=0.046, pad=0.04)

    ax.set_title("Scale {}".format(index))
    plt.show()

In [ ]:
# Same log-scale view in gray levels: positive coefficients with "gray",
# negative ones (negated) with "gray_r", each layer with its own upper bound.
for index, plane in enumerate(scale_list):
    lowest, highest = plane.min(), plane.max()
    print(lowest, highest)

    fig, ax = plt.subplots(figsize=(10,10))
    ax.set_axis_off()

    im1 = ax.imshow(plane, cmap="gray", norm=LogNorm(vmin=0.001, vmax=highest))

    if lowest < 0.:
        im2 = ax.imshow(-plane, cmap="gray_r", norm=LogNorm(vmin=0.001, vmax=-lowest))
        plt.colorbar(im2, fraction=0.046, pad=0.04)

    ax.set_title("Scale {}".format(index))
    plt.show()

In [ ]:
# Summing all planes inverts the transform (the planes telescope).
rebuilt_data = np.stack(scale_list).sum(axis=0)

fig, ax = plt.subplots(figsize=(10,10))
ax.set_axis_off()
ax.imshow(rebuilt_data, cmap="gray", interpolation="spline36");

In [ ]:
# Reconstruction error; `data` is uint8 but numpy promotes the difference to float64.
err = data - rebuilt_data
print(err.min(), err.max(), sep="\n")

Example 4: (artificial) noise filtering


In [ ]:
# NOTE(review): absolute, machine-specific path; adjust it to your own checkout.
file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau_512.png"
#file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau2.jpg"

# Read data: open the image and convert it to grayscale ('L' mode).
# The `with` block closes the underlying file once the pixels are copied into
# the numpy array; the bare pil_img.open(...) call left the file handle open.
with pil_img.open(file_path) as image_file:
    data = np.array(image_file.convert('L'))

num_scales = 4

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_axis_off()
ax.imshow(data, cmap="gray", interpolation="spline36");

In [ ]:
#%%timeit
# Decompose the clean photo; noise is added plane by plane in the next cell.
scale_list = transform(data, num_scales=num_scales)

In [ ]:
# Per-plane factors, both in units of the plane's standard deviation:
# NOISE_FACTOR scales the std of the added Gaussian noise,
# CLEAN_FACTOR scales the hard threshold used to clean the noisy plane.
NOISE_FACTOR = [5., 0.25, 0.15, 0.]
CLEAN_FACTOR = [1.5, 0.75, 0.5, 0.]  # [1.5, 0.7, 0.5, 0]
SEED = 0

assert len(scale_list) <= min(len(NOISE_FACTOR), len(CLEAN_FACTOR)), "one noise/clean factor per scale is required"

rng = np.random.default_rng(SEED)   # seeded so the noise (and every figure below) is reproducible

mask_list = []
noised_scale_list = []
cleaned_scale_list = []

for scale_index, scale in enumerate(scale_list):
    # COMPUTE THE PLANE STATISTICS
    mean = scale.mean()
    std = scale.std()
    print("Scale {}: mean={} std={}".format(scale_index, mean, std))

    # ADD NOISE: zero-mean so it does not shift the plane. Centering it on `mean`
    # would add a constant offset; e.g. the last plane, whose NOISE_FACTOR is 0,
    # would receive exactly +mean instead of staying unchanged.
    noise = rng.normal(loc=0., scale=NOISE_FACTOR[scale_index] * std, size=scale.shape)
    noised_scale = scale + noise
    noised_scale_list.append(noised_scale)

    # CLEAN SCALE: hard thresholding, keep only the coefficients above the threshold.
    # NOTE(review): `+ mean` may have been meant as a centering `- mean`; the two
    # only differ noticeably on planes whose mean is far from 0. Confirm.
    mask = np.abs(noised_scale + mean) > (std * CLEAN_FACTOR[scale_index])
    mask_list.append(mask)

    plt.imshow(mask)
    plt.colorbar()
    plt.title("Scale {} (mask)".format(scale_index))
    plt.show()

    cleaned_img = noised_scale * mask
    cleaned_scale_list.append(cleaned_img)

In [ ]:
# Side by side, for each plane: original, noisy and cleaned coefficients.
panel_titles = ("Scale {} (orig)", "Scale {} (noised)", "Scale {} (cleaned)")

for scale_index, planes in enumerate(zip(scale_list, noised_scale_list, cleaned_scale_list)):
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 8))

    for ax, plane, title in zip(axes, planes, panel_titles):
        ax.set_axis_off()
        # other cmaps: gray, Greys_r, seismic, bwr, coolwarm
        im = ax.imshow(plane, cmap="gray", interpolation="spline36")
        plt.colorbar(im, ax=ax)
        ax.set_title(title.format(scale_index))

    plt.show()

In [ ]:
# Image rebuilt from the noisy planes, before any thresholding.
rebuilt_data = np.stack(noised_scale_list).sum(axis=0)

fig, ax = plt.subplots(figsize=(10,10))
ax.set_axis_off()
ax.imshow(rebuilt_data, cmap="gray", interpolation="spline36");

In [ ]:
# Image rebuilt from the thresholded (cleaned) planes.
rebuilt_data = np.stack(cleaned_scale_list).sum(axis=0)

fig, ax = plt.subplots(figsize=(10,10))
ax.set_axis_off()
ax.imshow(rebuilt_data, cmap="gray", interpolation="spline36");

In [ ]:
# Difference between the original photo and the denoised reconstruction.
err = data - rebuilt_data
print(err.min(), err.max(), sep="\n")

In [ ]: