In [ ]:
#%pdb
%matplotlib inline
import matplotlib.pyplot as plt
import os
import yt
yt.mylog.setLevel("INFO")
import numpy as np
from yt_synchrotron_emissivity import *
from yt import FITSImageData
from yt.visualization.volume_rendering.off_axis_projection import off_axis_projection
from astropy.wcs import WCS
from astropy import units as u

In [ ]:
# Run configuration: input dataset, assumed observing geometry and image knobs.

# Ghost-cell setting of the post-processed synchrotron file; it is also the
# `gc%d` suffix of the filename below.
gc = 8
fname = os.path.join('/d/d5/ychen/2015_production_runs/1022_h1_10Myr', 'data',
                     'MHD_Jet_10Myr_hdf5_plt_cnt_0910_synchrotron_peak_gc%d' % gc)

# Assumed distance to the object; converts physical lengths into angles.
dist_obj = 165.95*yt.units.Mpc
# Assumed sky position of the object: [RA, Dec] in degrees (the WCS reference value).
coord = [229.5, 42.82]

# Observing frequencies as (value, unit) pairs.
nus = [(freq, 'MHz') for freq in (150, 1400)]
# The image covers 1/zoom_fac of the domain width.
zoom_fac = 8
# 'x'/'y'/'z' gives an on-axis projection; a vector such as [1,0,2] selects
# the off-axis branch of the projection cell.
proj_axis = 'x'
ptype = 'lobe'

In [ ]:
# Load the dataset at `fname`; the Stokes fields projected below are read from it.
ds_sync = yt.load(fname)
# Show the on-disk fields so the available Stokes fields can be checked by eye.
ds_sync.field_list

In [ ]:
# Image geometry, sky WCS and the FITS header keywords shared by every frequency.
# (`fields` is filled in by the projection cell.)
fields = []

# Physical image extent: components 1 and 2 of the domain width, divided by zoom_fac.
width = ds_sync.domain_width[1:]/zoom_fac
# Pixel counts, hard-coded per zoom factor. The native-resolution alternative
# would be ds_sync.domain_dimensions[1:]*ds_sync.refine_by**ds_sync.index.max_level//zoom_fac
res = [512, 1024] if zoom_fac == 8 else [1024, 2048]

# Angular pixel size (small-angle approximation): physical pixel size / distance.
rad = yt.units.rad
cdelt1 = (width[0]/dist_obj/res[0]*rad).in_units('deg')
cdelt2 = (width[1]/dist_obj/res[1]*rad).in_units('deg')

# Set up the WCS header.
w = WCS(naxis=2)
# Reference pixel at the image centre. FITS pixel indices are 1-based with
# pixel centres on integers, so the centre of an N-pixel axis is (N + 1)/2;
# N/2 would place `coord` half a pixel away from the image centre.
w.wcs.crpix = [0.5*(res[0] + 1), 0.5*(res[1] + 1)]
# Pixel sizes in degrees.
w.wcs.cdelt = [cdelt1.base, cdelt2.base]
# Sky position (RA, Dec in degrees) assigned to the reference pixel.
w.wcs.crval = coord
# The units of both axes are degrees.
w.wcs.cunit = ['deg']*2
w.wcs.equinox = 2000
wcs_header = w.to_header()

# Assume the beam covers exactly one pixel. For a circular Gaussian beam the
# solid angle is pi*FWHM**2/(4 ln 2), so FWHM = 2*sqrt(2 ln 2)*sqrt(area/(2 pi)).
beam_area = cdelt1*cdelt2
beam_axis = np.sqrt(beam_area/2/np.pi)*2*np.sqrt(2*np.log(2))
# Major and minor beam axes (equal), as a plain float in degrees.
beam_axis = float(beam_axis.in_units('deg').v)

# Extra header keywords: SIN-projected RA/Dec axes, a frequency axis and the beam.
header_dict = {
           'CTYPE1': 'RA---SIN',
           'CTYPE2': 'DEC--SIN',
           'CROTA1': (0, 'Rotation in degrees.'),
           'CROTA2': (0, 'Rotation in degrees.'),
           'CTYPE3': 'FREQ',
           'CUNIT3': 'Hz',
           'BMAJ': (beam_axis, 'Beam major axis (deg)'),
           'BMIN': (beam_axis, 'Beam minor axis (deg)'),
           'BPA': (0.0, 'Beam position angle (deg)')
          }

In [ ]:
# Project the Stokes I/Q/U fields for every frequency into a FITSImageData,
# then convert each Stokes I image to Jy/beam and attach its FITS header.

# Reset here so re-running this cell does not append duplicate fields to a
# list left over from a previous run.
fields = []
# The 'beam' unit depends only on the pixel solid angle, not on frequency, so
# register it once rather than once per frequency.
ds_sync.unit_registry.add('beam', float(beam_area.in_units('rad**2').v),
                          dimensions=yt.units.dimensions.solid_angle, tex_repr='beam')
for nu in nus:
    stokes = StokesFieldName(ptype, nu, proj_axis, field_type='flash')
    fields += stokes.IQU
    for field in stokes.IQU:
        ds_sync.field_info[field].units = 'Jy/cm/arcsec**2'
        ds_sync.field_info[field].output_units = 'Jy/cm/arcsec**2'

if proj_axis in ['x', 'y', 'z']:
    # On-axis projection. Project every requested field explicitly instead of
    # only the I field of whichever frequency the loop above ended on.
    prj = ds_sync.proj(fields, proj_axis)
    frb = prj.to_frb(width[0], res, height=width[1])
    fits_image = FITSImageData(frb, fields=fields, wcs=w)
else:
    # Off-axis projection along the vector `proj_axis`. The sanitized width
    # gets its own name so the global `width` from the setup cell is not
    # replaced, and the comprehension variable does not reuse the WCS name `w`.
    buf = {}
    prj_width = ds_sync.coordinates.sanitize_width(proj_axis, width, (1.0, 'unitary'))
    wd = tuple(wi.in_units('code_length').v for wi in prj_width)
    for field in fields:
        buf[field] = off_axis_projection(ds_sync, [0,0,0], proj_axis, wd,
                        res, field, north_vector=[1,0,0], num_threads=0).swapaxes(0,1)
    fits_image = FITSImageData(buf, fields=fields, wcs=w)

for nu in nus:
    stokes = StokesFieldName(ptype, nu, proj_axis, field_type='flash')
    field = stokes.I[1]
    fits_image[field].data.units.registry.add('beam', float(beam_area.in_units('rad**2').v),
                      dimensions=yt.units.dimensions.solid_angle, tex_repr='beam')
    fits_image.set_unit(field, 'Jy/beam')
    nu_q = yt.YTQuantity(*nu)
    # Per-frequency copy, so OBJECT/CRVAL3 are not written back into the
    # shared `header_dict` from the setup cell.
    field_header = dict(header_dict)
    field_header.update({
           'OBJECT': 'Simulation %i %s' % (nu_q.v, nu_q.units),
           'CRVAL3': int(nu_q.in_units('Hz').v)
            })
    fits_image[field].header.update(field_header)
    fits_image[field].header.update(wcs_header)

In [ ]:
# Cross-check a ProjectionPlot's fixed-resolution buffer against the FITS image
# resolution `res`, using the last Stokes I `field` from the projection cell.
# `proj` is created here so the cell does not depend on a plot left over in
# the kernel from an earlier session.
# NOTE(review): ProjectionPlot expects an on-axis proj_axis ('x'/'y'/'z') — confirm
# before using this cell with an off-axis vector.
proj = yt.ProjectionPlot(ds_sync, proj_axis, fields)
proj.set_buff_size(tuple(res))
print(proj.buff_size)
proj._recreate_frb()
print(proj._frb.data[('flash',field)].shape)
print(proj.frb[field].shape)

In [ ]:
# Show the type of object that indexing the plot `proj` by `field` returns.
# NOTE(review): `proj` is assumed to be the projection plot from the previous cell.
type(proj[field])

In [ ]:
# Compare the requested resolution with the shapes of the FRB image and of the
# FITS image data for the last Stokes I `field` from the projection cell.
# NOTE(review): `frb` is only defined when proj_axis is 'x', 'y' or 'z'.
print(res)
print(frb[field].shape)
print(fits_image[field].data.shape)

In [ ]:
# Write all projected fields to a local test file, overwriting any existing copy.
# NOTE(review): newer yt/astropy releases name this keyword `overwrite` rather
# than `clobber` — confirm against the installed version.
fitsfname = 'test.fits'
fits_image.writeto(fitsfname, clobber=True)

In [ ]:
# Load snapshot 0910 of the run and get the synchrotron FITS path for the
# current `ptype`/`proj_axis` from `synchrotron_fits_filename` (presumably
# provided by the yt_synchrotron_emissivity star import).
# Named `run_dir`, not `dir`, so the builtin dir() is not shadowed for the
# rest of the session.
run_dir = '/d/d5/ychen/2015_production_runs/1022_h1_10Myr/'
ds = yt.load(os.path.join(run_dir, 'data', 'MHD_Jet_10Myr_hdf5_plt_cnt_0910'))

fitsfname = synchrotron_fits_filename(ds, run_dir, ptype, proj_axis)

In [ ]:
# Read the FITS file back with the FITSImageData imported at the top of the
# notebook, and show the shape of the primary HDU's data.
fits_read = FITSImageData.from_file(fitsfname)
fits_read.hdulist[0].data.shape

In [ ]:
# Load the FITS file at `fitsfname` as a yt dataset. When the notebook is run
# in order, `fitsfname` comes from synchrotron_fits_filename a few cells above.
ds_fits = yt.load(fitsfname)

In [ ]:


In [ ]:
# Plot Stokes I slices at 24 frequencies for three snapshots of the off-axis
# [1, 0, 2] FITS files, and save the images next to each FITS file.
import copy
from itertools import chain

yt.mylog.setLevel('WARNING')

# Cell-specific names, so this cell does not overwrite the `proj_axis`, `nus`
# and `fields` globals used by the projection cells above. Otherwise,
# re-running those cells afterwards would silently switch them to the
# off-axis branch and 24 frequencies.
sed_proj_axis = [1, 0, 2]
sed_nus = [(f, 'MHz') for f in chain(range(100, 200, 25), range(200, 900, 50), range(900, 1500, 100))]
sed_fields = [StokesFieldName(ptype, nu_, sed_proj_axis, field_type='flash').I[1] for nu_ in sed_nus]

# Copy the colormap: calling set_bad on plt.cm.hot itself would modify the
# shared 'hot' colormap for every later matplotlib plot in the session.
slice_cmap = copy.copy(plt.cm.hot)
slice_cmap.set_bad('k')

fits_dir = '/d/d5/ychen/2015_production_runs/1022_h1_10Myr/cos_synchrotron_QU_nn_lobe/fits_24_freq'
for fnumber in ['0600', '0910', '1050']:
    # Presumably the '1_0_2' in the filename encodes sed_proj_axis.
    snap_fname = os.path.join(fits_dir, 'synchrotron_lobe_1_0_2_%s.fits' % fnumber)
    snap_ds = yt.load(snap_fname)
    # Keep only the celestial (RA/Dec) axes of the file's WCS.
    # NOTE(review): relies on yt reading `wcs_2d` when plotting — confirm.
    snap_ds.wcs_2d = snap_ds.wcs.celestial

    slc = yt.SlicePlot(snap_ds, 'z', sed_fields)
    slc.set_buff_size(snap_ds.domain_dimensions[0:2])
    for nu, field_name in zip(sed_nus, sed_fields):
        # Colour limits scale as nu**-0.5, with nu in GHz.
        norm = yt.YTQuantity(*nu).in_units('GHz').v**0.5
        slc.set_cmap(field_name, slice_cmap)
        slc.set_log(field_name, True)
        slc.set_zlim(field_name, 1E-5/norm, 1E-1/norm)
    slc.save(os.path.dirname(snap_fname))

In [ ]: