The goal of this notebook is to demonstrate how the DESI stellar templates, which are defined on a semi-regular grid of effective temperature, surface gravity, and iron abundance, can be interpolated onto arbitrary values of Teff, logg, and [Fe/H].
The `StarInterp` class defined below is adapted from code written by Ben Johnson (CfA) and released under an MIT license (https://github.com/bd-j/prospector/blob/master/prospect/sources/ssp_basis.py#L35-L266), but I'm not sure how to properly attribute it in the documentation of the class itself.
Briefly, the code uses a Delaunay tessellation scheme to find the simplex in Teff-logg-[Fe/H] parameter space which surrounds the desired (input) physical values, and then computes a weighted sum of the basis templates at the vertices of that simplex, where the weights are the barycentric coordinates of the input point within the simplex.
The notebook also compares the interpolation scheme against spectra interpolated by Carlos Allende Prieto using a custom routine written in Fortran.
In [1]:
%pylab inline
In [2]:
import os
import numpy as np
import fitsio
import matplotlib.pyplot as plt
from desisim.io import read_basis_templates
In [3]:
import matplotlib as mpl
# Use a larger default font size for every figure drawn after this point.
mpl.rcParams['font.size'] = 16
In [4]:
# Read the 'STAR' basis templates; the three returned items are bound, in
# order, to the flux, wavelength, and metadata variables.
baseflux, basewave, basemeta = read_basis_templates('STAR')
In [5]:
# Read one brick of the MWS/Galaxia mock catalog stored under $DESI_ROOT.
# os.environ[...] is used instead of os.getenv(...) so that an unset
# DESI_ROOT fails with "KeyError: 'DESI_ROOT'" rather than an opaque
# TypeError raised from inside os.path.join(None, ...).
mwspath = os.path.join(os.environ['DESI_ROOT'], 'mocks', 'mws', 'galaxia', 'alpha', 'v0.0.4')
mwsfile = os.path.join(mwspath, 'bricks', '215', '2150p000', 'allsky_galaxia_desi_2150p000.fits')

# upper=True requests upper-case column names (e.g. 'TEFF', 'LOGG', 'FEH',
# which are the columns used by the plotting cell that follows).
data = fitsio.read(mwsfile, upper=True, ext=1)

# Display the available column names.
data.dtype.names
Out[5]:
In [6]:
# Compare the stellar parameters spanned by the MWS/Galaxia mock with those of
# the basis templates in three projections: Teff-logg, Teff-[Fe/H], and
# logg-[Fe/H].
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))

# Teff vs logg.  The mock 'TEFF' column is raised to the power of ten before
# plotting (presumably it is stored as log10(Teff) -- confirm against the mock
# data model); the template 'TEFF' values are plotted as-is.
ax1.scatter(10**data['TEFF'], data['LOGG'], s=5, alpha=0.5,
            color='powderblue', label='MWS/Galaxia Mock')
ax1.scatter(basemeta['TEFF'], basemeta['LOGG'], label='Basis Templates')
ax1.set_xscale('log')
ax1.set_xticks([3000, 5000, 1E4, 2E4])
# Plain-number tick labels on the log axis.  The explicitly imported ``mpl``
# alias is used rather than the ``matplotlib`` name injected by %pylab.
ax1.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax1.set_xlabel('$T_{eff}$ (K)')
# Raw strings keep the LaTeX backslashes from being parsed as (invalid)
# Python escape sequences.
ax1.set_ylabel(r'$\log g$ (cm s$^{-2}$)')
ax1.legend(loc='lower right', fontsize=10, markerscale=1.0)

# Teff vs [Fe/H].
ax2.scatter(10**data['TEFF'], data['FEH'], s=5, alpha=0.5, color='powderblue')
ax2.scatter(basemeta['TEFF'], basemeta['FEH'])
ax2.set_xscale('log')
ax2.set_xticks([3000, 5000, 1E4, 2E4])
ax2.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax2.set_xlabel('$T_{eff}$ (K)')
ax2.set_ylabel('[Fe/H]')

# logg vs [Fe/H].
ax3.scatter(data['LOGG'], data['FEH'], s=5, alpha=0.5, color='powderblue')
ax3.scatter(basemeta['LOGG'], basemeta['FEH'])
ax3.set_ylabel('[Fe/H]')
ax3.set_xlabel(r'$\log g$ (cm s$^{-2}$)')

fig.subplots_adjust(wspace=0.22)
In [7]:
class StarInterp(object):
    """Interpolate the DESI stellar (or white dwarf) basis templates.

    The library spectra are interpolated to arbitrary stellar parameters using
    the barycentric coordinates of the enclosing simplex found from the
    Delaunay triangulation of the library parameters.  Requests that fall
    outside the convex hull of the library either raise a ``ValueError`` or
    fall back to an inverse-distance weighting of the nearest library models,
    depending on ``n_neighbors``.  This approach is not tractable for
    libraries with a large number of parameters.

    Adapted from code written by Ben Johnson (CfA) and released under an MIT
    license:
    https://github.com/bd-j/prospector/blob/master/prospect/sources/ssp_basis.py#L35-L266
    """

    # Library spectra; populated by ``load_templates``.
    _spectra = None

    def __init__(self, WD=False, verbose=False, n_neighbors=0, log_interp=True,
                 logify_Z=True, use_params=None, rescale_params=False, **kwargs):
        """Load the template library and build its triangulation and KD-tree.

        :param WD: (default: False)
            If True, read the white dwarf basis templates.  Otherwise, read
            the (normal) star templates.

        :param verbose: (default: False)
            If True, print information about the parameters used when a point
            is outside the convex hull.

        :param n_neighbors: (default: 0)
            Number of nearest neighbors to use when requested parameters are
            outside the convex hull of the library parameters.  If ``0`` then
            a ValueError is raised instead of returning the nearest spectrum.

        :param log_interp: (default: True)
            Switch to interpolate in log(flux) instead of linear flux.

        :param logify_Z: (default: True)
            If True and the library parameters contain a field named ``Z``,
            replace it by its base-10 logarithm and rename it ``logZ``.

        :param use_params:
            Sequence of strings.  If given, only use the listed parameters
            (which must be present in the `_libparams` structure) to build the
            grid and construct spectra.  Otherwise all fields of `_libparams`
            will be used.

        :param rescale_params: (default: False)
            If True, rescale the parameters to the unit cube before generating
            the triangulation and the KD-tree.  Note that the `param_vector`
            method will also rescale the input parameters in this case.

        :param kwargs:
            Accepted for call compatibility; not used.
        """
        # Cache initialization variables
        self.verbose = verbose
        self.logarithmic = log_interp
        self.logify_Z = logify_Z
        self.n_neighbors = n_neighbors
        self._rescale = rescale_params
        # Becomes True once any request falls outside the convex hull; it is
        # defined up front so that it can be read before that ever happens.
        self.edge_flag = False

        # Load the templates
        self.load_templates(WD=WD)

        # Do some important bookkeeping
        if use_params:
            self.stellar_pars = tuple(use_params)
        else:
            self.stellar_pars = self._libparams.dtype.names
        self.ndim = len(self.stellar_pars)

        # Build the triangulation and kdtree
        self.triangulate()
        self.build_kdtree()
        self.params = {}

    @property
    def wavelengths(self):
        """The wavelength array of the loaded templates."""
        return self._wave

    def load_templates(self, WD=False):
        """Read the DESI stellar or white dwarf basis templates.

        Sets the ``_wave``, ``_libparams`` and ``_spectra`` attributes, which
        are ndarrays of shape (nwave,), (nmodels,), and (nmodels, nwave),
        respectively; ``_libparams`` is a structured array with one 'f4' field
        per stellar parameter.

        :param WD: (default: False)
            If True, read the 'WD' templates (parameters TEFF and LOGG);
            otherwise read the 'STAR' templates (TEFF, LOGG and FEH).
        """
        from desisim.io import read_basis_templates

        if WD:
            objtype = 'WD'
            stellar_pars = ('TEFF', 'LOGG')
        else:
            objtype = 'STAR'
            stellar_pars = ('TEFF', 'LOGG', 'FEH')

        self._spectra, self._wave, meta = read_basis_templates(objtype)

        # Copy the parameter columns of the metadata table into a structured
        # np.ndarray; every field is filled below, so np.empty is safe.
        params = np.empty(len(meta), dtype=[(p, 'f4') for p in stellar_pars])
        for pp in stellar_pars:
            params[pp] = meta[pp]
        self._libparams = params

        if self.logify_Z and 'Z' in self._libparams.dtype.names:
            from numpy.lib import recfunctions as rfn
            self._libparams['Z'] = np.log10(self._libparams['Z'])
            # rename_fields returns a new array rather than renaming in
            # place, so the result has to be stored.
            self._libparams = rfn.rename_fields(self._libparams, {'Z': 'logZ'})

    def _model_points(self):
        """Return the library parameters as an (nmodels, ndim) array.

        When rescaling is enabled, ``parameter_range`` is (re)computed from
        the library and the points are mapped onto the unit cube, so that the
        triangulation and the KD-tree live in the same space as the vectors
        returned by ``param_vector``.
        """
        points = np.column_stack([self._libparams[d] for d in self.stellar_pars])
        if self._rescale:
            self.parameter_range = np.array([points.min(axis=0), points.max(axis=0)])
            # rescale_params squeezes its result; restore the 2-d shape.
            points = np.reshape(self.rescale_params(points), points.shape)
        return points

    def triangulate(self):
        """Build the Delaunay triangulation of the model library."""
        from scipy.spatial import Delaunay
        self._dtri = Delaunay(self._model_points())

    def build_kdtree(self):
        """Build the kdtree of the model points.

        The tree uses the same (optionally rescaled) coordinates as the
        triangulation, because ``weights`` queries it with the output of
        ``param_vector``.
        """
        from scipy.spatial import cKDTree as KDTree
        self._kdt = KDTree(self._model_points())

    def rescale_params(self, points):
        """Map ``points`` onto the unit cube spanned by the library.

        If rescaling is disabled, ``points`` is returned unchanged.  Otherwise
        each parameter is shifted by the library minimum and divided by the
        library range (``parameter_range``), and the result is squeezed.
        """
        if self._rescale:
            x = np.atleast_2d(points)
            x = (x - self.parameter_range[0, :]) / np.diff(self.parameter_range, axis=0)
            return np.squeeze(x)
        else:
            return points

    def param_vector(self, **kwargs):
        """Take a dictionary of parameters and return the stellar library
        parameter vector corresponding to these parameters as an ndarray
        (rescaled if rescaling is enabled).

        Raises a KeyError if the dictionary does not contain *all* of the
        required stellar parameters.
        """
        pvec = [kwargs[n] for n in self.stellar_pars]
        return self.rescale_params(np.array(pvec))

    def weights_kNN(self, target_points, k=1):
        """The interpolation weights are determined from the inverse distance
        to the k nearest neighbors.

        :param target_points: ndarray, shape (npar,) or (ntarg, npar)
            The coordinates to which you wish to interpolate, in the same
            (optionally rescaled) space as the KD-tree.

        :param k:
            The number of nearest neighbors to use.

        :returns inds: ndarray
            The model indices of the neighbors (shape (ntarg, k) with
            singleton dimensions squeezed out, but at least 1-d).

        :returns weights: ndarray
            The weight of each model in ``inds``; the weights for each target
            sum to one.
        """
        dists, inds = self._kdt.query(target_points, k=k)
        inds = np.atleast_1d(np.squeeze(inds))
        if k == 1:
            return inds, np.ones(inds.shape)

        dists = np.asarray(dists, dtype=float)
        exact = dists == 0
        with np.errstate(divide='ignore'):
            weights = 1.0 / dists
        # A target that coincides with a library point would otherwise get
        # inf / inf = NaN weights; give all the weight to the exact match(es).
        has_exact = exact.any(axis=-1, keepdims=True)
        weights = np.where(has_exact, exact.astype(float), weights)
        # keepdims makes the normalization act per target for batched input.
        weights = weights / weights.sum(axis=-1, keepdims=True)
        return inds, np.atleast_1d(np.squeeze(weights))

    def weights(self, **kwargs):
        """Delaunay weighting.  Return indices of the models forming the
        enclosing simplex, as well as the barycentric coordinates of the point
        within this simplex to use as weights.  If the point is outside the
        convex hull then fall back to nearest neighbors unless ``n_neighbors``
        is 0, in which case a ValueError is raised.

        :param kwargs:
            Must include a value for every parameter in ``stellar_pars``.

        :returns inds:
            Indices of the library models to combine.

        :returns wghts:
            The weight of each of those models.
        """
        inparams = np.squeeze(self.param_vector(**kwargs))
        triangle_ind = self._dtri.find_simplex(inparams)

        if triangle_ind == -1:
            # find_simplex signals "outside the triangulation" with -1.
            self.edge_flag = True
            if self.n_neighbors == 0:
                # Note that the reported values are the rescaled ones when
                # rescaling is enabled.
                pstring = ', '.join('{}={}'.format(name, value)
                                    for name, value in zip(self.stellar_pars, inparams))
                # Pass the string itself: unpacking it (``*pstring``) would
                # feed single characters to format().
                raise ValueError("Requested spectrum ({}) outside convex hull,"
                                 " and nearest neighbor interpolation turned "
                                 "off.".format(pstring))
            ind, wght = self.weights_kNN(inparams, k=self.n_neighbors)
            if self.verbose:
                print("Parameters {0} outside model convex hull. "
                      "Using model index {1} instead. ".format(inparams, ind))
            return ind, wght

        inds = self._dtri.simplices[triangle_ind, :]
        # The first ndim rows of the transform map an offset from the
        # reference vertex (stored in the last row) to the first ndim
        # barycentric coordinates; the final coordinate follows from the
        # coordinates summing to one.
        transform = self._dtri.transform[triangle_ind, :, :]
        Tinv = transform[:self.ndim, :]
        x_r = inparams - transform[self.ndim, :]
        bary = np.dot(Tinv, x_r)
        last = 1.0 - bary.sum()
        wghts = np.append(bary, last)

        # Return the vertices in order of increasing model index.
        oo = inds.argsort()
        return inds[oo], wghts[oo]

    def get_spectrum(self, **kwargs):
        """Given stellar parameters, obtain an interpolated spectrum at those
        parameters.

        :param **kwargs:
            Keyword arguments must include values for the parameters listed in
            ``stellar_pars``.

        :returns wave:
            The wavelengths at which the spectrum is defined.

        :returns spec:
            The spectrum interpolated to the requested parameters.  This has
            the same units as the supplied library spectra.

        :returns unc:
            The uncertainty spectrum, where the uncertainty is due to
            interpolation error.  Currently unimplemented (i.e. it is a None
            type object).
        """
        inds, wghts = self.weights(**kwargs)
        if self.logarithmic:
            # Weighted sum of log(flux), i.e. a weighted geometric mean.
            spec = np.exp(np.dot(wghts, np.log(self._spectra[inds, :])))
        else:
            spec = np.dot(wghts, self._spectra[inds, :])
        spec_unc = None
        return self._wave, spec, spec_unc
In [8]:
# Build the interpolator for the (non-WD) stellar templates.  With
# n_neighbors=5, a request outside the convex hull of the template grid is
# answered from the 5 nearest templates instead of raising a ValueError.
Star = StarInterp(n_neighbors=5)
In [9]:
# First test point; the figure drawn for it further down is titled
# 'Input Values Within Model Grid'.
teff, feh, logg = 6253, -0.06, 3.22
# Time a single interpolation at the test point.
%timeit wave, flux, ferr = Star.get_spectrum(TEFF=teff, FEH=feh, LOGG=logg)
In [10]:
# Indices of the templates that contribute to the interpolation at the first
# test point, and the weight given to each of them.
indx, wght = Star.weights(TEFF=teff, FEH=feh, LOGG=logg)
print(indx, wght)
In [11]:
# Show the first test point (red square) and the templates that contribute to
# its interpolation (black-edged points) on top of the full template grid, in
# the Teff-logg, Teff-[Fe/H], and logg-[Fe/H] projections.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))

# Teff vs logg.
ax1.scatter(Star._libparams['TEFF'], Star._libparams['LOGG'],
            alpha=0.5, color='powderblue', label='Basis Templates')
ax1.scatter(Star._libparams['TEFF'][indx], Star._libparams['LOGG'][indx],
            edgecolor='k', label='"Nearest" Template(s)')
ax1.plot(teff, logg, 's', color='red', label='Input Values')
ax1.legend(loc='lower right', fontsize=10)
ax1.set_xscale('log')
ax1.set_xticks([3000, 5000, 1E4, 2E4])
# Plain-number tick labels on the log axis.  The explicitly imported ``mpl``
# alias is used rather than the ``matplotlib`` name injected by %pylab.
ax1.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax1.set_xlabel('$T_{eff}$ (K)')
# Raw strings keep the LaTeX backslashes from being parsed as (invalid)
# Python escape sequences.
ax1.set_ylabel(r'$\log g$ (cm s$^{-2}$)')

# Teff vs [Fe/H].
ax2.scatter(Star._libparams['TEFF'], Star._libparams['FEH'],
            alpha=0.5, color='powderblue')
ax2.scatter(Star._libparams['TEFF'][indx], Star._libparams['FEH'][indx],
            edgecolor='k')
ax2.plot(teff, feh, 's', color='red')
ax2.set_ylim(-3.5, 1)
ax2.set_xscale('log')
ax2.set_xticks([3000, 5000, 1E4, 2E4])
ax2.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax2.set_xlabel('$T_{eff}$ (K)')
ax2.set_ylabel('[Fe/H]')

# logg vs [Fe/H].
ax3.scatter(Star._libparams['LOGG'], Star._libparams['FEH'],
            alpha=0.5, color='powderblue')
ax3.scatter(Star._libparams['LOGG'][indx], Star._libparams['FEH'][indx],
            edgecolor='k')
ax3.plot(logg, feh, 's', color='red')
ax3.set_ylabel('[Fe/H]')
ax3.set_ylim(-3.5, 1)
ax3.set_xlabel(r'$\log g$ (cm s$^{-2}$)')

plt.suptitle('Input Values Within Model Grid')
fig.subplots_adjust(wspace=0.22)
In [12]:
# Second test point; the figure drawn for it further down is titled
# 'Input Values Outside Model Grid'.
teff2, feh2, logg2 = 6253, -3.1, 4.5
# Time a single interpolation at the second test point.
%timeit wave, flux, ferr = Star.get_spectrum(TEFF=teff2, FEH=feh2, LOGG=logg2)
# NOTE(review): the disabled call below uses the first test point
# (teff, feh, logg), not teff2/feh2/logg2.
#wave, flux, ferr = Star.get_spectrum(TEFF=teff, FEH=feh, LOGG=logg)
In [13]:
# Indices of the templates used for the second test point, and the weight
# given to each of them.
indx2, wght2 = Star.weights(TEFF=teff2, FEH=feh2, LOGG=logg2)
print(indx2, wght2)
In [14]:
# Show the second test point (red square) and the templates used for it
# (black-edged points) on top of the full template grid, in the Teff-logg,
# Teff-[Fe/H], and logg-[Fe/H] projections.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))

# Teff vs logg.
ax1.scatter(Star._libparams['TEFF'], Star._libparams['LOGG'],
            alpha=0.5, color='powderblue', label='Basis Templates')
ax1.scatter(Star._libparams['TEFF'][indx2], Star._libparams['LOGG'][indx2],
            edgecolor='k', label='"Nearest" Template(s)')
ax1.plot(teff2, logg2, 's', color='red', label='Input Values')
ax1.legend(loc='lower right', fontsize=10, frameon=True)
ax1.set_xscale('log')
ax1.set_xticks([3000, 5000, 1E4, 2E4])
# Plain-number tick labels on the log axis.  The explicitly imported ``mpl``
# alias is used rather than the ``matplotlib`` name injected by %pylab.
ax1.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax1.set_xlabel('$T_{eff}$ (K)')
# Raw strings keep the LaTeX backslashes from being parsed as (invalid)
# Python escape sequences.
ax1.set_ylabel(r'$\log g$ (cm s$^{-2}$)')

# Teff vs [Fe/H].
ax2.scatter(Star._libparams['TEFF'], Star._libparams['FEH'],
            alpha=0.5, color='powderblue')
ax2.scatter(Star._libparams['TEFF'][indx2], Star._libparams['FEH'][indx2],
            edgecolor='k')
ax2.plot(teff2, feh2, 's', color='red')
ax2.set_ylim(-3.5, 1)
ax2.set_xscale('log')
ax2.set_xticks([3000, 5000, 1E4, 2E4])
ax2.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax2.set_xlabel('$T_{eff}$ (K)')
ax2.set_ylabel('[Fe/H]')

# logg vs [Fe/H].
ax3.scatter(Star._libparams['LOGG'], Star._libparams['FEH'],
            alpha=0.5, color='powderblue')
ax3.scatter(Star._libparams['LOGG'][indx2], Star._libparams['FEH'][indx2],
            edgecolor='k')
ax3.plot(logg2, feh2, 's', color='red')
ax3.set_ylim(-3.5, 1)
ax3.set_ylabel('[Fe/H]')
ax3.set_xlabel(r'$\log g$ (cm s$^{-2}$)')

plt.suptitle('Input Values Outside Model Grid')
fig.subplots_adjust(wspace=0.22)
In [15]:
# Read the reference interpolated spectra, one file per scheme directory
# ('o1', 'o2', 'o3').  The later plotting cell labels these three sets, in
# order, 'Linear', 'Quadratic', and 'Cubic'.
# os.environ[...] is used instead of os.getenv(...) so that an unset
# DESI_ROOT fails with "KeyError: 'DESI_ROOT'" rather than an opaque
# TypeError raised from inside os.path.join(None, ...).
opath = os.path.join(os.environ['DESI_ROOT'], 'data', 'mwsinterp')

oflux, odata = [], []
for interp in ('o1', 'o2', 'o3'):
    ofile = os.path.join(opath, interp, 'nsc1.fits')
    odata.append(fitsio.read(ofile, ext=1))  # table holding the 'OPARAM' column
    oflux.append(fitsio.read(ofile, ext=2))  # flux array, one row per star

# All the parameter values are the same in the three files, so the array
# dimensions and the stellar parameters are taken from the first ('o1') only.
nstar, npix = oflux[0].shape
oparam = {'FEH': odata[0]['OPARAM'][:, 0],
          'TEFF': odata[0]['OPARAM'][:, 1],
          'LOGG': odata[0]['OPARAM'][:, 2]}
In [16]:
# Show where the stellar parameters of the reference spectra fall relative to
# the template grid, in the Teff-logg, Teff-[Fe/H], and logg-[Fe/H]
# projections.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))

# Teff vs logg.
ax1.scatter(Star._libparams['TEFF'], Star._libparams['LOGG'],
            alpha=0.5, color='powderblue', label='Basis Templates')
ax1.scatter(oparam['TEFF'], oparam['LOGG'], label='Input Values')
ax1.legend(loc='lower right', fontsize=10, frameon=True)
ax1.set_xscale('log')
ax1.set_xticks([3000, 5000, 1E4, 2E4])
# Plain-number tick labels on the log axis.  The explicitly imported ``mpl``
# alias is used rather than the ``matplotlib`` name injected by %pylab.
ax1.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax1.set_xlabel('$T_{eff}$ (K)')
# Raw strings keep the LaTeX backslashes from being parsed as (invalid)
# Python escape sequences.
ax1.set_ylabel(r'$\log g$ (cm s$^{-2}$)')

# Teff vs [Fe/H].
ax2.scatter(Star._libparams['TEFF'], Star._libparams['FEH'],
            alpha=0.5, color='powderblue')
ax2.scatter(oparam['TEFF'], oparam['FEH'])
ax2.set_xscale('log')
ax2.set_xticks([3000, 5000, 1E4, 2E4])
ax2.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
ax2.set_xlabel('$T_{eff}$ (K)')
ax2.set_ylabel('[Fe/H]')

# logg vs [Fe/H].
ax3.scatter(Star._libparams['LOGG'], Star._libparams['FEH'],
            alpha=0.5, color='powderblue')
ax3.scatter(oparam['LOGG'], oparam['FEH'])
ax3.set_ylabel('[Fe/H]')
ax3.set_xlabel(r'$\log g$ (cm s$^{-2}$)')

fig.subplots_adjust(wspace=0.22)
In [17]:
# Interpolate the basis templates at the stellar parameters of every
# reference spectrum, one star per row of ``flux``.
wave = Star.wavelengths
flux = np.empty((nstar, npix), dtype='f4')  # every row is filled in the loop
for ii in range(nstar):
    star_pars = {key: oparam[key][ii] for key in ('TEFF', 'FEH', 'LOGG')}
    # get_spectrum returns (wavelengths, spectrum, uncertainty); only the
    # spectrum is needed here.
    _, flux1, _ = Star.get_spectrum(**star_pars)
    flux[ii, :] = flux1
In [18]:
# Percentage residuals of the Delaunay-interpolated spectra with respect to
# each of the three sets of reference spectra; resid[ii] belongs to oflux[ii].
resid = np.empty((3, nstar, len(wave)))
for ii in range(3):
    # Whole-array form of the per-star expression 100 * (flux / reference - 1).
    resid[ii, :, :] = 100 * (flux / oflux[ii] - 1)
In [19]:
# Pick a reproducible random subset of stars to display.
nshow = 10   # number of stars to show
seed = 123   # fixed seed so that the same stars are chosen on every run
rand = np.random.RandomState(seed)
# NOTE(review): choice() is called without replace=False, so the same star
# can presumably be selected more than once -- confirm this is intended.
these = rand.choice(nstar, size=nshow)
print(these)
In [20]:
# For each selected star, show the Delaunay-interpolated spectrum (left
# column) and its percentage residuals with respect to the reference spectra
# (right column).
lbl = ('Linear', 'Quadratic', 'Cubic')
fig, ax = plt.subplots(nshow, 2, figsize=(10, 3 * nshow), sharex=True)

# Left column: the interpolated spectrum, annotated with the stellar
# parameters of the star.
for jj in range(nshow):
    ax[jj, 0].plot(wave, flux[these[jj], :], alpha=0.5)
    ax[jj, 0].set_yticks([])
    ax[jj, 0].set_ylabel('Relative Flux')
    ax[jj, 0].text(0.03, 0.94,
                   '$T_{{eff}}={:.0f}$ K\n $log g={:.2f}$\n [Fe/H]={:.2f}'.format(
                       oparam['TEFF'][these[jj]], oparam['LOGG'][these[jj]], oparam['FEH'][these[jj]]),
                   horizontalalignment='left', verticalalignment='top',
                   transform=ax[jj, 0].transAxes, fontsize=12)

# Right column: the residuals.
for jj in range(nshow):
    ax[jj, 1].set_ylim(-90, 90)
    ax[jj, 1].axhline(y=0, ls='--', color='k')
    ax[jj, 1].yaxis.tick_right()
    ax[jj, 1].yaxis.set_label_position('right')
    ax[jj, 1].set_ylabel('Residuals (%)')
    ax[jj, 1].set_xlim(1000, 9500)
    # range(1, 3) plots only the 'Quadratic' and 'Cubic' residuals; the
    # 'Linear' set (index 0) is not drawn.
    for ii in range(1, 3):
        ax[jj, 1].plot(wave, resid[ii, these[jj], :],
                       alpha=0.5, label=lbl[ii])
    if jj == 0:
        ax[jj, 1].legend(loc='upper right')

# One x-axis label shared by the whole figure.  The raw string keeps the
# LaTeX backslash from being parsed as an (invalid) Python escape sequence.
fig.text(0.5, 0.12, r'Wavelength ($\AA$)', ha='center')
fig.subplots_adjust(wspace=0.02, hspace=0.02)
Although the method used here should be robust, there are significant systematic residuals blueward of ~5000 Å for both the quadratic and cubic interpolation schemes, especially (not surprisingly) at low temperature and high metallicity.
One possibility is that a template grid which is more finely defined in physical parameter space would help. Alternatively, perhaps the quadratic and cubic interpolation schemes are not as close to "truth" as we think.
In [ ]: