In [ ]:
# Show matplotlib figures inline and switch IPython to its most detailed traceback mode.
%matplotlib inline
%xmode verbose

In [ ]:
import numpy as np 
import matplotlib.pyplot as plt 
import astra 
import scipy
import alg 
import sirt 
import sirt_noise

In [ ]:
def err_l2(img, rec):
    """Mean squared error between a reference image and a reconstruction.

    Despite the name, this is the *mean* of the squared differences (the
    squared L2 norm divided by the number of elements), not the L2 norm.

    Parameters
    ----------
    img : array_like
        Reference image (e.g. the phantom).
    rec : array_like
        Reconstruction; must have exactly the same shape as ``img``.
        Any dimensionality is accepted (2-D slices, 3-D volumes, ...).

    Returns
    -------
    numpy.float64
        ``sum((img - rec)**2) / rec.size``.

    Raises
    ------
    ValueError
        If the shapes differ (broadcasting would silently use the wrong
        element count in the mean).
    """
    img = np.asarray(img)
    rec = np.asarray(rec)
    if img.shape != rec.shape:
        raise ValueError(
            "shape mismatch: img {} vs rec {}".format(img.shape, rec.shape))
    return np.mean((img - rec) ** 2)

In [ ]:
def mean_value(lam, num = 100):
    """Mean of ``-log(X)`` for ``X ~ Poisson(lam)``, summed over ``X = 1 .. num``.

    The ``X = 0`` outcome (where ``-log`` diverges) is left out of the sum,
    and outcomes above ``num`` are truncated, so ``num`` should lie well
    above ``lam``.

    Parameters
    ----------
    lam : float or array_like
        Non-negative Poisson rate(s). Arrays are handled elementwise.
    num : int, optional
        Largest outcome included in the truncated sum (default 100).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        ``-sum_{x=1}^{num} log(x) * lam**x * exp(-lam) / x!``, with the
        shape of ``lam``.

    Raises
    ------
    ValueError
        If any ``lam`` is negative.
    """
    lam = np.asarray(lam, dtype=np.float64)
    if np.any(lam < 0):
        raise ValueError("lam must be non-negative")
    # Outcomes along a new leading axis so they broadcast against lam.
    x = np.arange(1, num + 1, dtype=np.float64).reshape((-1,) + (1,) * lam.ndim)
    log_x = np.log(x)
    # Poisson log-pmf: x*log(lam) - log(x!) - lam, with log(x!) = cumsum(log(1..x)).
    # Working in log space avoids the overflow of lam**x / x! (and the underflow
    # of exp(-lam)) that made the running-product form return nan once lam
    # exceeded roughly 700.
    with np.errstate(divide='ignore'):
        log_lam = np.log(lam)  # -inf for lam == 0 -> all weights become exactly 0
    log_pmf = x * log_lam - np.cumsum(log_x, axis=0) - lam
    return -np.sum(log_x * np.exp(log_pmf), axis=0)

In [ ]:
def var_value(lam, M, num = 100):
    """Second moment of ``-log(X)`` about ``M`` for ``X ~ Poisson(lam)``, over ``X = 1 .. num``.

    With ``M`` set to the (truncated) mean of ``-log(X)``, this is its
    variance. As in the mean, the ``X = 0`` outcome is left out and outcomes
    above ``num`` are truncated, so ``num`` should lie well above ``lam``.

    Parameters
    ----------
    lam : float or array_like
        Non-negative Poisson rate(s). Arrays are handled elementwise.
    M : float or array_like
        Centre of the moment; must broadcast against ``lam``.
    num : int, optional
        Largest outcome included in the truncated sum (default 100).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        ``sum_{x=1}^{num} (-log(x) - M)**2 * lam**x * exp(-lam) / x!``.

    Raises
    ------
    ValueError
        If any ``lam`` is negative.
    """
    lam = np.asarray(lam, dtype=np.float64)
    if np.any(lam < 0):
        raise ValueError("lam must be non-negative")
    ndim = max(lam.ndim, np.ndim(M))
    # Outcomes along a new leading axis so they broadcast against lam and M.
    x = np.arange(1, num + 1, dtype=np.float64).reshape((-1,) + (1,) * ndim)
    log_x = np.log(x)
    # Poisson log-pmf evaluated in log space; the running product
    # lam**x / x! times exp(-lam) overflows to nan once lam exceeds roughly 700.
    with np.errstate(divide='ignore'):
        log_lam = np.log(lam)  # -inf for lam == 0 -> all weights become exactly 0
    log_pmf = x * log_lam - np.cumsum(log_x, axis=0) - lam
    return np.sum((-log_x - M) ** 2 * np.exp(log_pmf), axis=0)

In [ ]:
# --- Phantom: two overlapping ellipses plus a rectangle on a size x size grid ---
size = 50
mu1 = 0.006
mu2 = 0.005
mu3 = 0.004
phantom = np.zeros((size, size))
half_s = size / 2

# x is the row index and y the column index of every pixel (the same arrays
# `y, x = np.meshgrid(range(size), range(size))` produces).
x, y = np.indices((size, size))
xx = (x - half_s).astype('float32')
yy = (y - half_s).astype('float32')

mask_ell1 = np.power(xx + 0.1*size, 2)/np.power(0.35*size, 2) + np.power(yy, 2)/np.power(0.15*size, 2) <= 1
mask_ell2 = np.power(xx - 0.15*size, 2)/np.power(0.3*size, 2) + np.power(yy - 0.15*size, 2)/np.power(0.15*size, 2) <= 1
phantom[mask_ell1] = mu1
phantom[mask_ell2] = mu2
phantom[np.logical_and(mask_ell1, mask_ell2)] = mu3
phantom[int(0.15*size):int(0.35*size), int(0.2*size):int(0.5*size)] = mu3
phantom = 1e+1 * phantom

# --- Sinogram: parallel-beam geometry, projected with alg.gpu_fp ---
n_angles = 90
# linspace always yields exactly n_angles angles in [0, pi); np.arange with the
# float step 180 / n_angles can return one angle too many or too few when that
# step is not exactly representable.
angles = np.linspace(0.0, np.pi, n_angles, endpoint=False).astype('float32')

pg = astra.create_proj_geom('parallel', 1.0, size, angles)
vg = astra.create_vol_geom((size, size))
sino = alg.gpu_fp(pg, vg, phantom)
sino = sino.astype('float64')
print(sino.min(), sino.max())

# Beer-Lambert: expected photon counts for incident intensity i0.
# From here on `sino` holds counts, not line integrals.
i0 = 2e+2
sino = i0 * np.exp(-sino)
print(sino.min(), sino.max())

In [ ]:
# Per-bin mean (M) and variance (D) of -log(N) for N ~ Poisson(sino), where
# `sino` holds the expected counts. Both helpers are elementwise in lam, so the
# whole sinogram is passed at once instead of looping over bins.
# num=600 truncates the Poisson sums; it should sit well above the largest
# expected count (at most i0).
M = mean_value(sino, num=600)
D = var_value(sino, M, num=600)
print(D.min(), D.max())

# A zero variance would become an infinite weight and poison the weighted sinogram.
assert np.all(D > 0), "zero variance in some bins; 1 / std is undefined there"
Div = 1.0 / np.sqrt(D)
print(Div.min(), Div.max())

fig, axes = plt.subplots(1, 3, figsize=(15, 10))
panels = [
    (sino, 'Sinogram'),
    (D, 'Variance (V)'),
    (Div, '1.0 / standard deviation'),
]
for ax, (data, title) in zip(axes, panels):
    im = ax.imshow(data, interpolation=None, cmap="gray")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, orientation='horizontal')

In [ ]:
# Simulate one noisy measurement: Poisson photon counts around the expected
# counts in `sino`. A fixed seed keeps Restart & Run All reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
counts_noisy = rng.poisson(lam=sino).astype('float64')
# Clip to [1, i0]: counts above the flat-field value i0 would give negative
# line integrals, and a zero count would make log() return -inf.
counts_noisy = np.clip(counts_noisy, 1.0, i0)
print(counts_noisy.min(), counts_noisy.max())

# Noisy line integrals (the log-transformed data used for reconstruction).
sino_noise = np.log(i0) - np.log(counts_noisy)
print(sino_noise.min(), sino_noise.max())

# Noise-free line integrals. Stored under a new name instead of overwriting
# `sino`, so `sino` still holds expected counts and both this cell and the
# variance cell stay safe to re-run.
sino_clean = np.log(i0) - np.log(sino)
print(sino_clean.min(), sino_clean.max())

# Scale the noisy data bin-by-bin by Div (1 / standard deviation).
sino_new = sino_noise * Div
print(sino_new.min(), sino_new.max())

fig, axes = plt.subplots(1, 3, figsize=(15, 10))
panels = [
    (sino_clean, 'Sinogram'),
    (sino_noise, 'Noisy sinogram'),
    (sino_new, 'Noisy sinogram / standard deviation'),
]
for ax, (data, title) in zip(axes, panels):
    im = ax.imshow(data, interpolation=None, cmap="gray")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, orientation='horizontal')

In [ ]:
# Reconstruct the phantom twice, each from a zero initial guess:
#   rec_1 - sirt.run on the noisy line integrals `sino_noise`;
#   rec_2 - sirt_noise.run on the scaled data `sino_new`, also given `Div`.
N_ITER = 100
eps = 1e-30

proj_id = astra.create_projector('cuda', pg, vg)
try:
    W = astra.OpTomo(proj_id)

    # Fresh x0 per run in case a solver updates its initial guess in place
    # (not verifiable from this notebook).
    x0 = np.zeros_like(phantom)
    res_1 = sirt.run(W, sino_noise, x0, eps, N_ITER, 'steepest')
    en_1, alpha_1, rec_1 = res_1['energy'], res_1['alpha'], res_1['rec']

    x0 = np.zeros_like(phantom)
    res_2 = sirt_noise.run(W, sino_new, Div, x0, eps, N_ITER, 'steepest')
    en_2, alpha_2, rec_2 = res_2['energy'], res_2['alpha'], res_2['rec']
finally:
    # Release the projector even if one of the reconstructions raises.
    astra.projector.delete(proj_id)

In [ ]:
# Mean squared error of each reconstruction against the phantom.
er_1 = err_l2(phantom, rec_1)
er_2 = err_l2(phantom, rec_2)

panels = [
    (sino_noise, 'Noisy sinogram'),
    (D, 'Variance (V)'),
    (sino_new, 'Noisy sinogram / standard deviation'),
    (phantom, 'phantom'),
    (rec_1, 'SIRT, Err={:0.2e}'.format(er_1)),
    (rec_2, 'Variance SIRT, Err={:0.2e}'.format(er_2)),
]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for ax, (data, title) in zip(axes.flat, panels):
    im = ax.imshow(data, interpolation=None, cmap="gray")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, orientation='horizontal')

# Name the file after the incident intensity actually simulated; the previous
# hard-coded "log_poisson_500.png" did not match i0.
fig.savefig('log_poisson_{:d}.png'.format(int(i0)))

In [ ]: