In [ ]:
#%pdb
%matplotlib inline
import matplotlib
matplotlib.rcParams['font.family'] = 'stixgeneral'
matplotlib.rcParams['figure.dpi'] = 200
import matplotlib.pyplot as plt
import os
import yt
yt.mylog.setLevel("WARNING")
import numpy as np
from yt_synchrotron_emissivity import *
from mpl_toolkits.axes_grid1 import AxesGrid, ImageGrid
from scipy.ndimage import gaussian_filter
import pyfits

In [ ]:
# Earlier run directories, kept for reference only; the list actually used is
# defined inside plot_synchrotron_spectralindex_imagegrid below.
#dirs = ['/home/ychen/data/00only_0605_hinf/',\
#        '/home/ychen/data/00only_0529_h1/',\
#        '/home/ychen/data/00only_0605_h0/',]

# Particle-type tag. It is read as a module-level global by
# plot_synchrotron_spectralindex_imagegrid, which forwards it to
# synchrotron_fits_filename and StokesFieldName.
ptype = 'lobe'
def plot_synchrotron_spectralindex_imagegrid(proj_axis, nus):
    """Plot a grid of synchrotron spectral-index maps (rows: snapshots, columns: runs).

    For every (snapshot number, run directory) pair the Stokes-I images at the
    two frequencies in ``nus`` are read from the FITS file named by
    ``synchrotron_fits_filename``, smoothed with a Gaussian filter, and turned
    into a spectral-index map ``log10(I2/I1) / log10(nu2/nu1)``. Panels whose
    FITS file does not exist are left empty.

    Parameters
    ----------
    proj_axis :
        Projection axis, forwarded unchanged to ``synchrotron_fits_filename``
        and ``StokesFieldName`` (the notebook calls this with ``'x'`` and
        ``[1, 0, 2]``).
    nus : sequence of two (value, unit) pairs
        The two frequencies, e.g. ``[(150, 'MHz'), (1400, 'MHz')]``. The first
        element of each pair is used as the numeric frequency; both must be
        given in the same unit because only the ratio of the values is used.

    Returns
    -------
    (fig, grid)
        The matplotlib figure and the ``ImageGrid`` holding the panels.

    Raises
    ------
    ValueError
        If ``nus`` does not contain exactly two frequencies.

    Notes
    -----
    Reads the module-level ``ptype``.
    """
    # Fail early (before any dataset is loaded) if the frequency pair is malformed.
    if len(nus) != 2:
        raise ValueError('nus must contain exactly two (value, unit) frequencies, got %d'
                         % len(nus))
    nu1, nu2 = nus

    dirs = ['/d/d5/ychen/2015_production_runs/0204_hinf_10Myr/',
            '/d/d5/ychen/2015_production_runs/1022_h1_10Myr/',
            # Alternative h0 run: '/d/d5/ychen/2015_production_runs/0204_h0_10Myr/'
            '/d/d5/ychen/2016_production_runs/1212_h0_10Myr/',
           ]
    # One label per entry of `dirs`, in the same order.
    labels = ['toroidal', 'helical', 'poloidal']

    filenumbers = [100, 200, 600, 910, 1050]

    # Row-major panel order: each snapshot is one row, each run one column.
    panels = [(filenumber, run_dir)
              for filenumber in filenumbers
              for run_dir in dirs]

    sigma = 1  # Gaussian smoothing width passed to gaussian_filter (in pixels)

    cmap = plt.cm.jet

    fig = plt.figure(figsize=(8,6))

    grid = ImageGrid(fig, (0.075,0.05,0.85,0.90),
                    nrows_ncols = (len(filenumbers), len(dirs)),
                    axes_pad = 0.05,
                    label_mode = "L",
                    share_all = True,
                    cbar_location="right",
                    cbar_mode="single",
                    cbar_size="2%",
                    cbar_pad="0%")

    im = None  # last image drawn; feeds the single shared colorbar
    for i, (filenumber, run_dir) in enumerate(panels):
        # The dataset is needed for the FITS file name, the time stamp and unit handling.
        ds = yt.load(os.path.join(run_dir, 'data/MHD_Jet_10Myr_hdf5_plt_cnt_%04d' % filenumber))
        print(run_dir, filenumber, ds.current_time.in_units('Myr'))

        fitsname = synchrotron_fits_filename(ds, run_dir, ptype, proj_axis)
        if not os.path.isfile(fitsname):
            continue

        hdulist = pyfits.open(fitsname)
        try:
            # Smooth while the file is still open so that no reference to
            # file-backed image data outlives the handle.
            smoothed_I = {}
            for nu in nus:
                stokes = StokesFieldName(ptype, nu, proj_axis, field_type='flash')
                smoothed_I[nu] = gaussian_filter(hdulist[stokes.I[1]].data, sigma)
            # Image geometry is taken from the header of the last frequency read.
            header = hdulist[stokes.I[1]].header
            axis1_lo = -header['CRPIX1']*header['CDELT1'] + header['CRVAL1']
            axis1_hi = (header['NAXIS1'] - header['CRPIX1'])*header['CDELT1'] + header['CRVAL1']
            axis2_lo = -header['CRPIX2']*header['CDELT2'] + header['CRVAL2']
            axis2_hi = (header['NAXIS2'] - header['CRPIX2'])*header['CDELT2'] + header['CRVAL2']
        finally:
            hdulist.close()

        # The image is transposed for display, so FITS axis 2 runs horizontally
        # and FITS axis 1 vertically. Header coordinates are interpreted as cm.
        ext = ds.arr([axis2_lo, axis2_hi, axis1_lo, axis1_hi], input_units='cm').in_units('kpc')

        I1 = smoothed_I[nu1]
        I2 = smoothed_I[nu2]
        # Zero or negative intensities give inf/NaN here; those pixels are
        # expected to fall under the faint-emission mask below, so silence the warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            alpha = np.log10(I2/I1)/np.log10(nu2[0]/nu1[0])
        # Hide pixels where the smoothed second-frequency intensity is below 1E-7.
        alpha = np.ma.masked_where(I2<1E-7, np.array(alpha))

        ax = grid[i].axes
        im = ax.imshow(alpha.transpose(), cmap=cmap, vmin=-2, vmax=-0.5, extent=ext, origin='lower', aspect='equal')
        # Masked pixels show the axes background.
        ax.set_facecolor('navy')

        row, col = divmod(i, len(dirs))
        if row == 0:
            # Top row: label each column with its run name.
            ax.annotate(labels[col], (0.65, 0.75) , xycoords='axes fraction', color='white')
        if col == 0:
            # First column: label each row with the simulation time.
            timestamp = '%.1f Myr' % ds.current_time.in_units('Myr')
            ax.annotate(timestamp, (0.04, 0.75) , xycoords='axes fraction', color='white')

    # With cbar_mode="single" only cbar_axes[0] is the shared colorbar axes.
    # Every panel uses the same cmap/vmin/vmax, so any image can feed it.
    if im is not None:
        cbar = grid.cbar_axes[0].colorbar(im)
        cbar.ax.tick_params(direction='in')

    return fig, grid

In [ ]:
nus = [(150, 'MHz'), (1400, 'MHz')]
fig, grid = plot_synchrotron_spectralindex_imagegrid('x', nus)

In [ ]:
nus = [(150, 'MHz'), (1400, 'MHz')]
fig, grid = plot_synchrotron_spectralindex_imagegrid([1,0,2], nus)

In [ ]:
clabel = 'Spectral Index (%s) (1.4GHz/150MHz)' % ptype
cax = grid.cbar_axes[0]
cax.set_ylabel(clabel)

for i, ax in enumerate(grid.axes_all):
    #ax.tick_params(axis='x', color='grey')
    ax.tick_params(color='grey', direction='in')
    ax.grid(ls='--', alpha=0.5)
    if i == 6:
        ax.set_ylabel('y (kpc)')
    if i == 13:
        ax.set_xlabel('z (kpc)')
        #print(ax.get_xlabel())

fig.subplots_adjust(left=0.2, bottom=0.2, right=0.5)
fig.set_figwidth(8)
#fig.set_figheight(11)
fig

In [ ]:
# For x projection
for i, ax in enumerate(grid.axes_all):
    ax.set_yticks([-25,0,25])
    ax.set_yticklabels([-25,0,25])
    ax.set_xticks([-75,-50,-25,0,25,50,75])
    ax.set_xticklabels(['',-50,'',0,'',50,''])

fig.savefig('synchrotron_spectralindex_x.pdf', bbox_inches='tight')

In [ ]:
# For [1,0,2] projection
for i, ax in enumerate(grid.axes_all):
    ax.set_yticks([-25,0,25])
    ax.set_yticklabels([-25,0,25])
    ax.set_xticks([-50,-25,0,25,50])
    ax.set_xticklabels([-50,'',0,'',50])

fig.savefig('synchrotron_spectralindex_1_0_2.pdf', bbox_inches='tight')