In [ ]:
%matplotlib inline
In [ ]:
import pylab as plt
import numpy as np
import astra
import alg
import VarSIRT
In [ ]:
def STD(x):
    """Evaluate the empirical noise model ``165 + 1.1e-7 * x**2``.

    Works element-wise on NumPy arrays. The notebook inverts the result
    into per-bin weights (``V = 1 / D``), so it is used as a
    variance-like quantity despite the function name.
    """
    return 165. + 1.1e-7 * x * x
def load_data(path_sino, angle_step_deg=0.4, pixel_size_mm=11.000435e-3,
              source_object_mm=54., source_detector_mm=225.315):
    """Load a fan-beam sinogram image and build matching ASTRA geometries.

    Scan metadata this was written for::

        Image Pixel Size (um)=11.000435
        Object to Source (mm)=54.350
        Camera to Source (mm)=225.315

    Parameters
    ----------
    path_sino : str
        Sinogram image path, read with ``plt.imread``. If the image has a
        channel axis, only channel 0 is kept.
    angle_step_deg : float
        Angle between consecutive sinogram rows, in degrees.
    pixel_size_mm : float
        Image pixel size in mm; distances are divided by it so that the
        geometry is expressed in image-pixel units.
    source_object_mm : float
        Source-to-object distance in mm.
        NOTE(review): the metadata above says 54.350 mm, but the default
        keeps the previously hard-coded 54.0 so results do not change --
        confirm which value is intended.
    source_detector_mm : float
        Source-to-camera distance in mm.

    Returns
    -------
    proj_geom : dict
        ASTRA ``'fanflat'`` projection geometry.
    vol_geom : dict
        Square ASTRA volume geometry with one pixel per detector cell.
    sinogram : numpy.ndarray
        Vertically flipped ``float32`` sinogram; rows are projection angles,
        columns are detector cells.
    """
    sinogram = plt.imread(path_sino)
    if sinogram.ndim == 3:
        sinogram = sinogram[..., 0]
    sinogram = np.flipud(sinogram).astype('float32')

    n_angles, detector_cell = sinogram.shape

    # Distances in image-pixel units.
    os_distance = source_object_mm / pixel_size_mm
    ds_distance = source_detector_mm / pixel_size_mm

    # One angle per sinogram row, converted to radians, centred on zero and
    # then rotated by +90 degrees.
    angles = np.arange(n_angles) * angle_step_deg
    angles = angles.astype('float32') / 180.0 * np.pi
    angles = angles - (angles.max() + angles.min()) / 2
    angles = angles + np.pi / 2

    vol_geom = astra.create_vol_geom(detector_cell, detector_cell)
    # Detector cell width is set to ds/os, i.e. the geometric magnification
    # (presumably so one detector cell corresponds to one image pixel at
    # the object plane).
    proj_geom = astra.create_proj_geom(
        'fanflat', ds_distance / os_distance, detector_cell, angles,
        os_distance, ds_distance - os_distance)
    return proj_geom, vol_geom, sinogram
def recontructr(proj_geom, vol_geom, sinogram, n_iter=3000):
    """Reconstruct *sinogram* with FBP and with FBP-initialised SIRT.

    The FBP image from ``alg.gpu_fbp`` is multiplied by
    ``n_angles / DetectorWidth**2 / (pi / 2)`` and then used as the
    starting point for ``n_iter`` iterations of ``alg.gpu_sirt``.

    Returns
    -------
    rec_fbp, rec_sirt
        The rescaled FBP reconstruction and the SIRT reconstruction.
    """
    scale = sinogram.shape[0] / proj_geom['DetectorWidth'] ** 2 / (np.pi / 2)
    rec_fbp = alg.gpu_fbp(proj_geom, vol_geom, sinogram)
    rec_fbp *= scale
    # Hand the solver a copy of the initial guess (rec_VSIRT does the same
    # for VarSIRT). If the solver updates its starting array in place, the
    # FBP image returned here stays intact.
    rec_sirt = alg.gpu_sirt(proj_geom, vol_geom, sinogram, np.copy(rec_fbp), n_iter)
    return rec_fbp, rec_sirt
def rec_VSIRT(proj_geom, vol_geom, sinogram, rec_fbp, D, n_iter=3000,
              method='steepest', eps=1e-30):
    """Run the variance-weighted SIRT solver (``VarSIRT.run``).

    Each sinogram bin is weighted by ``V = 1 / D``. The weighted sinogram
    ``sinogram * V``, the weights and a copy of *rec_fbp* (initial guess)
    are passed to ``VarSIRT.run`` with *eps*, *n_iter* and *method*.

    Side effects: saves the weight map to ``STD.png`` (figure closed) and a
    semilog plot of the solver's ``'energy'`` history to ``conv.png``. The
    second figure is not closed here (presumably so the notebook shows it).

    Returns
    -------
    The ``'rec'`` entry of the solver result.
    """
    V = 1.0 / D

    fig = plt.figure(figsize=(20, 20))
    fig.add_subplot(1, 1, 1)
    plt.imshow(V, interpolation=None, cmap="gray")
    plt.colorbar(orientation='horizontal')
    plt.savefig("STD.png")
    plt.close(fig)

    sino_new = sinogram * V
    proj_id = astra.create_projector('cuda', proj_geom, vol_geom)
    try:
        W = astra.OpTomo(proj_id)
        result = VarSIRT.run(W, sino_new, V, np.copy(rec_fbp), eps, n_iter, method)
    finally:
        # Release the GPU projector so repeated calls do not accumulate them.
        astra.projector.delete(proj_id)

    energy = np.asarray(result['energy'])

    plt.figure()
    plt.semilogy(np.arange(len(energy)), energy, label="energy", linewidth=3.0)
    plt.grid(True)
    plt.ylabel('Value, a.u.')
    plt.xlabel('Number of iterations, a.u.')
    plt.legend(loc='best')
    plt.savefig("conv.png")
    return result['rec']
def _plot_crop_triplet(rec_a, rec_b, rows, cols, title_a, title_b,
                       vmin=None, vmax=None):
    """Draw a crop of *rec_a*, of *rec_b* and of their difference side by side.

    *rows*/*cols* are slices selecting the crop. *vmin*/*vmax* apply to the
    first two panels only; the difference panel is auto-scaled.
    """
    crop_a = rec_a[rows, cols]
    crop_b = rec_b[rows, cols]
    fig = plt.figure(figsize=(10, 4))
    panels = ((crop_a, title_a, vmin, vmax),
              (crop_b, title_b, vmin, vmax),
              (crop_a - crop_b, title_a + ' - ' + title_b, None, None))
    for position, (image, title, lo, hi) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position)
        plt.imshow(image, interpolation=None, cmap="gray", vmin=lo, vmax=hi)
        ax.set_title(title)
        plt.colorbar(orientation='horizontal')
    return fig


def _plot_line_profiles(profiles, row, start, stop, filename, step=25):
    """Plot mean-normalised row profiles over columns [start, stop), save and show.

    *profiles* is a sequence of ``(label, image)`` pairs. Each curve is
    ``image[row, start:stop] / mean(image[row, :])``. Tick labels show the
    original column numbers every *step* columns.
    """
    plt.figure(figsize=(10, 7))
    for label, image in profiles:
        plt.plot(image[row, start:stop] / np.mean(image[row, :]),
                 label=label, linewidth=1.0)
    plt.xticks(np.arange(0, stop - start + step, step),
               np.arange(start, stop + step, step))
    plt.grid(True)
    plt.ylabel('Value, a.u.')
    plt.xlabel('Pixel number, a.u.')
    plt.legend(loc='best')
    plt.savefig(filename)
    plt.show()


def plot_result(rec_fbp, rec_sirt, rec_vsirt):
    """Save and plot comparisons of the FBP, SIRT and VarSIRT reconstructions.

    All three images are flipped vertically first. The function then:

    - writes them as ``FBP.png``, ``SIRT.png`` and ``VarSIRT.png``;
    - saves a VarSIRT/SIRT/difference crop [500:650, 500:650] to
      ``diff_1.png`` and closes that figure;
    - saves a narrow crop [620:650, 560:600] (display range 0..100) to
      ``diff_2.png`` and shows it;
    - plots mean-normalised row-700 profiles over columns 400..600 for
      SIRT vs VarSIRT (``sr_1.png``) and FBP vs VarSIRT (``sr_2.png``).
    """
    rec_fbp = np.flipud(rec_fbp)
    rec_sirt = np.flipud(rec_sirt)
    rec_vsirt = np.flipud(rec_vsirt)
    plt.imsave('FBP.png', rec_fbp, cmap="gray")
    plt.imsave('SIRT.png', rec_sirt, cmap="gray")
    plt.imsave('VarSIRT.png', rec_vsirt, cmap="gray")

    _plot_crop_triplet(rec_vsirt, rec_sirt, slice(500, 650), slice(500, 650),
                       'VarSIRT', 'SIRT')
    plt.savefig("diff_1.png")
    plt.close()

    _plot_crop_triplet(rec_vsirt, rec_sirt, slice(620, 650), slice(560, 600),
                       'VarSIRT', 'SIRT', vmin=0, vmax=100)
    plt.savefig("diff_2.png")
    plt.show()

    row, start, stop = 700, 400, 600
    _plot_line_profiles((("SIRT", rec_sirt), ("VarSIRT", rec_vsirt)),
                        row, start, stop, "sr_1.png")
    _plot_line_profiles((("FBP", rec_fbp), ("VarSIRT", rec_vsirt)),
                        row, start, stop, "sr_2.png")
In [ ]:
print("Program is started")
# Input sinogram; must be defined before load_data is called.
# Alternative file: 'S1S2S3_2.74um__sino0245.tif'.
path_sino = './noise_recon/S1S2S3_NoAv_2.74um__sino0245.tif'
proj_geom, vol_geom, sinogram = load_data(path_sino)

# Overwrite five detector columns with the 16-bit maximum (presumably to
# simulate a defective/saturated detector stripe).
bad_cols = slice(200, 205)
saturated = 65535
sinogram[:, bad_cols] = saturated
plt.figure(figsize=(12, 9))
plt.imshow(sinogram)
plt.colorbar(orientation='horizontal')

# Per-bin noise model. The stripe gets saturated**2, so its VarSIRT
# weight 1 / D is negligible.
D = STD(sinogram)
D[:, bad_cols] = saturated ** 2
plt.figure(figsize=(12, 9))
plt.imshow(D)
plt.colorbar(orientation='horizontal')

rec_fbp, rec_sirt = recontructr(proj_geom, vol_geom, sinogram)
rec_vsirt = rec_VSIRT(proj_geom, vol_geom, sinogram, rec_fbp, D)
plot_result(rec_fbp, rec_sirt, rec_vsirt)
print("Completed successfully")
In [ ]:
# Display the working sinogram with a horizontal colour scale.
plt.figure(figsize=(12, 9))
plt.colorbar(plt.imshow(sinogram), orientation='horizontal')
In [ ]:
# Re-run the comparison plots (this also rewrites plot_result's PNG files).
plot_result(rec_fbp=rec_fbp, rec_sirt=rec_sirt, rec_vsirt=rec_vsirt)
In [ ]:
# Side-by-side crop [200:300, 500:600] of each reconstruction, fixed 0..150 range.
plt.figure(figsize=(15, 10))
for _pos, (_name, _img) in enumerate(
        (('FBP', rec_fbp), ('SIRT', rec_sirt), ('VSIRT', rec_vsirt)), start=1):
    plt.subplot(1, 3, _pos)
    plt.imshow(_img[200:300, 500:600], vmin=0, vmax=150,
               cmap=plt.cm.gray, interpolation='nearest')
    plt.title(_name)
    plt.colorbar(orientation='horizontal')
In [ ]:
# Full VarSIRT reconstruction, grey scale clipped to 0..200.
plt.figure(figsize=(15, 15))
plt.imshow(rec_vsirt, vmin=0, vmax=200, cmap=plt.cm.gray)
In [ ]:
# Full SIRT reconstruction, grey scale clipped to 0..200.
plt.figure(figsize=(15, 15))
plt.imshow(rec_sirt, vmin=0, vmax=200, cmap=plt.cm.gray)
In [ ]:
# Full FBP reconstruction, grey scale clipped to 0..200.
plt.figure(figsize=(15, 15))
plt.imshow(rec_fbp, vmin=0, vmax=200, cmap=plt.cm.gray)
In [ ]:
def _minmax_normalize(img):
    """Linearly rescale *img* to [0, 1]; a constant image maps to all zeros."""
    img = np.asarray(img)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN for constant images.
        return np.zeros(img.shape, dtype='float32')
    return (img - lo) / span


def images_diff(im1, im2):
    """Return an RGB image highlighting where *im1* and *im2* differ.

    Each image is min-max normalised to [0, 1] on its own. Then
    ``d = norm(im1) - norm(im2)`` is split by sign:

    - red channel: positive part of ``d`` (where *im1* is relatively brighter);
    - green and blue channels: negative part of ``d`` (shown as cyan).

    Parameters
    ----------
    im1, im2 : array_like
        Images of identical shape.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``im1.shape + (3,)``.

    Raises
    ------
    ValueError
        If the two shapes differ.
    """
    im1 = np.asarray(im1)
    im2 = np.asarray(im2)
    if im1.shape != im2.shape:
        raise ValueError('images_diff: shape mismatch %s vs %s'
                         % (im1.shape, im2.shape))

    diff = _minmax_normalize(im1) - _minmax_normalize(im2)
    rec_diff = np.zeros(im1.shape + (3,), dtype='float32')
    rec_diff[..., 0] = diff * (diff > 0)
    rec_diff[..., 1] = -diff * (diff < 0)
    rec_diff[..., 2] = rec_diff[..., 1]
    return rec_diff
In [ ]:
# Normalised colour-coded differences on the [500:650, 500:650] crop,
# amplified by a per-panel gain.
plt.figure(figsize=(15, 10))
for _pos, (_gain, _first, _second, _name) in enumerate((
        (100, rec_vsirt, rec_fbp, 'VSIRT/FBP'),
        (20, rec_vsirt, rec_sirt, 'VSIRT/SIRT'),
        (20, rec_sirt, rec_fbp, 'SIRT/FBP')), start=1):
    plt.subplot(1, 3, _pos)
    plt.imshow(_gain * images_diff(_first[500:650, 500:650],
                                   _second[500:650, 500:650]))
    plt.title(_name)
In [ ]:
# Raw signed differences on the [500:650, 500:650] crop, grey colormap.
plt.gray()
plt.figure(figsize=(15, 10))
for _pos, (_first, _second, _name) in enumerate((
        (rec_vsirt, rec_fbp, 'VSIRT-FBP'),
        (rec_vsirt, rec_sirt, 'VSIRT-SIRT'),
        (rec_sirt, rec_fbp, 'SIRT-FBP')), start=1):
    plt.subplot(1, 3, _pos)
    plt.imshow((_first - _second)[500:650, 500:650])
    plt.title(_name)
    plt.colorbar(orientation='horizontal')
In [ ]:
# CUDA projector for the loaded geometry, wrapped so that `W * image`
# forward-projects an image (as the following cells do).
# NOTE(review): this projector is not freed in the visible cells; consider
# astra.projector.delete(proj_id) once W is no longer needed.
proj_id = astra.create_projector(
    'cuda',
    proj_geom,
    vol_geom,
)
W = astra.OpTomo(proj_id)
In [ ]:
# Forward-project each reconstruction back to sinogram shape.
sino_fbp, sino_vsirt, sino_sirt = [
    (W * _rec).reshape(sinogram.shape)
    for _rec in (rec_fbp, rec_vsirt, rec_sirt)
]
In [ ]:
def build_mask(rec):
    """Return a boolean disc mask with the same 2-D shape as *rec*.

    Pixel ``(i, j)`` is inside when
    ``(i - n_rows/2)**2 + (j - n_cols/2)**2 < (min(n_rows, n_cols)/2)**2``,
    i.e. the largest disc that fits in the image. ``'ij'`` indexing keeps
    the mask shape equal to ``rec.shape`` for non-square images, so
    ``rec * mask`` broadcasts.

    NOTE(review): the centre is taken at ``n/2`` rather than the pixel-grid
    centre ``(n - 1)/2``, i.e. it is shifted by half a pixel.
    """
    n_rows, n_cols = rec.shape[:2]
    rows = np.arange(n_rows) - n_rows / 2
    cols = np.arange(n_cols) - n_cols / 2
    R, C = np.meshgrid(rows, cols, indexing='ij')
    radius = min(n_rows, n_cols) / 2
    return R * R + C * C < radius ** 2
In [ ]:
# Circular reconstruction-area mask sized from the FBP image.
mask = build_mask(rec=rec_fbp)
In [ ]:
# Reproject the unmasked reconstructions and show the colour-coded
# difference from the measured sinogram (gain 10).
sino_fbp, sino_vsirt, sino_sirt = [
    (W * _rec).reshape(sinogram.shape)
    for _rec in (rec_fbp, rec_vsirt, rec_sirt)
]
for exp_sino in (sino_fbp, sino_sirt, sino_vsirt):
    plt.figure(figsize=(15, 15))
    plt.imshow(10 * images_diff(exp_sino, sinogram))
In [ ]:
# Reproject the masked reconstructions and show the raw residual
# (reprojection minus measured sinogram) with a colour scale.
sino_fbp, sino_vsirt, sino_sirt = [
    (W * (_rec * mask)).reshape(sinogram.shape)
    for _rec in (rec_fbp, rec_vsirt, rec_sirt)
]
for exp_sino in (sino_fbp, sino_sirt, sino_vsirt):
    plt.figure(figsize=(15, 15))
    plt.colorbar(plt.imshow(exp_sino - sinogram), orientation='horizontal')
In [ ]:
# Reproject the masked reconstructions and show the colour-coded
# difference from the measured sinogram (gain 10).
sino_fbp, sino_vsirt, sino_sirt = [
    (W * (_rec * mask)).reshape(sinogram.shape)
    for _rec in (rec_fbp, rec_vsirt, rec_sirt)
]
for exp_sino in (sino_fbp, sino_sirt, sino_vsirt):
    plt.figure(figsize=(15, 15))
    plt.imshow(10 * images_diff(exp_sino, sinogram))
In [ ]: