In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib as mpl

import seaborn as sns; sns.set(style="white", font_scale=2)

import numpy as np
import pandas as pd
from astropy.io import fits
import glob

import sklearn

import keras
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.layers.noise import GaussianNoise

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.preprocessing.image import random_channel_shift
from keras.optimizers import SGD, Adam
from keras import backend as K
K.set_image_data_format('channels_first')

import scipy.ndimage as ndi

import matplotlib.patches as patches


Using TensorFlow backend.

In [162]:
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
mpl.rcParams['figure.figsize'] = np.array((10,6))*.6
mpl.rcParams['figure.facecolor'] = "white"

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# give access to importing dwarfz
import os, sys
dwarfz_package_dir = os.getcwd().split("dwarfz")[0]
if dwarfz_package_dir not in sys.path:
    sys.path.insert(0, dwarfz_package_dir)

import dwarfz
    
# back to regular import statements

In [4]:
# my modules that are DNN specific
import preprocessing
import geometry

To Do

  1. Read in fits images
  2. Apply pre-processing stretch (asinh)
  3. Crop the images to a smaller size
  4. Combine filters into one cube
  5. Create training set with labels
  6. Set up keras model
  7. Poke around at the results

0) Get files


In [5]:
images_dir = preprocessing.images_dir
images_dir


Out[5]:
'../data/galaxy_images_training/quarry_files/'

In [6]:
HSC_ids = [int(os.path.basename(image_file).split("-")[0])
           for image_file in glob.glob(os.path.join(images_dir, "*.fits"))]
HSC_ids = set(HSC_ids) # remove duplicates
HSC_ids = np.array(sorted(HSC_ids))

# now filter out galaxies missing bands
HSC_ids = [HSC_id
           for HSC_id in HSC_ids
           if len(glob.glob(os.path.join(images_dir, "{}*.fits".format(str(HSC_id)))))==5]
HSC_ids = np.array(HSC_ids)

In [7]:
HSC_ids.size


Out[7]:
1866

In [8]:
HSC_id = HSC_ids[1] # for when I need a single sample galaxy

In [9]:
bands = ["g", "r", "i", "z", "y"]

1) Read in fits image


In [10]:
image, flux_mag_0 = preprocessing.get_image(HSC_id, "g")
print("image size: {} x {}".format(*image.shape))
image


image size: 239 x 239
Out[10]:
array([[-0.00146446,  0.02150147,  0.00693631, ...,  0.02792656,
        -0.02018547,  0.01850448],
       [-0.02519608,  0.00035813, -0.0455959 , ..., -0.00586783,
        -0.00882499,  0.01241659],
       [ 0.03936524,  0.00645859, -0.02110163, ...,  0.02997875,
         0.009456  ,  0.00591614],
       ...,
       [-0.03543996,  0.04346127, -0.0372493 , ..., -0.0014411 ,
        -0.01001758,  0.03473332],
       [-0.01037648, -0.03287457,  0.04310744, ...,  0.02935715,
         0.02273993, -0.00532476],
       [-0.05991019, -0.08159582,  0.02607481, ...,  0.01012528,
         0.00453719, -0.00872836]], dtype=float32)

In [11]:
preprocessing.image_plotter(image)



In [12]:
preprocessing.image_plotter(np.log(image))


/Users/egentry/anaconda3/envs/tf/lib/python3.6/site-packages/ipykernel_launcher.py:1: RuntimeWarning: invalid value encountered in log
  """Entry point for launching an IPython kernel.

2) Apply stretch

We're using (negative) asinh magnitudes, as implemented by the HSC collaboration.

For more about the asinh magnitude system, see Lupton, Gunn & Szalay (1999), as used for SDSS. (It's explicitly given in the SDSS Algorithms documentation as well as this overview page.)

For the source of our code, see the HSC color image creator.

And for reference, a common form of this stretch is: $$ \mathrm{mag}_\mathrm{asinh} = - \left(\frac{2.5}{\ln(10)}\right) \left(\mathrm{asinh}\left(\frac{f/f_0}{2b}\right) + \ln(b) \right)$$ for dimensionless softening parameter $b$ and reference flux $f_0$.
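
As a concrete illustration, the formula above translates directly into code (a minimal sketch; the values of $b$ and $f_0$ below are placeholders, not the HSC defaults used later):

import numpy as np

def asinh_mag(flux, f0, b):
    # negative asinh magnitude, as in the expression above:
    # mag = -(2.5 / ln 10) * (asinh((f/f0) / (2b)) + ln b)
    return -(2.5 / np.log(10)) * (np.arcsinh((flux / f0) / (2 * b)) + np.log(b))

# illustrative values only -- both the softening b and the reference flux f0 are placeholders
asinh_mag(flux=1e-9, f0=3631.0, b=1e-10)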


In [13]:
def scale(x, fluxMag0):
    ### adapted from https://hsc-gitlab.mtk.nao.ac.jp/snippets/23
    # rescale fluxes to a common zero point (mag 19) so all bands are comparable;
    # note this scales x in place
    mag0 = 19
    scale = 10 ** (0.4 * mag0) / fluxMag0
    x *= scale

    # display-range limits and softening parameter for the arcsinh stretch
    u_min = -0.05
    u_max = 2. / 3.
    u_a = np.exp(10.)

    # arcsinh stretch, then map onto an approximately [0, 1] range
    x = np.arcsinh(u_a*x) / np.arcsinh(u_a)
    x = (x - u_min) / (u_max - u_min)

    return x

In [14]:
image_scaled = preprocessing.scale(image, flux_mag_0)

In [15]:
preprocessing.image_plotter(image_scaled)



In [16]:
sns.distplot(image_scaled.flatten())
plt.title("Distribution of Transformed Intensities")


/Users/egentry/anaconda3/envs/tf/lib/python3.6/site-packages/scipy/stats/stats.py:1713: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval
Out[16]:
Text(0.5, 1.0, 'Distribution of Transformed Intensities')

In [17]:
for band in bands:
    image, flux_mag_0 = preprocessing.get_image(HSC_id, band)
    image_scaled = preprocessing.scale(image, flux_mag_0)

    plt.figure()
    preprocessing.image_plotter(image_scaled)
    plt.title("{} band".format(band))
    plt.colorbar()


3) Crop Image

Am I properly handling odd numbers?
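
For reference, a center crop that stays centered for odd image or cutout sizes could look like this (a sketch; not necessarily identical to preprocessing.get_cutout):

def center_cutout(image, size):
    # crop the last two axes of `image` to (size, size), keeping the crop centered;
    # integer division puts any leftover odd pixel on the high side
    h, w = image.shape[-2:]
    top = (h - size) // 2
    left = (w - size) // 2
    return image[..., top:top + size, left:left + size]

center_cutout(image_scaled, 75).shape   # -> (75, 75)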


In [18]:
pre_transformed_image_size  = 150
post_transformed_image_size = 75

In [19]:
cutout = preprocessing.get_cutout(image_scaled, post_transformed_image_size)
cutout.shape


Out[19]:
(75, 75)

In [20]:
preprocessing.image_plotter(cutout)



In [21]:
for band in bands:
    image, flux_mag_0 = preprocessing.get_image(HSC_id, band)
    image_scaled = preprocessing.scale(image, flux_mag_0)
    cutout = preprocessing.get_cutout(image_scaled, post_transformed_image_size)

    plt.figure()
    preprocessing.image_plotter(cutout)
    plt.title("{} band".format(band))
    plt.colorbar()


4) Combine filters into cube


In [22]:
images = [None]*len(bands)
flux_mag_0s = [None]*len(bands)
cutouts = [None]*len(bands)
for i, band in enumerate(bands):
    images[i], flux_mag_0s[i] = preprocessing.get_image(HSC_id, band)
    
    cutouts[i] = preprocessing.get_cutout(
        preprocessing.scale(images[i], flux_mag_0s[i]), 
        post_transformed_image_size
    )

In [23]:
cutout_cube = np.array(cutouts)
cutout_cube.shape


Out[23]:
(5, 75, 75)

In [24]:
# must transform into [0,1] for plt.imshow
# the HSC standard tool accomplishes this by clipping instead.
plt.imshow(preprocessing.transform_0_1(cutout_cube[:3,:,:].transpose(1,2,0)) )


Out[24]:
<matplotlib.image.AxesImage at 0x1a25670400>
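
For comparison, the clipping approach used by the HSC tool would look roughly like this (a sketch):

plt.imshow(np.clip(cutout_cube[:3, :, :].transpose(1, 2, 0), 0, 1))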

In [25]:
for i, band in enumerate(bands):
    sns.distplot(cutout_cube[:,:,:].transpose(1,2,0)[:,:,i].flatten(), label=band)
    plt.legend(loc="best")


5) Load Training Set Labels


In [26]:
training_set_labels_filename = "../data/galaxy_images_training/2017_09_26-dwarf_galaxy_scores.csv"

In [27]:
df = pd.read_csv(training_set_labels_filename)
df = df.drop_duplicates("HSC_id")
df = df.set_index("HSC_id")
df = df[["low_z_low_mass"]]
df = df.rename(columns={"low_z_low_mass":"target"})
df.head()


Out[27]:
                   target
HSC_id
43158322471244656   False
43158605939114836   False
43159142810013665   False
43158734788125011   False
43158863637144621    True

In [29]:
def load_image_mappable(HSC_id):
    images      = [None]*len(bands)
    flux_mag_0s = [None]*len(bands)
    cutouts     = [None]*len(bands)
    for j, band in enumerate(bands):
        images[j], flux_mag_0s[j] = preprocessing.get_image(HSC_id, band)
        cutouts[j] = preprocessing.get_cutout(
            preprocessing.scale(images[j], flux_mag_0s[j]),
            pre_transformed_image_size)
    cutout_cube = np.array(cutouts)
    return cutout_cube

In [30]:
X = np.empty((len(HSC_ids), 5, 
              pre_transformed_image_size, pre_transformed_image_size))

In [31]:
X = np.array(list(map(load_image_mappable, HSC_ids)))

In [33]:
X_full = X
X_small = X[:,(0,2,4),:,:] # drop down to 3 bands

In [34]:
X.shape


Out[34]:
(1866, 5, 150, 150)

In [35]:
Y = df.loc[HSC_ids].target.values
Y


Out[35]:
array([False, False,  True, ..., False, False, False])

In [36]:
Y.mean()


Out[36]:
0.2792068595927117

Geometry!


In [37]:
h = pre_transformed_image_size
w = pre_transformed_image_size
transform_matrix = geometry.create_random_transform_matrix(h, w,
                                                  include_rotation=True,
                                                  translation_size = .01,
                                                  verbose=False)

x_tmp = X[0][:3]

result = geometry.apply_transform_new(x_tmp, transform_matrix, 
                            channel_axis=0, fill_mode="constant", cval=np.max(x_tmp))

result = preprocessing.get_cutout(result, post_transformed_image_size)  # crop the transformed image
plt.imshow(preprocessing.transform_0_1(result.transpose(1,2,0)))


Out[37]:
<matplotlib.image.AxesImage at 0x1a241ee438>

In [38]:
import ipywidgets
ipywidgets.interact(preprocessing.transform_plotter,
                    X = ipywidgets.fixed(X),
                    rotation_degrees = ipywidgets.IntSlider(min=0, max=360, step=15, value=45),
                    dx_after = ipywidgets.IntSlider(min=-15, max=15),
                    dy_after = ipywidgets.IntSlider(min=-15, max=15),
                    color = ipywidgets.fixed(True),
                    shear_degrees = ipywidgets.IntSlider(min=0, max=90, step=5, value=0),
                    zoom_x = ipywidgets.FloatSlider(min=.5, max=2, value=1),
                    crop = ipywidgets.Checkbox(value=True)
                    )


Out[38]:
<function preprocessing.transform_plotter(X, reflect_x=False, rotation_degrees=45, dx_after=0, dy_after=0, shear_degrees=0, zoom_x=1, crop=False, color=True)>

5b) Split training and testing set


In [119]:
np.random.seed(1)

randomized_indices = np.arange(X.shape[0])
np.random.shuffle(randomized_indices)

testing_fraction = 0.2
testing_set_indices = randomized_indices[:int(testing_fraction*X.shape[0])]
training_set_indices = np.array(list(set([*randomized_indices]) - set([*testing_set_indices])))
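
Since only about 28% of the labelled galaxies are dwarfs, a stratified split would preserve the class balance in both sets; scikit-learn offers this directly (an alternative sketch, not what is used above):

from sklearn.model_selection import train_test_split

# stratify on the labels so the dwarf fraction matches in both splits
training_set_indices_alt, testing_set_indices_alt = train_test_split(
    np.arange(X.shape[0]), test_size=0.2, stratify=Y, random_state=1)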

In [120]:
testing_set_indices.size


Out[120]:
373

In [121]:
training_set_indices.size


Out[121]:
1493

6b) Adapt NumpyArrayIterator

The original only allowed 1-, 3-, or 4-channel images; I have 5-channel images.

Also, I want to change how the augmentation is applied.


In [122]:
from data_generator import ArrayIterator

6c) Adapt ImageDataGenerator

The original only allowed 1-, 3-, or 4-channel images; I have 5-channel images. Also, I'm adjusting how the affine transformations are applied during data augmentation.
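
Conceptually, the adapted generator needs to accept 5-channel, channels-first arrays and apply the augmentation on the fly. A numpy-only sketch of that behaviour (not the actual data_generator code):

def flow_sketch(X, Y, batch_size=64, rng=np.random):
    # yield (images, labels) batches with random x/y reflections applied per image
    n = X.shape[0]
    while True:
        idx = rng.choice(n, size=batch_size, replace=False)
        batch = []
        for img in X[idx]:
            if rng.rand() < 0.5:
                img = img[:, :, ::-1]   # reflect in x
            if rng.rand() < 0.5:
                img = img[:, ::-1, :]   # reflect in y
            batch.append(img)
        yield np.array(batch), Y[idx]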


In [123]:
from data_generator import ImageDataGenerator

6d) Create Data Generator


In [124]:
print('Using real-time data augmentation.')

h_before, w_before = X[0,0].shape
print("image shape before: ({},{})".format(h_before, w_before))

h_after = post_transformed_image_size
w_after = post_transformed_image_size
print("image shape after:  ({},{})".format(h_after, w_after))

# get a closure that binds the image size to get_cutout
postprocessing_function = lambda image: preprocessing.get_cutout(image, post_transformed_image_size)

# this will do preprocessing and realtime data augmentation
datagen = ImageDataGenerator(
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    with_reflection_x=True, # randomly apply a reflection (in x)
    with_reflection_y=True, # randomly apply a reflection (in y)
    with_rotation=False, # randomly apply a rotation
    width_shift_range=0.002,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.002,  # randomly shift images vertically (fraction of total height)
    postprocessing_function=postprocessing_function, # get a cutout of the processed image
    output_image_shape=(post_transformed_image_size,post_transformed_image_size)
)


Using real-time data augmentation.
image shape before: (150,150)
image shape after:  (75,75)

In [125]:
datagen.fit(X_full[training_set_indices])

7) Set up keras model


In [142]:
input_shape = cutout_cube.shape

nb_dense = 64

In [143]:
from keras.applications import inception_v3, inception_resnet_v2, vgg19

In [144]:
vgg19.VGG19(include_top=False, input_shape=(3, 75, 75))


Out[144]:
<keras.engine.training.Model at 0x1a474e0e48>

In [146]:
model = Sequential()

### 1x1 convolution to make sure we only have 3 channels
n_channels_for_pretrained=3
one_by_one = Conv2D(n_channels_for_pretrained, 1, padding='same',
                 input_shape=input_shape)
model.add(one_by_one)

pretrained_input_shape = tuple([n_channels_for_pretrained, *input_shape[1:]])
pretrained_layers = vgg19.VGG19(include_top=False,
                                input_shape=pretrained_input_shape
                               )
for layer in pretrained_layers.layers:
    layer.trainable = False
model.add(pretrained_layers)

model.add(Flatten())
model.add(Dense(2*nb_dense, activation="relu"))
model.add(Dense(nb_dense, activation="relu"))
model.add(Dense(1, activation="sigmoid"))

In [147]:
model.layers


Out[147]:
[<keras.layers.convolutional.Conv2D at 0x1a480dd978>,
 <keras.engine.training.Model at 0x1a4626eba8>,
 <keras.layers.core.Flatten at 0x1a480dd908>,
 <keras.layers.core.Dense at 0x1a46294668>,
 <keras.layers.core.Dense at 0x1a4645cdd8>,
 <keras.layers.core.Dense at 0x1a6766a898>]

In [148]:
learning_rate = 0.00025
decay = 1e-5
momentum = 0.9

# sgd = SGD(lr=learning_rate, decay=decay, momentum=momentum, nesterov=True)

adam = Adam(lr=learning_rate)

In [149]:
logger_filename = "training_transfer.one_by_one.log"

model.compile(loss='binary_crossentropy', 
#               optimizer=sgd, 
              optimizer=adam,
             )

# can only manually set weights _after_ compiling
one_by_one_weights = np.zeros((1,1,5,3))
for i in range(3):
    one_by_one_weights[0, 0, i, i] = 1.
one_by_one.set_weights([one_by_one_weights,
                        np.zeros(3)])

if os.path.exists(logger_filename):
    logger_filename_tmp = logger_filename + ".old"
    os.rename(logger_filename, logger_filename_tmp)
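
As a quick sanity check (a sketch, not part of the original run), the identity-like kernel means the 1x1 convolution initially just copies the g, r, i channels through to the pretrained network:

x = np.random.rand(5, 75, 75)          # one 5-band image, channels first
w = one_by_one_weights[0, 0]           # shape (5, 3): input channel x output channel
y = np.einsum('chw,co->ohw', x, w)     # a 1x1 convolution is a per-pixel matrix product
assert np.allclose(y, x[:3])           # the first three bands pass through unchanged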

In [150]:
earlystopping = EarlyStopping(monitor='loss',
                              patience=35,
                              verbose=1,
                              mode='auto' )

In [151]:
csv_logger = keras.callbacks.CSVLogger(logger_filename,
                                       append=True)

In [152]:
# modelcheckpoint = ModelCheckpoint(pathinCat+'Models/bestmodelMAG.hdf5',monitor='val_loss',verbose=0,save_best_only=True)

8) Run basic keras model


In [153]:
goal_batch_size = 64
steps_per_epoch = max(2, training_set_indices.size//goal_batch_size)
batch_size = training_set_indices.size//steps_per_epoch
print("steps_per_epoch: ", steps_per_epoch)
print("batch_size: ", batch_size)
epochs = 200
verbose = 1


steps_per_epoch:  23
batch_size:  64

In [154]:
X_test_transformed = np.array([datagen.standardize(X_img)
                               for X_img in X_full[testing_set_indices]])

X_test_transformed.shape


Out[154]:
(373, 5, 75, 75)

In [156]:
history = model.fit_generator(datagen.flow(X_full[training_set_indices], 
                                           Y[training_set_indices],
                                           batch_size=batch_size,
                                          ),
                              steps_per_epoch=steps_per_epoch,
                              epochs=epochs,
                              validation_data=(X_test_transformed, 
                                               Y[testing_set_indices]),
                              verbose=verbose,
                              callbacks=[
#                                   earlystopping,
                                  csv_logger],
                              )


Epoch 1/200
23/23 [==============================] - 183s 8s/step - loss: 0.6205 - val_loss: 0.6131
Epoch 2/200
23/23 [==============================] - 186s 8s/step - loss: 0.5914 - val_loss: 0.6128
Epoch 3/200
23/23 [==============================] - 181s 8s/step - loss: 0.5782 - val_loss: 0.6128
Epoch 4/200
23/23 [==============================] - 178s 8s/step - loss: 0.6093 - val_loss: 0.6154
Epoch 5/200
23/23 [==============================] - 181s 8s/step - loss: 0.5846 - val_loss: 0.6055
Epoch 6/200
23/23 [==============================] - 181s 8s/step - loss: 0.5933 - val_loss: 0.6058
Epoch 7/200
23/23 [==============================] - 180s 8s/step - loss: 0.5892 - val_loss: 0.6069
Epoch 8/200
23/23 [==============================] - 181s 8s/step - loss: 0.5946 - val_loss: 0.6041
Epoch 9/200
23/23 [==============================] - 185s 8s/step - loss: 0.5816 - val_loss: 0.6011
Epoch 10/200
23/23 [==============================] - 181s 8s/step - loss: 0.5840 - val_loss: 0.6013
Epoch 11/200
23/23 [==============================] - 180s 8s/step - loss: 0.5794 - val_loss: 0.5980
Epoch 12/200
23/23 [==============================] - 181s 8s/step - loss: 0.5881 - val_loss: 0.6000
Epoch 13/200
23/23 [==============================] - 181s 8s/step - loss: 0.5874 - val_loss: 0.5976
Epoch 14/200
23/23 [==============================] - 177s 8s/step - loss: 0.5871 - val_loss: 0.5974
Epoch 15/200
23/23 [==============================] - 185s 8s/step - loss: 0.5841 - val_loss: 0.5971
Epoch 16/200
23/23 [==============================] - 181s 8s/step - loss: 0.5789 - val_loss: 0.6038
Epoch 17/200
23/23 [==============================] - 181s 8s/step - loss: 0.5880 - val_loss: 0.6099
Epoch 18/200
23/23 [==============================] - 181s 8s/step - loss: 0.5840 - val_loss: 0.6197
Epoch 19/200
23/23 [==============================] - 181s 8s/step - loss: 0.5972 - val_loss: 0.5957
Epoch 20/200
23/23 [==============================] - 181s 8s/step - loss: 0.5923 - val_loss: 0.6012
Epoch 21/200
23/23 [==============================] - 176s 8s/step - loss: 0.5880 - val_loss: 0.6007
Epoch 22/200
23/23 [==============================] - 185s 8s/step - loss: 0.5863 - val_loss: 0.6206
Epoch 23/200
23/23 [==============================] - 181s 8s/step - loss: 0.5892 - val_loss: 0.5990
Epoch 24/200
23/23 [==============================] - 182s 8s/step - loss: 0.5851 - val_loss: 0.6018
Epoch 25/200
23/23 [==============================] - 180s 8s/step - loss: 0.5833 - val_loss: 0.5962
Epoch 26/200
23/23 [==============================] - 187s 8s/step - loss: 0.5795 - val_loss: 0.5959
Epoch 27/200
23/23 [==============================] - 181s 8s/step - loss: 0.5837 - val_loss: 0.5998
Epoch 28/200
23/23 [==============================] - 181s 8s/step - loss: 0.5942 - val_loss: 0.5908
Epoch 29/200
23/23 [==============================] - 181s 8s/step - loss: 0.5629 - val_loss: 0.5945
Epoch 30/200
23/23 [==============================] - 181s 8s/step - loss: 0.5847 - val_loss: 0.5973
Epoch 31/200
23/23 [==============================] - 185s 8s/step - loss: 0.5870 - val_loss: 0.5950
Epoch 32/200
23/23 [==============================] - 181s 8s/step - loss: 0.5808 - val_loss: 0.5907
Epoch 33/200
23/23 [==============================] - 176s 8s/step - loss: 0.5984 - val_loss: 0.5931
Epoch 34/200
23/23 [==============================] - 181s 8s/step - loss: 0.5883 - val_loss: 0.5897
Epoch 35/200
23/23 [==============================] - 181s 8s/step - loss: 0.5757 - val_loss: 0.5901
Epoch 36/200
23/23 [==============================] - 181s 8s/step - loss: 0.5953 - val_loss: 0.5885
Epoch 37/200
23/23 [==============================] - 185s 8s/step - loss: 0.5790 - val_loss: 0.5884
Epoch 38/200
23/23 [==============================] - 180s 8s/step - loss: 0.5807 - val_loss: 0.5895
Epoch 39/200
23/23 [==============================] - 181s 8s/step - loss: 0.5890 - val_loss: 0.5914
Epoch 40/200
23/23 [==============================] - 176s 8s/step - loss: 0.5693 - val_loss: 0.5930
Epoch 41/200
23/23 [==============================] - 185s 8s/step - loss: 0.5811 - val_loss: 0.5925
Epoch 42/200
23/23 [==============================] - 181s 8s/step - loss: 0.5942 - val_loss: 0.5899
Epoch 43/200
23/23 [==============================] - 181s 8s/step - loss: 0.5821 - val_loss: 0.5894
Epoch 44/200
23/23 [==============================] - 177s 8s/step - loss: 0.5776 - val_loss: 0.5963
Epoch 45/200
23/23 [==============================] - 186s 8s/step - loss: 0.5784 - val_loss: 0.5905
Epoch 46/200
23/23 [==============================] - 180s 8s/step - loss: 0.5849 - val_loss: 0.5898
Epoch 47/200
23/23 [==============================] - 181s 8s/step - loss: 0.5862 - val_loss: 0.5890
Epoch 48/200
23/23 [==============================] - 180s 8s/step - loss: 0.5863 - val_loss: 0.5942
Epoch 49/200
23/23 [==============================] - 181s 8s/step - loss: 0.5805 - val_loss: 0.5934
Epoch 50/200
23/23 [==============================] - 182s 8s/step - loss: 0.5942 - val_loss: 0.5901
Epoch 51/200
23/23 [==============================] - 181s 8s/step - loss: 0.5871 - val_loss: 0.5897
Epoch 52/200
23/23 [==============================] - 180s 8s/step - loss: 0.5717 - val_loss: 0.5929
Epoch 53/200
23/23 [==============================] - 180s 8s/step - loss: 0.5972 - val_loss: 0.5961
Epoch 54/200
23/23 [==============================] - 181s 8s/step - loss: 0.5831 - val_loss: 0.5934
Epoch 55/200
23/23 [==============================] - 180s 8s/step - loss: 0.5671 - val_loss: 0.5888
Epoch 56/200
23/23 [==============================] - 180s 8s/step - loss: 0.5854 - val_loss: 0.5873
Epoch 57/200
23/23 [==============================] - 180s 8s/step - loss: 0.5810 - val_loss: 0.5889
Epoch 58/200
23/23 [==============================] - 184s 8s/step - loss: 0.5871 - val_loss: 0.5888
Epoch 59/200
23/23 [==============================] - 175s 8s/step - loss: 0.5774 - val_loss: 0.5905
Epoch 60/200
23/23 [==============================] - 185s 8s/step - loss: 0.5773 - val_loss: 0.5922
Epoch 61/200
23/23 [==============================] - 180s 8s/step - loss: 0.5900 - val_loss: 0.5893
Epoch 62/200
23/23 [==============================] - 175s 8s/step - loss: 0.5705 - val_loss: 0.5873
Epoch 63/200
23/23 [==============================] - 185s 8s/step - loss: 0.5928 - val_loss: 0.5875
Epoch 64/200
23/23 [==============================] - 181s 8s/step - loss: 0.5635 - val_loss: 0.6019
Epoch 65/200
23/23 [==============================] - 176s 8s/step - loss: 0.5941 - val_loss: 0.5848
Epoch 66/200
23/23 [==============================] - 180s 8s/step - loss: 0.5571 - val_loss: 0.6107
Epoch 67/200
23/23 [==============================] - 181s 8s/step - loss: 0.5792 - val_loss: 0.5996
Epoch 68/200
23/23 [==============================] - 185s 8s/step - loss: 0.5849 - val_loss: 0.5846
Epoch 69/200
23/23 [==============================] - 181s 8s/step - loss: 0.5857 - val_loss: 0.5866
Epoch 70/200
23/23 [==============================] - 175s 8s/step - loss: 0.5783 - val_loss: 0.5859
Epoch 71/200
23/23 [==============================] - 185s 8s/step - loss: 0.5863 - val_loss: 0.5848
Epoch 72/200
23/23 [==============================] - 180s 8s/step - loss: 0.5846 - val_loss: 0.5856
Epoch 73/200
23/23 [==============================] - 180s 8s/step - loss: 0.5902 - val_loss: 0.5891
Epoch 74/200
23/23 [==============================] - 181s 8s/step - loss: 0.5704 - val_loss: 0.5979
Epoch 75/200
23/23 [==============================] - 180s 8s/step - loss: 0.5738 - val_loss: 0.5871
Epoch 76/200
23/23 [==============================] - 180s 8s/step - loss: 0.5854 - val_loss: 0.5928
Epoch 77/200
23/23 [==============================] - 181s 8s/step - loss: 0.5858 - val_loss: 0.5836
Epoch 78/200
23/23 [==============================] - 180s 8s/step - loss: 0.5701 - val_loss: 0.5846
Epoch 79/200
23/23 [==============================] - 186s 8s/step - loss: 0.5830 - val_loss: 0.5893
Epoch 80/200
23/23 [==============================] - 180s 8s/step - loss: 0.5871 - val_loss: 0.5921
Epoch 81/200
23/23 [==============================] - 180s 8s/step - loss: 0.5792 - val_loss: 0.5944
Epoch 82/200
23/23 [==============================] - 181s 8s/step - loss: 0.5854 - val_loss: 0.5892
Epoch 83/200
23/23 [==============================] - 176s 8s/step - loss: 0.5880 - val_loss: 0.5883
Epoch 84/200
23/23 [==============================] - 186s 8s/step - loss: 0.5714 - val_loss: 0.5957
Epoch 85/200
23/23 [==============================] - 176s 8s/step - loss: 0.5878 - val_loss: 0.5909
Epoch 86/200
23/23 [==============================] - 185s 8s/step - loss: 0.5716 - val_loss: 0.5915
Epoch 87/200
23/23 [==============================] - 176s 8s/step - loss: 0.5851 - val_loss: 0.5911
Epoch 88/200
23/23 [==============================] - 180s 8s/step - loss: 0.5896 - val_loss: 0.5925
Epoch 89/200
23/23 [==============================] - 186s 8s/step - loss: 0.5731 - val_loss: 0.5926
Epoch 90/200
23/23 [==============================] - 180s 8s/step - loss: 0.5783 - val_loss: 0.5887
Epoch 91/200
23/23 [==============================] - 176s 8s/step - loss: 0.5763 - val_loss: 0.6011
Epoch 92/200
23/23 [==============================] - 181s 8s/step - loss: 0.5740 - val_loss: 0.6116
Epoch 93/200
23/23 [==============================] - 185s 8s/step - loss: 0.5838 - val_loss: 0.5961
Epoch 94/200
23/23 [==============================] - 181s 8s/step - loss: 0.5794 - val_loss: 0.5895
Epoch 95/200
23/23 [==============================] - 181s 8s/step - loss: 0.5771 - val_loss: 0.5962
Epoch 96/200
23/23 [==============================] - 180s 8s/step - loss: 0.5780 - val_loss: 0.5882
Epoch 97/200
23/23 [==============================] - 180s 8s/step - loss: 0.5798 - val_loss: 0.5969
Epoch 98/200
23/23 [==============================] - 180s 8s/step - loss: 0.5834 - val_loss: 0.6018
Epoch 99/200
23/23 [==============================] - 180s 8s/step - loss: 0.5803 - val_loss: 0.6066
Epoch 100/200
23/23 [==============================] - 181s 8s/step - loss: 0.5753 - val_loss: 0.5967
Epoch 101/200
23/23 [==============================] - 181s 8s/step - loss: 0.5833 - val_loss: 0.5914
Epoch 102/200
23/23 [==============================] - 186s 8s/step - loss: 0.5846 - val_loss: 0.5897
Epoch 103/200
23/23 [==============================] - 176s 8s/step - loss: 0.5755 - val_loss: 0.6019
Epoch 104/200
23/23 [==============================] - 186s 8s/step - loss: 0.5710 - val_loss: 0.5953
Epoch 105/200
23/23 [==============================] - 176s 8s/step - loss: 0.5919 - val_loss: 0.5872
Epoch 106/200
23/23 [==============================] - 180s 8s/step - loss: 0.5530 - val_loss: 0.6173
Epoch 107/200
23/23 [==============================] - 181s 8s/step - loss: 0.5984 - val_loss: 0.5885
Epoch 108/200
23/23 [==============================] - 185s 8s/step - loss: 0.5910 - val_loss: 0.5893
Epoch 109/200
23/23 [==============================] - 176s 8s/step - loss: 0.5522 - val_loss: 0.5915
Epoch 110/200
23/23 [==============================] - 185s 8s/step - loss: 0.5879 - val_loss: 0.5928
Epoch 111/200
23/23 [==============================] - 181s 8s/step - loss: 0.5747 - val_loss: 0.5933
Epoch 112/200
23/23 [==============================] - 175s 8s/step - loss: 0.5672 - val_loss: 0.6004
Epoch 113/200
23/23 [==============================] - 180s 8s/step - loss: 0.5901 - val_loss: 0.6026
Epoch 114/200
23/23 [==============================] - 187s 8s/step - loss: 0.5740 - val_loss: 0.5928
Epoch 115/200
23/23 [==============================] - 180s 8s/step - loss: 0.5766 - val_loss: 0.5936
Epoch 116/200
23/23 [==============================] - 181s 8s/step - loss: 0.5646 - val_loss: 0.5954
Epoch 117/200
23/23 [==============================] - 180s 8s/step - loss: 0.5723 - val_loss: 0.6084
Epoch 118/200
23/23 [==============================] - 181s 8s/step - loss: 0.5822 - val_loss: 0.5899
Epoch 119/200
23/23 [==============================] - 181s 8s/step - loss: 0.5735 - val_loss: 0.5944
Epoch 120/200
23/23 [==============================] - 181s 8s/step - loss: 0.5793 - val_loss: 0.5928
Epoch 121/200
23/23 [==============================] - 180s 8s/step - loss: 0.5745 - val_loss: 0.6017
Epoch 122/200
23/23 [==============================] - 181s 8s/step - loss: 0.5814 - val_loss: 0.5893
Epoch 123/200
23/23 [==============================] - 180s 8s/step - loss: 0.5733 - val_loss: 0.5939
Epoch 124/200
23/23 [==============================] - 181s 8s/step - loss: 0.5802 - val_loss: 0.6048
Epoch 125/200
23/23 [==============================] - 181s 8s/step - loss: 0.5624 - val_loss: 0.6394
Epoch 126/200
23/23 [==============================] - 181s 8s/step - loss: 0.5878 - val_loss: 0.5958
Epoch 127/200
23/23 [==============================] - 181s 8s/step - loss: 0.5847 - val_loss: 0.5933
Epoch 128/200
23/23 [==============================] - 181s 8s/step - loss: 0.5779 - val_loss: 0.5936
Epoch 129/200
23/23 [==============================] - 181s 8s/step - loss: 0.5727 - val_loss: 0.5925
Epoch 130/200
23/23 [==============================] - 185s 8s/step - loss: 0.5693 - val_loss: 0.6015
Epoch 131/200
23/23 [==============================] - 180s 8s/step - loss: 0.5745 - val_loss: 0.5996
Epoch 132/200
23/23 [==============================] - 181s 8s/step - loss: 0.5783 - val_loss: 0.5892
Epoch 133/200
23/23 [==============================] - 280s 12s/step - loss: 0.5825 - val_loss: 0.6081
Epoch 134/200
23/23 [==============================] - 309s 13s/step - loss: 0.5792 - val_loss: 0.6200
Epoch 135/200
23/23 [==============================] - 323s 14s/step - loss: 0.5869 - val_loss: 0.6039
Epoch 136/200
23/23 [==============================] - 316s 14s/step - loss: 0.5828 - val_loss: 0.5913
Epoch 137/200
23/23 [==============================] - 316s 14s/step - loss: 0.5742 - val_loss: 0.5896
Epoch 138/200
23/23 [==============================] - 307s 13s/step - loss: 0.5816 - val_loss: 0.6010
Epoch 139/200
23/23 [==============================] - 323s 14s/step - loss: 0.5740 - val_loss: 0.6057
Epoch 140/200
23/23 [==============================] - 316s 14s/step - loss: 0.5749 - val_loss: 0.5926
Epoch 141/200
23/23 [==============================] - 316s 14s/step - loss: 0.5763 - val_loss: 0.5922
Epoch 142/200
23/23 [==============================] - 316s 14s/step - loss: 0.5668 - val_loss: 0.6427
Epoch 143/200
23/23 [==============================] - 315s 14s/step - loss: 0.5817 - val_loss: 0.5918
Epoch 144/200
23/23 [==============================] - 316s 14s/step - loss: 0.5687 - val_loss: 0.5950
Epoch 145/200
23/23 [==============================] - 315s 14s/step - loss: 0.5721 - val_loss: 0.6170
Epoch 146/200
23/23 [==============================] - 316s 14s/step - loss: 0.5726 - val_loss: 0.5991
Epoch 147/200
23/23 [==============================] - 315s 14s/step - loss: 0.5743 - val_loss: 0.6087
Epoch 148/200
23/23 [==============================] - 317s 14s/step - loss: 0.5770 - val_loss: 0.5984
Epoch 149/200
23/23 [==============================] - 316s 14s/step - loss: 0.5609 - val_loss: 0.5924
Epoch 150/200
23/23 [==============================] - 317s 14s/step - loss: 0.5684 - val_loss: 0.5896
Epoch 151/200
23/23 [==============================] - 317s 14s/step - loss: 0.5743 - val_loss: 0.5912
Epoch 152/200
23/23 [==============================] - 315s 14s/step - loss: 0.5822 - val_loss: 0.5951
Epoch 153/200
23/23 [==============================] - 324s 14s/step - loss: 0.5618 - val_loss: 0.5944
Epoch 154/200
23/23 [==============================] - 307s 13s/step - loss: 0.5643 - val_loss: 0.6124
Epoch 155/200
23/23 [==============================] - 316s 14s/step - loss: 0.5826 - val_loss: 0.5953
Epoch 156/200
23/23 [==============================] - 316s 14s/step - loss: 0.5511 - val_loss: 0.5982
Epoch 157/200
23/23 [==============================] - 316s 14s/step - loss: 0.5825 - val_loss: 0.6118
Epoch 158/200
23/23 [==============================] - 317s 14s/step - loss: 0.5765 - val_loss: 0.6068
Epoch 159/200
23/23 [==============================] - 324s 14s/step - loss: 0.5595 - val_loss: 0.6165
Epoch 160/200
23/23 [==============================] - 316s 14s/step - loss: 0.5645 - val_loss: 0.6025
Epoch 161/200
23/23 [==============================] - 316s 14s/step - loss: 0.5701 - val_loss: 0.5939
Epoch 162/200
23/23 [==============================] - 315s 14s/step - loss: 0.5508 - val_loss: 0.6584
Epoch 163/200
23/23 [==============================] - 317s 14s/step - loss: 0.5823 - val_loss: 0.5986
Epoch 164/200
23/23 [==============================] - 309s 13s/step - loss: 0.5696 - val_loss: 0.5942
Epoch 165/200
23/23 [==============================] - 324s 14s/step - loss: 0.5671 - val_loss: 0.5930
Epoch 166/200
23/23 [==============================] - 316s 14s/step - loss: 0.5732 - val_loss: 0.6046
Epoch 167/200
23/23 [==============================] - 317s 14s/step - loss: 0.5848 - val_loss: 0.6038
Epoch 168/200
23/23 [==============================] - 316s 14s/step - loss: 0.5610 - val_loss: 0.5962
Epoch 169/200
23/23 [==============================] - 316s 14s/step - loss: 0.5677 - val_loss: 0.5963
Epoch 170/200
23/23 [==============================] - 315s 14s/step - loss: 0.5691 - val_loss: 0.6037
Epoch 171/200
23/23 [==============================] - 317s 14s/step - loss: 0.5664 - val_loss: 0.6424
Epoch 172/200
23/23 [==============================] - 315s 14s/step - loss: 0.5853 - val_loss: 0.6171
Epoch 173/200
23/23 [==============================] - 322s 14s/step - loss: 0.5685 - val_loss: 0.6347
Epoch 174/200
23/23 [==============================] - 308s 13s/step - loss: 0.5661 - val_loss: 0.5928
Epoch 175/200
23/23 [==============================] - 316s 14s/step - loss: 0.5765 - val_loss: 0.5996
Epoch 176/200
23/23 [==============================] - 316s 14s/step - loss: 0.5702 - val_loss: 0.6206
Epoch 177/200
23/23 [==============================] - 316s 14s/step - loss: 0.5553 - val_loss: 0.5977
Epoch 178/200
23/23 [==============================] - 325s 14s/step - loss: 0.5529 - val_loss: 0.6142
Epoch 179/200
23/23 [==============================] - 307s 13s/step - loss: 0.5762 - val_loss: 0.6915
Epoch 180/200
23/23 [==============================] - 324s 14s/step - loss: 0.5948 - val_loss: 0.6336
Epoch 181/200
23/23 [==============================] - 316s 14s/step - loss: 0.5656 - val_loss: 0.5972
Epoch 182/200
23/23 [==============================] - 315s 14s/step - loss: 0.5657 - val_loss: 0.6164
Epoch 183/200
23/23 [==============================] - 316s 14s/step - loss: 0.5612 - val_loss: 0.6098
Epoch 184/200
23/23 [==============================] - 306s 13s/step - loss: 0.5579 - val_loss: 0.6063
Epoch 185/200
23/23 [==============================] - 325s 14s/step - loss: 0.5587 - val_loss: 0.6072
Epoch 186/200
23/23 [==============================] - 307s 13s/step - loss: 0.5538 - val_loss: 0.6023
Epoch 187/200
23/23 [==============================] - 325s 14s/step - loss: 0.5751 - val_loss: 0.6030
Epoch 188/200
23/23 [==============================] - 315s 14s/step - loss: 0.5729 - val_loss: 0.6070
Epoch 189/200
23/23 [==============================] - 316s 14s/step - loss: 0.5563 - val_loss: 0.6029
Epoch 190/200
23/23 [==============================] - 316s 14s/step - loss: 0.5678 - val_loss: 0.6050
Epoch 191/200
23/23 [==============================] - 316s 14s/step - loss: 0.5588 - val_loss: 0.6466
Epoch 192/200
23/23 [==============================] - 315s 14s/step - loss: 0.5725 - val_loss: 0.6221
Epoch 193/200
23/23 [==============================] - 315s 14s/step - loss: 0.5651 - val_loss: 0.6452
Epoch 194/200
23/23 [==============================] - 316s 14s/step - loss: 0.5578 - val_loss: 0.6224
Epoch 195/200
23/23 [==============================] - 315s 14s/step - loss: 0.5709 - val_loss: 0.6041
Epoch 196/200
23/23 [==============================] - 315s 14s/step - loss: 0.5712 - val_loss: 0.6082
Epoch 197/200
23/23 [==============================] - 318s 14s/step - loss: 0.5702 - val_loss: 0.6213
Epoch 198/200
23/23 [==============================] - 317s 14s/step - loss: 0.5587 - val_loss: 0.6136
Epoch 199/200
23/23 [==============================] - 315s 14s/step - loss: 0.5716 - val_loss: 0.6168
Epoch 200/200
23/23 [==============================] - 325s 14s/step - loss: 0.5480 - val_loss: 0.6152

In [157]:
logged_history = pd.read_csv(logger_filename)
logged_history


Out[157]:
epoch loss val_loss
0 0 0.625163 0.613147
1 1 0.591443 0.612764
2 2 0.577214 0.612777
3 3 0.603634 0.615373
4 4 0.583025 0.605459
5 5 0.587915 0.605771
6 6 0.588803 0.606950
7 7 0.592144 0.604085
8 8 0.581558 0.601125
9 9 0.586347 0.601322
10 10 0.579989 0.597998
11 11 0.588126 0.600044
12 12 0.588524 0.597567
13 13 0.581074 0.597355
14 14 0.584146 0.597105
15 15 0.582932 0.603825
16 16 0.590283 0.609898
17 17 0.581452 0.619715
18 18 0.596163 0.595666
19 19 0.590907 0.601241
20 20 0.591377 0.600684
21 21 0.586311 0.620571
22 22 0.587392 0.598951
23 23 0.584681 0.601766
24 24 0.585888 0.596175
25 25 0.580034 0.595876
26 26 0.583248 0.599792
27 27 0.595192 0.590801
28 28 0.564765 0.594454
29 29 0.585922 0.597325
... ... ... ...
170 170 0.568673 0.642422
171 171 0.582559 0.617082
172 172 0.568536 0.634678
173 173 0.570030 0.592751
174 174 0.576366 0.599555
175 175 0.567817 0.620639
176 176 0.555246 0.597715
177 177 0.552859 0.614191
178 178 0.566975 0.691513
179 179 0.594818 0.633648
180 180 0.569522 0.597187
181 181 0.568437 0.616393
182 182 0.558482 0.609809
183 183 0.563748 0.606301
184 184 0.558688 0.607209
185 185 0.559000 0.602337
186 186 0.575082 0.602983
187 187 0.571717 0.607048
188 188 0.559430 0.602911
189 189 0.568724 0.605037
190 190 0.562074 0.646605
191 191 0.571279 0.622057
192 192 0.565039 0.645184
193 193 0.561905 0.622374
194 194 0.570240 0.604071
195 195 0.569247 0.608244
196 196 0.568280 0.621342
197 197 0.557124 0.613607
198 198 0.575547 0.616825
199 199 0.548008 0.615171

200 rows × 3 columns


In [158]:
from sklearn.metrics import log_loss

In [159]:
p = Y[training_set_indices].mean()
prior_loss = log_loss(Y[testing_set_indices], 
                      [p]*testing_set_indices.size)
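
For reference, this baseline is just the binary cross-entropy of always predicting the training-set dwarf fraction $p$: $$ \mathrm{loss}_\mathrm{prior} = -\left[\bar{y}_\mathrm{test}\,\ln p + \left(1-\bar{y}_\mathrm{test}\right)\ln(1-p)\right] $$ where $\bar{y}_\mathrm{test}$ is the dwarf fraction in the testing set.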

In [163]:
with mpl.rc_context(rc={"figure.figsize": (10,6)}):

        
    plt.axhline(prior_loss, label="Prior", 
                linestyle="dashed", color="black")
    
    plt.plot(logged_history["val_loss"], label="Validation")
    plt.plot(logged_history["loss"], label="Training")
    
    plt.xlabel("Epoch")
    plt.ylabel("Loss\n(avg. binary cross-entropy)")
    
    plt.legend()
    
    plt.ylim(bottom=.56, top=.62)



In [164]:
from sklearn.metrics import log_loss
with mpl.rc_context(rc={"figure.figsize": (10,6)}):

    simple_conv = lambda x: np.convolve(x, np.ones(5)/5, mode="valid")
    
    plt.axhline(prior_loss, label="Prior", 
                linestyle="dashed", color="black")
    
    plt.plot(simple_conv(logged_history["val_loss"]),
             label="Validation")
    plt.plot(simple_conv(logged_history["loss"]),
             label="Training")
    
    plt.xlabel("Epoch")
    plt.ylabel("Loss\n(avg. binary cross-entropy)")
    
    plt.legend()
    
    plt.ylim(bottom=.56, top=.62)


9) Look at validation results


In [59]:
class_probs = model.predict_proba(X_test_transformed).flatten()
class_probs


373/373 [==============================] - 1s 3ms/step
Out[59]:
array([0.38096353, 0.44315898, 0.42595476, 0.39431217, 0.44192123,
       0.42403397, 0.3084317 , 0.3977308 , 0.42297694, 0.44161874,
       0.4457367 , 0.4409836 , 0.40325916, 0.37228101, 0.28045028,
       0.39575815, 0.44849315, 0.4440918 , 0.4450764 , 0.4423715 ,
       0.4449345 , 0.41168362, 0.4467733 , 0.43635613, 0.44292593,
       0.41494936, 0.39176732, 0.39990672, 0.2934756 , 0.32629684,
       0.3822251 , 0.31482962, 0.3116483 , 0.4064709 , 0.30879626,
       0.2557003 , 0.41120014, 0.37519196, 0.44936234, 0.40176862,
       0.4401391 , 0.4009208 , 0.33045557, 0.43595022, 0.3117058 ,
       0.36470458, 0.42641282, 0.38719258, 0.22711629, 0.39770648,
       0.2820537 , 0.34835997, 0.43251485, 0.30706087, 0.25999638,
       0.40024894, 0.38270858, 0.3465161 , 0.38174748, 0.43951225,
       0.36454654, 0.37350082, 0.3985789 , 0.42196226, 0.43141177,
       0.391576  , 0.402127  , 0.445567  , 0.43924326, 0.4386993 ,
       0.4277306 , 0.43835342, 0.2692941 , 0.44942412, 0.4508125 ,
       0.15961964, 0.39857697, 0.44297135, 0.3566997 , 0.440747  ,
       0.2533559 , 0.37569338, 0.4425728 , 0.43577525, 0.38890773,
       0.43563932, 0.4241861 , 0.3932006 , 0.43114784, 0.4328482 ,
       0.4168405 , 0.39507186, 0.4186143 , 0.42405146, 0.44095054,
       0.4161974 , 0.33376172, 0.44632694, 0.43030286, 0.3956095 ,
       0.30100748, 0.39117464, 0.4447542 , 0.43389064, 0.43305126,
       0.34959513, 0.35396492, 0.3121762 , 0.40590182, 0.4294078 ,
       0.4349919 , 0.40991887, 0.41888702, 0.36120278, 0.4282187 ,
       0.44593552, 0.3572685 , 0.44955522, 0.44456702, 0.3408464 ,
       0.45547816, 0.42949343, 0.342261  , 0.37738895, 0.44656345,
       0.31829506, 0.40286332, 0.44911963, 0.42276067, 0.44706854,
       0.40765384, 0.31033507, 0.32841244, 0.25814795, 0.31241944,
       0.453109  , 0.44477054, 0.34910932, 0.43158117, 0.17553769,
       0.45416188, 0.41460636, 0.40567777, 0.44760412, 0.24019198,
       0.44046155, 0.40499735, 0.36315823, 0.36915693, 0.4149858 ,
       0.44576353, 0.38120922, 0.3828068 , 0.37243614, 0.44181246,
       0.3235208 , 0.44830713, 0.44839013, 0.44242042, 0.44256744,
       0.44800305, 0.44181132, 0.43351567, 0.41210204, 0.44121698,
       0.413092  , 0.44176257, 0.43989152, 0.44552085, 0.43252727,
       0.4258776 , 0.3768219 , 0.41349387, 0.44884515, 0.32588708,
       0.41172966, 0.44757605, 0.41906548, 0.3825629 , 0.2607594 ,
       0.38989165, 0.29450786, 0.44284907, 0.31305894, 0.4185819 ,
       0.29876572, 0.16504529, 0.44900298, 0.44687438, 0.44056222,
       0.44137925, 0.3948508 , 0.34090987, 0.38642806, 0.40271282,
       0.44723928, 0.4007566 , 0.43055147, 0.44907433, 0.44951198,
       0.43810576, 0.43672234, 0.43379405, 0.43176052, 0.44240138,
       0.40329707, 0.4387482 , 0.32325944, 0.40388075, 0.27106833,
       0.3865721 , 0.20918834, 0.34775093, 0.44877672, 0.25511292,
       0.44780973, 0.40457538, 0.44524664, 0.32925203, 0.43885016,
       0.35806048, 0.3756185 , 0.38967642, 0.4192389 , 0.24980316,
       0.44150057, 0.43218148, 0.41110608, 0.39357594, 0.29679593,
       0.44293633, 0.42880318, 0.44525757, 0.4079865 , 0.3264692 ,
       0.44915658, 0.42911676, 0.44559175, 0.4005376 , 0.3871582 ,
       0.4421507 , 0.27849784, 0.23117535, 0.44651335, 0.4048834 ,
       0.40808037, 0.29254013, 0.4494428 , 0.3446331 , 0.4238776 ,
       0.40741017, 0.42450374, 0.2993166 , 0.3820515 , 0.447232  ,
       0.2889412 , 0.42764917, 0.44313568, 0.2791731 , 0.445895  ,
       0.35959965, 0.28588116, 0.39521152, 0.42558408, 0.29218394,
       0.3922141 , 0.37945515, 0.38019413, 0.44780952, 0.44512308,
       0.38719988, 0.39650354, 0.4100446 , 0.25340104, 0.3529415 ,
       0.38736403, 0.29348183, 0.44824216, 0.39691976, 0.443862  ,
       0.38470116, 0.30114385, 0.39607376, 0.42483705, 0.31340092,
       0.44839117, 0.45280483, 0.28304958, 0.37893963, 0.44633308,
       0.44886526, 0.42260054, 0.36907005, 0.39242512, 0.36889833,
       0.43124402, 0.4156075 , 0.25506765, 0.44074667, 0.38222235,
       0.43462127, 0.4382892 , 0.35095876, 0.4299872 , 0.44595292,
       0.3694274 , 0.4358682 , 0.40623176, 0.39650482, 0.44579092,
       0.42501035, 0.44486007, 0.40626553, 0.4242109 , 0.44190362,
       0.45157403, 0.36083287, 0.37688532, 0.41013816, 0.3927443 ,
       0.45153722, 0.3697739 , 0.447224  , 0.4037631 , 0.42013228,
       0.2867513 , 0.43948248, 0.43644127, 0.44002444, 0.40816608,
       0.24360898, 0.44442272, 0.43882433, 0.4233473 , 0.40551072,
       0.44690314, 0.41971892, 0.44890967, 0.4048392 , 0.4508837 ,
       0.27985042, 0.43386275, 0.16176486, 0.44558656, 0.37172565,
       0.40866947, 0.326269  , 0.3866764 , 0.38622338, 0.29225364,
       0.36603895, 0.42216617, 0.29009834, 0.25085118, 0.43494695,
       0.41980428, 0.31129548, 0.3963757 , 0.37951544, 0.438067  ,
       0.4432989 , 0.44002596, 0.44598934, 0.35958058, 0.4162218 ,
       0.32581538, 0.43282866, 0.38365978, 0.44844076, 0.43104017,
       0.44456834, 0.36104265, 0.44557685], dtype=float32)

In [60]:
with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    sns.distplot(class_probs[Y[testing_set_indices]==True], color="g", label="true dwarfs")
    sns.distplot(class_probs[Y[testing_set_indices]==False], color="b", label="true non-dwarfs")

    plt.xlabel("p(dwarf | image)")
    plt.ylabel("density (galaxies)")

    plt.xlim(0, .7)
    plt.axvline(Y[training_set_indices].mean(), linestyle="dashed", color="black", label="prior\n(from training set)")
    plt.axvline(.5, linestyle="dotted", color="black", label="50/50")

    plt.legend(
        loc="upper left",
        bbox_to_anchor=(1, 1),
    )



In [61]:
from sklearn import metrics
from sklearn.metrics import roc_auc_score

with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    fpr, tpr, _ = metrics.roc_curve(Y[testing_set_indices], class_probs)
    roc_auc = roc_auc_score(Y[testing_set_indices], class_probs)

    plt.plot(fpr, tpr, label="DNN (AUC = {:.2})".format(roc_auc))
    plt.plot([0,1], [0,1], linestyle="dashed", color="black", label="random guessing")

    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")

    plt.title("ROC Curve")

    plt.legend(loc="best")



In [62]:
from sklearn.metrics import average_precision_score
with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    precision, recall, _ = metrics.precision_recall_curve(Y[testing_set_indices], class_probs)
    pr_auc = average_precision_score(Y[testing_set_indices], class_probs)

    plt.plot(recall, precision, label="AUC = {:.2}".format(pr_auc))
    plt.plot([0,1], [Y[testing_set_indices].mean()]*2, linestyle="dashed", color="black")

    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("Recall")
    plt.ylabel("Precision")

    plt.title("PR Curve")

    plt.legend(loc="best")


Apply a threshold

For now, just set the classification threshold at the prior class probability (estimated from the training set).

Note that under a symmetric loss function, this isn't as good as simply thresholding at 0.5.
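
For comparison (a quick sketch, not run above), the symmetric-loss choice is simply a 0.5 threshold:

predicted_classes_50 = class_probs > 0.5
metrics.confusion_matrix(Y[testing_set_indices], predicted_classes_50)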


In [63]:
predicted_classes = class_probs > (Y[training_set_indices].mean())
predicted_classes.mean()


Out[63]:
0.9195710455764075

In [64]:
confusion_matrix = metrics.confusion_matrix(Y[testing_set_indices], predicted_classes)
confusion_matrix


Out[64]:
array([[ 27, 257],
       [  3,  86]])

In [65]:
print("number of dwarfs (true)     : ", Y[testing_set_indices].sum())
print("number of dwarfs (predicted): ", predicted_classes.sum())


number of dwarfs (true)     :  89
number of dwarfs (predicted):  343

In [66]:
import itertools
# adapted from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

confusion_matrix_normalized = confusion_matrix.astype('float') / confusion_matrix.sum(axis=1)[:, np.newaxis]

plt.imshow(confusion_matrix_normalized, interpolation='nearest',cmap="gray_r")
# plt.title(title)
plt.colorbar()
tick_marks = np.arange(2)
classes = [False, True]
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = 'd'
thresh = 1 / 2.
for i, j in itertools.product(range(confusion_matrix.shape[0]), range(confusion_matrix.shape[1])):
    plt.text(j, i, format(confusion_matrix[i, j], fmt),
             fontdict={"size":20},
             horizontalalignment="center",
             color="white" if confusion_matrix_normalized[i, j] > thresh else "black")

plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.tight_layout()



In [69]:
print("  i - Y_true[i] - Y_pred[i] -  error?")
print("-------------------------------------")
for i in range(predicted_classes.size):
    print("{:>3} -    {:>1d}      -    {:d}      -   {:>2d}".format(i, 
                                                    Y[testing_set_indices][i], 
                                                    predicted_classes[i],
                                                    (Y[testing_set_indices][i] != predicted_classes[i]), 
                                                   ))


  i - Y_true[i] - Y_pred[i] -  error?
-------------------------------------
  0 -    0      -    1      -    1
  1 -    1      -    1      -    0
  2 -    0      -    1      -    1
  3 -    1      -    1      -    0
  4 -    0      -    1      -    1
  5 -    0      -    1      -    1
  6 -    0      -    1      -    1
  7 -    0      -    1      -    1
  8 -    0      -    1      -    1
  9 -    0      -    1      -    1
 10 -    1      -    1      -    0
 11 -    0      -    1      -    1
 12 -    1      -    1      -    0
 13 -    0      -    1      -    1
 14 -    0      -    0      -    0
 15 -    0      -    1      -    1
 16 -    0      -    1      -    1
 17 -    0      -    1      -    1
 18 -    1      -    1      -    0
 19 -    1      -    1      -    0
 20 -    0      -    1      -    1
 21 -    0      -    1      -    1
 22 -    1      -    1      -    0
 23 -    0      -    1      -    1
 24 -    0      -    1      -    1
 25 -    0      -    1      -    1
 26 -    0      -    1      -    1
 27 -    0      -    1      -    1
 28 -    0      -    1      -    1
 29 -    0      -    1      -    1
 30 -    0      -    1      -    1
 31 -    0      -    1      -    1
 32 -    1      -    1      -    0
 33 -    0      -    1      -    1
 34 -    0      -    1      -    1
 35 -    1      -    0      -    1
 36 -    0      -    1      -    1
 37 -    1      -    1      -    0
 38 -    0      -    1      -    1
 39 -    1      -    1      -    0
 40 -    0      -    1      -    1
 41 -    0      -    1      -    1
 42 -    1      -    1      -    0
 43 -    1      -    1      -    0
 44 -    0      -    1      -    1
 45 -    1      -    1      -    0
 46 -    0      -    1      -    1
 47 -    1      -    1      -    0
 48 -    0      -    0      -    0
 49 -    0      -    1      -    1
 50 -    0      -    0      -    0
 51 -    0      -    1      -    1
 52 -    1      -    1      -    0
 53 -    0      -    1      -    1
 54 -    0      -    0      -    0
 55 -    0      -    1      -    1
 56 -    0      -    1      -    1
 57 -    0      -    1      -    1
 58 -    0      -    1      -    1
 59 -    0      -    1      -    1
 60 -    1      -    1      -    0
 61 -    0      -    1      -    1
 62 -    0      -    1      -    1
 63 -    0      -    1      -    1
 64 -    0      -    1      -    1
 65 -    0      -    1      -    1
 66 -    0      -    1      -    1
 67 -    1      -    1      -    0
 68 -    1      -    1      -    0
 69 -    0      -    1      -    1
 70 -    0      -    1      -    1
 71 -    0      -    1      -    1
 72 -    0      -    0      -    0
 73 -    0      -    1      -    1
 74 -    0      -    1      -    1
 75 -    0      -    0      -    0
 76 -    0      -    1      -    1
 77 -    1      -    1      -    0
 78 -    0      -    1      -    1
 79 -    0      -    1      -    1
 80 -    0      -    0      -    0
 81 -    1      -    1      -    0
 82 -    0      -    1      -    1
 83 -    0      -    1      -    1
 84 -    0      -    1      -    1
 85 -    1      -    1      -    0
 86 -    0      -    1      -    1
 87 -    1      -    1      -    0
 88 -    0      -    1      -    1
 89 -    0      -    1      -    1
 90 -    1      -    1      -    0
 91 -    0      -    1      -    1
 92 -    0      -    1      -    1
 93 -    1      -    1      -    0
 94 -    0      -    1      -    1
 95 -    0      -    1      -    1
 96 -    0      -    1      -    1
 97 -    1      -    1      -    0
 98 -    1      -    1      -    0
 99 -    0      -    1      -    1
100 -    0      -    1      -    1
101 -    0      -    1      -    1
102 -    1      -    1      -    0
103 -    0      -    1      -    1
104 -    0      -    1      -    1
105 -    1      -    1      -    0
106 -    0      -    1      -    1
107 -    1      -    1      -    0
108 -    0      -    1      -    1
109 -    0      -    1      -    1
110 -    0      -    1      -    1
111 -    0      -    1      -    1
112 -    0      -    1      -    1
113 -    0      -    1      -    1
114 -    0      -    1      -    1
115 -    0      -    1      -    1
116 -    0      -    1      -    1
117 -    0      -    1      -    1
118 -    0      -    1      -    1
119 -    0      -    1      -    1
120 -    0      -    1      -    1
121 -    0      -    1      -    1
122 -    0      -    1      -    1
123 -    0      -    1      -    1
124 -    0      -    1      -    1
125 -    1      -    1      -    0
126 -    0      -    1      -    1
127 -    0      -    1      -    1
128 -    0      -    1      -    1
129 -    1      -    1      -    0
130 -    1      -    1      -    0
131 -    0      -    1      -    1
132 -    0      -    1      -    1
133 -    0      -    0      -    0
134 -    0      -    1      -    1
135 -    0      -    1      -    1
136 -    0      -    1      -    1
137 -    0      -    1      -    1
138 -    0      -    1      -    1
139 -    0      -    0      -    0
140 -    0      -    1      -    1
141 -    0      -    1      -    1
142 -    0      -    1      -    1
143 -    1      -    1      -    0
144 -    0      -    0      -    0
145 -    0      -    1      -    1
146 -    0      -    1      -    1
147 -    0      -    1      -    1
148 -    0      -    1      -    1
149 -    0      -    1      -    1
150 -    1      -    1      -    0
151 -    0      -    1      -    1
152 -    0      -    1      -    1
153 -    0      -    1      -    1
154 -    0      -    1      -    1
155 -    0      -    1      -    1
156 -    0      -    1      -    1
157 -    1      -    1      -    0
158 -    1      -    1      -    0
159 -    1      -    1      -    0
160 -    0      -    1      -    1
161 -    0      -    1      -    1
162 -    1      -    1      -    0
163 -    1      -    1      -    0
164 -    1      -    1      -    0
165 -    0      -    1      -    1
166 -    1      -    1      -    0
167 -    1      -    1      -    0
168 -    1      -    1      -    0
169 -    1      -    1      -    0
170 -    0      -    1      -    1
171 -    0      -    1      -    1
172 -    0      -    1      -    1
173 -    0      -    1      -    1
174 -    0      -    1      -    1
175 -    0      -    1      -    1
176 -    1      -    1      -    0
177 -    0      -    1      -    1
178 -    0      -    1      -    1
179 -    0      -    0      -    0
180 -    0      -    1      -    1
181 -    0      -    1      -    1
182 -    1      -    1      -    0
183 -    0      -    1      -    1
184 -    0      -    1      -    1
185 -    0      -    1      -    1
186 -    0      -    0      -    0
187 -    0      -    1      -    1
188 -    0      -    1      -    1
189 -    1      -    1      -    0
190 -    0      -    1      -    1
191 -    0      -    1      -    1
192 -    0      -    1      -    1
193 -    0      -    1      -    1
194 -    0      -    1      -    1
195 -    0      -    1      -    1
196 -    0      -    1      -    1
197 -    1      -    1      -    0
198 -    0      -    1      -    1
199 -    0      -    1      -    1
200 -    1      -    1      -    0
201 -    1      -    1      -    0
202 -    1      -    1      -    0
203 -    0      -    1      -    1
204 -    1      -    1      -    0
205 -    1      -    1      -    0
206 -    1      -    1      -    0
207 -    0      -    1      -    1
208 -    0      -    1      -    1
209 -    0      -    0      -    0
210 -    0      -    1      -    1
211 -    0      -    0      -    0
212 -    0      -    1      -    1
213 -    0      -    1      -    1
214 -    0      -    0      -    0
215 -    1      -    1      -    0
216 -    1      -    1      -    0
217 -    1      -    1      -    0
218 -    0      -    1      -    1
219 -    1      -    1      -    0
220 -    1      -    1      -    0
221 -    0      -    1      -    1
222 -    0      -    1      -    1
223 -    0      -    1      -    1
224 -    1      -    0      -    1
225 -    0      -    1      -    1
226 -    0      -    1      -    1
227 -    0      -    1      -    1
228 -    0      -    1      -    1
229 -    0      -    1      -    1
230 -    0      -    1      -    1
231 -    0      -    1      -    1
232 -    0      -    1      -    1
233 -    0      -    1      -    1
234 -    0      -    1      -    1
235 -    1      -    1      -    0
236 -    0      -    1      -    1
237 -    0      -    1      -    1
238 -    1      -    1      -    0
239 -    0      -    1      -    1
240 -    0      -    1      -    1
241 -    0      -    0      -    0
242 -    0      -    0      -    0
243 -    0      -    1      -    1
244 -    1      -    1      -    0
245 -    1      -    1      -    0
246 -    0      -    1      -    1
247 -    0      -    1      -    1
248 -    0      -    1      -    1
249 -    0      -    1      -    1
250 -    0      -    1      -    1
251 -    0      -    1      -    1
252 -    0      -    1      -    1
253 -    0      -    1      -    1
254 -    1      -    1      -    0
255 -    0      -    0      -    0
256 -    0      -    1      -    1
257 -    0      -    1      -    1
258 -    0      -    0      -    0
259 -    0      -    1      -    1
260 -    0      -    1      -    1
261 -    0      -    0      -    0
262 -    0      -    1      -    1
263 -    0      -    1      -    1
264 -    1      -    1      -    0
265 -    0      -    1      -    1
266 -    0      -    1      -    1
267 -    0      -    1      -    1
268 -    1      -    1      -    0
269 -    1      -    1      -    0
270 -    0      -    1      -    1
271 -    1      -    1      -    0
272 -    0      -    1      -    1
273 -    0      -    0      -    0
274 -    0      -    1      -    1
275 -    0      -    1      -    1
276 -    0      -    1      -    1
277 -    0      -    1      -    1
278 -    0      -    1      -    1
279 -    0      -    1      -    1
280 -    1      -    1      -    0
281 -    0      -    1      -    1
282 -    0      -    1      -    1
283 -    1      -    1      -    0
284 -    0      -    1      -    1
285 -    0      -    1      -    1
286 -    0      -    1      -    1
287 -    0      -    0      -    0
288 -    1      -    1      -    0
289 -    0      -    1      -    1
290 -    0      -    1      -    1
291 -    0      -    1      -    1
292 -    0      -    1      -    1
293 -    0      -    1      -    1
294 -    0      -    1      -    1
295 -    0      -    1      -    1
296 -    0      -    1      -    1
297 -    0      -    0      -    0
298 -    1      -    1      -    0
299 -    0      -    1      -    1
300 -    0      -    1      -    1
301 -    0      -    1      -    1
302 -    0      -    1      -    1
303 -    1      -    1      -    0
304 -    0      -    1      -    1
305 -    0      -    1      -    1
306 -    1      -    1      -    0
307 -    0      -    1      -    1
308 -    0      -    1      -    1
309 -    0      -    1      -    1
310 -    0      -    1      -    1
311 -    1      -    1      -    0
312 -    0      -    1      -    1
313 -    0      -    1      -    1
314 -    1      -    1      -    0
315 -    0      -    1      -    1
316 -    0      -    1      -    1
317 -    0      -    1      -    1
318 -    1      -    1      -    0
319 -    0      -    1      -    1
320 -    0      -    1      -    1
321 -    0      -    1      -    1
322 -    1      -    1      -    0
323 -    0      -    1      -    1
324 -    0      -    1      -    1
325 -    0      -    0      -    0
326 -    0      -    1      -    1
327 -    0      -    1      -    1
328 -    0      -    1      -    1
329 -    0      -    1      -    1
330 -    0      -    0      -    0
331 -    0      -    1      -    1
332 -    1      -    1      -    0
333 -    1      -    1      -    0
334 -    0      -    1      -    1
335 -    1      -    1      -    0
336 -    0      -    1      -    1
337 -    0      -    1      -    1
338 -    0      -    1      -    1
339 -    1      -    1      -    0
340 -    0      -    0      -    0
341 -    0      -    1      -    1
342 -    1      -    0      -    1
343 -    1      -    1      -    0
344 -    0      -    1      -    1
345 -    0      -    1      -    1
346 -    0      -    1      -    1
347 -    0      -    1      -    1
348 -    0      -    1      -    1
349 -    0      -    1      -    1
350 -    1      -    1      -    0
351 -    0      -    1      -    1
352 -    0      -    1      -    1
353 -    0      -    0      -    0
354 -    0      -    1      -    1
355 -    0      -    1      -    1
356 -    0      -    1      -    1
357 -    0      -    1      -    1
358 -    0      -    1      -    1
359 -    0      -    1      -    1
360 -    0      -    1      -    1
361 -    1      -    1      -    0
362 -    0      -    1      -    1
363 -    0      -    1      -    1
364 -    0      -    1      -    1
365 -    0      -    1      -    1
366 -    0      -    1      -    1
367 -    0      -    1      -    1
368 -    1      -    1      -    0
369 -    0      -    1      -    1
370 -    0      -    1      -    1
371 -    0      -    1      -    1
372 -    0      -    1      -    1

Analyze Errors
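
As a quick first pass at the errors, we could tabulate the raw confusion counts from the same arrays used below (a minimal sketch, not run in this notebook; it assumes Y[testing_set_indices] and predicted_classes are aligned 0/1 arrays over the testing set):

# sketch: confusion counts on the testing set
y_true = np.array(Y[testing_set_indices], dtype=int)
y_pred = np.array(predicted_classes, dtype=int)

true_pos  = np.sum((y_true == 1) & (y_pred == 1))
false_pos = np.sum((y_true == 0) & (y_pred == 1))
false_neg = np.sum((y_true == 1) & (y_pred == 0))
true_neg  = np.sum((y_true == 0) & (y_pred == 0))

print("TP: {}  FP: {}  FN: {}  TN: {}".format(true_pos, false_pos, false_neg, true_neg))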


In [77]:
HSC_ids


Out[77]:
array([43158176442374224, 43158176442374373, 43158176442374445, ...,
       43159155694916013, 43159155694916476, 43159155694917496])

In [79]:
df.loc[HSC_ids[testing_set_indices]].head()


Out[79]:
                   target
HSC_id
43158880817015286   False
43159142810013371    True
43158876522047030   False
43158863637144621    True
43158335356145699   False

In [80]:
COSMOS_filename = os.path.join(dwarfz.data_dir_default, 
                               "COSMOS_reference.sqlite")
COSMOS = dwarfz.datasets.COSMOS(COSMOS_filename)

In [81]:
HSC_filename = os.path.join(dwarfz.data_dir_default, 
                            "HSC_COSMOS_median_forced.sqlite3")
HSC = dwarfz.datasets.HSC(HSC_filename)

In [82]:
matches_filename = os.path.join(dwarfz.data_dir_default, 
                                "matches.sqlite3")
matches_df = dwarfz.matching.Matches.load_from_filename(matches_filename)

In [89]:
combined = matches_df[matches_df.match].copy()
combined["ra"]       = COSMOS.df.loc[combined.index].ra
combined["dec"]      = COSMOS.df.loc[combined.index].dec
combined["photo_z"]  = COSMOS.df.loc[combined.index].photo_z
combined["log_mass"] = COSMOS.df.loc[combined.index].mass_med
combined["active"]   = COSMOS.df.loc[combined.index].classification

combined = combined.set_index("catalog_2_ids")

combined.head()


Out[89]:
                        sep  match  error          ra       dec  photo_z  log_mass  active
catalog_2_ids
43158996781122114  0.114389   True  False  149.749393  1.618068   0.3797  11.07610       0
43158447025298860  0.471546   True  False  150.388349  1.614538   2.3343   8.99275       1
43158447025298862  0.202378   True  False  150.402935  1.614631   2.1991   9.71373       1
43158584464246387  0.207967   True  False  150.295083  1.614662   2.4407   9.77811       1
43158584464253383  0.295316   True  False  150.239919  1.614675   0.2079   7.04224       1
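
The red box drawn on the plots below marks photo_z < 0.15 and 8 < log_mass < 9. Assuming that box is the selection region behind the target labels (an assumption here, not re-derived in this notebook), the same flag could be recomputed from combined as a sanity check:

# sketch: recompute an "inside the red box" flag from the matched catalog
# (assumes the box photo_z < 0.15, 8 < log_mass < 9 is the positive-class region)
in_box = ((combined.photo_z < 0.15)
          & (combined.log_mass > 8)
          & (combined.log_mass < 9))
print("objects inside the box: {}".format(in_box.sum()))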

In [93]:
df_features_testing = combined.loc[HSC_ids[testing_set_indices]]

In [148]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
# error = true label - predicted label:
#   +1 -> missed positive (true 1, predicted 0); -1 -> false positive (true 0, predicted 1); 0 -> correct
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
# mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
#            C=class_probs,
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Number of objects",
            )

plt.suptitle("All Objects")


348
Out[148]:
Text(0.5,0.98,'All Objects')

In [146]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
# mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
           # C is assumed to line up with the masked rows; if class_probs covers the
           # full testing set, it should probably be masked as well (class_probs[mask])
           C=class_probs,
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Average Predicted\nProb. within bin",
            )


plt.suptitle("All Objects")


348
Out[146]:
Text(0.5,0.98,'All Objects')

^^ Huh, that's a pretty uniform-looking spread. It doesn't really seem to be trending in a useful direction (either near the desired boundaries or as you move further away from them).
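
To put a rough number on that impression, we could correlate the predicted probabilities with photo_z and log_mass inside the plotted window (a sketch, assuming class_probs is aligned row-for-row with df_tmp):

# sketch: does the predicted probability trend with photo_z or log_mass?
probs = np.asarray(class_probs)[np.asarray(mask)]

print("corr(prob, photo_z):  {:.3f}".format(np.corrcoef(probs, df_tmp.photo_z[mask])[0, 1]))
print("corr(prob, log_mass): {:.3f}".format(np.corrcoef(probs, df_tmp.log_mass[mask])[0, 1]))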


In [147]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
mask &= (df_tmp.error == -1)  # keep only the false positives (true label 0, predicted 1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
           C=class_probs,
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Average Predicted\nProb. within bin",
            )

plt.suptitle("False Positives")


243
Out[147]:
Text(0.5,0.98,'False Positives')
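
A natural follow-up (sketched only, not run here) would be to rank the false positives by predicted probability and visually inspect the most confident ones; again this assumes class_probs lines up row-for-row with df_tmp:

# sketch: most confident false positives, for visual follow-up
df_fp = df_tmp[mask].copy()
df_fp["prob"] = np.asarray(class_probs)[np.asarray(mask)]
df_fp.sort_values("prob", ascending=False).head(10)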

In [ ]: