In [ ]:
import batoid
import os
import yaml
import numpy as np
from scipy.interpolate import griddata
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
import galsim
from tqdm import tqdm
%matplotlib inline

In [ ]:
# Load the baseline optical model from "LSST_r.yaml" (presumably the LSST
# r-band prescription -- confirm against the YAML).  Later cells read this
# `telescope` object both as the reference optic and as the base for perturbations.
telescope = batoid.Optic.fromYaml("LSST_r.yaml")
# Disabled experiment, kept for reference only.
# NOTE(review): the lines below act on `fiducial_telescope`, a name that is not
# defined in the visible cells; they would have to target `telescope` (or the
# name would have to be defined) before being re-enabled.
# fiducial_telescope.clearObscuration()
# Re-insert M1 outer edge obscuration though.
# fiducial_telescope.itemDict['LSST.M1'].obscuration = batoid.ObscNegation(batoid.ObscCircle(8.360001/2))

In [ ]:
def simpleGridData(inputCoords, inputValue, outputGrid):
    """Interpolate scattered samples onto `outputGrid`, never returning NaN.

    Two `scipy.interpolate.griddata` passes are made over the same inputs:
    a 'cubic' pass and a 'nearest' pass.  Wherever the cubic result is not
    finite (e.g. requested points the cubic interpolant cannot cover), the
    nearest-neighbour value is used instead.

    Parameters
    ----------
    inputCoords : sample locations, passed straight to `griddata`.
    inputValue : sample values, one per location.
    outputGrid : locations at which to evaluate the interpolant.

    Returns
    -------
    Array of interpolated values, one per entry of `outputGrid`.
    """
    cubic, fallback = (
        griddata(inputCoords, inputValue, outputGrid, method=kind)
        for kind in ('cubic', 'nearest')
    )
    # Keep the smooth cubic value where it exists, otherwise fall back.
    return np.where(np.isfinite(cubic), cubic, fallback)

In [ ]:
def betterGridData(inputCoords, inputValue, outputGrid, k=20):
    """Cubic scattered-data interpolation with local Zernike-fit extrapolation.

    `scipy.interpolate.griddata(..., method='cubic')` is evaluated on
    `outputGrid`.  Every output point whose cubic value is not finite is then
    replaced: its `k` nearest input samples are fitted (least squares) with a
    galsim Zernike series built by `zernikeBasis(10, x, y)`, and the fitted
    series is evaluated at the output point.

    Parameters
    ----------
    inputCoords : (n, 2) array-like of sample positions.
    inputValue : (n,) array-like of sample values.
    outputGrid : (m, 2) array-like of positions to evaluate.
    k : int, optional
        Number of neighbours used for each local fit.  Clamped to the number
        of available samples so that a small input set does not make the
        neighbour query fail.

    Returns
    -------
    ndarray of interpolated values, one per row of `outputGrid`.
    """
    inputCoords = np.asarray(inputCoords)
    inputValue = np.asarray(inputValue)
    outputGrid = np.asarray(outputGrid)

    output = griddata(inputCoords, inputValue, outputGrid, method='cubic')
    toFix = np.nonzero(~np.isfinite(output))[0]
    if toFix.size == 0:
        # Cubic interpolation covered every point; no neighbour search needed.
        return output

    # Only the points that need repair are looked up in the tree; querying the
    # whole output grid would compute neighbours that are never used.
    k = min(k, len(inputCoords))
    tree = KDTree(inputCoords)
    _, neighbors = tree.query(outputGrid[toFix], k=k)

    for idx, neighbor in zip(toFix, neighbors):
        # No radii are passed, so galsim's default Zernike normalization applies.
        basis = galsim.zernike.zernikeBasis(
            10,
            inputCoords[neighbor, 0],
            inputCoords[neighbor, 1]
        )
        coefs, _, _, _ = np.linalg.lstsq(
            basis.T,
            inputValue[neighbor],
            rcond=None
        )
        output[idx] = galsim.zernike.Zernike(coefs).evalCartesian(
            *outputGrid[idx]
        )
    return output

In [ ]:
# Directory holding the M1M3 bending-mode table.  Set the M1M3_DATA_DIR
# environment variable to point at a different copy; the fallback is the path
# this notebook was originally written against.
M1M3dir = os.environ.get("M1M3_DATA_DIR", "/Users/josh/src/M1M3_ML/data/")
M1M3fn = os.path.join(M1M3dir, "M1M3_1um_156_grid.txt")
# Transposed so each column of the text table becomes one row of M1M3data.
M1M3data = np.loadtxt(M1M3fn).T

# Row 0 is used as a mirror flag: 1 selects M1 samples, 3 selects M3 samples.
w1 = np.where(M1M3data[0] == 1)[0]
w3 = np.where(M1M3data[0] == 3)[0]

# Rows 1 and 2 are read as the x and y positions of each sample.
M1x = M1M3data[1][w1]
M1y = M1M3data[2][w1]

M3x = M1M3data[1][w3]
M3y = M1M3data[2][w3]

# Indexing with the integer arrays w1/w3 yields copies, so the in-place
# division below does not modify M1M3data.
M1modes = M1M3data[3:23, w1]  # First 20 mirror modes
M3modes = M1M3data[3:23, w3]

# Modes are currently in the direction normal to surface.
# We want perturbations in the direction of the optic axis.
# So divide by the z-component of the surface normal.
M1modes /= telescope['M1'].surface.normal(M1x, M1y)[:,2]
M3modes /= telescope['M3'].surface.normal(M3x, M3y)[:,2]

M1points = np.column_stack([M1x, M1y])
M3points = np.column_stack([M3x, M3y])

# Regular 100 x 100 output grid spanning [-4.2, 4.2] on both axes, flattened.
# perturbed_telescope relies on this layout (x axis = first 100 entries of
# xgrid, y axis = every 100th entry of ygrid).
xgrid = np.linspace(-4.2, 4.2, 100)
xgrid, ygrid = np.meshgrid(xgrid, xgrid)
xgrid = xgrid.ravel()
ygrid = ygrid.ravel()
rgrid = np.hypot(xgrid, ygrid)

# 10 x 10 patch of extra points around the origin.  The gridding cells append
# these with value zero to the mode samples -- presumably to anchor the
# interpolation inside the central hole; confirm intent.
centerX, centerY = np.meshgrid(
    np.linspace(-0.15, 0.15, 10),
    np.linspace(-0.15, 0.15, 10)
)
centerPoints = np.column_stack([centerX.ravel(), centerY.ravel()])

In [ ]:
# Resample each of the 20 M1 and M3 modes onto the regular (xgrid, ygrid)
# grid with betterGridData.  The central patch `centerPoints` is appended to
# the samples with value zero for every mode.
gridPoints = np.column_stack([xgrid, ygrid])
centerZeros = np.zeros(len(centerPoints))
M1anchors = np.vstack([M1points, centerPoints])
M3anchors = np.vstack([M3points, centerPoints])

M1ModeBetterGrid, M3ModeBetterGrid = [], []
for mode in tqdm(range(20)):
    M1ModeBetterGrid.append(
        betterGridData(
            M1anchors,
            np.hstack([M1modes[mode], centerZeros]),
            gridPoints
        )
    )
    M3ModeBetterGrid.append(
        betterGridData(
            M3anchors,
            np.hstack([M3modes[mode], centerZeros]),
            gridPoints
        )
    )

# Stack into (mode, grid point) arrays.
M1ModeBetterGrid = np.array(M1ModeBetterGrid)
M3ModeBetterGrid = np.array(M3ModeBetterGrid)

In [ ]:
# Same resampling as the betterGridData cell, but using simpleGridData
# (cubic with nearest-neighbour fill) so the two schemes can be compared.
gridPoints = np.column_stack([xgrid, ygrid])
centerZeros = np.zeros(len(centerPoints))
M1anchors = np.vstack([M1points, centerPoints])
M3anchors = np.vstack([M3points, centerPoints])

M1ModeSimpleGrid, M3ModeSimpleGrid = [], []
for mode in tqdm(range(20)):
    M1ModeSimpleGrid.append(
        simpleGridData(
            M1anchors,
            np.hstack([M1modes[mode], centerZeros]),
            gridPoints
        )
    )
    M3ModeSimpleGrid.append(
        simpleGridData(
            M3anchors,
            np.hstack([M3modes[mode], centerZeros]),
            gridPoints
        )
    )

# Stack into (mode, grid point) arrays.
M1ModeSimpleGrid = np.array(M1ModeSimpleGrid)
M3ModeSimpleGrid = np.array(M3ModeSimpleGrid)

In [ ]:
# One panel per mode: the simpleGridData surfaces for M3 and M1, each shown
# only inside its own annulus of the grid.
r = np.hypot(xgrid, ygrid)
# Radial limits of the annuli used for display (presumably the clear
# apertures of M1 and M3 -- confirm against the optical prescription).
wM1 = (r > 2.558) & (r < 4.18)
wM3 = (r > 0.55) & (r < 2.508)
vmin, vmax = -2, 2

fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(12, 8))
for j, ax in enumerate(axes.flat):
    # One-hot amplitude vector selecting mode j, scaled by 1e6.
    bend = np.eye(20)[j] * 1e6
    # M3 is drawn first, then M1, on the same axes.
    for modeGrid, onMirror in ((M3ModeSimpleGrid, wM3), (M1ModeSimpleGrid, wM1)):
        ax.scatter(
            xgrid,
            ygrid,
            c=np.ma.masked_array(np.dot(bend, modeGrid), mask=~onMirror),
            s=5,
            vmin=vmin, vmax=vmax
        )
fig.tight_layout()
plt.show()

In [ ]:
# One panel per mode: difference between the simpleGridData and
# betterGridData surfaces, shown inside the M3 and M1 annuli.
r = np.hypot(xgrid, ygrid)
wM1 = (r > 2.558) & (r < 4.18)
wM3 = (r > 0.55) & (r < 2.508)
vmin, vmax = -0.1, 0.1

# The differences do not depend on the panel, so form them once.
M3GridDiff = M3ModeSimpleGrid - M3ModeBetterGrid
M1GridDiff = M1ModeSimpleGrid - M1ModeBetterGrid

fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(12, 8))
for j, ax in enumerate(axes.flat):
    # One-hot amplitude vector selecting mode j, scaled by 1e6.
    bend = np.eye(20)[j] * 1e6
    # M3 is drawn first, then M1, on the same axes.
    for gridDiff, onMirror in ((M3GridDiff, wM3), (M1GridDiff, wM1)):
        ax.scatter(
            xgrid,
            ygrid,
            c=np.ma.masked_array(np.dot(bend, gridDiff), mask=~onMirror),
            s=5,
            cmap='seismic',
            vmin=vmin, vmax=vmax
        )
fig.tight_layout()
plt.show()

In [ ]:
def rot(thx, thy):
    """Return the matrix product RotX(thx) . RotY(thy).

    `thx` and `thy` are handed unchanged to `batoid.RotX` / `batoid.RotY`
    (angle units are whatever those functions expect -- presumably radians).
    """
    aboutX = batoid.RotX(thx)
    aboutY = batoid.RotY(thy)
    return np.dot(aboutX, aboutY)

def perturbed_telescope(M2Shift, M2Tilt, cameraShift, cameraTilt, M1M3Bend, M2Bend):
    """Build a perturbed copy of the module-level `telescope`.

    Parameters
    ----------
    M2Shift : shift applied to 'M2' via `withGloballyShiftedOptic`.
    M2Tilt : pair (thx, thy) turned into a rotation by `rot` and applied to
        'M2' via `withLocallyRotatedOptic`.
    cameraShift, cameraTilt : same as above, applied to 'LSSTCamera'.
    M1M3Bend : sequence of mode amplitudes; the first 20 entries weight the
        rows of `M1ModeBetterGrid` and `M3ModeBetterGrid`.
    M2Bend : accepted for interface symmetry but not used by this function.

    Returns
    -------
    The optic returned by the chain of `with...` calls on `telescope`.
    """
    nModes = 20
    amplitudes = M1M3Bend[:nModes]

    def bentSurface(name, modeGrid):
        # Weighted sum of the gridded modes, laid out on the 100 x 100 grid:
        # xgrid[:100] is one row of x samples, ygrid[::100] the y samples.
        sag = np.dot(amplitudes, modeGrid[:nModes]).reshape(100, 100)
        delta = batoid.Bicubic(xgrid[:100], ygrid[::100], sag)
        # Add the bending figure on top of the nominal surface.
        return batoid.Sum([telescope[name].surface, delta])

    perturbedM1 = bentSurface('M1', M1ModeBetterGrid)
    perturbedM3 = bentSurface('M3', M3ModeBetterGrid)

    optic = telescope.withGloballyShiftedOptic('M2', M2Shift)
    optic = optic.withLocallyRotatedOptic('M2', rot(*M2Tilt))
    optic = optic.withGloballyShiftedOptic('LSSTCamera', cameraShift)
    optic = optic.withLocallyRotatedOptic('LSSTCamera', rot(*cameraTilt))
    optic = optic.withSurface('M1', perturbedM1)
    optic = optic.withSurface('M3', perturbedM3)
    return optic

In [ ]:
def tracePerturbed(thx, thy, *args, **kwargs):
    """Trace a polar fan of rays through a perturbed telescope and plot the spot.

    Parameters
    ----------
    thx, thy : field angle components in degrees (converted with
        `np.deg2rad` before being given to batoid).
    *args, **kwargs : forwarded unchanged to `perturbed_telescope`.

    The rays are launched at 625e-9 wavelength with nrad=100, naz=300, traced
    through the perturbed optic, and the unvignetted landing positions are
    scatter-plotted relative to their mean inside a +/-10e-6 window.
    """
    ptelescope = perturbed_telescope(*args, **kwargs)
    rays = batoid.RayVector.asPolar(
        optic=telescope, wavelength=625e-9,
        theta_x=np.deg2rad(thx), theta_y=np.deg2rad(thy),
        nrad=100, naz=300,
    )
    # Use the rays returned by trace() rather than relying on the input being
    # modified in place, and select unvignetted rays with an explicit mask
    # rather than relying on trimVignetted() mutating `rays`; the plot then
    # does not depend on which of those calls work in place.
    rays = ptelescope.trace(rays)
    good = ~rays.vignetted
    x = rays.x[good]
    y = rays.y[good]

    plt.scatter(x - np.mean(x), y - np.mean(y), s=0.1, alpha=0.3)
    plt.xlim(-10e-6, 10e-6)
    plt.ylim(-10e-6, 10e-6)
    plt.gca().set_aspect('equal')
    plt.show()

In [ ]:
# Spot diagram on axis with only bending mode index `ibend` excited
# (amplitude 0.14) and no rigid-body perturbations.
ibend = 4
bend = [int(k == ibend) for k in range(20)]
tracePerturbed(
    0.0, 0.0,
    M1M3Bend=0.14 * np.array(bend),
    M2Bend=None,
    M2Shift=(0.0, 0.0, 0.0),
    M2Tilt=(0.0, 0.0),
    cameraShift=(0.0, 0.0, 0.0),
    cameraTilt=(0.0, 0.0),
)

In [ ]:
def zernikeField(thxs, thys, *args, **kwargs):
    """Evaluate Zernike coefficients of a perturbed telescope over field points.

    Parameters
    ----------
    thxs, thys : sequences of field angle components in degrees, paired
        element by element.
    *args, **kwargs : forwarded unchanged to `perturbed_telescope`.

    Returns
    -------
    ndarray of shape (len(thxs), 29): row i holds the result of
    `batoid.analysis.zernikeGQ` (jmax=28, rings=10, reference='chief',
    wavelength 625e-9) for field point i.  Rows whose field radius exceeds
    1.76 degrees are filled with NaN instead of being evaluated.
    """
    ptelescope = perturbed_telescope(*args, **kwargs)

    jmax = 28
    zs = np.zeros((len(thxs), jmax + 1), dtype=float)
    # Each `row` is a view into zs, so assigning through it fills zs in place.
    for row, thx, thy in zip(zs, thxs, thys):
        if np.hypot(thx, thy) > 1.76:
            row[:] = float('nan')
        else:
            row[:] = batoid.analysis.zernikeGQ(
                ptelescope, np.deg2rad(thx), np.deg2rad(thy), 625e-9,
                jmax=jmax, rings=10,
                reference='chief'
            )
    return zs

In [ ]:
# 30 x 30 square grid of field angles spanning [-1.75, 1.75] on both axes,
# flattened to two 900-element coordinate vectors.
thx = np.linspace(-1.75, 1.75, 30)
thxs, thys = (axis.ravel() for axis in np.meshgrid(thx, thx))

In [ ]:
# Reference Zernikes: all rigid-body terms zero and all 20 bending
# amplitudes zero.  Later cells subtract this from each perturbed case.
noRigidBody = dict(
    M2Shift=(0.0, 0.0, 0.0),
    M2Tilt=(0.0, 0.0),
    cameraShift=(0.0, 0.0, 0.0),
    cameraTilt=(0.0, 0.0),
)
z0s = zernikeField(
    thxs, thys,
    M1M3Bend=np.zeros(20, dtype=int),
    M2Bend=None,
    **noRigidBody
)

In [ ]:
# Zernikes over the field with each bending mode excited on its own
# (amplitude 1e-1) and no rigid-body perturbation; zss[m] belongs to mode m.
zss = []
for ibend in tqdm(range(20)):
    bend = [int(k == ibend) for k in range(20)]
    modeZernikes = zernikeField(
        thxs, thys,
        M1M3Bend=1e-1 * np.array(bend),
        M2Bend=None,
        M2Shift=(0.0, 0.0, 0.0),
        M2Tilt=(0.0, 0.0),
        cameraShift=(0.0, 0.0, 0.0),
        cameraTilt=(0.0, 0.0),
    )
    zss.append(modeZernikes)

In [ ]:
# One Zernike pyramid per bending mode, showing the change in coefficients
# Z4 and up relative to the unperturbed reference z0s.
for j in range(20):
    fig = plt.figure(figsize=(13, 8))
    # Columns 4 onward; transposed so each row is one Zernike term over the field.
    deltaZ = (zss[j] - z0s)[:, 4:].T
    batoid.plotUtils.zernikePyramid(
        thxs, thys, deltaZ,
        title=f"mode {j+1}",
        vmin=-2, vmax=2,
        fig=fig
    )

In [ ]:
# One figure per Zernike index Z4..Z28; each of the 20 panels shows how that
# coefficient changes over the field when a single bending mode is excited.
scatterStyle = dict(s=5, vmin=-0.5, vmax=0.5, cmap='Spectral_r')
for zj in range(4, 29):
    fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(12, 8))
    for j, ax in enumerate(axes.flat):
        deltaZ = zss[j][:, zj] - z0s[:, zj]
        ax.scatter(thxs, thys, c=deltaZ, **scatterStyle)
        ax.set_aspect('equal')
        ax.set_title(f"Mode {j+1}")
    fig.suptitle(f"Z{zj}")
    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [ ]: