Develop/test confusion matrix code for spectral types [v0.1]


In [38]:
%matplotlib notebook

In [39]:
# imports
import numpy as np
from matplotlib import pyplot as plt

from astropy.table import Table
from desisim.spec_qa.redshifts import match_truth_z

Load Tables


In [3]:
# Directory prefix for the FITS tables read in the following cells.
# NOTE(review): hard-coded absolute path -- adjust for your environment.
path = '/scratch/DESI_SCRATCH/18.3/'

In [4]:
# Read truth.fits into an astropy Table (the per-object truth values).
simz_tab = Table.read(path+'truth.fits')

In [5]:
# Read zcatalog-mini.fits into an astropy Table
# (presumably the redshift-fit catalog to compare against truth -- confirm).
zb_tab = Table.read(path+'zcatalog-mini.fits')

Match


In [6]:
# Match the redshift catalog to the truth table.
# NOTE(review): this appears to modify simz_tab in place, adding masked
# Z / ZERR / ZWARN / SPECTYPE columns (they show up in the table displayed
# below) -- confirm against the desisim documentation.
match_truth_z(simz_tab, zb_tab, mini_read=True)


INFO: Upgrading Table to masked Table. Use Table.filled() to convert to unmasked table. [astropy.table.table]

Inspect the matched table


In [8]:
# Display the table after matching (astropy renders a truncated preview).
simz_tab


Out[8]:
<Table masked=True length=240902>
TARGETIDMOCKIDCONTAM_TARGETTRUEZTRUESPECTYPETEMPLATETYPETEMPLATESUBTYPETEMPLATEIDSEEDMAGVDISPFLUX_GFLUX_RFLUX_ZFLUX_W1FLUX_W2OIIFLUXHBETAFLUXTEFFLOGGFEHZZERRZWARNSPECTYPEDESI_TARGET
int64int64int64float32str10str10str10int32int64float32float32float32float32float32float32float32float32float32float32float32float32float64float64int64str6int64
2882303982179450881559431200.134829GALAXYBGS3627148285763219.810461.53635.2449613.027122.072317.045611.0778-1.00.0-1.0-1.0-1.0----------
2882303982179450891971350600.131031GALAXYBGS3100184751632418.553661.536323.236539.164357.429952.143738.0304-1.01.62485e-15-1.0-1.0-1.0----------
2882303982179450901792872600.145801GALAXYBGS1281115918852618.351461.536319.850150.108486.518174.636251.4383-1.00.0-1.0-1.0-1.0----------
2882303982179450911976207900.295097GALAXYBGS5168166423978417.993261.536318.452471.9985149.989178.462126.497-1.00.0-1.0-1.0-1.0----------
2882303982179450921397908900.404595GALAXYBGS607723658221220.154761.53632.099789.9580921.838936.39628.0412-1.00.0-1.0-1.0-1.0----------
2882303982179450931645920000.264106GALAXYBGS1126199116253719.117661.536311.386623.789435.423933.640526.4355-1.06.34972e-16-1.0-1.0-1.0----------
2882303982179450941852243400.124415GALAXYBGS543126639481818.612461.536315.30539.514973.938579.356554.7434-1.01.20094e-16-1.0-1.0-1.0----------
2882303982179450951784331600.32049GALAXYBGS89146563407619.959361.53633.7376211.39821.868327.329420.5494-1.08.81773e-17-1.0-1.0-1.0----------
2882303982179450961644018700.294273GALAXYBGS3795171574794818.060361.536316.036468.1686149.188206.879152.196-1.00.0-1.0-1.0-1.0----------
2882303982179450971296843100.206353GALAXYBGS67411112988619.658761.53634.8360115.321729.313529.373420.7506-1.00.0-1.0-1.0-1.0----------
..............................................................................
288230399866311538129845501.26878QSOQSO_T1238403679522.0308-1.01.540591.982641.961873.590396.32289-1.0-1.0-1.0-1.0-1.0----------
288230399866311539123092701.64302QSOQSO_T2238403679521.7264-1.02.03922.036552.255734.232497.84327-1.0-1.0-1.0-1.0-1.0----------
288230399866311540120554301.38468QSOQSO_T3238403679520.8983-1.04.372045.870575.6954212.156620.079-1.0-1.0-1.0-1.0-1.0----------
2882303998663115413368704.73661e-05WDWDDA963207035011018.422-1.041.62726.809614.52541.298990.688422-1.0-1.027000.08.5-1.0----------
28823039986631154238968200.000110076WDWDDA74151854322219.612-1.014.16910.14286.004720.6525230.353166-1.0-1.016750.08.0-1.0----------
288230399866311543860080-0.000210479WDWDDA27180067934019.832-1.012.156515.529315.96023.606372.02437-1.0-1.06500.08.25-1.0----------
28823039986631154438952400.000193467WDWDDB22212715190019.843-1.011.42939.425636.3804634.048163.9958-1.0-1.012000.08.0-1.0----------
2882303998663115453899640-3.33564e-06WDWDDA4131407324219.202-1.021.976630.879634.14738.658434.92947-1.0-1.06000.08.0-1.0----------
288230399866311546860730-5.00346e-06WDWDDA27145846440319.832-1.012.159715.532215.97373.608332.02574-1.0-1.06500.08.25-1.0----------
288230399866311547860720-1.10076e-05WDWDDA2758368567819.832-1.012.159615.532115.97343.608282.0257-1.0-1.06500.08.25-1.0----------

In [11]:
# List the distinct ZWARN values present; a masked entry (shown as --)
# indicates rows that have no ZWARN value.
np.unique(simz_tab['ZWARN'])


Out[11]:
<MaskedColumn name='ZWARN' dtype='int64' length=7>
0
4
32
36
1024
1028
--

In [16]:
# Boolean selector for rows whose ZWARN entry is not masked
# (taken here to mean the object has a measured redshift).
measured_z = np.logical_not(simz_tab['ZWARN'].mask)
# Show the distinct measured and true spectral types among those rows.
tuple(np.unique(simz_tab[measured_z][col]) for col in ('SPECTYPE', 'TRUESPECTYPE'))


Out[16]:
(<MaskedColumn name='SPECTYPE' dtype='str6' length=3>
 GALAXY
 QSO   
 STAR  , <MaskedColumn name='TRUESPECTYPE' dtype='str10' length=4>
 GALAXY    
 QSO       
 STAR      
 WD        )

Tally the confusion counts


In [18]:
# Cut the table: keep only the rows selected by measured_z
# (those with an unmasked ZWARN).
cut_simz = simz_tab[measured_z]

In [27]:
# Remove trailing padding from the fixed-width type strings so that the
# true and measured labels can be compared directly.
strip_ttypes, strip_stypes = (
    np.char.rstrip(cut_simz[column]) for column in ('TRUESPECTYPE', 'SPECTYPE')
)

In [29]:
# All TRUE types: the distinct true spectral types in the cut sample
# (np.unique returns them sorted).
ttypes = np.unique(strip_ttypes)
ttypes


Out[29]:
array(['GALAXY', 'QSO', 'STAR', 'WD'], 
      dtype='<U10')

In [30]:
# All SPEC types: the distinct measured spectral types in the cut sample
# (np.unique returns them sorted).
stypes = np.unique(strip_stypes)
stypes


Out[30]:
array(['GALAXY', 'QSO', 'STAR'], 
      dtype='<U6')

In [31]:
# One (initially empty) tally dict per true spectral type
results = {true_type: {} for true_type in ttypes}

In [36]:
# For each true type, count how many objects were assigned each measured type.
for true_type in ttypes:
    # Measured labels of the objects whose true type is `true_type`
    measured_labels = strip_stypes[strip_ttypes == true_type]
    tallies = results[true_type]
    # Seed the "correct" entry so it exists even when the true type never
    # occurs among the measured types
    tallies[true_type] = 0
    for meas_type in stypes:
        tallies[meas_type] = np.sum(measured_labels == meas_type)

In [37]:
results


Out[37]:
{'GALAXY': {'GALAXY': 34736, 'QSO': 2, 'STAR': 0},
 'QSO': {'GALAXY': 17, 'QSO': 3388, 'STAR': 0},
 'STAR': {'GALAXY': 521, 'QSO': 0, 'STAR': 2271},
 'WD': {'GALAXY': 7, 'QSO': 0, 'STAR': 71, 'WD': 0}}

Plot

Build the matrix


In [53]:
# Square matrix of zeros with one row and one column per entry of ttypes;
# it is filled from `results` in the next cell.
confuse = np.zeros((ttypes.size, ttypes.size))

In [54]:
# Copy the tallies into the matrix: row = true type, column = measured type,
# both following the order of ttypes.
for row, true_type in enumerate(ttypes):
    tallies = results[true_type]
    for col, other_type in enumerate(ttypes):
        # Types never tallied for this true type keep their initial zero
        if other_type in tallies:
            confuse[row, col] = tallies[other_type]
confuse


Out[54]:
array([[  3.47360000e+04,   2.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.70000000e+01,   3.38800000e+03,   0.00000000e+00,
          0.00000000e+00],
       [  5.21000000e+02,   0.00000000e+00,   2.27100000e+03,
          0.00000000e+00],
       [  7.00000000e+00,   0.00000000e+00,   7.10000000e+01,
          0.00000000e+00]])

Normalize


In [55]:
# Normalize each row (one true type) so its entries sum to 1, i.e. each cell
# becomes the fraction of that true type assigned to the column's type.
row_totals = confuse.sum(axis=1, keepdims=True)
# Guard against empty rows: a row with no counts stays all-zero instead of
# turning into NaN (0/0) and raising a RuntimeWarning.
np.divide(confuse, row_totals, out=confuse, where=row_totals > 0)
confuse


Out[55]:
array([[  9.99942426e-01,   5.75738384e-05,   0.00000000e+00,
          0.00000000e+00],
       [  4.99265786e-03,   9.95007342e-01,   0.00000000e+00,
          0.00000000e+00],
       [  1.86604585e-01,   0.00000000e+00,   8.13395415e-01,
          0.00000000e+00],
       [  8.97435897e-02,   0.00000000e+00,   9.10256410e-01,
          0.00000000e+00]])

Plot


In [60]:
# Draw the row-normalized confusion matrix.
# matshow puts matrix rows on the y-axis and columns on the x-axis, and
# `confuse` was filled with rows = true type, columns = measured type, so
# y is "True" and x is "Predicted".
fig, ax = plt.subplots()
image = ax.matshow(confuse)
fig.colorbar(image, ax=ax, label='Fraction of true type')
# Label the ticks with the spectral type names instead of bare indices
tick_positions = np.arange(ttypes.size)
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(ttypes)
ax.set_yticklabels(ttypes)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
plt.show()



In [ ]: