In this notebook we introduce the Gaussian Mixture Model (GMM) analysis performed in Souza et al. 2017 ( https://arxiv.org/pdf/1703.07607.pdf ) using the Python scikit-learn implementation.
Note that the results presented in the paper were obtained with R, so small differences may appear here. Note also that most of the figure-cosmetics functions are defined separately in `gmm_fig_style.py` for readability.
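If `gmm_fig_style.py` is not at hand, a minimal stand-in for its `set_plt_style` helper could look like the sketch below. The axis labels follow the usual BPT/WHAN conventions (log [NII]/Hα on the shared x-axis, log [OIII]/Hβ for the BPT panel, log EW(Hα) for the WHAN panel); this is only an assumption about what the real helper does, which certainly adds more cosmetics.
In [ ]:
# Hypothetical stand-in for the set_plt_style helper imported below;
# it only labels the two panels of the current figure.
import matplotlib.pyplot as plt

def set_plt_style():
    ax_bpt, ax_whan = plt.gcf().axes[:2]
    ax_bpt.set_xlabel(r'$\log\,[\mathrm{NII}]/\mathrm{H}\alpha$')
    ax_bpt.set_ylabel(r'$\log\,[\mathrm{OIII}]/\mathrm{H}\beta$')
    ax_whan.set_xlabel(r'$\log\,[\mathrm{NII}]/\mathrm{H}\alpha$')
    ax_whan.set_ylabel(r'$\log\,EW(\mathrm{H}\alpha)$')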
In [1]:
# importing useful libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import itertools
from sklearn.mixture import GaussianMixture
from gmm_fig_style import *
plt.style.use('seaborn-whitegrid')
sns.set_palette(sns.color_palette('Set2', 4))
my_cmap = mpl.cm.colors.ListedColormap(sns.color_palette('Set2', 4),
                                       name='from_list', N=4)
print('Color palette used throughout the notebook :')
sns.palplot(sns.color_palette("Set2",4))
In [2]:
# reading the catalogue
data = pd.read_csv('../Catalogue/GMM_catalogue.csv')
In [3]:
# display some information
data.info()
In [4]:
# A first look at the galaxy distribution in
# the BPT and WHAN diagrams
plt.figure(figsize=(16,6))
ax1 = plt.subplot(121)
data.plot.hexbin(x='xx_BPT_WHAN', y='yy_BPT', mincnt=1,
                 bins='log', gridsize=101, cmap='viridis', ax=ax1)
ax2 = plt.subplot(122)
data.plot.hexbin(x='xx_BPT_WHAN', y='yy_WHAN', mincnt=1,
                 bins='log', gridsize=101, cmap='viridis', ax=ax2)
set_plt_style()
In our study we tried several classification methods with different parameters, before settling on the GMM technique fitted with 2, 3 and 4 clusters. I show hereafter how to perform such a clustering with GMM using scikit-learn (version 0.18.1).
For each case we visualize the results in three different ways: a scatter plot colour-coded by group, the Gaussian components drawn as ellipses on the diagrams, and a kernel density estimate of each group.
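Before that, and purely as an illustration (this is not the model-selection procedure used in the paper), the 2-, 3- and 4-component fits can be compared with the Bayesian Information Criterion returned by `GaussianMixture.bic`; the sketch below reuses the `data` catalogue loaded above.
In [ ]:
# Illustrative comparison of the number of components via the BIC
# (lower is better); a sketch, not the paper's selection criterion.
X = data[['xx_BPT_WHAN', 'yy_BPT', 'yy_WHAN']]
for k in (2, 3, 4):
    gm = GaussianMixture(n_components=k, covariance_type='full').fit(X)
    print(k, 'components -> BIC =', gm.bic(X))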
In [5]:
# define and train GMM
gmm2 = GaussianMixture(covariance_type='full', n_components=2)
gmm2.fit(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
# return the probability of belonging to a group
proba_gmm2 = gmm2.predict_proba(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
labels2 = proba_gmm2.argmax(axis=1)
# add to the initial dataframe new columns
# containing the probability to belong to a group
data['gmm2_proba1'] = proba_gmm2[:,0]
data['gmm2_proba2'] = proba_gmm2[:,1]
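As a quick sanity check (a small sketch, not part of the original analysis), the hard labels obtained from the arg-max of `predict_proba` coincide with what `GaussianMixture.predict` returns, and `weights_` gives the mixing proportion of each Gaussian component:
In [ ]:
# argmax of the posterior probabilities is exactly the predict() output;
# weights_ holds the mixing proportions (relative sizes) of the two groups
X = data[['xx_BPT_WHAN', 'yy_BPT', 'yy_WHAN']]
assert (gmm2.predict(X) == labels2).all()
print('mixing proportions:', gmm2.weights_)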
In [6]:
# visualizing the resulting classification
plt.figure(figsize=(14,6))
plt.subplot(121)
plt.scatter(data.xx_BPT_WHAN, data.yy_BPT, c=labels2, s=0.4, cmap=my_cmap)
plt.subplot(122)
plt.scatter(data.xx_BPT_WHAN, data.yy_WHAN, c=labels2, s=0.4, cmap=my_cmap)
set_plt_style()
In [7]:
plt.figure(figsize=(14,6))
ax1 = plt.subplot(121)
for k, col in enumerate(['C0', 'C3']):
    plot_BPT_ell(gmm2.covariances_[k], gmm2.means_[k], ax=ax1, col=col)
ax2 = plt.subplot(122)
for k, col in enumerate(['C0', 'C3']):
    plot_WHAN_ell(gmm2.covariances_[k], gmm2.means_[k], ax=ax2, col=col)
set_plt_style()
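The `plot_BPT_ell` and `plot_WHAN_ell` helpers also live in `gmm_fig_style.py`. If you need to reproduce them, a possible sketch is shown below: each component's 3×3 covariance is projected onto the two dimensions of the diagram (the feature order follows the columns passed to `fit`) and drawn as a 2σ ellipse. The original helpers may differ in scaling and styling.
In [ ]:
# Hypothetical re-implementation of the ellipse helpers used above.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

def _plot_ell_2d(cov2, mean2, ax, col, n_sigma=2.0):
    # eigendecompose the 2x2 covariance to get axis lengths and orientation
    vals, vecs = np.linalg.eigh(cov2)      # eigenvalues in ascending order
    major = vecs[:, -1]                    # direction of largest variance
    angle = np.degrees(np.arctan2(major[1], major[0]))
    width, height = 2.0 * n_sigma * np.sqrt(vals[::-1])
    ax.add_patch(Ellipse(xy=mean2, width=width, height=height,
                         angle=angle, fill=False, color=col, lw=2))

def plot_BPT_ell(cov, mean, ax, col):
    # keep the (xx_BPT_WHAN, yy_BPT) block: rows/columns 0 and 1
    _plot_ell_2d(cov[np.ix_([0, 1], [0, 1])], mean[[0, 1]], ax, col)

def plot_WHAN_ell(cov, mean, ax, col):
    # keep the (xx_BPT_WHAN, yy_WHAN) block: rows/columns 0 and 2
    _plot_ell_2d(cov[np.ix_([0, 2], [0, 2])], mean[[0, 2]], ax, col)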
Seaborn makes it possible to visualize the density distribution of each group with a kernel density estimate (KDE). In the following plots we only keep the points whose probability of belonging to a group is greater than 50%.
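With two components the posterior probabilities sum to one, so this cut essentially reproduces the hard labels; with three or four components it also removes genuinely ambiguous galaxies. A quick count of the galaxies passing the cut (a simple check, not part of the original analysis):
In [ ]:
# number of galaxies above the 50% membership cut in each of the two groups
print((data.gmm2_proba1 > 0.5).sum(), (data.gmm2_proba2 > 0.5).sum())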
In [8]:
plt.figure(figsize=(14,6))
ax1 = plt.subplot(121)
sns.kdeplot(data[data.gmm2_proba1>0.5].xx_BPT_WHAN, data[data.gmm2_proba1>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Greens_d')
sns.kdeplot(data[data.gmm2_proba2>0.5].xx_BPT_WHAN, data[data.gmm2_proba2>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='RdPu_d')
ax2 = plt.subplot(122)
sns.kdeplot(data[data.gmm2_proba1>0.5].xx_BPT_WHAN, data[data.gmm2_proba1>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Greens_d')
sns.kdeplot(data[data.gmm2_proba2>0.5].xx_BPT_WHAN, data[data.gmm2_proba2>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='RdPu_d')
set_plt_style()
In [9]:
# define and train GMM
gmm3 = GaussianMixture(covariance_type='full', n_components=3)
gmm3.fit(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
# return the probability of belonging to a group
proba_gmm3 = gmm3.predict_proba(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
labels3 = proba_gmm3.argmax(axis=1)
# add to the initial dataframe new columns
# containing the probability to belong to a group
data['gmm3_proba1'] = proba_gmm3[:,0]
data['gmm3_proba2'] = proba_gmm3[:,1]
data['gmm3_proba3'] = proba_gmm3[:,2]
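Keep in mind that the order of the components returned by `GaussianMixture` depends on the random initialisation, so "group 1" of this three-component fit does not necessarily match "group 1" of the two-component fit (passing `random_state` to `GaussianMixture` makes an individual fit reproducible). A small sketch for a reproducible numbering, for instance sorting the components by the mean of the first feature:
In [ ]:
# components come in an arbitrary order; sorting by the mean value of
# xx_BPT_WHAN gives a reproducible numbering (illustrative only)
order3 = np.argsort(gmm3.means_[:, 0])
print('components ordered by mean xx_BPT_WHAN:', order3)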
In [10]:
plt.figure(figsize=(14,6))
plt.subplot(121)
plt.scatter(data.xx_BPT_WHAN, data.yy_BPT, c=labels3, s=0.4, cmap=my_cmap)
plt.subplot(122)
plt.scatter(data.xx_BPT_WHAN, data.yy_WHAN, c=labels3, s=0.4, cmap=my_cmap)
set_plt_style()
In [11]:
plt.figure(figsize=(14,6))
ax1 = plt.subplot(121)
for k, col in enumerate(['C0', 'C2', 'C3']):
    plot_BPT_ell(gmm3.covariances_[k], gmm3.means_[k], ax=ax1, col=col)
ax2 = plt.subplot(122)
for k, col in enumerate(['C0', 'C2', 'C3']):
    plot_WHAN_ell(gmm3.covariances_[k], gmm3.means_[k], ax=ax2, col=col)
set_plt_style()
In [13]:
plt.figure(figsize=(14,6))
ax1 = plt.subplot(121)
sns.kdeplot(data[data.gmm3_proba1>0.5].xx_BPT_WHAN, data[data.gmm3_proba1>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Greens_d')
sns.kdeplot(data[data.gmm3_proba2>0.5].xx_BPT_WHAN, data[data.gmm3_proba2>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Blues_d')
sns.kdeplot(data[data.gmm3_proba3>0.5].xx_BPT_WHAN, data[data.gmm3_proba3>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='RdPu_d')
ax2 = plt.subplot(122)
sns.kdeplot(data[data.gmm3_proba1>0.5].xx_BPT_WHAN, data[data.gmm3_proba1>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Greens_d')
sns.kdeplot(data[data.gmm3_proba2>0.5].xx_BPT_WHAN, data[data.gmm3_proba2>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Blues_d')
sns.kdeplot(data[data.gmm3_proba3>0.5].xx_BPT_WHAN, data[data.gmm3_proba3>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='RdPu_d')
set_plt_style()
In [14]:
# define and train GMM
gmm4 = GaussianMixture(covariance_type='full', n_components=4)
gmm4.fit(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
# return the probability of belonging to a group
proba_gmm4 = gmm4.predict_proba(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
labels4 = proba_gmm4.argmax(axis=1)
# add to the initial dataframe new columns
# containing the probability to belong to a group
data['gmm4_proba1'] = proba_gmm4[:,0]
data['gmm4_proba2'] = proba_gmm4[:,1]
data['gmm4_proba3'] = proba_gmm4[:,2]
data['gmm4_proba4'] = proba_gmm4[:,3]
data['GMM4_group'] = labels4
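As a quick summary of the four-component solution (a simple check, not a result from the paper), the size of each group can be read off the new column:
In [ ]:
# number of galaxies per group in the 4-component classification
print(data['GMM4_group'].value_counts().sort_index())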
In [16]:
plt.figure(figsize=(14,6))
plt.subplot(121)
plt.scatter(data.xx_BPT_WHAN, data.yy_BPT, c=labels4, s=0.4, cmap=my_cmap)
plt.subplot(122)
plt.scatter(data.xx_BPT_WHAN, data.yy_WHAN, c=labels4, s=0.4, cmap=my_cmap)
set_plt_style()