In [ ]:
# Notebook display setup: the active magic selects matplotlib's inline backend;
# the verbose-traceback magic below is left disabled.
%matplotlib inline
# %xmode verbose

In [ ]:
import os
import pickle
from glob import glob

import astra
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.linalg
import skimage
import skimage.io
from cycler import cycler

import alg
import RegVarSIRT

In [ ]:
def log_progress(sequence, every=None, size=None, name='Items'):
    """Yield the items of ``sequence`` while updating a notebook progress widget.

    Parameters
    ----------
    sequence : iterable to walk through.
    every : refresh the widget on the first item and then on every ``every``-th
        item. When the length is known and ``every`` is omitted it is 1 for up
        to 200 items and ``int(size / 200)`` otherwise. It must be given when
        the length cannot be determined.
    size : total number of items; taken from ``len(sequence)`` when omitted.
    name : text shown in front of the counter.

    Yields
    ------
    Each element of ``sequence``, unchanged and in order.

    The bar is marked 'danger' (and the exception re-raised) if iteration is
    interrupted, and 'success' once the sequence is exhausted.
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    # Ask the sequence for its length unless the caller supplied one; if it has
    # no len() the total stays unknown (size remains None).
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            pass
    length_unknown = size is None

    if length_unknown:
        assert every is not None, 'sequence is iterator, set every'
    elif every is None:
        every = 1 if size <= 200 else int(size / 200)     # every 0.5%

    if length_unknown:
        # no total available: show a full bar in the 'info' style
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    display(VBox(children=[label, progress]))

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            refresh = index == 1 or index % every == 0
            if refresh and length_unknown:
                label.value = '{name}: {index} / ?'.format(name=name, index=index)
            elif refresh:
                progress.value = index
                label.value = u'{name}: {index} / {size}'.format(
                    name=name, index=index, size=size)
            yield record
    except BaseException:
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(name=name, index=str(index or '?'))

In [ ]:
def err_l2(img, rec):
    """Sum of squared differences between ``img`` and ``rec``, divided by
    the product of the first two dimensions of ``rec``."""
    n_rows, n_cols = rec.shape[0], rec.shape[1]
    residual = img - rec
    return np.sum(residual ** 2) / (n_rows * n_cols)

In [ ]:
# Synthetic phantom: two overlapping ellipses plus three rectangular inserts,
# followed by its parallel-beam sinogram.
size = 512
mu1, mu2, mu3 = 0.006, 0.005, 0.004
phantom = np.zeros((size, size))
half_s = size / 2

# pixel coordinates measured from the image centre
y, x = np.meshgrid(range(size), range(size))
xx = (x - half_s).astype('float32')
yy = (y - half_s).astype('float32')

# implicit ellipse equations: values <= 1 are inside the ellipse
ell1 = (xx + 0.1*size)**2/np.power(0.35*size, 2) + yy**2/np.power(0.15*size, 2)
ell2 = (xx - 0.15*size)**2/np.power(0.3*size, 2) + (yy - 0.15*size)**2/np.power(0.15*size, 2)
mask_ell1 = ell1 <= 1
mask_ell2 = ell2 <= 1

phantom[mask_ell1] = mu1
phantom[mask_ell2] = mu2
phantom[mask_ell1 & mask_ell2] = mu3   # the overlap gets its own value

# rectangular inserts; later assignments overwrite earlier ones where they overlap
phantom[int(0.15*size):int(0.35*size), int(0.2*size):int(0.5*size)] = mu3
phantom[int(0.20*size):int(0.25*size), int(0.25*size):int(0.3*size)] = 0
phantom[int(0.30*size):int(0.35*size), int(0.35*size):int(0.4*size)] = mu1*10
phantom = phantom * 1e+1

# Sinogram: size//2 equally spaced angles over [0, 180) degrees, converted to radians
n_angles = size//2
angles = np.arange(0.0, 180.0,  180.0 / n_angles)
angles = angles.astype('float32') / 180 * np.pi

pg = astra.create_proj_geom('parallel', 1.0, size, angles)
vg = astra.create_vol_geom((size, size))
sino = alg.gpu_fp(pg, vg, phantom)
print(sino.min(), sino.max())

In [ ]:
def load_data(path_sino, resize_x=None, resize_angle=None):
    """Read a sinogram image and build the matching ASTRA fan-beam geometry.

    Parameters
    ----------
    path_sino : path of the image, read with ``plt.imread``; for a 3-D image
        only channel 0 is used.
    resize_x : keep every ``resize_x``-th column of the sinogram (4 if None).
    resize_angle : keep every ``resize_angle``-th row of the sinogram (4 if None).

    Returns
    -------
    (proj_geom, vol_geom, sinogram)
        'fanflat' projection geometry, a square volume geometry whose side
        equals the number of retained columns, and the vertically flipped,
        subsampled float32 sinogram (rows are projections).
    """
    x_resize = 4 if resize_x is None else resize_x
    angle_resize = 4 if resize_angle is None else resize_angle

    sinogram = plt.imread(path_sino)
    if sinogram.ndim == 3:
        sinogram = sinogram[..., 0]
    sinogram = np.flipud(sinogram).astype('float32')[::angle_resize, ::x_resize]

    n_angles, detector_cell = sinogram.shape[0], sinogram.shape[1]

    # Distances are expressed in units of the (subsampled) pixel size.
    # NOTE(review): presumably os_/ds_ are the source-object and source-detector
    # distances in the same length unit as the 2.50e-3 pixel size — confirm.
    pixel_size = 2.50e-3*x_resize
    os_distance = 49.430 / pixel_size
    ds_distance = 225.315 / pixel_size

    # 0.2 degrees between consecutive original rows, converted to radians,
    # then shifted so the angular range is centred on pi/2.
    angles = (np.arange(n_angles) * 0.2 * angle_resize).astype('float32') / 180.0 * np.pi
    angles = angles - (angles.max() + angles.min()) / 2
    angles = angles + np.pi / 2

    vol_geom = astra.create_vol_geom(detector_cell, detector_cell)
    proj_geom = astra.create_proj_geom(
        'fanflat',
        ds_distance / os_distance,      # detector width
        detector_cell,
        angles,
        os_distance,
        (ds_distance - os_distance),
    )
    return proj_geom, vol_geom, sinogram

In [ ]:
# Collect the noisy sinogram files. glob() returns names in a filesystem-dependent
# order, and later cells pick sinograms[0] / sinograms[-1], so sort for a
# reproducible selection.
# NOTE(review): absolute, machine-specific path — the notebook only runs where it exists.
sinograms = sorted(glob('/home/makov/diskmnt/big/yaivan/noisy_data/noised_sino/BHI3_2.49um_*__sino0980.tif'))
print(len(sinograms))

In [ ]:
# Make 'gray' the default colormap for subsequent image plots.
plt.gray()

In [ ]:
for sinogram in log_progress(sinograms):
    x_resize = 2
    ang_resize = 2
    name = os.path.split(sinogram)[-1]
    print(name)
    pg, vg , sino_noise = load_data(sinogram, x_resize, ang_resize)
    
    # estimate noise
    D = 1.0 * (0.05 + (sino_noise/65535.)**2)
    Div = 1.0 / D
    
    plt.figure(figsize=(15,15))
    plt.imshow(sino_noise, cmap='gray')
    plt.colorbar(orientation='horizontal')
    plt.show()
    
    proj_id = astra.create_projector('cuda', pg, vg)
    W = astra.OpTomo(proj_id)
    x0 = np.zeros((sino_noise.shape[1], sino_noise.shape[1]))
    
    k = sino_noise.shape[0]/pg['DetectorWidth']**2/(np.pi/2)
    rec_fbp = alg.gpu_fbp(pg, vg, sino_noise)*k*x_resize * 2  # fixit
    
    eps = 1e-10
    n_iter = 200

    Lambda=4.

    #SIRT
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=n_iter, step='steepest')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('SIRT')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_s.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec

    #SIRT+TV
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), Lambda=Lambda, eps=eps, n_it=n_iter, step='steepest')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('SIRT+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_s_tv.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec

    #CGLS
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=n_iter, step='CG')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('CGLS')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_cg.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec

    #CGLS+TV
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), Lambda=Lambda, eps=eps, n_it=n_iter, step='CG')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('CGLS+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_cg_tv.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec 
    
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=3, step='CG')
    x0 = rec['rec']
    
    #VSIRT
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, eps=eps, n_it=n_iter, step='steepest')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VSIRT')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vs.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
        
    del rec['x_k']
    %xdel rec

    #VSIRT+TV
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, Lambda=Lambda, eps=eps, n_it=n_iter, step='steepest')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VSIRT+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vs_tv.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    
    del rec['x_k']
    %xdel rec

    #VCGLS
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, eps=eps, n_it=n_iter, step='CG')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VCGLS')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vcg.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    
    del rec['x_k']
    %xdel rec

    #VCGLS+TV
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, Lambda=Lambda, eps=eps, n_it=n_iter, step='CG')
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VCGLS+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vcg_tv.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    
    del rec['x_k']
    %xdel rec

In [ ]:
# Reference reconstruction: FBP of the first sinogram loaded without subsampling
# (resize factors 1, 1), then every second pixel is kept in both directions.
for sinogram in log_progress([sinograms[0],]):
    name = os.path.split(sinogram)[-1]
    pg, vg , sino_noise = load_data(sinogram, 1, 1)
    k = sino_noise.shape[0]/pg['DetectorWidth']**2/(np.pi/2)
    # gpu_fbp is called with the geometries directly, so the projector/OpTomo
    # pair that used to be created here (and never used or freed) is gone.
    rec_reference = alg.gpu_fbp(pg, vg, sino_noise)*k*2

rec_reference = rec_reference[::2,::2]

In [ ]:
# Show the reference reconstruction on a fixed 0..150 gray scale.
plt.figure(figsize=(10,10))
plt.imshow(rec_reference, cmap='gray', vmin=0, vmax=150)
plt.title('ref_rec')
plt.show()

In [ ]:
from cycler import cycler
for sinogram in log_progress(sinograms):
    plt.figure(figsize=(8,8))
    name = os.path.split(sinogram)[-1]
    short_name=name.split('_')[2]
    prefixes = ['s','cg']
    prefixes.extend(['v'+p for p in prefixes])
    prefixes.extend([p+'_tv' for p in prefixes])
#     prefixes.extend([p+'_n' for p in prefixes])
    
    plt.rc('axes', prop_cycle=(cycler('color', ['r', 'g', 'b', 'y']*4)+
                           cycler('linestyle', ['-']*4+ ['--']*4+['-']*4+ ['--']*4)+
                           cycler('lw', [1]*8+ [2]*8)))
    
    for prefix in log_progress(prefixes):
        if not os.path.exists('{}_2_2_{}.pkl'.format(name, prefix)):
            print('NOT found {}'.format(prefix))
            continue

            
        with open('{}_2_2_{}.pkl'.format(name, prefix),'rb') as f:
            res = pickle.load(f)
            
        plt.semilogy(res['energy'], label='{}_{}'.format(short_name, prefix))
        %xdel res
        
    plt.title('Energy')
    plt.xlabel('Interation number')
    plt.ylabel('a.u.')
    plt.grid()
    plt.legend(loc=0)
    plt.savefig('{}_energy.png'.format(short_name))
    plt.show()

In [ ]:
def diff(x0,x1,norm):
    """Distance of reconstruction ``x0`` from reference ``x1``.

    Uses the global boolean array ``mask``: the ``norm``-norm of ``x0 - x1``
    over the masked pixels plus the ``norm``-norm of ``x0`` alone over the
    unmasked pixels. ``norm`` is passed to ``scipy.linalg.norm`` as its order.
    """
    d0 = x0[mask>0].ravel()
    d1 = x1[mask>0].ravel()
    return sp.linalg.norm(d0-d1,norm)+sp.linalg.norm(x0[mask==0].ravel(),norm)

# The mask depends only on the reference reconstruction, so build it once.
mask = rec_reference>15

# quality[short_name][prefix] -> {'l1': [...], 'l2': [...]} (one value per stored
# iterate); the entry stays an empty dict when the pickle is missing.
quality = {}
for sinogram in log_progress(sinograms):
    name = os.path.split(sinogram)[-1]
    short_name=name.split('_')[2]
    quality[short_name] = {}

    prefixes = ['s','cg']
    prefixes.extend(['v'+p for p in prefixes])
    prefixes.extend([p+'_tv' for p in prefixes])
    prefixes.extend([p+'_n' for p in prefixes])

    for prefix in log_progress(prefixes):
        quality[short_name][prefix] = {}
        pkl_path = '{}_2_2_{}.pkl'.format(name, prefix)
        if not os.path.exists(pkl_path):
            print('NOT found {}'.format(prefix))
            continue

        # NOTE: pickle.load runs arbitrary code; only open files written by this notebook.
        with open(pkl_path,'rb') as f:
            res = pickle.load(f)
            res['fbp'] = res['fbp']/4

        l1 = [diff(x_k, rec_reference, 1) for x_k in res['x_k']]
        l2 = [diff(x_k, rec_reference, 2) for x_k in res['x_k']]

        quality[short_name][prefix]['l1'] = l1
        quality[short_name][prefix]['l2'] = l2

        # rewritten after every variant so partial results survive an interruption
        with open('quality_all.pkl','wb') as f:
            pickle.dump(quality, f)

        pos_l1 = np.argmin(l1)
        pos_l2 = np.argmin(l2)

        # error curves, each scaled by its own maximum
        plt.figure(figsize=(10,5))
        plt.plot(np.asarray(l1)/np.max(l1), label='l1')
        # plt.hlines(sp.linalg.norm(res['fbp'].ravel()-rec_reference.ravel(), 1)/np.max(l1),
        #            0, len(l1), 'r', label='fbp L1')
        plt.plot(np.asarray(l2)/np.max(l2), label='l2')
        # plt.hlines(sp.linalg.norm(res['fbp'].ravel()-rec_reference.ravel(), 2)/np.max(l2),
        #            0, len(l2), label='fbp L2')
        plt.legend(loc=0)
        plt.ylabel('a.u.')
        plt.xlabel('Iteration')
        plt.title('{}  {}'.format(short_name, prefix))
        plt.grid()
        plt.savefig('quality_{}_{}.png'.format(short_name, prefix))
        plt.show()

        # 200x200 crop: best-L2 iterate, last iterate, FBP and reference
        plt.figure(figsize=(15,15))
        plt.subplot(221)
        plt.imshow(res['x_k'][pos_l2][600:800,400:600], interpolation='nearest', vmin=0, vmax=150)
        plt.title('{}_{}_{}'.format(short_name, prefix,pos_l2))
        plt.subplot(222)
        plt.imshow(res['x_k'][-1][600:800,400:600], interpolation='nearest', vmin=0, vmax=150)
        plt.title('{}_{}_{}'.format(short_name, prefix,len(l2)))
        plt.subplot(223)
        plt.imshow(res['fbp'][600:800,400:600], interpolation='nearest', vmin=0, vmax=150)
        plt.title('{}_{}'.format(short_name, 'fbp'))
        plt.subplot(224)
        # plt.imshow(res['x_k'][50][600:800,400:600], interpolation='nearest', vmin=0, vmax=300)
        plt.imshow(rec_reference[600:800,400:600], interpolation='nearest', vmin=0, vmax=150)
        plt.title('{}_{}'.format(short_name, 'reference'))

        # was plt.savefig(plt.savefig(...)): the outer call received None and raised
        plt.savefig('rec_{}_{}.png'.format(short_name, prefix))
        plt.show()

In [ ]:
# One figure per entry of `quality`: the stored 'l2' curve of every variant whose
# key does not end in 'n', saved as 'quality_plot_l2_<key>.png'.
for k,v in quality.items():
    plt.figure(figsize=(10,10))
    plt.title(k)
    for kk,vv in v.items():
        # skip variants without results and the ones whose name ends in 'n'
        if 'l2' not in vv or kk.endswith('n'):
            continue
        plt.plot(vv['l2'], label = kk)

    plt.xlabel('Iterations')
    plt.legend(loc=0)
    plt.grid()
    plt.savefig('quality_plot_l2_{}.png'.format(k))
    plt.show()

In [ ]:
# For every entry of `quality`, mark each variant's smallest L2 error against the
# iteration index at which it occurs; saved as 'quality_plot_final_<key>.png'.
# (`cycler` comes from the import cell.)

# The marker/colour cycle is the same for every figure, so set it once.
plt.rc('axes', prop_cycle=(cycler('color', ['r', 'g', 'b', 'y']*2)+
                   cycler('marker', ['o']*4+ ['x']*4)
                   )
      )

for k,v in quality.items():
    plt.figure(figsize=(10,10))
    plt.title(k)

    for kk,vv in v.items():
        if ('l2' in vv) and not (kk.endswith('n')):
            # a single point per variant; the marker comes from the prop cycle
            plt.plot(np.argmin(vv['l2']), np.min(vv['l2']), label = kk)

    plt.xlabel('Iterations')
    plt.ylabel('Reconstruction error')
    plt.legend(loc=0)
    plt.grid()
    plt.savefig('quality_plot_final_{}.png'.format(k))
    plt.show()

In [ ]:
# Display the list of sinogram paths.
sinograms

In [ ]:
# Histogram (1000 bins) of the reference reconstruction values with a vertical marker line.
# NOTE(review): the line is drawn at 14 while the mask in this notebook is
# `rec_reference>15` — confirm which threshold is intended.
plt.figure(figsize=(12,10))
plt.hist(rec_reference.ravel(), bins=1000);
plt.vlines(14,0,100000)
plt.grid()

In [ ]:
# Boolean mask of reference pixels above 15; diff() reads it as a global.
mask = rec_reference>15

In [ ]:
# Twice the pixel sum of each stored iterate, plotted against its index.
# Depends on `res` left behind by a previously run cell.
plt.plot([np.sum(x)*2 for x in res['x_k']])

In [ ]:
# Sum the sinogram along its last axis, then average those sums.
sino_noise.sum(axis=-1).mean()

In [ ]:
# Display the most recently assigned projection geometry.
pg

In [ ]:
def diff(x0,x1,norm):
    """Distance of reconstruction ``x0`` from reference ``x1``.

    Uses the global boolean array ``mask``: the ``norm``-norm of ``x0 - x1``
    over the masked pixels plus the ``norm``-norm of ``x0`` alone over the
    unmasked pixels. ``norm`` is passed to ``scipy.linalg.norm`` as its order.
    """
    d0 = x0[mask>0].ravel()
    d1 = x1[mask>0].ravel()
    return sp.linalg.norm(d0-d1,norm)+sp.linalg.norm(x0[mask==0].ravel(),norm)

# L1 / L2 distance of every stored iterate of the currently loaded `res`
# from the reference reconstruction.
l1 = [diff(x_k, rec_reference, 1) for x_k in res['x_k']]
l2 = [diff(x_k, rec_reference, 2) for x_k in res['x_k']]

# both curves are scaled by their own maximum
plt.figure(figsize=(10,5))
plt.plot(np.asarray(l1)/np.max(l1), label='l1')
# plt.hlines(sp.linalg.norm(res['fbp'].ravel()-rec_reference.ravel(), 1)/np.max(l1),
#            0, len(l1), 'r', label='fbp L1')
plt.plot(np.asarray(l2)/np.max(l2), label='l2')
# plt.hlines(sp.linalg.norm(res['fbp'].ravel()-rec_reference.ravel(), 2)/np.max(l2),
#            0, len(l2), label='fbp L2')
plt.legend(loc=0)
plt.ylabel('a.u.')
plt.xlabel('Iteration')   # was plt.x.label(...), which raises AttributeError
plt.grid()
plt.show()

pos_l1 = np.argmin(l1)
pos_l2 = np.argmin(l2)

In [ ]:
# Write iterate 30 of the currently loaded `res` to test.tiff after converting it
# with skimage.img_as_float.
skimage.io.imsave('test.tiff', skimage.img_as_float(res['x_k'][30]))

In [ ]:
# Compare row 500 of the reference, the FBP result and one stored iterate.
# Depends on `name`, `prefix` and `res` left behind by previously run cells.
short_name=name.split('_')[2]
iteration = 3   # index into res['x_k'] of the iterate to show
plt.figure(figsize=(10,10))
# plt.imshow(res['x_k'][40], interpolation='nearest', vmin=0, vmax=300)
plt.plot(rec_reference[500], label='reference')
plt.plot(res['fbp'][500], label='fbp')
# the legend entry is derived from the index (it used to read '40' while iterate 3 was drawn)
plt.plot(res['x_k'][iteration][500], label=str(iteration))
plt.title('{}_{}'.format(short_name, prefix))
plt.legend(loc=0)
plt.grid()
plt.show()

In [ ]:
# Plot the 'alpha' series stored in the currently loaded result.
# NOTE(review): presumably the per-iteration step size — confirm in RegVarSIRT.
plt.plot(res['alpha'])
plt.grid()

In [ ]:
# Load the stored 'cg' result of the first sinogram into `res`
# (its FBP image is divided by 4 on load).
for sinogram in log_progress([sinograms[0],]):
    mask = rec_reference>15
    name = os.path.split(sinogram)[-1]
    short_name=name.split('_')[2]
    # NOTE(review): this replaces any quality entries already stored for this sinogram.
    quality[short_name] = {}

    prefixes = ['cg',]

    for prefix in log_progress(prefixes):
        quality[short_name][prefix] = {}
        pkl_path = '{}_2_2_{}.pkl'.format(name, prefix)
        if not os.path.exists(pkl_path):
            print('NOT found {}'.format(prefix))
            # without this the open() below raised FileNotFoundError
            continue

        # NOTE: pickle.load runs arbitrary code; only open files written by this notebook.
        with open(pkl_path,'rb') as f:
            res = pickle.load(f)
            res['fbp'] = res['fbp']/4

In [ ]:
# Display `k`.
# NOTE(review): `k` is both the FBP scaling factor and the loop variable of the
# `for k,v in quality.items()` cells, so the value shown depends on which cell ran last.
k

In [ ]:
import imp

In [ ]:
imp.reload(RegVarSIRT)

In [ ]:
# Short trial (20 iterations) of the weighted CG + TV run with normalize=True on
# the last sinogram; nothing is written to disk.
for sinogram in log_progress([sinograms[-1],]):
    x_resize = 2
    ang_resize = 2
    name = os.path.split(sinogram)[-1]
    print(name)
    pg, vg , sino_noise = load_data(sinogram, x_resize, ang_resize)

    # estimate noise
    D = 1.0 * (0.05 + (sino_noise/65535.)**2)
    Div = 1.0 / D

    plt.figure(figsize=(15,15))
    plt.imshow(sino_noise, cmap='gray')
    plt.colorbar(orientation='horizontal')
    plt.show()

    eps = 1e-10
    n_iter = 20

    Lambda=4.

    proj_id = astra.create_projector('cuda', pg, vg)
    try:
        W = astra.OpTomo(proj_id)
        x0 = np.zeros((sino_noise.shape[1], sino_noise.shape[1]))

        k = sino_noise.shape[0]/pg['DetectorWidth']**2/(np.pi/2)
        rec_fbp = alg.gpu_fbp(pg, vg, sino_noise)*k*x_resize * 2  # fixit

        # starting point: 3 unweighted CG iterations from zero
        rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=3, step='CG')
        x0 = rec['rec']

        #VCGLS+TV
        rec = RegVarSIRT.run(W, sino_noise, x0, Div, Lambda=Lambda, eps=eps, n_it=n_iter, step='CG', normalize=True)
        rec['fbp'] = rec_fbp
    finally:
        # ASTRA objects live until explicitly deleted
        astra.projector.delete(proj_id)

    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'][600:800,400:600], interpolation='nearest',vmin=0,vmax=150)
    plt.title('VCGLS+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()

In [ ]:
# Re-draw the crop and energy curve of the `rec` left by the previous cell.
# (The lines after the first used to be indented, which is an IndentationError.)
plt.figure(figsize=(15,10))
plt.subplot(121)
plt.imshow(rec['rec'][600:800,400:600], interpolation='nearest',vmin=0,vmax=150)
plt.title('VCGLS+TV')
plt.subplot(122)
plt.semilogy(rec['energy'])
plt.grid()
plt.show()

In [ ]:
for sinogram in log_progress([sinograms[-1],]):
    x_resize = 2
    ang_resize = 2
    name = os.path.split(sinogram)[-1]
    print(name)
    pg, vg , sino_noise = load_data(sinogram, x_resize, ang_resize)
    
    # estimate noise
    D = 1.0 * (0.05 + (sino_noise/65535.)**2)
    Div = 1.0 / D
    
    plt.figure(figsize=(15,15))
    plt.imshow(sino_noise, cmap='gray')
    plt.colorbar(orientation='horizontal')
    plt.show()
    
    proj_id = astra.create_projector('cuda', pg, vg)
    W = astra.OpTomo(proj_id)
    x0 = np.zeros((sino_noise.shape[1], sino_noise.shape[1]))
    
    k = sino_noise.shape[0]/pg['DetectorWidth']**2/(np.pi/2)
    rec_fbp = alg.gpu_fbp(pg, vg, sino_noise)*k*x_resize * 2  # fixit
    
    eps = 1e-10
    n_iter = 200

    Lambda=4.

    #SIRT
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=n_iter,
                         step='steepest', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('SIRT')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_s_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec

    #SIRT+TV
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), Lambda=Lambda, eps=eps, n_it=n_iter,
                         step='steepest', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('SIRT+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_s_tv_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec

    #CGLS
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=n_iter,
                         step='CG', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('CGLS')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_cg_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec

    #CGLS+TV
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), Lambda=Lambda, eps=eps, n_it=n_iter,
                         step='CG', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('CGLS+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_cg_tv_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    del rec['x_k']
    %xdel rec 
    
    rec = RegVarSIRT.run(W, sino_noise, np.zeros_like(x0), eps=eps, n_it=3,
                         step='CG', normalize=True)
    x0 = rec['rec']
    
    #VSIRT
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, eps=eps, n_it=n_iter,
                         step='steepest', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VSIRT')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vs_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
        
    del rec['x_k']
    %xdel rec

    #VSIRT+TV
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, Lambda=Lambda, eps=eps, n_it=n_iter,
                         step='steepest', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VSIRT+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vs_tv_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    
    del rec['x_k']
    %xdel rec

    #VCGLS
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, eps=eps, n_it=n_iter,
                         step='CG', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VCGLS')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vcg_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    
    del rec['x_k']
    %xdel rec

    #VCGLS+TV
    rec = RegVarSIRT.run(W, sino_noise, x0, Div, Lambda=Lambda, eps=eps, n_it=n_iter,
                         step='CG', normalize=True)
    rec['fbp'] = rec_fbp
    
    plt.figure(figsize=(15,10))
    plt.subplot(121)
    plt.imshow(rec['rec'], interpolation='nearest')
    plt.title('VCGLS+TV')
    plt.subplot(122)
    plt.semilogy(rec['energy'])
    plt.grid()
    plt.show()
    
    with open('{}_{}_{}_vcg_tv_n.pkl'.format(name, x_resize, ang_resize),'wb') as f:
        pickle.dump(rec, f)
    
    del rec['x_k']
    %xdel rec

In [ ]:
from cycler import cycler
for sinogram in log_progress([sinograms[-1],]):
    plt.figure(figsize=(8,8))
    name = os.path.split(sinogram)[-1]
    short_name=name.split('_')[2]
    prefixes = ['s','cg']
    prefixes.extend(['v'+p for p in prefixes])
    prefixes.extend([p+'_tv' for p in prefixes])
    prefixes.extend([p+'_n' for p in prefixes])
    
    plt.rc('axes', prop_cycle=(cycler('color', ['r', 'g', 'b', 'y']*4)+
                           cycler('linestyle', ['-']*4+ ['--']*4+['-']*4+ ['--']*4)+
                           cycler('lw', [1]*8+ [2]*8)))
    
    for prefix in log_progress(prefixes):
        if not os.path.exists('{}_2_2_{}.pkl'.format(name, prefix)):
            print('NOT found {}'.format(prefix))

            
        with open('{}_2_2_{}.pkl'.format(name, prefix),'rb') as f:
            res = pickle.load(f)
            
        plt.semilogy(res['energy'], label='{}_{}'.format(short_name, prefix))
        %xdel res
        
    plt.title('Energy')
    plt.xlabel('Interation number')
    plt.ylabel('a.u.')
    plt.grid()
    plt.legend(loc=0)
    plt.savefig('{}_energy.png'.format(short_name))
    plt.show()

In [ ]: