TODO:
See also
In [ ]:
%matplotlib inline
import matplotlib
matplotlib.rcParams['figure.figsize'] = (7, 7)
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import PIL.Image as pil_img # PIL.Image is a module not a class...
The family of wavelets $\psi_{a,b}$ (where $(a, b) \in \mathbb{R}^{+*} \times \mathbb{R}$) is derived from the "mother wavelet" $\Psi$:
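$$\forall t \in \mathbb{R}, \quad \psi_{\color{green}{a}, \color{blue}{b}}(t) = \frac{1}{\sqrt{\color{green}{a}}} \Psi\left(\frac{t-\color{blue}{b}}{\color{green}{a}}\right)$$where $\color{green}{a}$ is the scale (dilation) parameter and $\color{blue}{b}$ is the translation parameter.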
The original signal $f$ can be reconstructed from the weights $g(a, b)$ (defined below) as:
$$f(t) = \frac{1}{C_{\Psi}} \int_{-\infty}^{\infty} \int_{0}^{\infty} \frac{\color{red}{g(a,b)}}{|\color{green}{a}|^2} \psi_{\color{green}{a},\color{blue}{b}}(t) ~ da ~ db$$where $C_{\Psi}$ is a constant that depends on the chosen mother wavelet $\Psi$ (the integral over $a$ runs over $\mathbb{R}^{+*}$, consistent with the definition of the family above).
Weights are given by:
$$\color{red}{g(a, b)} = \int_{-\infty}^{\infty} f(t) ~ \psi_{\color{green}{a}, \color{blue}{b}}^*(t) ~ dt$$where $z^*$ denotes the complex conjugate of $z$.
wavelet_transform(input_image, num_scales):
scales[0] $\leftarrow$ input_image
for $i \in [0, \dots, \text{num\_scales} - 2]$
$\quad$ scales[$i + 1$] $\leftarrow$ convolve(scales[$i$], $i$)
$\quad$ scales[$i$] $\leftarrow$ scales[$i$] - scales[$i + 1$]
convolve(scale, $s_i$):
$c_0 \leftarrow 3/8$
$c_1 \leftarrow 1/4$
$c_2 \leftarrow 1/16$
$s \leftarrow \lfloor 2^{s_i} + 0.5 \rfloor$
for all columns $x_i$
$\quad$ for all rows $y_i$
$\quad\quad$ buff[$x_i$, $y_i$] $\leftarrow$ $c_0$ . scale[$x_i$, $y_i$] + $c_1$ . scale[$x_i-s$, $y_i$] + $c_1$ . scale[$x_i+s$, $y_i$] + $c_2$ . scale[$x_i-2s$, $y_i$] + $c_2$ . scale[$x_i+2s$, $y_i$]
for all columns $x_i$
$\quad$ for all rows $y_i$
$\quad\quad$ scale[$x_i$, $y_i$] $\leftarrow$ $c_0$ . buff[$x_i$, $y_i$] + $c_1$ . buff[$x_i$, $y_i-s$] + $c_1$ . buff[$x_i$, $y_i+s$] + $c_2$ . buff[$x_i$, $y_i-2s$] + $c_2$ . buff[$x_i$, $y_i+2s$]
Each pass writes into a separate array (buff for the first pass, then scale for the second), so no pass reads a value it has already overwritten; out-of-range indices are resolved by the chosen border rule.
In [ ]:
# Tool functions
def read_fits_file(file_path):
    """Return the data of the primary HDU (index 0) of the FITS file at *file_path*.

    The file is opened in a ``with`` block so that it is closed even when
    reading the primary HDU raises (the previous open/close pair leaked the
    file handle in that case).
    """
    with fits.open(file_path) as hdu_list:
        data = hdu_list[0].data  # accessed before the file is closed
    return data
def plot(data, title="", cmap="gnuplot2"):
    """Display a 2D array as an image in a new figure, with a colorbar.

    Parameters
    ----------
    data : 2D array
        Values to display; drawn without interpolation and with the first
        row at the bottom (``origin='lower'``).
    title : str
        Title of the axes.
    cmap : str or Colormap
        Colormap forwarded to ``imshow``.
    """
    figure, axes = plt.subplots()
    axes.set_title(title)
    mappable = axes.imshow(data, interpolation='nearest', origin='lower', cmap=cmap)  # cmap=cm.inferno and cmap="inferno" are both valid
    figure.colorbar(mappable, ax=axes)
    plt.show()
In [ ]:
# %load /Users/jdecock/git/pub/jdhp/snippets/science/wavelet_transform/bspline_wavelet_transform.py
def _mirror_index(index, size):
    """Reflect *index* into [0, size) without repeating the edge (... 2 1 | 0 1 2 ... n-1 | n-2 ...)."""
    if size == 1:
        return 0
    period = 2 * (size - 1)
    index %= period
    return period - index if index >= size else index


def _symmetric_index(index, size):
    """Reflect *index* into [0, size) repeating the edge (... 1 0 | 0 1 ... n-1 | n-1 n-2 ...)."""
    period = 2 * size
    index %= period
    return period - 1 - index if index >= size else index


def get_pixel_value(image, x, y, type_border):
    """Return ``image[x, y]``, resolving out-of-range coordinates with a border rule.

    Parameters
    ----------
    image : 2D numpy array
    x, y : int
        Line (first axis) and column (second axis) coordinates. They may be
        negative or larger than the image, by any amount.
    type_border : int
        0 -- zero padding: out-of-range coordinates read 0.
        1 -- periodic: coordinates wrap around (``x % num_lines``).
        2 -- mirror, edge not repeated: ``-1 -> 1`` and ``n -> n - 2``.
        3 -- symmetric, edge repeated: ``-1 -> 0`` and ``n -> n - 1``.

    Raises
    ------
    ValueError
        If ``type_border`` is not 0, 1, 2 or 3.
    """
    num_lines, num_col = image.shape  # only 2D images are supported
    if type_border == 0:
        # Explicit bounds check: numpy silently wraps negative indices, so
        # catching IndexError alone would read from the opposite edge.
        if 0 <= x < num_lines and 0 <= y < num_col:
            return image[x, y]
        return 0
    if type_border == 1:
        return image[x % num_lines, y % num_col]
    if type_border == 2:
        # Modular reflection handles overshoots of any size (the old code
        # raised IndexError once the overshoot reached the image size).
        return image[_mirror_index(x, num_lines), _mirror_index(y, num_col)]
    if type_border == 3:
        return image[_symmetric_index(x, num_lines), _symmetric_index(y, num_col)]
    raise ValueError("unknown type_border: {!r}".format(type_border))
In [ ]:
def _border_index_map(size, offset, axis, type_border):
    """Return, for each position k along one axis, where get_pixel_value reads k + offset.

    The result ``m`` is an integer array of length ``size``: ``m[k] - 1`` is
    the index actually read for coordinate ``k + offset`` and ``m[k] == 0``
    means get_pixel_value returns 0 there. It is computed by calling
    get_pixel_value on a probe array of 1-based indices, so the border rules
    stay defined in one place only.
    """
    one_based = np.arange(1, size + 1)
    if axis == 0:
        probe = one_based.reshape(size, 1)
        values = [get_pixel_value(probe, k + offset, 0, type_border) for k in range(size)]
    else:
        probe = one_based.reshape(1, size)
        values = [get_pixel_value(probe, 0, k + offset, type_border) for k in range(size)]
    return np.array(values, dtype=np.intp)


def _shifted_with_border(image, offset, axis, type_border):
    """Return the array whose element k along *axis* is what get_pixel_value reads at k + offset."""
    index_map = _border_index_map(image.shape[axis], offset, axis, type_border)
    # Prepend one slice of zeros so that a map value of 0 selects a zero.
    padded = np.insert(image, 0, 0., axis=axis)
    return np.take(padded, index_map, axis=axis)


def smooth_bspline(input_image, type_border, step_trou):
    """Smooth a 2D image with the separable "a trous" B-spline kernel.

    Port of the C++ ``smooth_bspline`` routine. The 1D kernel
    ``[1/16, 1/4, 3/8, 1/4, 1/16]`` has its taps ``step`` pixels apart, with
    ``step = int(2**step_trou + 0.5)``. It is applied first along each line
    (second array axis, into a buffer), then along each column (first array
    axis) of that buffer.

    The computation is vectorized with numpy instead of the previous
    per-pixel Python loop (10 get_pixel_value calls per pixel). Terms are
    accumulated in the same order as before (h0, then h1, then h2), so
    results are unchanged.

    Parameters
    ----------
    input_image : 2D numpy array
        Image to smooth; it is copied to float64 and left unmodified.
    type_border : int
        Border mode, interpreted by get_pixel_value.
    step_trou : int
        Scale index; sets the distance between kernel taps.

    Returns
    -------
    numpy.ndarray
        Smoothed float64 image with the same shape as ``input_image``.
    """
    input_image = input_image.astype('float64', copy=True)

    coeff_h0 = 3. / 8.
    coeff_h1 = 1. / 4.
    coeff_h2 = 1. / 16.

    num_lines, num_col = input_image.shape  # raises ValueError unless the image is 2D

    step = int(pow(2., step_trou) + 0.5)

    def _filter_along(image, axis):
        # One 1D pass of the dilated kernel along ``axis``.
        out = coeff_h0 * image
        out += coeff_h1 * (_shifted_with_border(image, -step, axis, type_border)
                           + _shifted_with_border(image, step, axis, type_border))
        out += coeff_h2 * (_shifted_with_border(image, -2 * step, axis, type_border)
                           + _shifted_with_border(image, 2 * step, axis, type_border))
        return out

    buff = _filter_along(input_image, axis=1)  # along each line
    img_out = _filter_along(buff, axis=0)      # then along each column
    return img_out
In [ ]:
def transform(image, num_scales, type_border=3):
    """Compute the "a trous" B-spline wavelet decomposition of a 2D image.

    With ``c_0 = image`` (copied to float64) and ``c_{s+1} = smooth_bspline(c_s)``
    at step index ``s``, the returned list has ``num_scales`` elements:
    ``scale_list[s] = c_s - c_{s+1}`` for ``s < num_scales - 1``, and the last
    element is the final smoothed image ``c_{num_scales-1}``. Summing the list
    therefore gives back ``image`` up to floating-point rounding. For
    ``num_scales <= 1`` the list only holds the float64 copy of ``image``.

    Parameters
    ----------
    image : 2D numpy array
    num_scales : int
        Number of elements of the returned list.
    type_border : int, optional
        Border mode forwarded to smooth_bspline (see get_pixel_value).
        Defaults to 3, the value previously hard-coded here.
    """
    scale_list = [image.astype('float64', copy=True)]
    for scale_index in range(num_scales - 1):
        previous_scale = scale_list[scale_index]
        next_scale = smooth_bspline(previous_scale, type_border, scale_index)
        previous_scale -= next_scale  # in place: turns c_s into the detail c_s - c_{s+1}
        scale_list.append(next_scale)
    return scale_list
In [ ]:
# Test image: a single unit impulse in the middle of an 81 x 81 grid of
# zeros, so each scale of the transform shows the shape of its filter.
data = np.zeros((81, 81))
data[40, 40] = 1.0
num_scales = 6
plt.imshow(data);
In [ ]:
#%%timeit
# Decompose the current `data` into `num_scales` wavelet scales.
scale_list = transform(image=data, num_scales=num_scales)
In [ ]:
# Display every scale of the decomposition in its own figure.
for scale_index, scale in enumerate(scale_list):
    plot(data=scale, title="Scale {}".format(scale_index))
In [ ]:
# Horizontal profile through the impulse (row 40) at scale 3.
plt.plot(scale_list[3][40]);
In [ ]:
# Inverse transform: adding all the scales back together.
rebuilt_data = np.asarray(scale_list).sum(axis=0)
plt.imshow(rebuilt_data);
In [ ]:
# Reconstruction error; both extremes should be at floating-point rounding level.
err = data - rebuilt_data
print(np.min(err))
print(np.max(err))
In [ ]:
# Load a FITS test image and choose the number of scales for it.
file_path = "/Users/jdecock/git/pub/jdhp/snippets/science/wavelet_transform/test.fits"
data = read_fits_file(file_path=file_path)
num_scales = 4
plt.imshow(data, cmap="gnuplot2");
In [ ]:
#%%timeit
# Decompose the FITS image into `num_scales` wavelet scales.
scale_list = transform(image=data, num_scales=num_scales)
In [ ]:
# Display every scale of the decomposition in its own figure.
for scale_index, scale in enumerate(scale_list):
    plot(data=scale, title="Scale {}".format(scale_index))
In [ ]:
# Inverse transform: adding all the scales back together.
rebuilt_data = np.asarray(scale_list).sum(axis=0)
plt.imshow(rebuilt_data, cmap="gnuplot2");
In [ ]:
# Reconstruction error; both extremes should be at floating-point rounding level.
err = data - rebuilt_data
print(np.min(err))
print(np.max(err))
In [ ]:
file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau_512.png"
#file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau2.jpg"

# Read data: open the image, convert it to grayscale ('L' mode) and copy it
# into a numpy array. The `with` block closes the underlying file, which the
# bare pil_img.open(...) call left open.
with pil_img.open(file_path) as pil_image:
    data = np.array(pil_image.convert('L'))

num_scales = 4

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_axis_off()
ax.imshow(data, cmap="gray", interpolation="spline36");
In [ ]:
#%%timeit
# Decompose the grayscale photo into `num_scales` wavelet scales.
scale_list = transform(image=data, num_scales=num_scales)
In [ ]:
# Show each scale as a grayscale image with its own colorbar.
for scale_index, scale in enumerate(scale_list):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_axis_off()
    ax.set_title("Scale {}".format(scale_index))
    im = ax.imshow(scale, cmap="gray", interpolation="spline36")  # other cmaps: gray, Greys_r, seismic, bwr, coolwarm
    fig.colorbar(im, ax=ax, fraction=0.045, pad=0.04)
    plt.show()
In [ ]:
# Log-scale display of each scale: positive coefficients in red, and the
# negated negative coefficients overlaid in blue. Both use the same range,
# [0.01, largest absolute coefficient]; LogNorm masks non-positive values,
# so each overlay only shows one sign.
for scale_index, scale in enumerate(scale_list):
    print(scale.min(), scale.max())
    fig, ax = plt.subplots(figsize=(10,10))
    ax.set_axis_off()
    amplitude = max(scale.max(), -scale.min())
    im1 = ax.imshow(scale, cmap="Reds", norm=LogNorm(vmin=0.01, vmax=amplitude))
    fig.colorbar(im1, ax=ax, fraction=0.04125, pad=0.06)
    if scale.min() < 0.:
        im2 = ax.imshow(-scale, cmap="Blues", norm=LogNorm(vmin=0.01, vmax=amplitude))
        fig.colorbar(im2, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title("Scale {}".format(scale_index))
    plt.show()
In [ ]:
# Log-scale grayscale display of each scale: positive coefficients with the
# "gray" colormap, and the negated negative coefficients overlaid with
# "gray_r". Only the negative overlay gets a colorbar.
for scale_index, scale in enumerate(scale_list):
    print(scale.min(), scale.max())
    fig, ax = plt.subplots(figsize=(10,10))
    ax.set_axis_off()
    im1 = ax.imshow(scale, cmap="gray", norm=LogNorm(vmin=0.001, vmax=scale.max()))
    if scale.min() < 0.:
        negative_amplitude = -scale.min()
        im2 = ax.imshow(-scale, cmap="gray_r", norm=LogNorm(vmin=0.001, vmax=negative_amplitude))
        fig.colorbar(im2, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title("Scale {}".format(scale_index))
    plt.show()
In [ ]:
# Inverse transform: adding all the scales back together.
rebuilt_data = np.asarray(scale_list).sum(axis=0)
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_axis_off()
ax.imshow(rebuilt_data, cmap="gray", interpolation="spline36");
In [ ]:
# Reconstruction error; both extremes should be at floating-point rounding level.
err = data - rebuilt_data
print(np.min(err))
print(np.max(err))
In [ ]:
file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau_512.png"
#file_path = "/Users/jdecock/git/pub/jdhp/snippets/sample-images/doisneau2.jpg"

# Read data: open the image, convert it to grayscale ('L' mode) and copy it
# into a numpy array. The `with` block closes the underlying file, which the
# bare pil_img.open(...) call left open.
with pil_img.open(file_path) as pil_image:
    data = np.array(pil_image.convert('L'))

num_scales = 4

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_axis_off()
ax.imshow(data, cmap="gray", interpolation="spline36");
In [ ]:
#%%timeit
# Decompose the grayscale photo into `num_scales` wavelet scales.
scale_list = transform(image=data, num_scales=num_scales)
In [ ]:
# Noise / hard-thresholding experiment on the wavelet scales.
# NOISE_FACTOR[i] is the std of the Gaussian noise added to scale i, in units
# of that scale's std; CLEAN_FACTOR[i] is the threshold (same units) under
# which coefficients of the noised scale i are set to zero. Both lists must
# have one entry per scale (num_scales == 4 here).
NOISE_FACTOR = [5., 0.25, 0.15, 0.]
CLEAN_FACTOR = [1.5, 0.75, 0.5, 0.]  # [1.5, 0.7, 0.5, 0]

mask_list = []
noised_scale_list = []
cleaned_scale_list = []

for scale_index, scale in enumerate(scale_list):
    # COMPUTE THE STD
    mean = scale.mean()
    std = scale.std()
    print("Scale {}: mean={} std={}".format(scale_index, mean, std))

    # ADD NOISE: zero-mean Gaussian noise. The previous loc=mean shifted every
    # scale by its own mean; on the last (smoothed) scale, whose NOISE_FACTOR
    # is 0, that added a constant offset equal to the scale mean.
    noise = np.random.normal(loc=0., scale=NOISE_FACTOR[scale_index] * std, size=scale.shape)
    #noise = np.random.uniform(low=-CLEAN_FACTOR[scale_index] * std * 0.9,
    #                          high=CLEAN_FACTOR[scale_index] * std * 0.9,
    #                          size=scale.shape)
    noised_scale = scale + noise
    noised_scale_list.append(noised_scale)

    # CLEAN SCALE: hard thresholding, keep only coefficients above the threshold.
    # NOTE(review): the mean is *added* before taking the magnitude; centring
    # would subtract it. Confirm which one is intended.
    mask = np.abs(noised_scale + mean) > (std * CLEAN_FACTOR[scale_index])
    mask_list.append(mask)

    plt.imshow(mask)
    plt.colorbar()
    plt.show()

    #mask = noised_scale > (std * CLEAN_FACTOR[scale_index])
    cleaned_img = noised_scale * mask
    cleaned_scale_list.append(cleaned_img)

#cleaned_scale_list[0] = noised_scale_list[0] * mask_list[1]  # experimental
In [ ]:
# For every scale, compare side by side the original coefficients, the
# noised coefficients and the thresholded ("cleaned") coefficients.
for scale_index, (scale, noised_scale, cleaned_scale) in enumerate(zip(scale_list, noised_scale_list, cleaned_scale_list)):
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(16, 8))
    panels = ((ax1, scale, "Scale {} (orig)"),
              (ax2, noised_scale, "Scale {} (noised)"),
              (ax3, cleaned_scale, "Scale {} (cleaned)"))
    for panel_ax, panel_image, title_template in panels:
        panel_ax.set_axis_off()
        panel_im = panel_ax.imshow(panel_image, cmap="gray", interpolation="spline36")  # other cmaps: gray, Greys_r, seismic, bwr, coolwarm
        fig.colorbar(panel_im, ax=panel_ax)
        panel_ax.set_title(title_template.format(scale_index))
    plt.show()
In [ ]:
# Image rebuilt from the noised scales.
rebuilt_data = np.asarray(noised_scale_list).sum(axis=0)
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_axis_off()
ax.imshow(rebuilt_data, cmap="gray", interpolation="spline36");
In [ ]:
# Image rebuilt from the cleaned (noised then thresholded) scales.
rebuilt_data = np.asarray(cleaned_scale_list).sum(axis=0)
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_axis_off()
ax.imshow(rebuilt_data, cmap="gray", interpolation="spline36");
In [ ]:
# Difference between the original image and the one rebuilt from the
# cleaned scales (computed in the previous cell).
err = data - rebuilt_data
print(np.min(err))
print(np.max(err))
In [ ]: