In [115]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib as mpl

import seaborn as sns; sns.set(style="white", font_scale=2)

import numpy as np
import pandas as pd
from astropy.io import fits
import glob

import sklearn

import keras
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.layers.noise import GaussianNoise

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.preprocessing.image import random_channel_shift
from keras.optimizers import SGD, Adam
from keras import backend as K
K.set_image_data_format('channels_first')

import scipy.ndimage as ndi

import matplotlib.patches as patches

In [ ]:
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
mpl.rcParams['figure.figsize'] = np.array((10,6))*.6
mpl.rcParams['figure.facecolor'] = "white"

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# give access to importing dwarfz
import os, sys
dwarfz_package_dir = os.getcwd().split("dwarfz")[0]
if dwarfz_package_dir not in sys.path:
    sys.path.insert(0, dwarfz_package_dir)

import dwarfz
    
# back to regular import statements

In [6]:
# my modules that are DNN specific
import preprocessing
import geometry

To Do

  1. Read in FITS images
  2. Apply pre-processing stretch (asinh)
  3. Crop image smaller
  4. Combine filters into one cube
  5. Create training set with labels
  6. Set up keras model
  7. Poke around at the results

0) Get files


In [8]:
images_dir = preprocessing.images_dir
images_dir


Out[8]:
'../data/galaxy_images_training/quarry_files/'

In [9]:
HSC_ids = [int(os.path.basename(image_file).split("-")[0])
           for image_file in glob.glob(os.path.join(images_dir, "*.fits"))]
HSC_ids = set(HSC_ids) # remove duplicates
HSC_ids = np.array(sorted(HSC_ids))

# now filter out galaxies missing bands
HSC_ids = [HSC_id
           for HSC_id in HSC_ids
           if len(glob.glob(os.path.join(images_dir, "{}*.fits".format(str(HSC_id)))))==5]
HSC_ids = np.array(HSC_ids)

In [10]:
HSC_ids.size


Out[10]:
1866

In [11]:
HSC_id = HSC_ids[1] # for when I need a single sample galaxy

In [12]:
bands = ["g", "r", "i", "z", "y"]

1) Read in FITS image


In [13]:
image, flux_mag_0 = preprocessing.get_image(HSC_id, "g")
print("image size: {} x {}".format(*image.shape))
image


image size: 239 x 239
Out[13]:
array([[-0.00146446,  0.02150147,  0.00693631, ...,  0.02792656,
        -0.02018547,  0.01850448],
       [-0.02519608,  0.00035813, -0.0455959 , ..., -0.00586783,
        -0.00882499,  0.01241659],
       [ 0.03936524,  0.00645859, -0.02110163, ...,  0.02997875,
         0.009456  ,  0.00591614],
       ...,
       [-0.03543996,  0.04346127, -0.0372493 , ..., -0.0014411 ,
        -0.01001758,  0.03473332],
       [-0.01037648, -0.03287457,  0.04310744, ...,  0.02935715,
         0.02273993, -0.00532476],
       [-0.05991019, -0.08159582,  0.02607481, ...,  0.01012528,
         0.00453719, -0.00872836]], dtype=float32)

In [14]:
preprocessing.image_plotter(image)



In [16]:
preprocessing.image_plotter(np.log(image))


/Users/egentry/anaconda3/envs/tf/lib/python3.6/site-packages/ipykernel_launcher.py:1: RuntimeWarning: invalid value encountered in log
  """Entry point for launching an IPython kernel.

2) Apply stretch

We're using (negative) asinh magnitudes, as implemented by the HSC collaboration.

For more about the asinh magnitude system, see Lupton, Gunn & Szalay (1999), as used for SDSS. (It's given explicitly in the SDSS Algorithms documentation as well as this overview page.)

For the source of our code, see the HSC color image creator.

And for reference, a common form of this stretch is: $$ \mathrm{mag}_\mathrm{asinh} = - \left(\frac{2.5}{\ln(10)}\right) \left(\mathrm{asinh}\left(\frac{f/f_0}{2b}\right) + \ln(b) \right)$$ for dimensionless softening parameter $b$ and reference flux $f_0$.
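For reference, that formula transcribes directly into numpy (this is just an illustration; the stretch actually applied below is the HSC variant implemented in preprocessing.scale):

def asinh_mag(flux, flux_0, b):
    # Lupton, Gunn & Szalay (1999) asinh magnitude with dimensionless softening b
    return -(2.5 / np.log(10)) * (np.arcsinh((flux / flux_0) / (2 * b)) + np.log(b))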


In [17]:
def scale(x, fluxMag0):
    ### adapted from https://hsc-gitlab.mtk.nao.ac.jp/snippets/23
    # rescale counts so that a magnitude-19 source has flux 1
    # (note: this modifies x in place)
    mag0 = 19
    scale = 10 ** (0.4 * mag0) / fluxMag0
    x *= scale

    # asinh stretch parameters: softening u_a and display range [u_min, u_max]
    u_min = -0.05
    u_max = 2. / 3.
    u_a = np.exp(10.)

    # apply the asinh stretch, then map the display range onto [0, 1]
    x = np.arcsinh(u_a*x) / np.arcsinh(u_a)
    x = (x - u_min) / (u_max - u_min)

    return x

In [18]:
image_scaled = preprocessing.scale(image, flux_mag_0)

In [19]:
preprocessing.image_plotter(image_scaled)



In [20]:
sns.distplot(image_scaled.flatten())
plt.title("Distribution of Transformed Intensities")


/Users/egentry/anaconda3/envs/tf/lib/python3.6/site-packages/scipy/stats/stats.py:1713: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval
Out[20]:
Text(0.5, 1.0, 'Distribution of Transformed Intensities')

In [22]:
for band in bands:
    image, flux_mag_0 = preprocessing.get_image(HSC_id, band)
    image_scaled = preprocessing.scale(image, flux_mag_0)

    plt.figure()
    preprocessing.image_plotter(image_scaled)
    plt.title("{} band".format(band))
    plt.colorbar()


3) Crop Image

Am I properly handling odd numbers?
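preprocessing.get_cutout isn't reproduced here; a minimal sketch of the assumed behaviour (a center crop over the last two axes, so it works for both (H, W) images and (channels, H, W) cubes) looks roughly like:

def center_cutout(image, size):
    # hypothetical sketch, not the actual implementation
    # with an odd leftover, the extra row/column lands on the high-index side
    h, w = image.shape[-2:]
    top = (h - size) // 2
    left = (w - size) // 2
    return image[..., top:top + size, left:left + size]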


In [23]:
pre_transformed_image_size  = 150
post_transformed_image_size = 75

In [24]:
cutout = preprocessing.get_cutout(image_scaled, post_transformed_image_size)
cutout.shape


Out[24]:
(75, 75)

In [25]:
preprocessing.image_plotter(cutout)



In [26]:
for band in bands:
    image, flux_mag_0 = preprocessing.get_image(HSC_id, band)
    image_scaled = preprocessing.scale(image, flux_mag_0)
    cutout = preprocessing.get_cutout(image_scaled, post_transformed_image_size)

    plt.figure()
    preprocessing.image_plotter(cutout)
    plt.title("{} band".format(band))
    plt.colorbar()


4) Combine filters into cube


In [27]:
images = [None]*len(bands)
flux_mag_0s = [None]*len(bands)
cutouts = [None]*len(bands)
for i, band in enumerate(bands):
    images[i], flux_mag_0s[i] = preprocessing.get_image(HSC_id, band)
    
    cutouts[i] = preprocessing.get_cutout(
        preprocessing.scale(images[i], flux_mag_0s[i]), 
        post_transformed_image_size
    )

In [28]:
cutout_cube = np.array(cutouts)
cutout_cube.shape


Out[28]:
(5, 75, 75)

In [29]:
# must transform into [0,1] for plt.imshow
# the HSC standard tool accomplishes this by clipping instead.
plt.imshow(preprocessing.transform_0_1(cutout_cube[:3,:,:].transpose(1,2,0)) )


Out[29]:
<matplotlib.image.AxesImage at 0x1a2ab49ba8>
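The two options mentioned in that comment look roughly like this (assumed behaviour for transform_0_1, not its actual source):

def rescale_0_1(x):
    # min-max rescale into [0, 1]
    return (x - x.min()) / (x.max() - x.min())

def clip_0_1(x):
    # what the HSC standard tool does instead: clip out-of-range values
    return np.clip(x, 0, 1)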

In [30]:
for i, band in enumerate(bands):
    sns.distplot(cutout_cube[:,:,:].transpose(1,2,0)[:,:,i].flatten(), label=band)
    plt.legend(loc="best")


5) Load Training Set Labels


In [31]:
training_set_labels_filename = "../data/galaxy_images_training/2017_09_26-dwarf_galaxy_scores.csv"

In [32]:
df = pd.read_csv(training_set_labels_filename)
df = df.drop_duplicates("HSC_id")
df = df.set_index("HSC_id")
df = df[["low_z_low_mass"]]
df = df.rename(columns={"low_z_low_mass":"target"})
df.head()


Out[32]:
target
HSC_id
43158322471244656 False
43158605939114836 False
43159142810013665 False
43158734788125011 False
43158863637144621 True

In [37]:
def load_image_mappable(HSC_id):
    images      = [None]*len(bands)
    flux_mag_0s = [None]*len(bands)
    cutouts     = [None]*len(bands)
    for j, band in enumerate(bands):
        images[j], flux_mag_0s[j] = preprocessing.get_image(HSC_id, band)
        cutouts[j] = preprocessing.get_cutout(
            preprocessing.scale(images[j], flux_mag_0s[j]),
            pre_transformed_image_size)
    cutout_cube = np.array(cutouts)
    return cutout_cube

In [38]:
X = np.empty((len(HSC_ids), 5, 
              pre_transformed_image_size, pre_transformed_image_size))

In [39]:
X = np.array(list(map(load_image_mappable, HSC_ids)))

In [41]:
X_full = X
X_small = X[:,(0,2,4),:,:] # drop down to 3 bands

In [42]:
X.shape


Out[42]:
(1866, 5, 150, 150)

In [43]:
Y = df.loc[HSC_ids].target.values
Y


Out[43]:
array([False, False,  True, ..., False, False, False])

In [44]:
Y.mean()


Out[44]:
0.2792068595927117

Geometry!


In [46]:
h = pre_transformed_image_size
w = pre_transformed_image_size
transform_matrix = geometry.create_random_transform_matrix(h, w,
                                                  include_rotation=True,
                                                  translation_size = .01,
                                                  verbose=False)

x_tmp = X[0][:3]

result = geometry.apply_transform_new(x_tmp, transform_matrix, 
                            channel_axis=0, fill_mode="constant", cval=np.max(x_tmp))

result = preprocessing.get_cutout(result, post_transformed_image_size)
plt.imshow(preprocessing.transform_0_1(result.transpose(1,2,0)))


Out[46]:
<matplotlib.image.AxesImage at 0x1a2a0821d0>

In [47]:
import ipywidgets
ipywidgets.interact(preprocessing.transform_plotter,
                    X = ipywidgets.fixed(X),
                    rotation_degrees = ipywidgets.IntSlider(min=0, max=360, step=15, value=45),
                    dx_after = ipywidgets.IntSlider(min=-15, max=15),
                    dy_after = ipywidgets.IntSlider(min=-15, max=15),
                    color = ipywidgets.fixed(True),
                    shear_degrees = ipywidgets.IntSlider(min=0, max=90, step=5, value=0),
                    zoom_x = ipywidgets.FloatSlider(min=.5, max=2, value=1),
                    crop = ipywidgets.Checkbox(value=True)
                    )


Out[47]:
<function preprocessing.transform_plotter(X, reflect_x=False, rotation_degrees=45, dx_after=0, dy_after=0, shear_degrees=0, zoom_x=1, crop=False, color=True)>

5b) Split training and testing set


In [141]:
np.random.seed(1)

randomized_indices = np.arange(X.shape[0])
np.random.shuffle(randomized_indices)

testing_fraction = 0.2
testing_set_indices = randomized_indices[:int(testing_fraction*X.shape[0])]
training_set_indices = np.array(list(set([*randomized_indices]) - set([*testing_set_indices])))

In [142]:
testing_set_indices.size


Out[142]:
373

In [143]:
training_set_indices.size


Out[143]:
1493

6b) Adapt NumpyArrayIterator

The original only allowed 1-, 3-, or 4-channel images. I have 5-channel images.

Also, I want to change the way the augmentation happens.
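The restriction being removed is a channel-count check along these lines (an illustration of the idea, not the actual Keras source; the exact message and behaviour vary by Keras version):

channels_axis = 1  # 'channels_first', as set via K.set_image_data_format above
if x.shape[channels_axis] not in {1, 3, 4}:
    raise ValueError("expected 1, 3 or 4 channels on axis {}; got shape {}"
                     .format(channels_axis, x.shape))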


In [144]:
from data_generator import ArrayIterator

6c) Adapt ImageDataGenerator

The original only allowed 1-, 3-, or 4-channel images. I have 5-channel images. Also, I'm adjusting the way the affine transformations work for the data augmentation.


In [145]:
from data_generator import ImageDataGenerator

6d) Create Data Generator


In [146]:
print('Using real-time data augmentation.')

h_before, w_before = X[0,0].shape
print("image shape before: ({},{})".format(h_before, w_before))

h_after = post_transformed_image_size
w_after = post_transformed_image_size
print("image shape after:  ({},{})".format(h_after, w_after))

# get a closure that binds the image size to get_cutout
postprocessing_function = lambda image: preprocessing.get_cutout(image, post_transformed_image_size)

# this will do preprocessing and realtime data augmentation
datagen = ImageDataGenerator(
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    with_reflection_x=True, # randomly apply a reflection (in x)
    with_reflection_y=True, # randomly apply a reflection (in y)
    with_rotation=False, # randomly apply a rotation
    width_shift_range=0.002,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.002,  # randomly shift images vertically (fraction of total height)
    postprocessing_function=postprocessing_function, # get a cutout of the processed image
    output_image_shape=(post_transformed_image_size,post_transformed_image_size)
)


Using real-time data augmentation.
image shape before: (150,150)
image shape after:  (75,75)

In [147]:
datagen.fit(X_small[training_set_indices])

7) Set up keras model


In [148]:
input_shape = cutout_cube.shape

nb_dense = 64

In [149]:
from keras.applications import inception_v3, inception_resnet_v2, vgg19

In [150]:
vgg19.VGG19(include_top=False, input_shape=(3, 75, 75))


Out[150]:
<keras.engine.training.Model at 0x1a4cc32f60>

In [152]:
model = Sequential()

# # # 1x1 convolution to make sure we only have 3 channels
n_channels_for_pretrained=3
# model.add(Conv2D(n_channels_for_pretrained, 1, padding='same',
#                  input_shape=input_shape))

pretrained_input_shape = tuple([n_channels_for_pretrained, *input_shape[1:]])
pretrained_layers = vgg19.VGG19(include_top=False,
                                input_shape=pretrained_input_shape
                               )
for layer in pretrained_layers.layers:
    layer.trainable = False
model.add(pretrained_layers)


model.add(Flatten())
model.add(Dense(2*nb_dense, activation="relu"))
model.add(Dense(nb_dense, activation="relu"))
model.add(Dense(1, activation="sigmoid"))

In [153]:
model.layers


Out[153]:
[<keras.engine.training.Model at 0x1a5dd849e8>,
 <keras.layers.core.Flatten at 0x1a58f47c88>,
 <keras.layers.core.Dense at 0x1a5dda8208>,
 <keras.layers.core.Dense at 0x1a5df16f60>,
 <keras.layers.core.Dense at 0x1a5e9a1908>]

In [154]:
learning_rate = 0.00025
decay = 1e-5
momentum = 0.9

# sgd = SGD(lr=learning_rate, decay=decay, momentum=momentum, nesterov=True)

adam = Adam(lr=learning_rate)

In [155]:
logger_filename = "training_transfer.log"

model.compile(loss='binary_crossentropy', 
#               optimizer=sgd, 
              optimizer=adam,
#               metrics=["accuracy"]
             )

if os.path.exists(logger_filename):
    logger_filename_tmp = logger_filename + ".old"
    os.rename(logger_filename, logger_filename_tmp)

In [156]:
earlystopping = EarlyStopping(monitor='loss',
                              patience=35,
                              verbose=1,
                              mode='auto' )

In [157]:
csv_logger = keras.callbacks.CSVLogger(logger_filename,
                                       append=True)

In [158]:
# modelcheckpoint = ModelCheckpoint(pathinCat+'Models/bestmodelMAG.hdf5',monitor='val_loss',verbose=0,save_best_only=True)

8) Run basic keras model


In [159]:
goal_batch_size = 64
steps_per_epoch = max(2, training_set_indices.size//goal_batch_size)
batch_size = training_set_indices.size//steps_per_epoch
print("steps_per_epoch: ", steps_per_epoch)
print("batch_size: ", batch_size)
epochs = 200
verbose = 1


steps_per_epoch:  23
batch_size:  64

In [160]:
X_test_transformed = np.array([datagen.standardize(X_img)
                               for X_img in X_small[testing_set_indices]])

X_test_transformed.shape


Out[160]:
(373, 3, 75, 75)

In [162]:
history = model.fit_generator(datagen.flow(X_small[training_set_indices], 
                                           Y[training_set_indices],
                                           batch_size=batch_size,
                                          ),
                              steps_per_epoch=steps_per_epoch,
                              epochs=epochs,
                              validation_data=(X_test_transformed, 
                                               Y[testing_set_indices]),
                              verbose=verbose,
                              callbacks=[
#                                   earlystopping,
                                  csv_logger],
                              )


Epoch 1/200
23/23 [==============================] - 150s 7s/step - loss: 0.6063 - val_loss: 0.6149
Epoch 2/200
23/23 [==============================] - 150s 7s/step - loss: 0.5902 - val_loss: 0.6116
Epoch 3/200
23/23 [==============================] - 146s 6s/step - loss: 0.5893 - val_loss: 0.6063
Epoch 4/200
23/23 [==============================] - 153s 7s/step - loss: 0.5911 - val_loss: 0.6065
Epoch 5/200
23/23 [==============================] - 143s 6s/step - loss: 0.5854 - val_loss: 0.6078
Epoch 6/200
23/23 [==============================] - 150s 7s/step - loss: 0.5845 - val_loss: 0.6043
Epoch 7/200
23/23 [==============================] - 150s 7s/step - loss: 0.5847 - val_loss: 0.6015
Epoch 8/200
23/23 [==============================] - 146s 6s/step - loss: 0.5903 - val_loss: 0.6032
Epoch 9/200
23/23 [==============================] - 147s 6s/step - loss: 0.6016 - val_loss: 0.6027
Epoch 10/200
23/23 [==============================] - 149s 6s/step - loss: 0.5823 - val_loss: 0.6035
Epoch 11/200
23/23 [==============================] - 151s 7s/step - loss: 0.5772 - val_loss: 0.6006
Epoch 12/200
23/23 [==============================] - 144s 6s/step - loss: 0.6009 - val_loss: 0.5995
Epoch 13/200
23/23 [==============================] - 148s 6s/step - loss: 0.5834 - val_loss: 0.5995
Epoch 14/200
23/23 [==============================] - 148s 6s/step - loss: 0.5836 - val_loss: 0.5977
Epoch 15/200
23/23 [==============================] - 149s 6s/step - loss: 0.5939 - val_loss: 0.5952
Epoch 16/200
23/23 [==============================] - 150s 7s/step - loss: 0.5781 - val_loss: 0.6034
Epoch 17/200
23/23 [==============================] - 149s 6s/step - loss: 0.5831 - val_loss: 0.5936
Epoch 18/200
23/23 [==============================] - 147s 6s/step - loss: 0.5851 - val_loss: 0.5972
Epoch 19/200
23/23 [==============================] - 148s 6s/step - loss: 0.5896 - val_loss: 0.5927
Epoch 20/200
23/23 [==============================] - 147s 6s/step - loss: 0.5875 - val_loss: 0.5906
Epoch 21/200
23/23 [==============================] - 152s 7s/step - loss: 0.5728 - val_loss: 0.5999
Epoch 22/200
23/23 [==============================] - 146s 6s/step - loss: 0.5902 - val_loss: 0.5918
Epoch 23/200
23/23 [==============================] - 149s 6s/step - loss: 0.5843 - val_loss: 0.6290
Epoch 24/200
23/23 [==============================] - 146s 6s/step - loss: 0.5873 - val_loss: 0.6042
Epoch 25/200
23/23 [==============================] - 149s 6s/step - loss: 0.5879 - val_loss: 0.5943
Epoch 26/200
23/23 [==============================] - 149s 6s/step - loss: 0.5858 - val_loss: 0.5975
Epoch 27/200
23/23 [==============================] - 145s 6s/step - loss: 0.5749 - val_loss: 0.5914
Epoch 28/200
23/23 [==============================] - 148s 6s/step - loss: 0.5852 - val_loss: 0.5925
Epoch 29/200
23/23 [==============================] - 146s 6s/step - loss: 0.5965 - val_loss: 0.5902
Epoch 30/200
23/23 [==============================] - 150s 7s/step - loss: 0.5806 - val_loss: 0.5941
Epoch 31/200
23/23 [==============================] - 147s 6s/step - loss: 0.5781 - val_loss: 0.5911
Epoch 32/200
23/23 [==============================] - 150s 7s/step - loss: 0.5851 - val_loss: 0.6027
Epoch 33/200
23/23 [==============================] - 146s 6s/step - loss: 0.5911 - val_loss: 0.5887
Epoch 34/200
23/23 [==============================] - 152s 7s/step - loss: 0.5731 - val_loss: 0.5905
Epoch 35/200
23/23 [==============================] - 144s 6s/step - loss: 0.5925 - val_loss: 0.5979
Epoch 36/200
23/23 [==============================] - 154s 7s/step - loss: 0.5884 - val_loss: 0.5879
Epoch 37/200
23/23 [==============================] - 146s 6s/step - loss: 0.5800 - val_loss: 0.5885
Epoch 38/200
23/23 [==============================] - 149s 6s/step - loss: 0.5793 - val_loss: 0.5887
Epoch 39/200
23/23 [==============================] - 144s 6s/step - loss: 0.5826 - val_loss: 0.5874
Epoch 40/200
23/23 [==============================] - 153s 7s/step - loss: 0.5821 - val_loss: 0.5946
Epoch 41/200
23/23 [==============================] - 148s 6s/step - loss: 0.5742 - val_loss: 0.5884
Epoch 42/200
23/23 [==============================] - 148s 6s/step - loss: 0.5870 - val_loss: 0.6244
Epoch 43/200
23/23 [==============================] - 148s 6s/step - loss: 0.5870 - val_loss: 0.5873
Epoch 44/200
23/23 [==============================] - 146s 6s/step - loss: 0.5718 - val_loss: 0.5899
Epoch 45/200
23/23 [==============================] - 150s 7s/step - loss: 0.5939 - val_loss: 0.5869
Epoch 46/200
23/23 [==============================] - 146s 6s/step - loss: 0.5719 - val_loss: 0.5955
Epoch 47/200
23/23 [==============================] - 149s 6s/step - loss: 0.5869 - val_loss: 0.5865
Epoch 48/200
23/23 [==============================] - 147s 6s/step - loss: 0.5777 - val_loss: 0.5897
Epoch 49/200
23/23 [==============================] - 150s 7s/step - loss: 0.5841 - val_loss: 0.5871
Epoch 50/200
23/23 [==============================] - 146s 6s/step - loss: 0.5788 - val_loss: 0.5874
Epoch 51/200
23/23 [==============================] - 149s 6s/step - loss: 0.5878 - val_loss: 0.5851
Epoch 52/200
23/23 [==============================] - 148s 6s/step - loss: 0.5771 - val_loss: 0.5844
Epoch 53/200
23/23 [==============================] - 149s 6s/step - loss: 0.5846 - val_loss: 0.5905
Epoch 54/200
23/23 [==============================] - 149s 6s/step - loss: 0.5771 - val_loss: 0.5822
Epoch 55/200
23/23 [==============================] - 147s 6s/step - loss: 0.5748 - val_loss: 0.5832
Epoch 56/200
23/23 [==============================] - 146s 6s/step - loss: 0.5872 - val_loss: 0.5824
Epoch 57/200
23/23 [==============================] - 149s 6s/step - loss: 0.5721 - val_loss: 0.5822
Epoch 58/200
23/23 [==============================] - 148s 6s/step - loss: 0.5859 - val_loss: 0.5984
Epoch 59/200
23/23 [==============================] - 152s 7s/step - loss: 0.5915 - val_loss: 0.5891
Epoch 60/200
23/23 [==============================] - 146s 6s/step - loss: 0.5618 - val_loss: 0.5886
Epoch 61/200
23/23 [==============================] - 146s 6s/step - loss: 0.5913 - val_loss: 0.5901
Epoch 62/200
23/23 [==============================] - 152s 7s/step - loss: 0.5711 - val_loss: 0.5873
Epoch 63/200
23/23 [==============================] - 146s 6s/step - loss: 0.5828 - val_loss: 0.5949
Epoch 64/200
23/23 [==============================] - 151s 7s/step - loss: 0.5792 - val_loss: 0.5913
Epoch 65/200
23/23 [==============================] - 146s 6s/step - loss: 0.5971 - val_loss: 0.5914
Epoch 66/200
23/23 [==============================] - 153s 7s/step - loss: 0.5733 - val_loss: 0.5868
Epoch 67/200
23/23 [==============================] - 146s 6s/step - loss: 0.5853 - val_loss: 0.5873
Epoch 68/200
23/23 [==============================] - 149s 6s/step - loss: 0.5666 - val_loss: 0.5880
Epoch 69/200
23/23 [==============================] - 145s 6s/step - loss: 0.5819 - val_loss: 0.5873
Epoch 70/200
23/23 [==============================] - 150s 7s/step - loss: 0.5838 - val_loss: 0.5882
Epoch 71/200
23/23 [==============================] - 146s 6s/step - loss: 0.5768 - val_loss: 0.5944
Epoch 72/200
23/23 [==============================] - 150s 7s/step - loss: 0.5744 - val_loss: 0.6068
Epoch 73/200
23/23 [==============================] - 146s 6s/step - loss: 0.5813 - val_loss: 0.6266
Epoch 74/200
23/23 [==============================] - 147s 6s/step - loss: 0.5848 - val_loss: 0.5965
Epoch 75/200
23/23 [==============================] - 151s 7s/step - loss: 0.5849 - val_loss: 0.5866
Epoch 76/200
23/23 [==============================] - 144s 6s/step - loss: 0.5674 - val_loss: 0.5942
Epoch 77/200
23/23 [==============================] - 149s 6s/step - loss: 0.5851 - val_loss: 0.6007
Epoch 78/200
23/23 [==============================] - 146s 6s/step - loss: 0.5885 - val_loss: 0.5856
Epoch 79/200
23/23 [==============================] - 150s 7s/step - loss: 0.5622 - val_loss: 0.5832
Epoch 80/200
23/23 [==============================] - 147s 6s/step - loss: 0.5839 - val_loss: 0.5875
Epoch 81/200
23/23 [==============================] - 153s 7s/step - loss: 0.5867 - val_loss: 0.5905
Epoch 82/200
23/23 [==============================] - 147s 6s/step - loss: 0.5672 - val_loss: 0.6072
Epoch 83/200
23/23 [==============================] - 150s 7s/step - loss: 0.5955 - val_loss: 0.6050
Epoch 84/200
23/23 [==============================] - 147s 6s/step - loss: 0.5795 - val_loss: 0.5882
Epoch 85/200
23/23 [==============================] - 146s 6s/step - loss: 0.5673 - val_loss: 0.6048
Epoch 86/200
23/23 [==============================] - 147s 6s/step - loss: 0.5866 - val_loss: 0.5857
Epoch 87/200
23/23 [==============================] - 151s 7s/step - loss: 0.5718 - val_loss: 0.5883
Epoch 88/200
23/23 [==============================] - 151s 7s/step - loss: 0.5838 - val_loss: 0.5910
Epoch 89/200
23/23 [==============================] - 150s 7s/step - loss: 0.5857 - val_loss: 0.6024
Epoch 90/200
23/23 [==============================] - 144s 6s/step - loss: 0.5731 - val_loss: 0.5864
Epoch 91/200
23/23 [==============================] - 150s 7s/step - loss: 0.5818 - val_loss: 0.5955
Epoch 92/200
23/23 [==============================] - 150s 7s/step - loss: 0.5697 - val_loss: 0.5872
Epoch 93/200
23/23 [==============================] - 147s 6s/step - loss: 0.5696 - val_loss: 0.5951
Epoch 94/200
23/23 [==============================] - 149s 6s/step - loss: 0.5840 - val_loss: 0.5930
Epoch 95/200
23/23 [==============================] - 145s 6s/step - loss: 0.5726 - val_loss: 0.5865
Epoch 96/200
23/23 [==============================] - 150s 7s/step - loss: 0.5772 - val_loss: 0.6031
Epoch 97/200
23/23 [==============================] - 146s 6s/step - loss: 0.5772 - val_loss: 0.5996
Epoch 98/200
23/23 [==============================] - 148s 6s/step - loss: 0.5761 - val_loss: 0.5920
Epoch 99/200
23/23 [==============================] - 146s 6s/step - loss: 0.5877 - val_loss: 0.5950
Epoch 100/200
23/23 [==============================] - 152s 7s/step - loss: 0.5693 - val_loss: 0.6088
Epoch 101/200
23/23 [==============================] - 145s 6s/step - loss: 0.5829 - val_loss: 0.5927
Epoch 102/200
23/23 [==============================] - 153s 7s/step - loss: 0.5647 - val_loss: 0.5977
Epoch 103/200
23/23 [==============================] - 146s 6s/step - loss: 0.5975 - val_loss: 0.5963
Epoch 104/200
23/23 [==============================] - 145s 6s/step - loss: 0.5721 - val_loss: 0.6261
Epoch 105/200
23/23 [==============================] - 148s 6s/step - loss: 0.5841 - val_loss: 0.5929
Epoch 106/200
23/23 [==============================] - 147s 6s/step - loss: 0.5753 - val_loss: 0.5976
Epoch 107/200
23/23 [==============================] - 152s 7s/step - loss: 0.5749 - val_loss: 0.5947
Epoch 108/200
23/23 [==============================] - 144s 6s/step - loss: 0.5809 - val_loss: 0.6394
Epoch 109/200
23/23 [==============================] - 150s 7s/step - loss: 0.5697 - val_loss: 0.5907
Epoch 110/200
23/23 [==============================] - 147s 6s/step - loss: 0.5778 - val_loss: 0.6025
Epoch 111/200
23/23 [==============================] - 153s 7s/step - loss: 0.5668 - val_loss: 0.5988
Epoch 112/200
23/23 [==============================] - 147s 6s/step - loss: 0.5777 - val_loss: 0.6030
Epoch 113/200
23/23 [==============================] - 150s 7s/step - loss: 0.5723 - val_loss: 0.6321
Epoch 114/200
23/23 [==============================] - 143s 6s/step - loss: 0.5813 - val_loss: 0.5911
Epoch 115/200
23/23 [==============================] - 152s 7s/step - loss: 0.5837 - val_loss: 0.6341
Epoch 116/200
23/23 [==============================] - 147s 6s/step - loss: 0.5575 - val_loss: 0.6037
Epoch 117/200
23/23 [==============================] - 149s 6s/step - loss: 0.5587 - val_loss: 0.6026
Epoch 118/200
23/23 [==============================] - 146s 6s/step - loss: 0.5851 - val_loss: 0.5896
Epoch 119/200
23/23 [==============================] - 149s 6s/step - loss: 0.5735 - val_loss: 0.5984
Epoch 120/200
23/23 [==============================] - 146s 6s/step - loss: 0.5768 - val_loss: 0.5878
Epoch 121/200
23/23 [==============================] - 148s 6s/step - loss: 0.5755 - val_loss: 0.5905
Epoch 122/200
23/23 [==============================] - 146s 6s/step - loss: 0.5756 - val_loss: 0.5979
Epoch 123/200
23/23 [==============================] - 145s 6s/step - loss: 0.5683 - val_loss: 0.6090
Epoch 124/200
23/23 [==============================] - 149s 6s/step - loss: 0.5791 - val_loss: 0.5820
Epoch 125/200
23/23 [==============================] - 148s 6s/step - loss: 0.5812 - val_loss: 0.5889
Epoch 126/200
23/23 [==============================] - 150s 7s/step - loss: 0.5551 - val_loss: 0.5844
Epoch 127/200
23/23 [==============================] - 146s 6s/step - loss: 0.5855 - val_loss: 0.5957
Epoch 128/200
23/23 [==============================] - 149s 6s/step - loss: 0.5803 - val_loss: 0.5849
Epoch 129/200
23/23 [==============================] - 150s 7s/step - loss: 0.5755 - val_loss: 0.5862
Epoch 130/200
23/23 [==============================] - 149s 6s/step - loss: 0.5690 - val_loss: 0.5846
Epoch 131/200
23/23 [==============================] - 146s 6s/step - loss: 0.5719 - val_loss: 0.5808
Epoch 132/200
23/23 [==============================] - 149s 6s/step - loss: 0.5646 - val_loss: 0.5940
Epoch 133/200
23/23 [==============================] - 145s 6s/step - loss: 0.5819 - val_loss: 0.5862
Epoch 134/200
23/23 [==============================] - 152s 7s/step - loss: 0.5686 - val_loss: 0.5839
Epoch 135/200
23/23 [==============================] - 147s 6s/step - loss: 0.5744 - val_loss: 0.5880
Epoch 136/200
23/23 [==============================] - 146s 6s/step - loss: 0.5752 - val_loss: 0.5951
Epoch 137/200
23/23 [==============================] - 151s 7s/step - loss: 0.5645 - val_loss: 0.6116
Epoch 138/200
23/23 [==============================] - 147s 6s/step - loss: 0.5788 - val_loss: 0.5960
Epoch 139/200
23/23 [==============================] - 147s 6s/step - loss: 0.5689 - val_loss: 0.5992
Epoch 140/200
23/23 [==============================] - 142s 6s/step - loss: 0.5816 - val_loss: 0.7129
Epoch 141/200
23/23 [==============================] - 149s 6s/step - loss: 0.5712 - val_loss: 0.5956
Epoch 142/200
23/23 [==============================] - 151s 7s/step - loss: 0.5713 - val_loss: 0.5868
Epoch 143/200
23/23 [==============================] - 149s 6s/step - loss: 0.5857 - val_loss: 0.5844
Epoch 144/200
23/23 [==============================] - 146s 6s/step - loss: 0.5690 - val_loss: 0.5893
Epoch 145/200
23/23 [==============================] - 128s 6s/step - loss: 0.5677 - val_loss: 0.5839
Epoch 146/200
23/23 [==============================] - 94s 4s/step - loss: 0.5709 - val_loss: 0.5862
Epoch 147/200
23/23 [==============================] - 93s 4s/step - loss: 0.5614 - val_loss: 0.5911
Epoch 148/200
23/23 [==============================] - 93s 4s/step - loss: 0.5770 - val_loss: 0.6105
Epoch 149/200
23/23 [==============================] - 93s 4s/step - loss: 0.5819 - val_loss: 0.6538
Epoch 150/200
23/23 [==============================] - 93s 4s/step - loss: 0.5884 - val_loss: 0.6585
Epoch 151/200
23/23 [==============================] - 96s 4s/step - loss: 0.5662 - val_loss: 0.5922
Epoch 152/200
23/23 [==============================] - 92s 4s/step - loss: 0.5714 - val_loss: 0.6083
Epoch 153/200
23/23 [==============================] - 96s 4s/step - loss: 0.5691 - val_loss: 0.6139
Epoch 154/200
23/23 [==============================] - 93s 4s/step - loss: 0.5564 - val_loss: 0.5819
Epoch 155/200
23/23 [==============================] - 94s 4s/step - loss: 0.5836 - val_loss: 0.5949
Epoch 156/200
23/23 [==============================] - 91s 4s/step - loss: 0.5645 - val_loss: 0.5965
Epoch 157/200
23/23 [==============================] - 96s 4s/step - loss: 0.5730 - val_loss: 0.5852
Epoch 158/200
23/23 [==============================] - 93s 4s/step - loss: 0.5660 - val_loss: 0.5927
Epoch 159/200
23/23 [==============================] - 93s 4s/step - loss: 0.5627 - val_loss: 0.6043
Epoch 160/200
23/23 [==============================] - 92s 4s/step - loss: 0.5625 - val_loss: 0.5800
Epoch 161/200
23/23 [==============================] - 96s 4s/step - loss: 0.5652 - val_loss: 0.5813
Epoch 162/200
23/23 [==============================] - 94s 4s/step - loss: 0.5598 - val_loss: 0.5960
Epoch 163/200
23/23 [==============================] - 93s 4s/step - loss: 0.5670 - val_loss: 0.5926
Epoch 164/200
23/23 [==============================] - 93s 4s/step - loss: 0.5810 - val_loss: 0.5826
Epoch 165/200
23/23 [==============================] - 94s 4s/step - loss: 0.5570 - val_loss: 0.5893
Epoch 166/200
23/23 [==============================] - 93s 4s/step - loss: 0.5795 - val_loss: 0.6044
Epoch 167/200
23/23 [==============================] - 93s 4s/step - loss: 0.5697 - val_loss: 0.5912
Epoch 168/200
23/23 [==============================] - 93s 4s/step - loss: 0.5721 - val_loss: 0.6060
Epoch 169/200
23/23 [==============================] - 94s 4s/step - loss: 0.5604 - val_loss: 0.5864
Epoch 170/200
23/23 [==============================] - 93s 4s/step - loss: 0.5755 - val_loss: 0.5978
Epoch 171/200
23/23 [==============================] - 94s 4s/step - loss: 0.5607 - val_loss: 0.5863
Epoch 172/200
23/23 [==============================] - 94s 4s/step - loss: 0.5639 - val_loss: 0.6009
Epoch 173/200
23/23 [==============================] - 125s 5s/step - loss: 0.5668 - val_loss: 0.6111
Epoch 174/200
23/23 [==============================] - 154s 7s/step - loss: 0.5628 - val_loss: 0.5868
Epoch 175/200
23/23 [==============================] - 157s 7s/step - loss: 0.5728 - val_loss: 0.6231
Epoch 176/200
23/23 [==============================] - 161s 7s/step - loss: 0.5681 - val_loss: 0.5944
Epoch 177/200
23/23 [==============================] - 154s 7s/step - loss: 0.5700 - val_loss: 0.6202
Epoch 178/200
23/23 [==============================] - 160s 7s/step - loss: 0.5712 - val_loss: 0.5883
Epoch 179/200
23/23 [==============================] - 153s 7s/step - loss: 0.5735 - val_loss: 0.6911
Epoch 180/200
23/23 [==============================] - 158s 7s/step - loss: 0.5576 - val_loss: 0.6151
Epoch 181/200
23/23 [==============================] - 160s 7s/step - loss: 0.5745 - val_loss: 0.5992
Epoch 182/200
23/23 [==============================] - 156s 7s/step - loss: 0.5796 - val_loss: 0.6076
Epoch 183/200
23/23 [==============================] - 152s 7s/step - loss: 0.5537 - val_loss: 0.5899
Epoch 184/200
23/23 [==============================] - 162s 7s/step - loss: 0.5561 - val_loss: 0.6340
Epoch 185/200
23/23 [==============================] - 156s 7s/step - loss: 0.5625 - val_loss: 0.5861
Epoch 186/200
23/23 [==============================] - 156s 7s/step - loss: 0.5688 - val_loss: 0.5894
Epoch 187/200
23/23 [==============================] - 152s 7s/step - loss: 0.5487 - val_loss: 0.6050
Epoch 188/200
23/23 [==============================] - 160s 7s/step - loss: 0.5633 - val_loss: 0.6016
Epoch 189/200
23/23 [==============================] - 156s 7s/step - loss: 0.5716 - val_loss: 0.6209
Epoch 190/200
23/23 [==============================] - 157s 7s/step - loss: 0.5618 - val_loss: 0.5951
Epoch 191/200
23/23 [==============================] - 157s 7s/step - loss: 0.5584 - val_loss: 0.6164
Epoch 192/200
23/23 [==============================] - 157s 7s/step - loss: 0.5673 - val_loss: 0.6586
Epoch 193/200
23/23 [==============================] - 156s 7s/step - loss: 0.5703 - val_loss: 0.6123
Epoch 194/200
23/23 [==============================] - 158s 7s/step - loss: 0.5655 - val_loss: 0.6244
Epoch 195/200
23/23 [==============================] - 157s 7s/step - loss: 0.5564 - val_loss: 0.5976
Epoch 196/200
23/23 [==============================] - 156s 7s/step - loss: 0.5567 - val_loss: 0.5988
Epoch 197/200
23/23 [==============================] - 157s 7s/step - loss: 0.5717 - val_loss: 0.5895
Epoch 198/200
23/23 [==============================] - 160s 7s/step - loss: 0.5642 - val_loss: 0.5930
Epoch 199/200
23/23 [==============================] - 154s 7s/step - loss: 0.5685 - val_loss: 0.6252
Epoch 200/200
23/23 [==============================] - 160s 7s/step - loss: 0.5491 - val_loss: 0.6183

In [163]:
logged_history = pd.read_csv(logger_filename)
logged_history


Out[163]:
epoch loss val_loss
0 0 0.604966 0.614906
1 1 0.591833 0.611575
2 2 0.587082 0.606265
3 3 0.591132 0.606473
4 4 0.586355 0.607815
5 5 0.587027 0.604261
6 6 0.584728 0.601473
7 7 0.589871 0.603235
8 8 0.597793 0.602655
9 9 0.578286 0.603495
10 10 0.577246 0.600585
11 11 0.601189 0.599494
12 12 0.584351 0.599510
13 13 0.579366 0.597704
14 14 0.594418 0.595248
15 15 0.578069 0.603376
16 16 0.586432 0.593551
17 17 0.586998 0.597205
18 18 0.586362 0.592716
19 19 0.586411 0.590637
20 20 0.572845 0.599873
21 21 0.592535 0.591805
22 22 0.584710 0.628970
23 23 0.585402 0.604165
24 24 0.587066 0.594320
25 25 0.585818 0.597481
26 26 0.578573 0.591441
27 27 0.584643 0.592499
28 28 0.597982 0.590190
29 29 0.582428 0.594075
... ... ... ...
170 170 0.562039 0.586296
171 171 0.566052 0.600897
172 172 0.566762 0.611098
173 173 0.564094 0.586803
174 174 0.573751 0.623115
175 175 0.568087 0.594387
176 176 0.564233 0.620183
177 177 0.571201 0.588319
178 178 0.563146 0.691147
179 179 0.559985 0.615093
180 180 0.574548 0.599249
181 181 0.582403 0.607553
182 182 0.553024 0.589932
183 183 0.556126 0.634021
184 184 0.564074 0.586116
185 185 0.572400 0.589404
186 186 0.552891 0.604999
187 187 0.563263 0.601616
188 188 0.571406 0.620868
189 189 0.564572 0.595052
190 190 0.558080 0.616419
191 191 0.569750 0.658632
192 192 0.569914 0.612299
193 193 0.565014 0.624427
194 194 0.556423 0.597618
195 195 0.555038 0.598813
196 196 0.571619 0.589494
197 197 0.564240 0.593031
198 198 0.564383 0.625169
199 199 0.549107 0.618255

200 rows × 3 columns


In [164]:
from sklearn.metrics import log_loss

In [165]:
p = Y[training_set_indices].mean()
prior_loss = log_loss(Y[testing_set_indices], 
                      [p]*testing_set_indices.size)

In [166]:
from sklearn.metrics import log_loss
with mpl.rc_context(rc={"figure.figsize": (10,6)}):

        
    plt.axhline(prior_loss, label="Prior", 
                linestyle="dashed", color="black")
    
    plt.plot(logged_history["val_loss"], label="Validation")
    plt.plot(logged_history["loss"], label="Training")
    
    plt.xlabel("Epoch")
    plt.ylabel("Loss\n(avg. binary cross-entropy)")
    
    plt.legend()
    
    plt.ylim(bottom=.56, top=.62)



In [167]:
from sklearn.metrics import log_loss
with mpl.rc_context(rc={"figure.figsize": (10,6)}):

    simple_conv = lambda x: np.convolve(x, np.ones(5)/5, mode="valid")
    
    plt.axhline(prior_loss, label="Prior", 
                linestyle="dashed", color="black")
    
    plt.plot(simple_conv(logged_history["val_loss"]),
             label="Validation")
    plt.plot(simple_conv(logged_history["loss"]),
             label="Training")
    
    plt.xlabel("Epoch")
    plt.ylabel("Loss\n(avg. binary cross-entropy)")
    
    plt.legend()
    
    plt.ylim(bottom=.56, top=.62)


9) Look at validation results


In [59]:
class_probs = model.predict_proba(X_test_transformed).flatten()
class_probs


373/373 [==============================] - 1s 3ms/step
Out[59]:
array([0.38096353, 0.44315898, 0.42595476, 0.39431217, 0.44192123,
       0.42403397, 0.3084317 , 0.3977308 , 0.42297694, 0.44161874,
       0.4457367 , 0.4409836 , 0.40325916, 0.37228101, 0.28045028,
       0.39575815, 0.44849315, 0.4440918 , 0.4450764 , 0.4423715 ,
       0.4449345 , 0.41168362, 0.4467733 , 0.43635613, 0.44292593,
       0.41494936, 0.39176732, 0.39990672, 0.2934756 , 0.32629684,
       0.3822251 , 0.31482962, 0.3116483 , 0.4064709 , 0.30879626,
       0.2557003 , 0.41120014, 0.37519196, 0.44936234, 0.40176862,
       0.4401391 , 0.4009208 , 0.33045557, 0.43595022, 0.3117058 ,
       0.36470458, 0.42641282, 0.38719258, 0.22711629, 0.39770648,
       0.2820537 , 0.34835997, 0.43251485, 0.30706087, 0.25999638,
       0.40024894, 0.38270858, 0.3465161 , 0.38174748, 0.43951225,
       0.36454654, 0.37350082, 0.3985789 , 0.42196226, 0.43141177,
       0.391576  , 0.402127  , 0.445567  , 0.43924326, 0.4386993 ,
       0.4277306 , 0.43835342, 0.2692941 , 0.44942412, 0.4508125 ,
       0.15961964, 0.39857697, 0.44297135, 0.3566997 , 0.440747  ,
       0.2533559 , 0.37569338, 0.4425728 , 0.43577525, 0.38890773,
       0.43563932, 0.4241861 , 0.3932006 , 0.43114784, 0.4328482 ,
       0.4168405 , 0.39507186, 0.4186143 , 0.42405146, 0.44095054,
       0.4161974 , 0.33376172, 0.44632694, 0.43030286, 0.3956095 ,
       0.30100748, 0.39117464, 0.4447542 , 0.43389064, 0.43305126,
       0.34959513, 0.35396492, 0.3121762 , 0.40590182, 0.4294078 ,
       0.4349919 , 0.40991887, 0.41888702, 0.36120278, 0.4282187 ,
       0.44593552, 0.3572685 , 0.44955522, 0.44456702, 0.3408464 ,
       0.45547816, 0.42949343, 0.342261  , 0.37738895, 0.44656345,
       0.31829506, 0.40286332, 0.44911963, 0.42276067, 0.44706854,
       0.40765384, 0.31033507, 0.32841244, 0.25814795, 0.31241944,
       0.453109  , 0.44477054, 0.34910932, 0.43158117, 0.17553769,
       0.45416188, 0.41460636, 0.40567777, 0.44760412, 0.24019198,
       0.44046155, 0.40499735, 0.36315823, 0.36915693, 0.4149858 ,
       0.44576353, 0.38120922, 0.3828068 , 0.37243614, 0.44181246,
       0.3235208 , 0.44830713, 0.44839013, 0.44242042, 0.44256744,
       0.44800305, 0.44181132, 0.43351567, 0.41210204, 0.44121698,
       0.413092  , 0.44176257, 0.43989152, 0.44552085, 0.43252727,
       0.4258776 , 0.3768219 , 0.41349387, 0.44884515, 0.32588708,
       0.41172966, 0.44757605, 0.41906548, 0.3825629 , 0.2607594 ,
       0.38989165, 0.29450786, 0.44284907, 0.31305894, 0.4185819 ,
       0.29876572, 0.16504529, 0.44900298, 0.44687438, 0.44056222,
       0.44137925, 0.3948508 , 0.34090987, 0.38642806, 0.40271282,
       0.44723928, 0.4007566 , 0.43055147, 0.44907433, 0.44951198,
       0.43810576, 0.43672234, 0.43379405, 0.43176052, 0.44240138,
       0.40329707, 0.4387482 , 0.32325944, 0.40388075, 0.27106833,
       0.3865721 , 0.20918834, 0.34775093, 0.44877672, 0.25511292,
       0.44780973, 0.40457538, 0.44524664, 0.32925203, 0.43885016,
       0.35806048, 0.3756185 , 0.38967642, 0.4192389 , 0.24980316,
       0.44150057, 0.43218148, 0.41110608, 0.39357594, 0.29679593,
       0.44293633, 0.42880318, 0.44525757, 0.4079865 , 0.3264692 ,
       0.44915658, 0.42911676, 0.44559175, 0.4005376 , 0.3871582 ,
       0.4421507 , 0.27849784, 0.23117535, 0.44651335, 0.4048834 ,
       0.40808037, 0.29254013, 0.4494428 , 0.3446331 , 0.4238776 ,
       0.40741017, 0.42450374, 0.2993166 , 0.3820515 , 0.447232  ,
       0.2889412 , 0.42764917, 0.44313568, 0.2791731 , 0.445895  ,
       0.35959965, 0.28588116, 0.39521152, 0.42558408, 0.29218394,
       0.3922141 , 0.37945515, 0.38019413, 0.44780952, 0.44512308,
       0.38719988, 0.39650354, 0.4100446 , 0.25340104, 0.3529415 ,
       0.38736403, 0.29348183, 0.44824216, 0.39691976, 0.443862  ,
       0.38470116, 0.30114385, 0.39607376, 0.42483705, 0.31340092,
       0.44839117, 0.45280483, 0.28304958, 0.37893963, 0.44633308,
       0.44886526, 0.42260054, 0.36907005, 0.39242512, 0.36889833,
       0.43124402, 0.4156075 , 0.25506765, 0.44074667, 0.38222235,
       0.43462127, 0.4382892 , 0.35095876, 0.4299872 , 0.44595292,
       0.3694274 , 0.4358682 , 0.40623176, 0.39650482, 0.44579092,
       0.42501035, 0.44486007, 0.40626553, 0.4242109 , 0.44190362,
       0.45157403, 0.36083287, 0.37688532, 0.41013816, 0.3927443 ,
       0.45153722, 0.3697739 , 0.447224  , 0.4037631 , 0.42013228,
       0.2867513 , 0.43948248, 0.43644127, 0.44002444, 0.40816608,
       0.24360898, 0.44442272, 0.43882433, 0.4233473 , 0.40551072,
       0.44690314, 0.41971892, 0.44890967, 0.4048392 , 0.4508837 ,
       0.27985042, 0.43386275, 0.16176486, 0.44558656, 0.37172565,
       0.40866947, 0.326269  , 0.3866764 , 0.38622338, 0.29225364,
       0.36603895, 0.42216617, 0.29009834, 0.25085118, 0.43494695,
       0.41980428, 0.31129548, 0.3963757 , 0.37951544, 0.438067  ,
       0.4432989 , 0.44002596, 0.44598934, 0.35958058, 0.4162218 ,
       0.32581538, 0.43282866, 0.38365978, 0.44844076, 0.43104017,
       0.44456834, 0.36104265, 0.44557685], dtype=float32)

In [60]:
with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    sns.distplot(class_probs[Y[testing_set_indices]==True], color="g", label="true dwarfs")
    sns.distplot(class_probs[Y[testing_set_indices]==False], color="b", label="true non-dwarfs")

    plt.xlabel("p(dwarf | image)")
    plt.ylabel("density (galaxies)")

    plt.xlim(0, .7)
    plt.axvline(Y[training_set_indices].mean(), linestyle="dashed", color="black", label="prior\n(from training set)")
    plt.axvline(.5, linestyle="dotted", color="black", label="50/50")

    plt.legend(
        loc="upper left",
        bbox_to_anchor=(1, 1),
    )



In [61]:
from sklearn import metrics
from sklearn.metrics import roc_auc_score

with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    fpr, tpr, _ = metrics.roc_curve(Y[testing_set_indices], class_probs)
    roc_auc = roc_auc_score(Y[testing_set_indices], class_probs)

    plt.plot(fpr, tpr, label="DNN (AUC = {:.2})".format(roc_auc))
    plt.plot([0,1], [0,1], linestyle="dashed", color="black", label="random guessing")

    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")

    plt.title("ROC Curve")

    plt.legend(loc="best")



In [62]:
from sklearn.metrics import average_precision_score
with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    precision, recall, _ = metrics.precision_recall_curve(Y[testing_set_indices], class_probs)
    pr_auc = average_precision_score(Y[testing_set_indices], class_probs)

    plt.plot(recall, precision, label="AUC = {:.2}".format(pr_auc))
    plt.plot([0,1], [Y[testing_set_indices].mean()]*2, linestyle="dashed", color="black")

    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("Recall")
    plt.ylabel("Precision")

    plt.title("PR Curve")

    plt.legend(loc="best")


Apply a threshold

For now, just threshold at the prior class probability (estimated from the training set).

Note that under a symmetric loss function, this isn't as good as simply thresholding at 0.5.
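As a quick sanity check of that note, a small sketch (reusing variables defined above) comparing accuracy, i.e. the complement of the symmetric 0-1 loss, at the two candidate thresholds:

y_true = Y[testing_set_indices]
for threshold in (Y[training_set_indices].mean(), 0.5):
    y_pred = class_probs > threshold
    print("threshold = {:.3f}: accuracy = {:.3f}".format(
        threshold, (y_pred == y_true).mean()))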


In [63]:
predicted_classes = class_probs > (Y[training_set_indices].mean())
predicted_classes.mean()


Out[63]:
0.9195710455764075

In [64]:
confusion_matrix = metrics.confusion_matrix(Y[testing_set_indices], predicted_classes)
confusion_matrix


Out[64]:
array([[ 27, 257],
       [  3,  86]])

In [65]:
print("number of dwarfs (true)     : ", Y[testing_set_indices].sum())
print("number of dwarfs (predicted): ", predicted_classes.sum())


number of dwarfs (true)     :  89
number of dwarfs (predicted):  343

In [66]:
import itertools
# adapted from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

confusion_matrix_normalized = confusion_matrix.astype('float') / confusion_matrix.sum(axis=1)[:, np.newaxis]

plt.imshow(confusion_matrix_normalized, interpolation='nearest',cmap="gray_r")
# plt.title(title)
plt.colorbar()
tick_marks = np.arange(2)
classes = [False, True]
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = 'd'
thresh = 1 / 2.
for i, j in itertools.product(range(confusion_matrix.shape[0]), range(confusion_matrix.shape[1])):
    plt.text(j, i, format(confusion_matrix[i, j], fmt),
             fontdict={"size":20},
             horizontalalignment="center",
             color="white" if confusion_matrix_normalized[i, j] > thresh else "black")

plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.tight_layout()



In [69]:
print("  i - Y_true[i] - Y_pred[i] -  error?")
print("-------------------------------------")
for i in range(predicted_classes.size):
    print("{:>3} -    {:>1d}      -    {:d}      -   {:>2d}".format(i, 
                                                    Y[testing_set_indices][i], 
                                                    predicted_classes[i],
                                                    (Y[testing_set_indices][i] != predicted_classes[i]), 
                                                   ))


  i - Y_true[i] - Y_pred[i] -  error?
-------------------------------------
  0 -    0      -    1      -    1
  1 -    1      -    1      -    0
  2 -    0      -    1      -    1
  3 -    1      -    1      -    0
  4 -    0      -    1      -    1
  5 -    0      -    1      -    1
  6 -    0      -    1      -    1
  7 -    0      -    1      -    1
  8 -    0      -    1      -    1
  9 -    0      -    1      -    1
 10 -    1      -    1      -    0
 11 -    0      -    1      -    1
 12 -    1      -    1      -    0
 13 -    0      -    1      -    1
 14 -    0      -    0      -    0
 15 -    0      -    1      -    1
 16 -    0      -    1      -    1
 17 -    0      -    1      -    1
 18 -    1      -    1      -    0
 19 -    1      -    1      -    0
 20 -    0      -    1      -    1
 21 -    0      -    1      -    1
 22 -    1      -    1      -    0
 23 -    0      -    1      -    1
 24 -    0      -    1      -    1
 25 -    0      -    1      -    1
 26 -    0      -    1      -    1
 27 -    0      -    1      -    1
 28 -    0      -    1      -    1
 29 -    0      -    1      -    1
 30 -    0      -    1      -    1
 31 -    0      -    1      -    1
 32 -    1      -    1      -    0
 33 -    0      -    1      -    1
 34 -    0      -    1      -    1
 35 -    1      -    0      -    1
 36 -    0      -    1      -    1
 37 -    1      -    1      -    0
 38 -    0      -    1      -    1
 39 -    1      -    1      -    0
 40 -    0      -    1      -    1
 41 -    0      -    1      -    1
 42 -    1      -    1      -    0
 43 -    1      -    1      -    0
 44 -    0      -    1      -    1
 45 -    1      -    1      -    0
 46 -    0      -    1      -    1
 47 -    1      -    1      -    0
 48 -    0      -    0      -    0
 49 -    0      -    1      -    1
 50 -    0      -    0      -    0
 51 -    0      -    1      -    1
 52 -    1      -    1      -    0
 53 -    0      -    1      -    1
 54 -    0      -    0      -    0
 55 -    0      -    1      -    1
 56 -    0      -    1      -    1
 57 -    0      -    1      -    1
 58 -    0      -    1      -    1
 59 -    0      -    1      -    1
 60 -    1      -    1      -    0
 61 -    0      -    1      -    1
 62 -    0      -    1      -    1
 63 -    0      -    1      -    1
 64 -    0      -    1      -    1
 65 -    0      -    1      -    1
 66 -    0      -    1      -    1
 67 -    1      -    1      -    0
 68 -    1      -    1      -    0
 69 -    0      -    1      -    1
 70 -    0      -    1      -    1
 71 -    0      -    1      -    1
 72 -    0      -    0      -    0
 73 -    0      -    1      -    1
 74 -    0      -    1      -    1
 75 -    0      -    0      -    0
 76 -    0      -    1      -    1
 77 -    1      -    1      -    0
 78 -    0      -    1      -    1
 79 -    0      -    1      -    1
 80 -    0      -    0      -    0
 81 -    1      -    1      -    0
 82 -    0      -    1      -    1
 83 -    0      -    1      -    1
 84 -    0      -    1      -    1
 85 -    1      -    1      -    0
 86 -    0      -    1      -    1
 87 -    1      -    1      -    0
 88 -    0      -    1      -    1
 89 -    0      -    1      -    1
 90 -    1      -    1      -    0
 91 -    0      -    1      -    1
 92 -    0      -    1      -    1
 93 -    1      -    1      -    0
 94 -    0      -    1      -    1
 95 -    0      -    1      -    1
 96 -    0      -    1      -    1
 97 -    1      -    1      -    0
 98 -    1      -    1      -    0
 99 -    0      -    1      -    1
100 -    0      -    1      -    1
101 -    0      -    1      -    1
102 -    1      -    1      -    0
103 -    0      -    1      -    1
104 -    0      -    1      -    1
105 -    1      -    1      -    0
106 -    0      -    1      -    1
107 -    1      -    1      -    0
108 -    0      -    1      -    1
109 -    0      -    1      -    1
110 -    0      -    1      -    1
111 -    0      -    1      -    1
112 -    0      -    1      -    1
113 -    0      -    1      -    1
114 -    0      -    1      -    1
115 -    0      -    1      -    1
116 -    0      -    1      -    1
117 -    0      -    1      -    1
118 -    0      -    1      -    1
119 -    0      -    1      -    1
120 -    0      -    1      -    1
121 -    0      -    1      -    1
122 -    0      -    1      -    1
123 -    0      -    1      -    1
124 -    0      -    1      -    1
125 -    1      -    1      -    0
126 -    0      -    1      -    1
127 -    0      -    1      -    1
128 -    0      -    1      -    1
129 -    1      -    1      -    0
130 -    1      -    1      -    0
131 -    0      -    1      -    1
132 -    0      -    1      -    1
133 -    0      -    0      -    0
134 -    0      -    1      -    1
135 -    0      -    1      -    1
136 -    0      -    1      -    1
137 -    0      -    1      -    1
138 -    0      -    1      -    1
139 -    0      -    0      -    0
140 -    0      -    1      -    1
141 -    0      -    1      -    1
142 -    0      -    1      -    1
143 -    1      -    1      -    0
144 -    0      -    0      -    0
145 -    0      -    1      -    1
146 -    0      -    1      -    1
147 -    0      -    1      -    1
[rows 148 through 372 of this printed table omitted for brevity; they follow the same four-column format as the rows above]

Analyze Errors


In [77]:
HSC_ids


Out[77]:
array([43158176442374224, 43158176442374373, 43158176442374445, ...,
       43159155694916013, 43159155694916476, 43159155694917496])

In [79]:
df.loc[HSC_ids[testing_set_indices]].head()


Out[79]:
                     target
HSC_id
43158880817015286     False
43159142810013371      True
43158876522047030     False
43158863637144621      True
43158335356145699     False

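Before looking at individual failures, it helps to summarize the test-set errors as a confusion matrix. A minimal sketch (not part of the original run), reusing `Y`, `testing_set_indices`, and `predicted_classes` from the surrounding cells:

In [ ]:
from sklearn.metrics import confusion_matrix, classification_report

y_true = np.array(Y[testing_set_indices], dtype=int)
y_pred = np.array(predicted_classes, dtype=int)

# rows are the true class, columns are the predicted class
print(confusion_matrix(y_true, y_pred))

# per-class precision, recall, and F1
print(classification_report(y_true, y_pred))
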
In [80]:
COSMOS_filename = os.path.join(dwarfz.data_dir_default, 
                               "COSMOS_reference.sqlite")
COSMOS = dwarfz.datasets.COSMOS(COSMOS_filename)

In [81]:
HSC_filename = os.path.join(dwarfz.data_dir_default, 
                            "HSC_COSMOS_median_forced.sqlite3")
HSC = dwarfz.datasets.HSC(HSC_filename)

In [82]:
matches_filename = os.path.join(dwarfz.data_dir_default, 
                                "matches.sqlite3")
matches_df = dwarfz.matching.Matches.load_from_filename(matches_filename)

In [89]:
combined = matches_df[matches_df.match].copy()
combined["ra"]       = COSMOS.df.loc[combined.index].ra
combined["dec"]      = COSMOS.df.loc[combined.index].dec
combined["photo_z"]  = COSMOS.df.loc[combined.index].photo_z
combined["log_mass"] = COSMOS.df.loc[combined.index].mass_med
combined["active"]   = COSMOS.df.loc[combined.index].classification

combined = combined.set_index("catalog_2_ids")

combined.head()


Out[89]:
                        sep  match  error          ra       dec  photo_z  log_mass  active
catalog_2_ids
43158996781122114  0.114389   True  False  149.749393  1.618068   0.3797  11.07610       0
43158447025298860  0.471546   True  False  150.388349  1.614538   2.3343   8.99275       1
43158447025298862  0.202378   True  False  150.402935  1.614631   2.1991   9.71373       1
43158584464246387  0.207967   True  False  150.295083  1.614662   2.4407   9.77811       1
43158584464253383  0.295316   True  False  150.239919  1.614675   0.2079   7.04224       1

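The red rectangle drawn in the plots below spans photo_z < 0.15 and 8 < log M_star < 9. As a quick sanity check, here is a minimal sketch counting how many matched objects fall inside that box; whether the box exactly matches the "target" definition used for training is an assumption here:

In [ ]:
# selection box taken from the Rectangle([0, 8], .15, 1) drawn below
# (assumed, not confirmed, to match the training "target" definition)
in_box = ((combined.photo_z < 0.15) &
          (combined.log_mass > 8) &
          (combined.log_mass < 9))
print("matched objects inside the selection box:", in_box.sum())
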
In [93]:
df_features_testing = combined.loc[HSC_ids[testing_set_indices]]

In [148]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
# mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
#            C=class_probs,
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Number of objects",
            )

plt.suptitle("All Objects")


348
Out[148]:
Text(0.5,0.98,'All Objects')
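
A raw hexbin of counts can hide where the classifier actually fails. Below is a minimal sketch (not part of the original run) that bins the same masked objects with np.histogram2d and shows the false-positive *fraction* per bin, using the error == -1 coding from the cell above:

In [ ]:
# 2D bins over the same (photo_z, log_mass) plane as the hexbin above
all_counts, xedges, yedges = np.histogram2d(df_tmp.photo_z[mask],
                                            df_tmp.log_mass[mask],
                                            bins=20)

# false positives only: error == -1 means true 0 but predicted 1
fp_sel = mask & (df_tmp.error == -1)
fp_counts, _, _ = np.histogram2d(df_tmp.photo_z[fp_sel],
                                 df_tmp.log_mass[fp_sel],
                                 bins=[xedges, yedges])

with np.errstate(divide="ignore", invalid="ignore"):
    fp_fraction = fp_counts / all_counts   # NaN where a bin is empty

plt.pcolormesh(xedges, yedges, fp_fraction.T, cmap="Reds", vmin=0, vmax=1)
plt.colorbar(label="False positive fraction")
plt.xlabel("photo z")
plt.ylabel("log M_star")
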

In [146]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
# mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
           C=class_probs[mask],  # mask C so it lines up with the masked x/y values
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Average Predicted\nProb. within bin",
            )


plt.suptitle("All Objects")


348
Out[146]:
Text(0.5,0.98,'All Objects')

^^ Huh, that's a pretty uniform-looking spread. The predicted probabilities don't seem to trend in a useful direction, either near the desired selection boundaries or as you move farther from them.

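One way to back up that impression is to check whether the predicted probability correlates with photo_z or log M_star at all within this sample. A minimal sketch, assuming class_probs is ordered like df_tmp (one entry per testing-set object):

In [ ]:
from scipy import stats

# restrict class_probs to the same objects as the masked plot above
probs = np.asarray(class_probs)[np.asarray(mask)]

for name, values in [("photo_z",  df_tmp.photo_z[mask]),
                     ("log_mass", df_tmp.log_mass[mask])]:
    rho, p = stats.spearmanr(probs, values)
    print("{:8s}: Spearman rho = {:+.2f}  (p = {:.2g})".format(name, rho, p))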

In [147]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
           C=class_probs[mask],  # mask C so it lines up with the masked x/y values
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Average Predicted\nProb. within bin",
            )

plt.suptitle("False Positives")


243
Out[147]:
Text(0.5,0.98,'False Positives')
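
To close the loop on the false positives, it can be useful to pull out the ones the network was most confident about and inspect their images by eye. A minimal sketch, again assuming class_probs is aligned with df_tmp:

In [ ]:
# most confident false positives, for visual follow-up
fp = df_tmp[df_tmp.error == -1].copy()
fp["predicted_prob"] = np.asarray(class_probs)[np.asarray(df_tmp.error == -1)]
fp.sort_values("predicted_prob", ascending=False).head(10)
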

In [ ]: