In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.wcs import WCS
import time
import pyrap.images as pim

from joblib import Parallel, delayed
import multiprocessing

from astropy import units as u
from astropy.coordinates import SkyCoord
import os
import glob
from mpl_toolkits.axes_grid.anchored_artists import AnchoredText
import shutil

from scipy.interpolate import griddata

%matplotlib inline
#plt.rcParams['savefig.dpi'] = 100
plt.rcParams['figure.figsize']=8,8
plt.rcParams['font.size']=12
plt.rcParams['axes.labelsize']='xx-large'
plt.rcParams['axes.titlesize']='xx-large'

In [ ]:
def gen_index(n_measurements,n_images,i):
    '''Return the flat-array positions of measurement i in every image.

    The per-image measurement blocks are laid out back to back, each
    n_measurements long, so measurement i of image k sits at
    k * n_measurements + i.

    input:
        n_measurements: number of measurements stored per image (block length).
        n_images: number of images (blocks); 0 gives an empty array.
        i: position of the measurement inside a block.
    output:
        integer array of length n_images: [i, n_measurements + i, 2*n_measurements + i, ...]
    '''
    # Exact integer arithmetic instead of a float linspace that is cast back to
    # int (float64 cannot represent every integer above 2**53).
    return np.arange(n_images, dtype=int) * n_measurements + i

In [ ]:
def get_index(h,i):
    '''Return the flat-array positions of aperture i in every image, reading
    the block layout from a header mapping.

    Blocks of h["num_apertures"] values are laid out back to back, one block
    per image, so aperture i of image k sits at k * h["num_apertures"] + i.

    input:
        h: mapping with integer entries "num_apertures" (block length) and
           "num_images" (number of blocks).
        i: position of the aperture inside a block.
    output:
        integer array of length h["num_images"].
    '''
    # Exact integer arithmetic rather than a float linspace cast back to int.
    return np.arange(h["num_images"], dtype=int) * h["num_apertures"] + i

In [ ]:
def aperture_pixels(data,x,y,r,meshx,meshy):
    '''Collect the values of a 2D array inside a circular aperture.

    input:
        data: 2D array to sample.
        x, y: aperture centre, in the same units as meshx / meshy.
        r: aperture radius; pixels at distance <= r are included.
        meshx, meshy: pixel-coordinate grids (e.g. from np.meshgrid) used to
            measure each pixel's distance from (x, y).
    output:
        1D array of the selected data values, in row-major order.
    '''
    dx = meshx - x
    dy = meshy - y
    distance = np.sqrt(dx**2 + dy**2.)
    inside = np.nonzero(distance <= r)
    return data[inside]

In [ ]:
def annulus(x,y,inner, outter, center=512):
    '''Returns the indexs for a 2D annulus.
    input:
        x: x axis meshgrid
        y: y axis meshgrid
        inner: inner '''
    dinner = np.cos(np.radians(inner))*center
    doutter =  np.cos(np.radians(outter))*center

    return (np.sqrt((x-center)**2.+(y-center)**2.) <= doutter)& (np.sqrt((x-center)**2.+(y-center)**2.) >= dinner)

In [ ]:
def loop_save(img, save=1, step=8):
    '''Sample an RMS image on a regular pixel grid and record, for every sample
    with finite sky coordinates, its pixel position, RA/Dec and pixel value.

    Input:
        img: path to a FITS image. The plane data[0, 0, :, :] of the primary
            HDU is sampled, and its WCS must have four axes (two celestial
            axes followed by two non-celestial ones).
        save: 1 = write the samples to img + ".npz" (array key "arr_0") and
            return None (default); 0 = return the samples.
        step: sampling stride in pixels along both image axes (default 8,
            i.e. rows/columns 0, 8, 16, ... up to the image size).

    Returns:
        None when saving, or when img + ".npz" already exists (the image is
        then skipped). Otherwise a 1D structured array in row-major sample
        order, with fields:
            ra, dec (float32): world coordinates from wcs_pix2world,
            x (int32): row index of the sample, y (int32): column index,
            rms (float64): the pixel value.

    Notes:
        * "rms" used to be an object (np.ndarray) column, which forced
          pickling. It is now a plain float so the .npz loads without
          allow_pickle.
        * NOTE(review): the field named "x" holds the ROW index and "y" the
          COLUMN index (numpy data[row, col]), the reverse of the FITS x/y
          convention. This is kept for compatibility with existing outputs.
        * NOTE(review): the pixel indices are 0-based, but wcs_pix2world is
          called with origin=1, so RA/Dec appear to belong to the pixel one
          step lower on both axes than the sampled value. This is kept
          because it affects which samples count as finite. Confirm, then
          switch to origin 0.
    '''
    out_file = img + ".npz"
    if os.path.exists(out_file):
        print("Skipping: " + img)
        return
    print("Processing: " + img)

    image_wcs = WCS(img)
    # Copy the plane out of the HDU so the file can be closed safely.
    with fits.open(img) as hdul:
        img_data = np.array(hdul[0].data[0, 0, :, :])

    t1 = time.time()
    # Integer sample grid; row_grid[a, b] / col_grid[a, b] are row/column indices.
    rows = np.arange(0, img_data.shape[0], step)
    cols = np.arange(0, img_data.shape[1], step)
    row_grid, col_grid = np.meshgrid(rows, cols, indexing="ij")

    # The column index is passed as the first WCS pixel axis and the row index
    # as the second. The two non-celestial axes (presumably frequency/Stokes)
    # are set to pixel 1, and the last argument is origin=1 (see NOTE above).
    ra, dec, _, _ = image_wcs.wcs_pix2world(col_grid.astype(np.float64),
                                             row_grid.astype(np.float64),
                                             1, 1, 1)
    # Keep only samples whose sky position is defined.
    keep = np.isfinite(ra) & np.isfinite(dec)

    samples = np.zeros(int(np.count_nonzero(keep)),
                       dtype=[("ra", np.float32),
                              ("dec", np.float32),
                              ("x", np.int32),
                              ("y", np.int32),
                              ("rms", np.float64)])
    kept_rows = row_grid[keep]
    kept_cols = col_grid[keep]
    samples["ra"] = ra[keep]
    samples["dec"] = dec[keep]
    samples["x"] = kept_rows
    samples["y"] = kept_cols
    samples["rms"] = img_data[kept_rows, kept_cols]

    if save:
        np.savez(out_file, samples)
        print(time.time() - t1)
        return
    print(time.time() - t1)
    return samples

In [ ]:
## this needs to point to a folder full of pyse image.rms.fits files.
# Set the RMS_IMAGE_DIR environment variable (or edit the default). A trailing
# path separator is guaranteed so later cells can build paths as imdir + name.
imdir = os.path.join(os.environ.get("RMS_IMAGE_DIR", "."), "")
fits_list = sorted(glob.glob(imdir + "*.image.rms.fits"))

# Pixel-coordinate grids for a 1024 x 1024 image (meshx varies along columns,
# meshy along rows).
meshx, meshy = np.meshgrid(np.linspace(0,1023,1024),np.linspace(0,1023,1024))

# At most 12 workers, and never more than the machine has cores.
n_jobs = min(12, multiprocessing.cpu_count())
# loop_save writes its results to disk; keep the list of Nones out of the output.
parallel_results = Parallel(n_jobs=n_jobs)(delayed(loop_save)(img) for img in fits_list)

In [ ]:
t1 = time.time()
# loop_save() writes "<image>.npz" next to every "*.image.rms.fits" input, so
# the per-image sample files end in ".image.rms.fits.npz".
img_list = sorted(glob.glob(imdir + "*.image.rms.fits.npz"))
if not img_list:
    raise IOError("No *.image.rms.fits.npz files found in " + imdir + " - run loop_save first.")

rms_annulus = np.zeros(18)

# Collect the columns of every file. The number of finite samples per image is
# not assumed to be constant (it used to be hard-coded as 12849).
x_parts, y_parts, rms_parts, ra_parts, dec_parts = [], [], [], [], []
for i, path in enumerate(img_list):
    print("%d %s" % (i, path))
    # allow_pickle=True: older versions of loop_save stored "rms" as an object
    # column, which numpy only unpickles on request. Only load .npz files
    # produced by this pipeline (unpickling untrusted files can run code).
    with np.load(path, allow_pickle=True) as load_data:
        samples = load_data["arr_0"]
    x_parts.append(samples["x"])
    y_parts.append(samples["y"])
    rms_parts.append(samples["rms"])
    ra_parts.append(samples["ra"])
    dec_parts.append(samples["dec"])

x = np.concatenate(x_parts).astype(np.int32)
y = np.concatenate(y_parts).astype(np.int32)
std = np.concatenate(rms_parts).astype(np.float32)
ra = np.concatenate(ra_parts).astype(np.float32)
dec = np.concatenate(dec_parts).astype(np.float32)
# Per-sample beam RMS: starts as the sample's own RMS and is replaced by the
# band median wherever the sample falls inside an elevation band below.
beam = std.copy()
weight = np.ones(len(std), dtype=np.float32)
total = len(img_list)

# 18 elevation bands of 5 deg; band i spans elevations [5*i, 5*i + 5].
# Later bands overwrite shared boundary pixels, as before.
for i in range(18):
    mask = annulus(x, y, 5.0 * i + 5.0, 5.0 * i)
    rms_annulus[i] = np.median(std[mask])
    beam[mask] = rms_annulus[i]

print(time.time() - t1)

# Plot elevation angle (lower edge of each band) vs. median RMS in the band,
# pooled across all RMS maps.
plt.plot(np.arange(18) * 5.0, rms_annulus, lw=5, c="black", label="total")
plt.xlabel("Elevation Angle [deg]")
plt.ylabel("Median RMS [Jy/beam]")
plt.legend();

In [ ]:
# Cache the pooled per-sample columns so the aggregation does not need rerunning.
# (np.savez appends ".npz" when it is missing, so "tot_beam" becomes tot_beam.npz.)
for _fname, _key, _arr in (("tot_beam", "beam", beam),
                           ("tot_std.npz", "std", std),
                           ("tot_ra.npz", "ra", ra),
                           ("tot_dec.npz", "dec", dec),
                           ("tot_x.npz", "x", x),
                           ("tot_y.npz", "y", y)):
    np.savez(imdir + _fname, **{_key: _arr})