In [ ]:
import batoid
import numpy as np
from IPython.display import display
from ipywidgets import interact, interactive_output, interact_manual
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline

In [ ]:
# galsim is optional: the Zernike widget cells at the end of the notebook are
# guarded on `has_galsim` and are skipped when it cannot be imported.
# Only ImportError is treated as "not available"; a bare `except:` would also
# swallow KeyboardInterrupt/SystemExit and hide unrelated failures.
try:
    import galsim
except ImportError:
    has_galsim = False
else:
    has_galsim = True

In [ ]:
# Unperturbed optical model read from DECam.yaml. The interactive cells below
# never modify it; they derive shifted/rotated copies from it instead.
fiducial_telescope = batoid.Optic.fromYaml(
    "DECam.yaml"
)

In [ ]:
def spotPlot(telescope, wavelength, theta_x, theta_y, logscale, ax):
    """Draw a centroid-centered spot diagram for one field angle onto ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to trace (the rays are traced through it, then
        vignetted rays are removed).
    wavelength : float
        Wavelength in nm; converted to meters before tracing.
    theta_x, theta_y : float
        Field angle in degrees.
    logscale : float
        Both axes span +/- 10**logscale microns.
    ax : matplotlib Axes
        Axes to draw into.
    """
    rays = batoid.RayVector.asPolar(
        optic=telescope,
        # Start the polar ray grid at the central-obscuration radius
        # (pupilObscuration is presumably a fraction of the pupil radius).
        inner=telescope.pupilObscuration*telescope.pupilSize/2,
        theta_x=np.deg2rad(theta_x), theta_y=np.deg2rad(theta_y),
        nrad=48, naz=192, wavelength=wavelength*1e-9
    )

    telescope.trace(rays)
    rays.trimVignetted()
    spots = np.vstack([rays.x, rays.y])
    # If every ray was vignetted there is no centroid to subtract; skipping the
    # mean avoids numpy's "Mean of empty slice" warning (the plot is just empty).
    if spots.shape[1] > 0:
        spots -= np.mean(spots, axis=1)[:, None]
    spots *= 1e6  # meters -> microns

    ax.scatter(spots[0], spots[1], s=1, alpha=0.5)
    lim = 10**logscale
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    # Comma first, then a thin space (\,) before theta_y.
    ax.set_title(r"$\theta_x = {:4.2f},\,\theta_y = {:4.2f}$".format(theta_x, theta_y))
    ax.set_xlabel("microns")
    ax.set_ylabel("microns")

In [ ]:
def wavefrontPlot(telescope, wavelength, theta_x, theta_y, ax):
    """Draw the wavefront map over the telescope pupil onto ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to analyze.
    wavelength : float
        Wavelength in nm; converted to meters for batoid.
    theta_x, theta_y : float
        Field angle in degrees.
    ax : matplotlib Axes
        Axes to draw into; a colorbar is attached to it.
    """
    wf = batoid.analysis.wavefront(
        telescope,
        np.deg2rad(theta_x), np.deg2rad(theta_y),
        wavelength*1e-9,
        nx=128
    )
    # The image spans the full pupil diameter, centered on zero.
    wfplot = ax.imshow(
        wf.array,
        extent=np.r_[-1,1,-1,1]*telescope.pupilSize/2
    )
    # Title added so this panel matches the "FFT PSF" / "Huygens PSF" panels.
    ax.set_title("Wavefront")
    ax.set_xlabel("meters")
    ax.set_ylabel("meters")
    # NOTE(review): colorbar is unlabeled; units of wf.array (likely waves)
    # are not visible here -- confirm before adding a label.
    plt.colorbar(wfplot, ax=ax)

In [ ]:
def fftPSFPlot(telescope, wavelength, theta_x, theta_y, ax):
    """Render the FFT-based PSF, normalized to unit sum, onto ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to analyze.
    wavelength : float
        Wavelength in nm; converted to meters for batoid.
    theta_x, theta_y : float
        Field angle in degrees.
    ax : matplotlib Axes
        Target axes; the axis extent is in microns and a colorbar is attached.
    """
    psf = batoid.analysis.fftPSF(
        telescope,
        np.deg2rad(theta_x), np.deg2rad(theta_y),
        wavelength*1e-9, nx=32
    )

    # The lattice vectors should be close to multiples of [1,0] and [0,1].
    # A negative multiplier would make the PSF appear upside-down, so rotate
    # the image by 180 degrees in that case; this keeps it comparable with the
    # spot diagram.
    if psf.primitiveVectors[0, 0] < 0:
        psf.array = psf.array[::-1, ::-1]

    pixel_size = np.sqrt(np.abs(np.linalg.det(psf.primitiveVectors)))
    npix = psf.array.shape[0]
    psf.array /= np.sum(psf.array)

    # pixel_size is in meters; the trailing 1e6 converts the extent to microns.
    image = ax.imshow(
        psf.array,
        extent=np.r_[-1, 1, -1, 1]*pixel_size*npix/2*1e6
    )
    ax.set_title("FFT PSF")
    ax.set_xlabel("micron")
    ax.set_ylabel("micron")
    plt.colorbar(image, ax=ax)

In [ ]:
def huygensPSFPlot(telescope, wavelength, theta_x, theta_y, ax):
    """Draw the Huygens PSF, normalized to unit sum, onto ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to analyze.
    wavelength : float
        Wavelength in nm; converted to meters for batoid.
    theta_x, theta_y : float
        Field angle in degrees.
    ax : matplotlib Axes
        Axes to draw into; the extent is in microns and a colorbar is attached.
    """
    huygensPSF = batoid.analysis.huygensPSF(
        telescope, np.deg2rad(theta_x), np.deg2rad(theta_y),
        wavelength*1e-9, nx=32
    )
    # We should be very close to primitive vectors that are a multiple of
    # [1,0] and [0,1].  If the multiplier is negative though, then this will
    # make it look like our PSF is upside-down.  So we check for this here and
    # invert if necessary.  This will make it easier to compare with the spot
    # diagram, for instance
    if huygensPSF.primitiveVectors[0,0] < 0:
        huygensPSF.array = huygensPSF.array[::-1,::-1]

    huygensPSF.array /= np.sum(huygensPSF.array)
    scale = np.sqrt(np.abs(np.linalg.det(huygensPSF.primitiveVectors)))
    nxout = huygensPSF.array.shape[0]

    # Draw on the axes we were given.  pyplot's imshow would target the
    # "current" axes instead, i.e. the wrong panel whenever this is not the last
    # subplot created.
    huygensplot = ax.imshow(
        huygensPSF.array,
        extent=np.r_[-1,1,-1,1]*scale*nxout/2*1e6  # meters -> microns
    )
    ax.set_title("Huygens PSF")
    ax.set_xlabel("micron")
    ax.set_ylabel("micron")
    plt.colorbar(huygensplot, ax=ax)

In [ ]:
# Control groups for the panel viewer.  The dict keys are also the keyword
# names the plotting callback `f` receives, so they must match its parameters.
# LaTeX labels use raw strings: "\l", "\m", "\p" in ordinary strings are
# invalid escape sequences (SyntaxWarning on Python 3.12+).  The resulting
# label text is unchanged.
what = dict(
    do_spot = widgets.Checkbox(value=True, description='Spot'),
    do_wavefront = widgets.Checkbox(value=True, description='Wavefront'),
    do_fftPSF = widgets.Checkbox(value=True, description='FFT PSF'),
    do_huygensPSF = widgets.Checkbox(value=True, description='Huygens PSF')
)
# Wavelength (nm), field angle (deg), and spot-diagram half-width (10**scale microns).
where = dict(
    wavelength=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0, description=r"$\lambda$ (nm)"),
    theta_x=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=-0.5, description=r"$\theta_x (deg)$"),
    theta_y=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=0.0, description=r"$\theta_y (deg)$"),
    logscale=widgets.FloatSlider(min=1, max=3, step=0.1, value=1, description="scale")
)
# Perturbation of one optic: shifts (mm, mm, microns) and tilts (arcmin).
perturb = dict(
    optic=widgets.Dropdown(
        options=list(fiducial_telescope.itemDict.keys()),
        value='BlancoDECam.DECam'
    ),
    dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0, description="dx ($mm$)"),
    dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0, description="dy ($mm$)"),
    dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0, description=r"dz ($\mu m$)"),
    dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0, description=r"d$\phi_x$ (arcmin)"),
    dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0, description=r"d$\phi_y$ (arcmin)"),
)

