In [366]:
import os
import warnings

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from skimage import feature, morphology
from skimage.transform import probabilistic_hough_line

from astropy.io import fits

In [350]:
def downsample_image(image, n):
    """Block-average a 2D image by an integer factor along both axes.

    Rows/columns at the end of each axis that do not fill a complete
    n x n block are dropped before averaging.

    Parameters
    ----------
    image : ndarray
        2D image data with shape (ny, nx).
    n : int
        Downsampling factor, i.e. the side length of each averaging block.

    Returns
    -------
    result : ndarray
        Array of shape (ny // n, nx // n); each element is the mean of the
        corresponding n x n block of ``image``.
    """
    nrows_out, ncols_out = (dim // n for dim in image.shape)
    trimmed = image[:nrows_out * n, :ncols_out * n]
    blocks = trimmed.reshape(nrows_out, n, ncols_out, n)
    # Average over the in-block column offset (last axis) first; once that
    # axis is gone, the in-block row offset is axis -2.
    col_avg = blocks.mean(axis=-1)
    return col_avg.mean(axis=-2)

def plot_mask(maskfile, hdu='MASK', downsample=1):
    """Render one HDU of a FITS file as a borderless grayscale figure.

    Parameters
    ----------
    maskfile : str or path-like
        FITS file to read.
    hdu : str or int, optional
        HDU whose data is plotted (default 'MASK'). For 'IMAGE' the gray
        scale spans [0, 100]; for any other HDU it spans [0, 1].
    downsample : int, optional
        If > 1, block-average the data by this factor with
        `downsample_image` before plotting.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure sized so that one (downsampled) data pixel maps to one figure
        pixel at 256 dpi, with a single frameless axes filling the canvas.
    """
    # Read inside a context manager so the file handle is released. Both
    # branches produce a new in-memory array, so nothing refers to the
    # (possibly memory-mapped) HDU data once the file is closed.
    with fits.open(maskfile) as hdus:
        img = hdus[hdu].data
        if downsample > 1:
            m = downsample_image(img, downsample)
        else:
            m = np.copy(img)

    dpi = 256
    fig = plt.figure(figsize=(m.shape[1] / dpi, m.shape[0] / dpi), dpi=dpi)
    # Axes covering the whole figure with no frame/ticks: only the data is drawn.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_axis_off()

    vmax = 100. if hdu == 'IMAGE' else 1.
    ax.imshow(m, cmap='gray', vmin=0, vmax=vmax, origin='lower')
    fig.canvas.draw()
    return fig

def dilate_mask(maskfile, hdu='MASK', downsample=1, dilate=10):
    """Plot a FITS mask HDU after binarizing and dilating it.

    Parameters
    ----------
    maskfile : str or path-like
        FITS file to read.
    hdu : str or int, optional
        HDU whose data is used (default 'MASK').
    downsample : int, optional
        If > 1, block-average the data by this factor with
        `downsample_image` before thresholding.
    dilate : int, optional
        Side length, in (downsampled) pixels, of the square structuring
        element used for the binary dilation.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Borderless figure of the dilated boolean mask, one data pixel per
        figure pixel at 256 dpi. Note that the dilated array itself is not
        returned, only the figure.
    """
    # Read inside a context manager so the file handle is released; the
    # array used below is always a fresh in-memory copy.
    with fits.open(maskfile) as hdus:
        img = hdus[hdu].data
        if downsample > 1:
            m = downsample_image(img, downsample)
        else:
            m = np.copy(img)

    # Any positive value counts as masked; grow the masked regions with a
    # dilate x dilate square footprint.
    m_binary = m > 0
    m_dilate = morphology.binary_dilation(m_binary, np.ones((dilate, dilate)))

    dpi = 256
    fig = plt.figure(figsize=(m_dilate.shape[1] / dpi, m_dilate.shape[0] / dpi), dpi=dpi)
    # Axes covering the whole figure with no frame/ticks: only the data is drawn.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_axis_off()

    ax.imshow(m_dilate, cmap='gray', vmin=0, vmax=1, origin='lower')
    fig.canvas.draw()
    return fig

def plot_chunks(img, mask, proc=None, prefix='test', title=None, nchunks=8, vmax=100.):
    """Save zoomed comparison panels for each tile of an nchunks x nchunks grid.

    The arrays are cut into a regular grid of tiles. For each tile, a figure
    showing the IMAGE and MASK cutouts side by side is written to the file
    ``'{prefix}_{i:02d}_{j:02d}.pdf'``, where i/j are the tile row/column
    indices. If ``proc`` is given, the figure also shows the ``proc`` cutout.

    Parameters
    ----------
    img : ndarray
        2D image; displayed with gray scale fixed to [0, vmax].
    mask : ndarray
        2D array with the same shape as ``img``; auto-scaled.
    proc : ndarray or None, optional
        Optional third 2D array with the same shape, shown in a third,
        auto-scaled panel (presumably a processed mask -- confirm with callers).
    prefix : str, optional
        Prefix used in panel titles and output file names.
    title : str or None, optional
        Extra text appended to the third panel's title.
    nchunks : int, optional
        Number of tiles along each axis (default 8, i.e. 64 output files).
    vmax : float, optional
        Upper gray-scale limit of the IMAGE panel (default 100).

    Returns
    -------
    None
        When the input shapes disagree, a UserWarning is issued and no files
        are written.
    """
    arrays = [img, mask] if proc is None else [img, mask, proc]
    shapes = [a.shape for a in arrays]
    if any(shape != img.shape for shape in shapes):
        # Mismatched inputs cannot be tiled consistently: skip, but say why.
        warnings.warn('plot_chunks: input shapes differ {}; no plots written.'.format(shapes),
                      stacklevel=2)
        return None

    nrow, ncol = img.shape
    row_edges = np.linspace(0, nrow, nchunks + 1, dtype=int)
    col_edges = np.linspace(0, ncol, nchunks + 1, dtype=int)
    npanel = len(arrays)
    proc_title = prefix if title is None else '{}: {}'.format(prefix, title)

    for i, (r0, r1) in enumerate(zip(row_edges[:-1], row_edges[1:])):
        for j, (c0, c1) in enumerate(zip(col_edges[:-1], col_edges[1:])):
            output = '{}_{:02d}_{:02d}.pdf'.format(prefix, i, j)
            fig, axes = plt.subplots(1, npanel, figsize=(4 * npanel, 4))

            ax = axes[0]
            ax.imshow(img[r0:r1, c0:c1], cmap='gray', origin='lower',
                      interpolation='nearest', vmin=0, vmax=vmax)
            ax.set(xticks=[], yticks=[], title='{}: IMAGE'.format(prefix))
            # Annotate the tile with its pixel ranges and its grid indices.
            ax.text(0.02, 0.02, '{}:{}, {}:{}'.format(r0, r1, c0, c1), color='yellow',
                    fontsize=8, transform=ax.transAxes)
            ax.text(0.02, 0.96, '{}, {}'.format(i, j), color='yellow', fontsize=8,
                    transform=ax.transAxes)

            ax = axes[1]
            ax.imshow(mask[r0:r1, c0:c1], cmap='gray', origin='lower', interpolation='nearest')
            ax.set(xticks=[], yticks=[], title='{}: MASK'.format(prefix))

            if proc is not None:
                ax = axes[2]
                ax.imshow(proc[r0:r1, c0:c1], cmap='gray', origin='lower',
                          interpolation='nearest')
                ax.set(xticks=[], yticks=[], title=proc_title)

            fig.tight_layout()
            fig.savefig(output)
            # Close this specific figure; otherwise nchunks**2 figures accumulate.
            plt.close(fig)

In [364]:
# Directory holding the preprocessed CMX exposures. Set $DESI_CMX_DIR to point
# elsewhere instead of editing a user-specific absolute path.
DATA_DIR = os.environ.get('DESI_CMX_DIR', '/Users/sybenzvi/Documents/DESI/cmx')
fig = plot_mask(os.path.join(DATA_DIR, 'preproc-z3-00044518.fits'), hdu='IMAGE', downsample=4)
fig.savefig('raw_image.png')