In [ ]:
%matplotlib inline
# %xmode verbose
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
import astra
import alg
import RegVarSIRT
In [ ]:
def err_l2(img, rec):
    """Return the sum of squared differences between *img* and *rec*, per pixel.

    The squared residual is summed over all elements and divided by
    ``rec.shape[0] * rec.shape[1]``; for 2-D images this is the mean squared
    error.

    Parameters
    ----------
    img, rec : ndarray
        Reference image and reconstruction; must be broadcast-compatible and
        *rec* must have at least two dimensions.
    """
    n_pixels = rec.shape[0] * rec.shape[1]
    residual = img - rec
    return np.sum(np.square(residual)) / n_pixels
In [ ]:
# make phantom: two overlapping ellipses plus three rectangular inserts with
# piecewise-constant values (mu1, mu2, mu3, 0 and 10*mu1); the whole image is
# scaled by 10 at the end.
size = 512
mu1, mu2, mu3 = 0.006, 0.005, 0.004
half_s = size / 2
# With meshgrid's default 'xy' indexing the FIRST output (`y`) varies along
# axis 1 and the SECOND (`x`) along axis 0, so `xx` is the centred row
# coordinate and `yy` the centred column coordinate.
y, x = np.meshgrid(range(size), range(size))
xx = (x - half_s).astype('float32')
yy = (y - half_s).astype('float32')
mask_ell1 = (xx + 0.1*size)**2 / np.power(0.35*size, 2) + yy**2 / np.power(0.15*size, 2) <= 1
mask_ell2 = (xx - 0.15*size)**2 / np.power(0.3*size, 2) + (yy - 0.15*size)**2 / np.power(0.15*size, 2) <= 1

# Paint regions in order; later assignments overwrite earlier ones.
phantom = np.zeros((size, size))
phantom[mask_ell1] = mu1
phantom[mask_ell2] = mu2
phantom[mask_ell1 & mask_ell2] = mu3
# Rectangles given as (row_start, row_stop, col_start, col_stop) fractions of size.
for (r0, r1, c0, c1), value in (
    ((0.15, 0.35, 0.2, 0.5), mu3),
    ((0.20, 0.25, 0.25, 0.3), 0),
    ((0.30, 0.35, 0.35, 0.4), mu1*10),
):
    phantom[int(r0*size):int(r1*size), int(c0*size):int(c1*size)] = value
phantom = 1e+1 * phantom
# make sinogram: parallel-beam forward projection of the phantom over
# n_angles evenly spaced angles in [0, 180) degrees.
n_angles = size // 2
# linspace with endpoint=False always yields exactly n_angles samples; a
# float-step np.arange can return an extra angle when 180 / n_angles is not
# exactly representable in binary floating point.
angles = np.linspace(0.0, 180.0, n_angles, endpoint=False)
angles = angles.astype('float32') / 180 * np.pi
pg = astra.create_proj_geom('parallel', 1.0, size, angles)
vg = astra.create_vol_geom((size, size))
sino = alg.gpu_fp(pg, vg, phantom)
print(sino.min(), sino.max())
In [ ]:
# add noise: simulate a noisy measurement whose per-bin standard deviation is
# D = 0.05 + sino**2 / 100. The 0.05 floor keeps D > 0, so 1 / D is finite.
noise_seed = 0  # fixed for reproducible comparisons; None draws a fresh realisation
D = 0.05 + sino**2 / 100.0
sino_noise = np.random.RandomState(noise_seed).normal(sino, D)
print(sino_noise.min(), sino_noise.max())
# Per-bin weights passed to the variance-weighted solvers below.
# NOTE(review): this is 1 / STD; if RegVarSIRT.run expects inverse-variance
# weights it should be 1 / D**2 -- confirm against RegVarSIRT.
Div = 1.0 / D
In [ ]:
def _unpack_run(result):
    """Print ``result['iter']`` and return ``(alpha, energy, rec)`` from a RegVarSIRT.run result."""
    print(result['iter'])
    return result['alpha'], result['energy'], result['rec']


# The projector must be released with astra.projector.delete; do it in
# `finally` so a failing run does not leak it.
proj_id = astra.create_projector('cuda', pg, vg)
try:
    W = astra.OpTomo(proj_id)
    x0 = np.zeros_like(phantom)
    rec_fbp = alg.gpu_fbp(pg, vg, sino_noise)
    eps = 1e-10
    n_iter = 10
    Lambda = 2.
    # SIRT, started from zeros
    rec_1 = RegVarSIRT.run(W, sino_noise, x0, eps=eps, n_it=n_iter, step='steepest')
    al1, en1, rec_s = _unpack_run(rec_1)
    # SIRT+TV, warm-started from the SIRT result
    rec_4 = RegVarSIRT.run(W, sino_noise, rec_s, Lambda=Lambda, eps=eps, n_it=n_iter, step='steepest')
    al4, en4, rec_st = _unpack_run(rec_4)
    # VSIRT (weighted by Div) and VSIRT+TV, both warm-started from rec_st.
    # NOTE(review): confirm rec_st (SIRT+TV) rather than rec_s is the intended
    # common starting point, and that VSIRT+TV is meant to run 100x as many
    # iterations as the other solvers.
    rec_2 = RegVarSIRT.run(W, sino_noise, rec_st, Div, eps=eps, n_it=n_iter, step='steepest')
    al2, en2, rec_v = _unpack_run(rec_2)
    rec_3 = RegVarSIRT.run(W, sino_noise, rec_st, Div, Lambda=Lambda, eps=eps, n_it=n_iter*100, step='steepest')
    al3, en3, rec_vt = _unpack_run(rec_3)
finally:
    astra.projector.delete(proj_id)
In [ ]:
# Per-pixel squared error (err_l2) of each reconstruction against the phantom.
er_v = err_l2(phantom, rec_v)
er_vt = err_l2(phantom, rec_vt)
er_s = err_l2(phantom, rec_s)
er_st = err_l2(phantom, rec_st)
er_fbp = err_l2(phantom, rec_fbp)
# Shared grey-level window for the phantom and all reconstructions.
miv, mav = 0, 0.1
# Noisy sinogram multiplied by the solver weights Div.
sino_new = sino_noise * Div

plt.figure(figsize=(15, 15))
# Row 1: data-domain images, each auto-scaled, with colorbar ticks at
# min, max/3, 2*max/3 and max.
data_panels = [
    (sino_noise, 'a) Noisy sinogram'),
    (D, 'b) STD'),
    (sino_new, 'c) Sinogram / STD'),
]
for pos, (img, title) in enumerate(data_panels, start=1):
    plt.subplot(3, 3, pos)
    plt.imshow(img, interpolation=None, cmap="gray")
    plt.title(title, loc='left')
    top = img.max()
    plt.colorbar(ticks=[img.min(), top / 3, 2 * top / 3, top], orientation='horizontal')
# Rows 2-3: phantom and reconstructions on the common [miv, mav] window,
# annotated with their err_l2 against the phantom.
image_panels = [
    (phantom, 'd) Phantom'),
    (rec_s, 'e) SIRT, Err={:0.2e}'.format(er_s)),
    (rec_v, 'f) VarSIRT, Err={:0.2e}'.format(er_v)),
    (rec_fbp, 'g) FBP, Err={:0.2e}'.format(er_fbp)),
    (rec_st, 'h) SIRT+TV, Err={:0.2e}'.format(er_st)),
    (rec_vt, 'i) VarSIRT+TV, Err={:0.2e}'.format(er_vt)),
]
for pos, (img, title) in enumerate(image_panels, start=len(data_panels) + 1):
    plt.subplot(3, 3, pos)
    plt.imshow(img, interpolation=None, cmap="gray", vmin=miv, vmax=mav)
    plt.title(title, loc='left')
    plt.colorbar(orientation='horizontal')
In [ ]:
f2 = plt.figure(figsize=(10, 7))
# 'energy' history returned by RegVarSIRT.run for each solver, on a log y-axis.
for curve, label in ((en1, "SIRT"), (en2, "VSIRT"), (en3, "VSIRT+TV"), (en4, "SIRT+TV")):
    plt.semilogy(curve, label=label, linewidth=1.0)
plt.grid(True)
plt.ylabel('Value, a.u.')
plt.xlabel('Iteration number, a.u.')
plt.legend(loc='best');
# To export instead of displaying inline: plt.savefig("10_sr.png"); plt.show()
In [ ]:
# 'alpha' history returned by RegVarSIRT.run for each solver, on a log y-axis
# (presumably the per-iteration step from step='steepest' -- confirm in RegVarSIRT).
# Open a new figure so the curves are not drawn onto the energy plot when the
# inline backend does not start one per cell.
plt.figure()
plt.semilogy(al1, label="SIRT", linewidth=1.0)
plt.semilogy(al2, label="VSIRT", linewidth=1.0)
plt.semilogy(al3, label="VSIRT+TV", linewidth=1.0)
plt.semilogy(al4, label="SIRT+TV", linewidth=1.0)
plt.grid(True)
plt.ylabel('Value, a.u.')
plt.xlabel('Iteration number, a.u.')
plt.legend(loc='best');
In [ ]:
# Iterations to display: 1..19, then every 10th up to 90, then every 100th up to 4900.
snapshot_iters = np.hstack([np.arange(1, 20), np.arange(20, 100, 10), np.arange(100, 5000, 100)])
# (file-name prefix, panel title) of each saved snapshot series, read as
# '<prefix>_<iter:04d>.png' from the working directory.
snapshot_series = [('s', 'SIRT'), ('s_tv', 'SIRT+TV'), ('v', 'VSIRT'), ('v_tv', 'VSIRT+TV')]


def _load_gray(path):
    """Read a PNG with plt.imread and collapse its channels (if any) into one grey image."""
    img = plt.imread(path)
    return img.sum(axis=-1) if img.ndim == 3 else img


for i in snapshot_iters:
    # Load everything first so a missing snapshot skips the iteration without
    # leaving a half-drawn figure behind; real plotting errors still surface.
    try:
        frames = [_load_gray('{}_{:04d}.png'.format(prefix, i)) for prefix, _ in snapshot_series]
    except OSError:
        continue
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 5, 1)
    plt.imshow(rec_fbp, cmap='gray', vmin=0, vmax=0.1)
    plt.title('FBP {}'.format(i))
    for pos, (frame, (_, name)) in enumerate(zip(frames, snapshot_series), start=2):
        plt.subplot(1, 5, pos)
        plt.imshow(frame, cmap='gray')
        plt.title('{} {}'.format(name, i))
    plt.show()
In [ ]:
# !zsh -c "rm {s,v}_*.png"