In [2]:
%pylab inline
In [3]:
import corner as cn
import desitarget.cuts
from astropy.io import fits
from astropy.table import Table
from sklearn.mixture import GMM
from desitarget.targetmask import desi_mask, bgs_mask
In [4]:
class GaussianMixtureModel(object):
    """Lightweight Gaussian mixture model that can be stored in a FITS file.

    The object only holds the mixture parameters (weights, means, covariances
    and the covariance type) and knows how to draw random samples from them.

    Parameters
    ----------
    weights : array
        Mixture weight of each component.
    means : array
        Component means; its ``shape`` is unpacked as
        ``(n_components, n_dimensions)``.
    covars : array
        Component covariances; ``covars[i]`` is passed to
        ``multivariate_normal`` as the covariance of component ``i``.
    covtype : str
        Covariance type label. Only ``'full'`` is supported by ``sample``.
    """

    def __init__(self, weights, means, covars, covtype):
        self.weights = weights
        self.means = means
        self.covars = covars
        self.covtype = covtype
        self.n_components, self.n_dimensions = self.means.shape

    @staticmethod
    def save(model, filename):
        """Write the parameters of a fitted mixture ``model`` to ``filename``.

        ``model`` must expose ``covariance_type``, ``weights_``, ``means_`` and
        ``covars_``. The covariance type is stored under the ``covtype`` header
        key of the ``weights`` HDU, which is appended first. An existing file
        is overwritten.
        """
        hdus = fits.HDUList()
        hdr = fits.Header()
        hdr['covtype'] = model.covariance_type
        hdus.append(fits.ImageHDU(model.weights_, name='weights', header=hdr))
        hdus.append(fits.ImageHDU(model.means_, name='means'))
        hdus.append(fits.ImageHDU(model.covars_, name='covars'))
        # NOTE(review): newer astropy releases spell this keyword `overwrite`;
        # confirm against the astropy version this notebook is pinned to.
        hdus.writeto(filename, clobber=True)

    @staticmethod
    def load(filename):
        """Read a model previously written by ``save`` and return it.

        The file is always closed, even if one of the expected HDUs or the
        ``covtype`` header key is missing.
        """
        with fits.open(filename, memmap=False) as hdus:
            # save() attaches the covariance type to the first HDU it appends.
            covtype = hdus[0].header['covtype']
            model = GaussianMixtureModel(
                hdus['weights'].data, hdus['means'].data, hdus['covars'].data,
                covtype)
        return model

    def sample(self, n_samples=1, random_state=None):
        """Draw ``n_samples`` random points from the mixture.

        Parameters
        ----------
        n_samples : int
            Number of points to generate.
        random_state : numpy.random.RandomState or None
            Generator to draw from; a fresh unseeded one is created if None.

        Returns
        -------
        array of shape ``(n_samples, n_dimensions)``

        Raises
        ------
        NotImplementedError
            If the covariance type is anything other than ``'full'``.
        """
        if self.covtype != 'full':
            # This must be raised, not returned: returning the exception object
            # would hand callers an error instance in place of the samples.
            raise NotImplementedError(
                'covariance type "{0}" not implemented yet.'.format(self.covtype))

        # Code adapted from sklearn's GMM.sample()
        if random_state is None:
            random_state = np.random.RandomState()

        weight_cdf = np.cumsum(self.weights)
        X = np.empty((n_samples, self.n_dimensions))
        rand = random_state.rand(n_samples)

        # decide which component to use for each sample
        comps = weight_cdf.searchsorted(rand)
        # If rounding leaves the cumulative weights just below 1, searchsorted
        # can return n_components; clamp so no row of X stays uninitialised.
        comps = np.minimum(comps, self.n_components - 1)

        # for each component, generate all needed samples
        for comp in range(self.n_components):
            # occurrences of current component in X
            comp_in_X = (comp == comps)
            # number of those occurrences
            num_comp_in_X = comp_in_X.sum()
            if num_comp_in_X > 0:
                X[comp_in_X] = random_state.multivariate_normal(
                    self.means[comp], self.covars[comp], num_comp_in_X)
        return X
In [30]:
def classify_targets(cuts):
    """Split a target table into LRG, ELG, QSO and BGS sub-tables.

    A row is selected as LRG/ELG/QSO when the matching ``desi_mask`` bit is set
    in its ``DESI_TARGET`` value, and as BGS when the ``bgs_mask.BGS_BRIGHT``
    bit is set in ``BGS_TARGET``. The selections are independent, so a row may
    appear in more than one of the returned tables.

    Returns
    -------
    tuple
        ``(lrgs, elgs, qsos, bgs)``, each being ``cuts`` indexed by the
        corresponding boolean mask.
    """
    def has_bit(column, bit):
        # Non-zero after the bitwise AND means the bit is set.
        return (cuts[column] & bit).astype(bool)

    is_lrg = has_bit('DESI_TARGET', desi_mask.LRG)
    is_elg = has_bit('DESI_TARGET', desi_mask.ELG)
    is_qso = has_bit('DESI_TARGET', desi_mask.QSO)
    is_bgs = has_bit('BGS_TARGET', bgs_mask.BGS_BRIGHT)

    return cuts[is_lrg], cuts[is_elg], cuts[is_qso], cuts[is_bgs]
def get_bic(data, components_range):
    """Return the BIC of a full-covariance GMM fit for each component count.

    Parameters
    ----------
    data : array
        Training data handed to ``GMM.fit`` and ``GMM.bic``.
    components_range : iterable of int
        Numbers of Gaussian components to try, in order.

    Returns
    -------
    list
        One BIC value per entry of ``components_range``, in the same order.
    """
    def bic_for(n_components):
        # Fit a fresh model for this component count and score it on the
        # same data it was trained on.
        candidate = GMM(n_components=n_components, covariance_type='full')
        candidate.fit(data)
        return candidate.bic(data)

    return [bic_for(n_components) for n_components in components_range]
def plot_bic(bic, components_range):
    """Plot BIC/100 against the number of Gaussian components.

    Parameters
    ----------
    bic : sequence of float
        BIC values, one per entry of ``components_range``.
    components_range : iterable of int
        The component counts the BIC values were computed for.

    The title reports the component count with the lowest BIC.
    """
    components = list(components_range)
    # argmin yields a position in `bic`; map it back to the component count at
    # that position (the two differ whenever the range does not start at 0).
    best_n_components = components[int(np.argmin(bic))]

    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    ax.plot(components, np.asarray(bic) / 100, marker='s', ls='-')
    ax.set_xlabel('Number of Gaussian Components')
    ax.set_ylabel('Bayesian Information Criterion/100')
    ax.set_title('Optimal number of components = {:d}'.format(best_n_components))
    plt.show()
def flux_to_mag(data):
    """Convert g, r and z fluxes to magnitudes.

    Columns 1, 2 and 4 of ``data['DECAM_FLUX']`` are read as the g, r and z
    fluxes. Only rows where all three fluxes are strictly positive are kept,
    since the logarithm is undefined otherwise. Each magnitude is computed as
    ``22.5 - 2.5 * log10(flux)``.

    Returns
    -------
    tuple of arrays
        ``(g_mag, r_mag, z_mag)``, all of the same length.
    """
    decam_flux = data['DECAM_FLUX']
    g_flux, r_flux, z_flux = decam_flux[:, 1], decam_flux[:, 2], decam_flux[:, 4]

    # A row is usable only if every band has a positive flux.
    positive = (g_flux > 0) & (r_flux > 0) & (z_flux > 0)

    def to_mag(flux):
        return 22.5 - 2.5 * np.log10(flux[positive])

    return to_mag(g_flux), to_mag(r_flux), to_mag(z_flux)
def make_gmm_model(X_data, components_range, model_filename, seed=123, bic_plot=False):
    """Fit a full-covariance GMM with the BIC-optimal component count and save it.

    Parameters
    ----------
    X_data : array
        Training data.
    components_range : iterable of int
        Candidate numbers of Gaussian components to compare by BIC.
    model_filename : str
        Path of the FITS file the fitted model is written to.
    seed : int
        Seed of the random generator used for the final fit. The models
        fitted during the BIC scan are not seeded by this function.
    bic_plot : bool
        If True, plot the BIC values before fitting the final model.
    """
    # Materialise once so the candidates can be both iterated and indexed.
    components = list(components_range)

    # list of bic values for given range of components
    bic = get_bic(X_data, components)

    # option to plot bic values
    if bic_plot:
        plot_bic(bic, components)

    # argmin gives the *position* of the lowest BIC; the optimal number of
    # components is the candidate at that position, not the position itself.
    n_comp = components[int(np.argmin(bic))]

    gen = np.random.RandomState(seed)
    model = GMM(n_components=n_comp, covariance_type="full", random_state=gen).fit(X_data)
    GaussianMixtureModel.save(model, model_filename)
    print('Saved GMM as {:s}.'.format(model_filename))
In [6]:
def dr2_data(object_data):
    """Return the g, r, z magnitudes of ``object_data`` as one 2-D array.

    The three magnitude arrays from ``flux_to_mag`` are stacked and transposed,
    so each row holds the ``(g, r, z)`` magnitudes of one object.
    """
    magnitudes = flux_to_mag(object_data)
    stacked = np.array(list(magnitudes))
    return stacked.T
def sample_magnitudes(target_type, n_targets, random_state=None):
    """Draw random magnitudes from the saved GMM of a target class.

    Parameters
    ----------
    target_type : str
        One of ``'LRG'``, ``'ELG'``, ``'QSO'`` or ``'BGS'``.
    n_targets : int
        Number of samples to draw.
    random_state : numpy.random.RandomState or None
        Generator forwarded to ``GaussianMixtureModel.sample``.

    Returns
    -------
    array
        The samples returned by ``GaussianMixtureModel.sample``.

    Raises
    ------
    ValueError
        If ``target_type`` is not one of the supported classes.
    """
    model_files = {
        'LRG': 'data/lrgMag_gmm.fits',
        'ELG': 'data/elgMag_gmm.fits',
        'QSO': 'data/qsoMag_gmm.fits',
        'BGS': 'data/bgsMag_gmm.fits',
    }
    # Fail with a clear message instead of reaching the sampling call with no
    # model bound.
    if target_type not in model_files:
        raise ValueError(
            'Unknown target type "{0}"; expected one of {1}.'.format(
                target_type, ', '.join(sorted(model_files))))

    model = GaussianMixtureModel.load(model_files[target_type])
    return model.sample(n_targets, random_state)
In [31]:
# Read the table of DR2 targets that have passed the selection cuts, then
# split it into LRG, ELG, QSO and BGS sub-tables according to their target bits.
cuts = Table.read('data/all_cuts.fits', format='fits')
lrg, elg, qso, bgs = classify_targets(cuts)
In [12]:
# Random number generator used when sampling from the saved models.
# NOTE: this single generator is passed to every sample_magnitudes() call
# below, so the drawn samples depend on the order in which those cells run.
seed = 123
gen = np.random.RandomState(seed)
# Candidate numbers of Gaussian components (1 through 35) compared by BIC.
components_range = range(1,36)
In [36]:
# g, r, z magnitudes of the LRG targets, one row per object.
lrgMag = dr2_data(lrg)

# Split the LRGs in half: the first half trains the model, the second half is
# held out for cross-validation, so the sets compared later are similar in size.
N_lrg_tot = len(lrgMag)
N_lrg_half = np.floor(N_lrg_tot / 2.).astype(int)
lrgMag_data, lrgMag_cross = lrgMag[:N_lrg_half], lrgMag[N_lrg_half:N_lrg_tot]

print(len(lrgMag_data), len(lrgMag_cross))
In [37]:
# Scan the BIC over components_range (with a plot), fit the LRG training half
# and save the model to the file that sample_magnitudes('LRG', ...) reads.
make_gmm_model(lrgMag_data, components_range, model_filename='data/lrgMag_gmm.fits',
bic_plot=True)
In [38]:
# Draw as many synthetic LRG magnitudes from the saved model as there are
# training points. This advances the shared generator `gen`.
lrgMag_sample = sample_magnitudes('LRG', n_targets=len(lrgMag_data), random_state=gen)
In [39]:
# Corner plots of the LRG g, r, z magnitudes on a common axis range: one figure
# each for the training data, the model sample and the cross-validation half.
axes_range = [(15, 27)] * 3
corner_kwargs = dict(
    labels=[r"$g$", r"$r$", r"$z$"],
    show_titles=True,
    title_kwargs={"fontsize": 12},
    range=axes_range,
)
fig1 = cn.corner(lrgMag_data, **corner_kwargs)    # Data
fig2 = cn.corner(lrgMag_sample, **corner_kwargs)  # Sample
fig3 = cn.corner(lrgMag_cross, **corner_kwargs)   # Cross-validation
In [22]:
# g, r, z magnitudes of the ELG targets, one row per object.
elgMag = dr2_data(elg)

# Only the first fifth of the ELGs is used (presumably to keep the set sizes
# comparable to the other target classes - confirm). That subset is split in
# half: first half for training, second half for cross-validation.
N_elg_tot = np.floor(len(elgMag) / 5.).astype(int)
N_elg_half = np.floor(N_elg_tot / 2.).astype(int)
elgMag_data, elgMag_cross = elgMag[:N_elg_half], elgMag[N_elg_half:N_elg_tot]

print(len(elgMag_data), len(elgMag_cross))
In [23]:
# Scan the BIC over components_range (with a plot), fit the ELG training half
# and save the model to the file that sample_magnitudes('ELG', ...) reads.
make_gmm_model(elgMag_data, components_range, model_filename='data/elgMag_gmm.fits',
bic_plot=True)
In [24]:
# Draw as many synthetic ELG magnitudes from the saved model as there are
# training points. This advances the shared generator `gen`.
elgMag_sample = sample_magnitudes('ELG', n_targets=len(elgMag_data), random_state=gen)
In [25]:
# Corner plots of the ELG g, r, z magnitudes on a common axis range: one figure
# each for the training data, the model sample and the cross-validation half.
axes_range = [(12, 26)] * 3
corner_kwargs = dict(
    labels=[r"$g$", r"$r$", r"$z$"],
    show_titles=True,
    title_kwargs={"fontsize": 12},
    range=axes_range,
)
fig1 = cn.corner(elgMag_data, **corner_kwargs)    # Data
fig2 = cn.corner(elgMag_sample, **corner_kwargs)  # Sample
fig3 = cn.corner(elgMag_cross, **corner_kwargs)   # Cross-validation
In [26]:
# g, r, z magnitudes of the QSO targets, one row per object.
qsoMag = dr2_data(qso)

# Split the QSOs in half: the first half trains the model, the second half is
# held out for cross-validation, so the sets compared later are similar in size.
N_qso_tot = len(qsoMag)
N_qso_half = np.floor(N_qso_tot / 2.).astype(int)
qsoMag_data, qsoMag_cross = qsoMag[:N_qso_half], qsoMag[N_qso_half:N_qso_tot]

print(len(qsoMag_data), len(qsoMag_cross))
In [27]:
# Scan the BIC over components_range (with a plot), fit the QSO training half
# and save the model to the file that sample_magnitudes('QSO', ...) reads.
make_gmm_model(qsoMag_data, components_range, model_filename='data/qsoMag_gmm.fits',
bic_plot=True)
In [28]:
# Draw as many synthetic QSO magnitudes from the saved model as there are
# training points. This advances the shared generator `gen`.
qsoMag_sample = sample_magnitudes('QSO', n_targets=len(qsoMag_data), random_state=gen)
In [29]:
# Corner plots of the QSO g, r, z magnitudes on a common axis range: one figure
# each for the training data, the model sample and the cross-validation half.
axes_range = [(14, 28)] * 3
corner_kwargs = dict(
    labels=[r"$g$", r"$r$", r"$z$"],
    show_titles=True,
    title_kwargs={"fontsize": 12},
    range=axes_range,
)
fig1 = cn.corner(qsoMag_data, **corner_kwargs)    # Data
fig2 = cn.corner(qsoMag_sample, **corner_kwargs)  # Sample
fig3 = cn.corner(qsoMag_cross, **corner_kwargs)   # Cross-validation
In [35]:
# g, r, z magnitudes of the BGS targets, one row per object.
bgsMag = dr2_data(bgs)

# Only the first half of the BGS targets is used, so the size of the data sets
# is closer to those of the other targets. Floor division keeps the count an
# integer: under true division `len(bgsMag)/2` is a float, which is not a
# valid slice bound.
N_bgs_tot = len(bgsMag) // 2
# Split that subset in half: first half for training, second half for
# cross-validation.
N_bgs_half = np.floor(N_bgs_tot / 2.).astype(int)
bgsMag_data = bgsMag[:N_bgs_half]
bgsMag_cross = bgsMag[N_bgs_half:N_bgs_tot]

print(len(bgsMag_data), len(bgsMag_cross))
In [13]:
# Scan the BIC over components_range (with a plot), fit the BGS training half
# and save the model to the file that sample_magnitudes('BGS', ...) reads.
make_gmm_model(bgsMag_data, components_range, model_filename='data/bgsMag_gmm.fits',
bic_plot=True)
In [15]:
# Draw as many synthetic BGS magnitudes from the saved model as there are
# training points. This advances the shared generator `gen`.
bgsMag_sample = sample_magnitudes('BGS', n_targets=len(bgsMag_data), random_state=gen)
In [16]:
# Corner plots of the BGS g, r, z magnitudes on a common axis range: one figure
# each for the training data, the model sample and the cross-validation half.
axes_range = [(7, 26)] * 3
corner_kwargs = dict(
    labels=[r"$g$", r"$r$", r"$z$"],
    show_titles=True,
    title_kwargs={"fontsize": 12},
    range=axes_range,
)
fig1 = cn.corner(bgsMag_data, **corner_kwargs)    # Data
fig2 = cn.corner(bgsMag_sample, **corner_kwargs)  # Sample
fig3 = cn.corner(bgsMag_cross, **corner_kwargs)   # Cross-validation