def f(do_spot, do_wavefront, do_fftPSF, do_huygensPSF,
      wavelength, theta_x, theta_y, optic, dx, dy, dz, dthx, dthy, logscale, **kwargs):
    """Redraw the selected diagnostic panels for a perturbed telescope.

    Callback for ``interactive_output``; each argument is the current value of
    the widget with the same name.  The chosen ``optic`` is shifted globally by
    (dx, dy) mm and dz microns, then rotated locally by dthx / dthy arcmin about
    x / y.  One 4x4-inch panel is drawn per checked ``do_*`` box, in the order
    spot, wavefront, FFT PSF, Huygens PSF.  Nothing is drawn when none are checked.
    """
    rotation = batoid.RotX(dthx*np.pi/180/60).dot(batoid.RotY(dthy*np.pi/180/60))
    shifted = fiducial_telescope.withGloballyShiftedOptic(optic, [dx*1e-3, dy*1e-3, dz*1e-6])
    telescope = shifted.withLocallyRotatedOptic(optic, rotation)

    # (checkbox value, drawer) pairs in display order; each drawer takes an Axes.
    candidates = [
        (do_spot, lambda ax: spotPlot(telescope, wavelength, theta_x, theta_y, logscale, ax)),
        (do_wavefront, lambda ax: wavefrontPlot(telescope, wavelength, theta_x, theta_y, ax)),
        (do_fftPSF, lambda ax: fftPSFPlot(telescope, wavelength, theta_x, theta_y, ax)),
        (do_huygensPSF, lambda ax: huygensPSFPlot(telescope, wavelength, theta_x, theta_y, ax)),
    ]
    drawers = [draw for wanted, draw in candidates if wanted]
    if not drawers:
        return

    count = len(drawers)
    fig, axes = plt.subplots(ncols=count, figsize=(4*count, 4), squeeze=False)
    for panel, draw in zip(axes.ravel(), drawers):
        draw(panel)
    fig.tight_layout()

# One name -> widget mapping for interactive_output, which calls `f` with the
# current widget values as keyword arguments (hence the keys must match f's
# parameter names).
all_widgets = {**what, **where, **perturb}

output = interactive_output(f, all_widgets)
# Three control columns side by side, with the plot output underneath.
display(
    widgets.VBox([
        widgets.HBox([
            widgets.VBox(list(what.values())),
            widgets.VBox(list(where.values())),
            widgets.VBox(list(perturb.values())),
        ]),
        output,
    ])
)

In [ ]:
if has_galsim:
    # LaTeX labels use raw strings; "\l", "\m", "\p" would otherwise be invalid
    # escape sequences (SyntaxWarning on Python 3.12+).  The label text is unchanged.
    @interact(wavelen=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0,
                                          description=r"$\lambda$ (nm)"),
              theta_x=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=-0.5,
                                          description=r"$\theta_x (deg)$"),
              theta_y=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=0.0,
                                          description=r"$\theta_y (deg)$"),
              optic=widgets.Dropdown(
                  options=list(fiducial_telescope.itemDict.keys()),
                  value='BlancoDECam.DECam'
              ),
              dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dx ($mm$)"),
              dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dy ($mm$)"),
              dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0,
                                     description=r"dz ($\mu m$)"),
              dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_x$ (arcmin)"),
              dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_y$ (arcmin)"))
    def zernike(wavelen, theta_x, theta_y, optic, dx, dy, dz, dthx, dthy):
        """Print pupil Zernike coefficients j=1..22 for one field angle.

        ``optic`` is shifted globally by (dx, dy) mm and dz microns, then
        rotated locally by dthx / dthy arcmin.  ``wavelen`` is in nm and the
        field angle is in degrees.  Output is two columns: j=1..11 on the
        left and j=12..22 on the right.
        """
        telescope = (fiducial_telescope
                .withGloballyShiftedOptic(optic, [dx*1e-3, dy*1e-3, dz*1e-6])
                .withLocallyRotatedOptic(
                        optic,
                        batoid.RotX(dthx*np.pi/180/60).dot(batoid.RotY(dthy*np.pi/180/60))
                )
        )

        # NOTE(review): eps (presumably the annular obscuration ratio) is 0.1
        # here but 0.61 in the field-of-view cell below -- confirm which is intended.
        z = batoid.analysis.zernike(
            telescope, np.deg2rad(theta_x), np.deg2rad(theta_y), wavelen*1e-9,
            jmax=22, eps=0.1, nx=128
        )
        # Index 0 is skipped (the loop starts at j=1).  With jmax=22, len(z)//2
        # is 11, which splits the terms into two columns: 1..11 and 12..22.
        half = len(z)//2
        for i in range(1, half+1):
            print("{:6d}   {:6.3f}      {:6d}  {:6.3f}".format(i, z[i], i+half, z[i+half]))
else:
    print("galsim is not available; skipping the single-field Zernike widget.")

In [ ]:
if has_galsim:
    # LaTeX labels use raw strings; "\l", "\m", "\p" would otherwise be invalid
    # escape sequences (SyntaxWarning on Python 3.12+).  The label text is unchanged.
    @interact_manual(
        wavelen=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0,
                                    description=r"$\lambda$ (nm)"),
        optic=widgets.Dropdown(
            options=list(fiducial_telescope.itemDict.keys()),
            value='BlancoDECam.DECam'
        ),
        z_coef=widgets.Dropdown(
            options=list(range(1, 56)), value=1,
            description="Zernike coefficient"
        ),
        z_amp=widgets.FloatSlider(min=-0.1, max=0.1, step=0.01, value=0.0,
                                  description="Zernike amplitude"),
        dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                               description="dx ($mm$)"),
        dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                               description="dy ($mm$)"),
        dz=widgets.FloatSlider(min=-500, max=500, step=10, value=0.0,
                               description=r"dz ($\mu m$)"),
        dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                 description=r"d$\phi_x$ (arcmin)"),
        dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                 description=r"d$\phi_y$ (arcmin)"),
        do_resid=widgets.Checkbox(value=False, description="residual?"))
    def zFoV(wavelen, optic, z_coef, z_amp, dx, dy, dz, dthx, dthy, do_resid):
        """Map pupil Zernikes across the field and fit them with double Zernikes.

        Steps:

        1. ``optic`` is shifted globally by (dx, dy) mm and dz microns, then
           rotated locally by dthx / dthy arcmin.
        2. If ``z_amp`` is nonzero, a Zernike term is added to the optic's
           surface.  The term is placed at index ``z_coef`` of the coefficient
           array, with amplitude ``z_amp`` times the wavelength, i.e. in waves.
        3. Pupil Zernikes (jmax=21) are computed on a 15x15 grid of field angles
           (deg), keeping only points within 1.1 deg of the axis.  They are
           drawn as a Zernike pyramid.
        4. With ``do_resid``, the unperturbed telescope's Zernikes at the same
           wavelength are subtracted first.
        5. The field dependence is fit with focal-plane Zernikes.  The ten
           largest double-Zernike terms (pupil j >= 4) are printed with a
           running root-sum-square, followed by the total.
        """
        telescope = (fiducial_telescope
                .withGloballyShiftedOptic(optic, [dx*1e-3, dy*1e-3, dz*1e-6])
                .withLocallyRotatedOptic(
                        optic,
                        batoid.RotX(dthx*np.pi/180/60).dot(batoid.RotY(dthy*np.pi/180/60))
                )
        )
        if z_amp != 0:
            try:
                interface = telescope[optic]
                s0 = interface.surface
            except (KeyError, ValueError, AttributeError):
                # Best effort, as before, but say so instead of silently
                # ignoring the request.  Presumably only single interfaces
                # expose `.surface`.
                print("Cannot add a Zernike perturbation to {!r}; ignoring z_amp.".format(optic))
            else:
                s1 = batoid.Sum([
                    s0,
                    batoid.Zernike(
                        # Zeros up to index z_coef, then the amplitude (meters).
                        [0]*z_coef+[z_amp*wavelen*1e-9],
                        R_outer=interface.outRadius,
                        R_inner=interface.inRadius,
                    )
                ])
                telescope = telescope.withSurface(optic, s1)

        # Residuals are much smaller than raw coefficients, so use a tighter color scale.
        vmin, vmax = (-0.05, 0.05) if do_resid else (-0.3, 0.3)

        thxs = np.linspace(-1.1, 1.1, 15)
        thys = np.linspace(-1.1, 1.1, 15)
        zs = []
        thxplot = []
        thyplot = []
        for thx in thxs:
            for thy in thys:
                # Keep only field points inside the 1.1 deg radius.
                if np.hypot(thx, thy) > 1.1:
                    continue
                z = batoid.analysis.zernike(
                    telescope, np.deg2rad(thx), np.deg2rad(thy), wavelen*1e-9,
                    jmax=21, eps=0.61, nx=16
                )
                if do_resid:
                    # The reference must use the SAME wavelength as the
                    # perturbed telescope.  The coefficients are
                    # wavelength-dependent, so a fixed reference wavelength
                    # would leave a spurious term in the residual.
                    z -= batoid.analysis.zernike(
                        fiducial_telescope, np.deg2rad(thx), np.deg2rad(thy), wavelen*1e-9,
                        jmax=21, eps=0.61, nx=16
                    )
                thxplot.append(thx)
                thyplot.append(thy)
                zs.append(z)
        zs = np.array(zs).T  # rows: pupil index j, columns: field points
        thxplot = np.array(thxplot)
        thyplot = np.array(thyplot)
        fig = plt.figure(figsize=(13, 8))
        # Rows 0..3 (index 0 plus the lowest-order terms 1..3) are not plotted.
        batoid.plotUtils.zernikePyramid(thxplot, thyplot, zs[4:], vmin=vmin, vmax=vmax, fig=fig)
        plt.show()

        # Double Zernike: least-squares fit of each pupil coefficient's field
        # dependence with focal-plane Zernikes up to j=22 on a 1.1 deg radius.
        fBasis = galsim.zernike.zernikeBasis(22, thxplot, thyplot, 1.1)
        dzs, _, _, _ = np.linalg.lstsq(fBasis.T, zs.T, rcond=None)
        dzs = dzs[:, 4:]  # drop pupil j=0..3, matching the pyramid above
        asort = np.argsort(np.abs(dzs).ravel())[::-1]
        focal_idx, pupil_idx = np.unravel_index(asort[:10], dzs.shape)
        cumsum = 0.0
        for fid, pid in zip(focal_idx, pupil_idx):
            val = dzs[fid, pid]
            cumsum += val**2
            print("{:3d} {:3d} {:8.4f} {:8.4f}".format(fid, pid+4, val, np.sqrt(cumsum)))
        print("sum sqr dz {:8.4f}".format(np.sqrt(np.sum(dzs**2))))
else:
    print("galsim is not available; skipping the field-of-view Zernike widget.")

In [ ]: