In [ ]:
%matplotlib inline

In [ ]:
import os
import glob

import numpy as np
from scipy.ndimage import median_filter
import pylab as plt
from skimage.io import imread

import astra
import h5py

In [ ]:
def log_progress(sequence, every=None, size=None):
    """Iterate over ``sequence`` while showing an ipywidgets progress bar.

    This is a generator: it yields the items of ``sequence`` unchanged and
    updates a progress bar plus a text label displayed in the notebook.

    Parameters
    ----------
    sequence : iterable
        Items to iterate over. If it has no ``len()`` and ``size`` is not
        given, it is treated as an iterator of unknown length.
    every : int, optional
        Refresh the widgets every ``every`` items. Defaults to 1 for up to
        200 items and to ``size // 200`` (about every 0.5%) otherwise.
        Must be given explicitly when the length is unknown.
    size : int, optional
        Total number of items; overrides ``len(sequence)``.

    Yields
    ------
    The items of ``sequence`` in order.

    Raises
    ------
    AssertionError
        If the length is unknown and ``every`` was not supplied.
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                # Floor division keeps `every` an integer on Python 3 as well;
                # with true division `index % every == 0` would almost never
                # hold and the bar would stop updating.
                every = size // 200     # every 0.5%
    else:
        assert every is not None, 'sequence is iterator, set every'

    if is_iterator:
        # Unknown length: show a full bar in "info" style and only count items.
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{index} / ?'.format(index=index)
                else:
                    progress.value = index
                    label.value = u'{index} / {size}'.format(
                        index=index,
                        size=size
                    )
            yield record
    except BaseException:
        # Anything that interrupts the iteration (including the consumer
        # abandoning the generator) marks the bar as failed, then propagates.
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        # `unicode` does not exist on Python 3; a unicode format string gives
        # the same text on both major versions.
        label.value = u'{0}'.format(index or '?')

In [ ]:
# Root directory of the scan. Later cells read the projection images from its
# 'data' subfolder and write reconstruction results into its 'res' subfolder.
data_folder = '/diskmnt/a/makov/tomo_data/podurec/krest'

In [ ]:
# Collect the projection images in sorted filename order; the position of a
# file in this list is later used as its projection index.
data_files = sorted(glob.glob(os.path.join(data_folder, 'data', '*.tif')))
# print() with a single argument behaves the same on Python 2 and Python 3.
print(len(data_files))

In [ ]:
# Read the first frame only to learn the image size, then allocate a float32
# stack with one slot per projection file.
test_im = imread(data_files[0])
sinogram = np.zeros(
    shape=(len(data_files), test_im.shape[0], test_im.shape[1]),
    dtype=np.float32,
)

In [ ]:
# Fill the stack frame by frame; passing the length explicitly lets the
# progress bar show "index / total" without building a temporary list.
for idf, data_file in log_progress(enumerate(data_files), size=len(data_files)):
    sinogram[idf] = imread(data_file)

In [ ]:
# Shape is (number of files, image rows, image columns).
print(sinogram.shape)

In [ ]:
# Preview the frame loaded from file index 300.
plt.figure(figsize=(15, 15))
plt.imshow(sinogram[300, :, :], cmap=plt.cm.viridis)
plt.colorbar(orientation='horizontal')

In [ ]:
# Preview image row 290 across all files: vertical axis is the file index,
# horizontal axis is the image column.
plt.figure(figsize=(15, 15))
plt.imshow(sinogram[:, 290], cmap=plt.cm.viridis)
plt.colorbar(orientation='horizontal')

In [ ]:
def build_reconstruction_geomety(detector_size, angles):
    """Return an ASTRA 'parallel' projection geometry.

    The geometry is created with a detector spacing of 1.0,
    ``detector_size`` detector cells and the given ``angles``.
    """
    return astra.create_proj_geom('parallel', 1.0, detector_size, angles)

In [ ]:
def astra_tomo2d(sinogram, angles, n_iterations=1000):
    """Reconstruct one 2-D slice from a sinogram with ASTRA's SART_CUDA.

    Parameters
    ----------
    sinogram : 2-D numpy array
        Projection data; ``sinogram.shape[1]`` is used as the detector size
        and as the side length of the square reconstruction grid.
    angles : numpy array
        Projection angles; cast to float64 before being handed to ASTRA.
    n_iterations : int, optional
        Number of algorithm iterations to run. Defaults to 1000, the value
        that was previously hard-coded.

    Returns
    -------
    tuple
        ``(rec, proj_geom, cfg)``: the reconstructed image, the projection
        geometry and the algorithm configuration dictionary that were used.
    """
    angles = angles.astype('float64')  # hack for astra stability, may be removed in future releases
    detector_size = sinogram.shape[1]

    rec_size = detector_size  # size of reconstruction region
    vol_geom = astra.create_vol_geom(rec_size, rec_size)
    proj_geom = build_reconstruction_geomety(detector_size, angles)

    # Track the ASTRA handles so they are released even when a step fails;
    # GPU memory is tied up in the algorithm object and main RAM in the data
    # objects.
    sinogram_id = rec_id = alg_id = None
    try:
        sinogram_id = astra.data2d.create('-sino', proj_geom, data=sinogram)
        # Create a data object for the reconstruction
        rec_id = astra.data2d.create('-vol', vol_geom)

        # Set up the parameters for a reconstruction algorithm using the GPU.
        # Available algorithms:
        # SIRT_CUDA, SART_CUDA, EM_CUDA, FBP_CUDA (see the FBP sample)
        cfg = astra.astra_dict('SART_CUDA')
        cfg['ReconstructionDataId'] = rec_id
        cfg['ProjectionDataId'] = sinogram_id
        cfg['option'] = {}
        cfg['option']['MinConstraint'] = 0  # clamp the reconstruction to non-negative values

        # Create the algorithm object from the configuration structure
        alg_id = astra.algorithm.create(cfg)

        astra.algorithm.run(alg_id, n_iterations)
        # Get the result
        rec = astra.data2d.get(rec_id)
    finally:
        if alg_id is not None:
            astra.algorithm.delete(alg_id)
        if rec_id is not None:
            astra.data2d.delete(rec_id)
        if sinogram_id is not None:
            astra.data2d.delete(sinogram_id)
        # Kept from the original implementation: drops every remaining
        # ASTRA object, not only the ones created here.
        astra.clear()
    return rec, proj_geom, cfg

In [ ]:
# Indices of the projections used for reconstruction (all of them by default).
slices = np.arange(len(sinogram))
# Alternative that skips projections 150..219:
# slices = np.hstack([slices[:150],slices[220:]])

In [ ]:
# Projection angle in radians, using a step of 0.5 degrees per projection index.
angles = np.pi * (slices * 0.5 / 180.)
plt.plot(angles)

In [ ]:
# Prepare the sinogram of image row 40 for reconstruction.
s = median_filter(np.copy(sinogram[slices, 40, :]), size=5)  # 5x5 median denoising
# Rescale to [0, 1] and add a small offset so the logarithm stays finite.
s = 1e-2 + (s - s.min()) / (s.max() - s.min())
# s[s<0.001] = 0.001
s = np.negative(np.log(s))
# Divide every projection (row of s) by its own sum.
s = (s.T / s.T.sum(axis=0)).T

In [ ]:
# Show the preprocessed sinogram `s` with a horizontal colour bar.
plt.figure(figsize=(15,15))
plt.imshow(s,cmap=plt.cm.viridis)
plt.colorbar(orientation='horizontal')

In [ ]:
# Per-projection sum of the preprocessed sinogram (summed over its last axis).
plt.plot(np.sum(s, axis=-1))

In [ ]:
# Drop the last 28 detector columns, then overlay the first projection with
# the mirrored last projection.
st = s[:, :-28]
plt.figure(figsize=(10, 5))
plt.plot(st[0])
plt.plot(st[-1][::-1])
plt.grid(True)

In [ ]:
# NOTE: `rec` holds the whole (image, proj_geom, cfg) tuple returned by
# astra_tomo2d here; the reconstructed image itself is rec[0].
rec = astra_tomo2d(st, angles)

In [ ]:
# Show the first element of `rec` (the reconstructed image) in grey scale.
plt.figure(figsize=(15,15))
plt.imshow(rec[0],cmap=plt.cm.gray)
plt.colorbar(orientation='horizontal')

In [ ]:
# Reconstruct every image row and store each result in res/res.h5 under the
# row index as dataset name.
res_dir = os.path.join(data_folder, 'res')
if not os.path.isdir(res_dir):
    # The HDF5 file cannot be created when its parent directory is missing.
    os.makedirs(res_dir)

with h5py.File(os.path.join(res_dir, 'res.h5'), 'w') as f:
    for slice_id in log_progress(range(sinogram.shape[1])):
        # Same preprocessing as for the single test row.
        s = np.copy(sinogram[slices, slice_id, :])
        s = median_filter(s, 5)
        # Rescale to [0, 1] and add a small offset so the logarithm stays finite.
        s = (s - s.min()) / (s.max() - s.min()) + 1e-2
        s = -np.log(s)
        # Divide every projection (row of s) by its own sum.
        s = (s.T / s.T.sum(axis=0)).T

        st = s[:, :-28]  # drop the last 28 detector columns
        rec, proj_geom, cfg = astra_tomo2d(st, angles)
        f[str(slice_id)] = rec

In [ ]:
# Find the largest value over all reconstructed slices; it is used later as
# the common upper limit of the grey scale when exporting images.
m = 0
with h5py.File(os.path.join(data_folder,'res','res.h5'),'r') as f:
    for name in log_progress(f.keys()):
        # `[()]` reads the whole dataset; `Dataset.value` is deprecated and
        # no longer available in h5py 3.
        x = f[name][()]
        m = np.maximum(m, x.max())

In [ ]:
# print() with a single argument behaves the same on Python 2 and Python 3.
print(m)

In [ ]:
# Export every reconstructed slice as a grey-scale image named after its
# dataset, using the common [0, m] intensity range.
with h5py.File(os.path.join(data_folder,'res','res.h5'),'r') as f:
    for name in log_progress(f.keys()):
        # `[()]` reads the whole dataset; `Dataset.value` is deprecated and
        # no longer available in h5py 3.
        x = f[name][()]
        plt.imsave(os.path.join(data_folder,'res','{}.tif'.format(name)),
                   x, vmin=0, vmax=m, cmap=plt.cm.gray)

In [ ]:
# For two selected slices, save and show a crop of our SART reconstruction
# next to a crop of the '<slice>.tif' image found in data_folder (once with
# the automatic colour range and once with the lower limit fixed at zero).
with h5py.File(os.path.join(data_folder,'res','res.h5'),'r') as f:
    for name in [290, 590]:
        # A native str is a valid h5py key on both Python 2 and 3; calling
        # `.decode()` on it fails on Python 3.
        plt.figure(figsize=(15,15))
        plt.imshow(f[str(name)][500:800,300:1000],cmap=plt.cm.viridis)
        plt.title('SART {}'.format(name))
        plt.colorbar(orientation='horizontal')
        plt.savefig('SART {}.png'.format(name))
        plt.show()

        # Read the comparison image once and reuse the flipped crop.
        t = imread(os.path.join(data_folder,str(name)+'.tif'))
        reference_crop = np.flipud(t[400:700,100:800])
        variants = (
            ('Podurets {}', {}),
            ('Podurets {}_upper_zeros', {'vmin': 0}),
        )
        for label_template, imshow_kwargs in variants:
            label = label_template.format(name)
            plt.figure(figsize=(15,15))
            plt.imshow(reference_crop,cmap=plt.cm.viridis,**imshow_kwargs)
            plt.title(label)
            plt.colorbar(orientation='horizontal')
            plt.savefig(label + '.png')
            plt.show()

In [ ]:
# Dump all reconstructed slices into one raw binary file and write an Amira
# script ('tomo.hx') that loads it.
with open('amira.raw', 'wb') as amira_file:
    with h5py.File(os.path.join(data_folder,'res','res.h5'),'r') as h5f:
        n_slices = len(h5f.keys())
        if n_slices == 0:
            # Without this check `x` below would be undefined (or a stale
            # value left over from an earlier cell).
            raise ValueError('res.h5 contains no reconstructed slices')

        # Datasets are addressed by consecutive integer names '0', '1', ...
        for i in range(n_slices):
            # A native str is a valid h5py key on Python 2 and 3; `[()]`
            # replaces the deprecated `Dataset.value`.
            x = h5f[str(i)][()]
            np.array(x).tofile(amira_file)

        file_shape = x.shape

        # NOTE(review): slices are written one after another (columns vary
        # fastest, then rows, then slice index) while the header lists the
        # slice count as the second dimension - confirm the intended axis order.
        with open('tomo.hx', 'w') as af:
            af.write('# Amira Script\n')
            af.write('remove -all\n')
            af.write(r'[ load -raw ${SCRIPTDIR}/amira.raw little xfastest float 1 '+
                     str(file_shape[1])+' '+str(n_slices)+' '+str(file_shape[0])+
                     ' 0 '+str(file_shape[1]-1)+' 0 '+str(n_slices-1)+' 0 '+str(file_shape[0]-1)+
                     ' ] setLabel tomo.raw\n')

In [ ]:
# Display the shape of the last slice written to amira.raw.
file_shape

In [ ]:
# Regenerate only the Amira loader script; `file_shape` must already be
# defined by the export cell.
# NOTE(review): the header lists the slice count as the second dimension -
# confirm this matches the axis order used when amira.raw was written.
with h5py.File(os.path.join(data_folder, 'res', 'res.h5'), 'r') as h5f, \
        open('tomo.hx', 'w') as af:
    af.write('# Amira Script\n')
    af.write('remove -all\n')
    load_command = ''.join([
        r'[ load -raw ${SCRIPTDIR}/amira.raw little xfastest float 1 ',
        str(file_shape[1]), ' ', str(len(h5f.keys())), ' ', str(file_shape[0]),
        ' 0 ', str(file_shape[1]-1),
        ' 0 ', str(len(h5f.keys())-1),
        ' 0 ', str(file_shape[0]-1),
        ' ] setLabel tomo.raw\n',
    ])
    af.write(load_command)

In [ ]: