In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib as mpl

import seaborn as sns; sns.set(style="white", font_scale=2)

import numpy as np
import pandas as pd
from astropy.io import fits
import glob

import sklearn
import sklearn.metrics

import keras
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.layers.noise import GaussianNoise

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.preprocessing.image import random_channel_shift
from keras.optimizers import SGD, Adam
from keras import backend as K
K.set_image_data_format('channels_first')

import scipy.ndimage as ndi

import matplotlib.patches as patches
import pathlib


Using TensorFlow backend.

In [2]:
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
mpl.rcParams['figure.figsize'] = np.array((10,6))
mpl.rcParams['figure.facecolor'] = "white"

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# give access to importing dwarfz
import os, sys
dwarfz_package_dir = os.getcwd().split("dwarfz")[0]
if dwarfz_package_dir not in sys.path:
    sys.path.insert(0, dwarfz_package_dir)

import dwarfz
    
# back to regular import statements

To Do

  1. Read in fits images
  2. Apply pre-processing stretch (asinh)
  3. Crop images smaller
  4. Combine filters into one cube
  5. Create training set with labels
  6. Set up keras model
  7. Poke around at the results

0) Get files


In [5]:
images_dir = "../data/galaxy_images_training/quarry_files/"

In [6]:
HSC_ids = [int(os.path.basename(image_file).split("-")[0])
           for image_file in glob.glob(os.path.join(images_dir, "*.fits"))]
HSC_ids = set(HSC_ids) # remove duplicates
HSC_ids = np.array(sorted(HSC_ids))

# now filter out galaxies missing bands
HSC_ids = [HSC_id
           for HSC_id in HSC_ids
           if len(glob.glob(os.path.join(images_dir, "{}*.fits".format(str(HSC_id)))))==5]
HSC_ids = np.array(HSC_ids)

In [7]:
HSC_ids.size


Out[7]:
1866

In [8]:
HSC_id = HSC_ids[2] # for when I need a single sample galaxy

In [9]:
bands = ["g", "r", "i", "z", "y"]

1) Read in fits image


In [10]:
import preprocessing

In [11]:
image, flux_mag_0 = preprocessing.get_image(HSC_id, "g")
print("image size: {} x {}".format(*image.shape))
image


image size: 239 x 239
Out[11]:
array([[-5.1946715e-02,  5.2224647e-02,  1.2836658e-02, ...,
        -4.2683050e-02,  4.6599112e-02, -9.3426881e-03],
       [-2.4148472e-02,  4.9130496e-02, -1.0081916e-05, ...,
         4.0721711e-02,  1.4573295e-02, -4.0571071e-02],
       [-6.8525448e-03,  5.2679256e-03, -6.4173182e-03, ...,
         3.4137823e-02, -7.7981845e-02, -2.1136018e-02],
       ...,
       [-8.8869138e-03,  3.1337745e-02,  8.5042417e-04, ...,
         3.2482971e-02, -2.9819449e-02,  4.8630014e-03],
       [ 5.8512330e-02,  4.9714420e-02,  8.1128389e-02, ...,
        -2.5828091e-02,  3.4173269e-02, -4.4498667e-02],
       [ 1.3712918e-02, -2.2365674e-02,  3.0620810e-02, ...,
         2.5152223e-02, -1.5306322e-02, -4.1226905e-02]], dtype=float32)

In [12]:
preprocessing.image_plotter(image)



In [13]:
preprocessing.image_plotter(np.log(image))


/Users/egentry/anaconda3/envs/tf36/lib/python3.6/site-packages/ipykernel_launcher.py:1: RuntimeWarning: invalid value encountered in log
  """Entry point for launching an IPython kernel.

2) Apply stretch

We're using (negative) asinh magnitudes, as implemented by the HSC collaboration.

For more about the asinh magnitude system, see Lupton, Gunn & Szalay (1999), as used for SDSS. (It's explicitly given in the SDSS Algorithms documentation, as well as this overview page.)

For the source of our code, see the HSC color image creator.

And for reference, a common form of this stretch is: $$ \mathrm{mag}_\mathrm{asinh} = - \left(\frac{2.5}{\ln(10)}\right) \left(\mathrm{asinh}\left(\frac{f/f_0}{2b}\right) + \ln(b) \right)$$ for dimensionless softening parameter $b$ and reference flux $f_0$.
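For concreteness, here is a minimal sketch of that formula in code. It is not the actual preprocessing.scale implementation (whose source isn't shown in this notebook), and the softening parameter b below is only an illustrative placeholder:

import numpy as np

def asinh_mag(flux, flux_mag_0, b=1e-12):
    # Asinh ("luptitude") magnitude of a flux image, per the formula above.
    #   flux       : raw pixel values
    #   flux_mag_0 : flux corresponding to magnitude zero (from the FITS header)
    #   b          : dimensionless softening parameter (placeholder value)
    f_over_f0 = flux / flux_mag_0
    return -(2.5 / np.log(10.0)) * (np.arcsinh(f_over_f0 / (2.0 * b)) + np.log(b))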


In [14]:
image_scaled = preprocessing.scale(image, flux_mag_0)

In [15]:
preprocessing.image_plotter(image_scaled)



In [16]:
sns.distplot(image_scaled.flatten())
plt.title("Distribution of Transformed Intensities")


Out[16]:
Text(0.5, 1.0, 'Distribution of Transformed Intensities')

In [17]:
for band in bands:
    image, flux_mag_0 = preprocessing.get_image(HSC_id, band)
    image_scaled = preprocessing.scale(image, flux_mag_0)

    plt.figure()
    preprocessing.image_plotter(image_scaled)
    plt.title("{} band".format(band))
    plt.colorbar()


3) Crop Image

Am I properly handling odd image sizes?
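Below is a minimal center-crop sketch (not necessarily identical to preprocessing.get_cutout, whose source isn't shown here) that spells out how odd sizes fall out of the floor division:

def center_cutout(image, size):
    # Crop a (..., H, W) array to a centered (..., size, size) cutout.
    # When H - size (or W - size) is odd, the extra pixel lands on the
    # high-index side, because the start offset uses floor division.
    h, w = image.shape[-2:]
    i0 = (h - size) // 2
    j0 = (w - size) // 2
    return image[..., i0:i0 + size, j0:j0 + size]

# e.g. cropping 150 -> 75 starts at row 37 (37 rows before the cutout, 38 after)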


In [18]:
pre_transformed_image_size  = 150
post_transformed_image_size = 75

In [19]:
cutout = preprocessing.get_cutout(image_scaled, post_transformed_image_size)
cutout.shape


Out[19]:
(75, 75)

In [20]:
preprocessing.image_plotter(cutout)



In [21]:
for band in bands:
    image, flux_mag_0 = preprocessing.get_image(HSC_id, band)
    image_scaled = preprocessing.scale(image, flux_mag_0)
    cutout = preprocessing.get_cutout(image_scaled, post_transformed_image_size)

    plt.figure()
    preprocessing.image_plotter(cutout)
    plt.title("{} band".format(band))
    plt.colorbar()


4) Combine filters into cube


In [22]:
images = [None]*len(bands)
flux_mag_0s = [None]*len(bands)
cutouts = [None]*len(bands)
for i, band in enumerate(bands):
    images[i], flux_mag_0s[i] = preprocessing.get_image(HSC_id, band)
    
    cutouts[i] = preprocessing.get_cutout(
        preprocessing.scale(images[i], flux_mag_0s[i]), 
        post_transformed_image_size
    )

In [23]:
cutout_cube = np.array(cutouts)
cutout_cube.shape


Out[23]:
(5, 75, 75)

In [24]:
# must transform into [0,1] for plt.imshow
# the HSC standard tool accomplishes this by clipping instead.
plt.imshow(preprocessing.transform_0_1(cutout_cube[(4,2,0),:,:].transpose(1,2,0)) )


Out[24]:
<matplotlib.image.AxesImage at 0x133106ef0>
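For reference, the two normalization strategies mentioned in the comment above could look roughly like this (a sketch only; the actual preprocessing.transform_0_1 isn't reproduced here):

import numpy as np

def rescale_0_1(cube):
    # Min-max rescale into [0, 1] so plt.imshow accepts the float RGB cube.
    lo, hi = cube.min(), cube.max()
    return (cube - lo) / (hi - lo)

def clip_0_1(cube, lo, hi):
    # Alternative used by the HSC color-image tool: clip to a fixed window.
    return np.clip((cube - lo) / (hi - lo), 0.0, 1.0)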

In [25]:
for i, band in enumerate(bands):
    sns.distplot(cutout_cube[:,:,:].transpose(1,2,0)[:,:,i].flatten(), label=band)
    plt.legend(loc="best")
    plt.xlabel("pixel intensity (asinh scale)")


5) Load Training Set Labels


In [26]:
training_set_labels_filename = "../data/galaxy_images_training/2017_09_26-dwarf_galaxy_scores.csv"

In [27]:
df = pd.read_csv(training_set_labels_filename)
df = df.drop_duplicates("HSC_id")
df = df.set_index("HSC_id")
df = df[["low_z_low_mass"]]
df = df.rename(columns={"low_z_low_mass":"target"})
df.head()


Out[27]:
                   target
HSC_id
43158322471244656   False
43158605939114836   False
43159142810013665   False
43158734788125011   False
43158863637144621    True

In [28]:
def load_image_mappable(HSC_id):
    images      = [None]*len(bands)
    flux_mag_0s = [None]*len(bands)
    cutouts     = [None]*len(bands)
    for j, band in enumerate(bands):
        images[j], flux_mag_0s[j] = preprocessing.get_image(HSC_id, band)
        cutouts[j] = preprocessing.get_cutout(
            preprocessing.scale(images[j], flux_mag_0s[j]),
            pre_transformed_image_size
        )
    cutout_cube = np.array(cutouts)
    return cutout_cube

In [29]:
X = np.empty((len(HSC_ids), 5, 
              pre_transformed_image_size, pre_transformed_image_size))

In [30]:
X = np.array(list(map(load_image_mappable, HSC_ids)))

In [31]:
X_full = X
X_small = X[:,(0,2,4),:,:] # drop down to 3 bands

In [32]:
X.shape


Out[32]:
(1866, 5, 150, 150)

In [33]:
Y = df.loc[HSC_ids].target.values
Y


Out[33]:
array([False, False,  True, ..., False, False, False])

In [34]:
Y.mean()


Out[34]:
0.2792068595927117

Geometric Transformations for Data Augmentation


In [35]:
import geometry

h = pre_transformed_image_size
w = pre_transformed_image_size
transform_matrix = geometry.create_random_transform_matrix(h, w,
                                                  include_rotation=True,
                                                  translation_size = .01,
                                                  verbose=False)

x_tmp = X[0][:3]

result = geometry.apply_transform_new(x_tmp, transform_matrix, 
                            channel_axis=0, fill_mode="constant", cval=np.max(x_tmp))

result = preprocessing.get_cutout(result, post_transformed_image_size)
plt.imshow(preprocessing.transform_0_1(result.transpose(1,2,0)))


Out[35]:
<matplotlib.image.AxesImage at 0x133134eb8>

In [36]:
import ipywidgets
ipywidgets.interact(preprocessing.transform_plotter,
                    X = ipywidgets.fixed(X),
                    rotation_degrees = ipywidgets.IntSlider(min=0, max=360, step=15, value=45),
                    dx_after = ipywidgets.IntSlider(min=-15, max=15),
                    dy_after = ipywidgets.IntSlider(min=-15, max=15),
                    color = ipywidgets.fixed(True),
                    shear_degrees = ipywidgets.IntSlider(min=0, max=90, step=5, value=0),
                    zoom_x = ipywidgets.FloatSlider(min=.5, max=2, value=1),
                    crop = ipywidgets.Checkbox(value=True)
                    )


Out[36]:
<function preprocessing.transform_plotter(X, reflect_x=False, rotation_degrees=45, dx_after=0, dy_after=0, shear_degrees=0, zoom_x=1, crop=False, color=True)>

5b) Split training and testing set


In [37]:
randomized_indices = np.arange(X.shape[0])
np.random.seed(42)
np.random.shuffle(randomized_indices)

testing_fraction = 0.2
testing_set_indices = randomized_indices[:int(testing_fraction*X.shape[0])]
training_set_indices = np.array(list(set(randomized_indices) - set(testing_set_indices)))

In [38]:
print(training_set_indices.size)
print(Y[training_set_indices].mean())


1493
0.27461486939048896

In [39]:
print(testing_set_indices.size)
print(Y[testing_set_indices].mean())


373
0.2975871313672922

In [40]:
p = Y[training_set_indices].mean()
prior_loss = sklearn.metrics.log_loss(Y[testing_set_indices], 
                                      [p]*testing_set_indices.size)

prior_loss


Out[40]:
0.6101087775907272
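
For reference, this baseline is just the binary cross-entropy of always predicting the training-set positive fraction $p \approx 0.275$: $$ \mathcal{L}_\mathrm{prior} = -\left[\bar{y}\,\ln p + \left(1 - \bar{y}\right)\ln\left(1 - p\right)\right] $$ where $\bar{y} \approx 0.298$ is the test-set positive fraction; plugging in the numbers reproduces the $\approx 0.610$ above. This is the number the learning curves below need to beat.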

6b) Adapt NumpyArrayIterator

The original only allowed 1-, 3-, or 4-channel images; I have 5-channel images.

I also want to change the way the augmentation is applied.
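
As a rough illustration of the kind of change needed (this is not the actual data_generator.ArrayIterator code), a channel-agnostic batch iterator can simply skip the channel-count check:

import numpy as np

def batch_iterator(X, y, batch_size=64, rng=np.random):
    # Yield shuffled (X_batch, y_batch) pairs from an (N, C, H, W) array,
    # with no restriction on the number of channels C.
    n = X.shape[0]
    while True:
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield X[idx], y[idx]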


In [41]:
from data_generator import ArrayIterator

6c) Adapt ImageDataGenerator

The original only allowed 1-, 3-, or 4-channel images; I have 5-channel images. I'm also adjusting the way the affine transformations are applied for the data augmentation.
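
For intuition only (the real logic lives in data_generator.py and geometry.py, which aren't shown here), a random reflection plus a small shift on a single (C, H, W) sample can be written as:

import numpy as np
import scipy.ndimage as ndi

def random_flip_shift(x, max_shift=0.002, rng=np.random):
    # Randomly reflect a (C, H, W) image in x and/or y,
    # then shift by up to max_shift (a fraction of the image size).
    if rng.rand() < 0.5:
        x = x[:, :, ::-1]  # reflection in x
    if rng.rand() < 0.5:
        x = x[:, ::-1, :]  # reflection in y
    h, w = x.shape[1:]
    dy = rng.uniform(-max_shift, max_shift) * h
    dx = rng.uniform(-max_shift, max_shift) * w
    return ndi.shift(x, shift=(0, dy, dx), order=1, mode="constant", cval=0.0)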


In [42]:
from data_generator import ImageDataGenerator

6d) Create Data Generator


In [43]:
print('Using real-time data augmentation.')

h_before, w_before = X[0,0].shape
print("image shape before: ({},{})".format(h_before, w_before))

h_after = post_transformed_image_size
w_after = post_transformed_image_size
print("image shape after:  ({},{})".format(h_after, w_after))

# get a closure that binds the image size to get_cutout
postprocessing_function = lambda image: preprocessing.get_cutout(image, post_transformed_image_size)

# this will do preprocessing and realtime data augmentation
datagen = ImageDataGenerator(
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    with_reflection_x=True, # randomly apply a reflection (in x)
    with_reflection_y=True, # randomly apply a reflection (in y)
    with_rotation=False, # randomly apply a rotation
    width_shift_range=0.002,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.002,  # randomly shift images vertically (fraction of total height)
    postprocessing_function=postprocessing_function, # get a cutout of the processed image
    output_image_shape=(post_transformed_image_size,post_transformed_image_size)
)


Using real-time data augmentation.
image shape before: (150,150)
image shape after:  (75,75)
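
Equivalently, the closure above can be written with a small factory function instead of a lambda; the behaviour is identical, so this is purely a stylistic alternative:

def make_cutout_postprocessor(size):
    # Return a function that crops any image to a fixed cutout size.
    def postprocess(image):
        return preprocessing.get_cutout(image, size)
    return postprocess

postprocessing_function = make_cutout_postprocessor(post_transformed_image_size)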

In [44]:
datagen.fit(X[training_set_indices])

7) Set up keras model


In [45]:
n_conv_filters = 16
conv_kernel_size = 4
input_shape = cutout_cube.shape

dropout_fraction = .50

nb_dense = 64

In [46]:
model = Sequential()

model.add(Conv2D(n_conv_filters, conv_kernel_size,
                        padding='same', input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(dropout_fraction))


model.add(Conv2D(n_conv_filters, conv_kernel_size*2,
                        padding='same',))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(dropout_fraction))

model.add(Conv2D(n_conv_filters, conv_kernel_size*4,
                        padding='same',))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(dropout_fraction))

model.add(Flatten())
model.add(Dense(2*nb_dense, activation="relu"))
model.add(Dense(nb_dense, activation="relu"))
model.add(Dense(1, activation="sigmoid"))

In [47]:
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 16, 75, 75)        1296      
_________________________________________________________________
activation_1 (Activation)    (None, 16, 75, 75)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 37, 37)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 16, 37, 37)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 37, 37)        16400     
_________________________________________________________________
activation_2 (Activation)    (None, 16, 37, 37)        0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 18, 18)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 16, 18, 18)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 18, 18)        65552     
_________________________________________________________________
activation_3 (Activation)    (None, 16, 18, 18)        0         
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 16, 9, 9)          0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 16, 9, 9)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1296)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               166016    
_________________________________________________________________
dense_2 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 65        
=================================================================
Total params: 257,585
Trainable params: 257,585
Non-trainable params: 0
_________________________________________________________________

In [48]:
learning_rate = 0.00025
decay = 1e-5
momentum = 0.9

# sgd = SGD(lr=learning_rate, decay=decay, momentum=momentum, nesterov=True)

adam = Adam(lr=learning_rate)

In [49]:
logger_filename = "training.log"

model.compile(loss='binary_crossentropy', 
#               optimizer=sgd, 
              optimizer=adam,
#               metrics=["accuracy"]
             )

if os.path.exists(logger_filename):
    logger_filename_tmp = logger_filename + ".old"
    os.rename(logger_filename, logger_filename_tmp)

In [50]:
earlystopping = EarlyStopping(monitor='loss',
                              patience=35,
                              verbose=1,
                              mode='auto' )

In [51]:
csv_logger = keras.callbacks.CSVLogger(logger_filename,
                                       append=True)

In [52]:
# modelcheckpoint = ModelCheckpoint(pathinCat+'Models/bestmodelMAG.hdf5',monitor='val_loss',verbose=0,save_best_only=True)

8) Run basic keras model


In [53]:
goal_batch_size = 64
steps_per_epoch = max(2, training_set_indices.size//goal_batch_size)
batch_size = training_set_indices.size//steps_per_epoch
print("steps_per_epoch: ", steps_per_epoch)
print("batch_size: ", batch_size)
epoch_to_save_state = 100
epochs = 400
verbose = 1


steps_per_epoch:  23
batch_size:  64

In [54]:
X_test_transformed = np.array([datagen.standardize(X_img)
                               for X_img in X[testing_set_indices]])

X_test_transformed.shape


Out[54]:
(373, 5, 75, 75)

In [55]:
prior_loss


Out[55]:
0.6101087775907272

In [56]:
args = (datagen.flow(X[training_set_indices],
                     Y[training_set_indices],
                     batch_size=batch_size,
                     ),
        )

kwargs = dict(steps_per_epoch=steps_per_epoch,
              validation_data=(X_test_transformed, 
                               Y[testing_set_indices]),
              verbose=verbose,
              callbacks=[csv_logger],
             )

history = model.fit_generator(*args,
                              epochs=epoch_to_save_state,
                              **kwargs,
                              )  

print("saving state after {} epochs".format(epoch_to_save_state), flush=True)
weights_at_middle = model.get_weights()
print("continuing to get a full learning curve, including overfitted regime")


history = model.fit_generator(*args,
                              initial_epoch=epoch_to_save_state,
                              epochs=epochs,
                              **kwargs,
                              )  

weights_at_end = model.get_weights()


print("restoring model to state at epoch = {}".format(epoch_to_save_state))
model.set_weights(weights_at_middle)
model.save_weights("cached_weights.h5")


Epoch 1/100
23/23 [==============================] - 24s 1s/step - loss: 0.6754 - val_loss: 0.6380
Epoch 2/100
23/23 [==============================] - 23s 991ms/step - loss: 0.6524 - val_loss: 0.6374
Epoch 3/100
23/23 [==============================] - 23s 1s/step - loss: 0.5931 - val_loss: 0.6630
Epoch 4/100
23/23 [==============================] - 23s 1s/step - loss: 0.5973 - val_loss: 0.6581
Epoch 5/100
23/23 [==============================] - 25s 1s/step - loss: 0.5950 - val_loss: 0.6481
Epoch 6/100
23/23 [==============================] - 25s 1s/step - loss: 0.5932 - val_loss: 0.6388
Epoch 7/100
23/23 [==============================] - 22s 969ms/step - loss: 0.5851 - val_loss: 0.6431
Epoch 8/100
23/23 [==============================] - 23s 1s/step - loss: 0.5777 - val_loss: 0.6439
Epoch 9/100
23/23 [==============================] - 23s 1s/step - loss: 0.5918 - val_loss: 0.6393
Epoch 10/100
23/23 [==============================] - 23s 994ms/step - loss: 0.5938 - val_loss: 0.6303
Epoch 11/100
23/23 [==============================] - 24s 1s/step - loss: 0.5847 - val_loss: 0.6362
Epoch 12/100
23/23 [==============================] - 22s 967ms/step - loss: 0.5992 - val_loss: 0.6298
Epoch 13/100
23/23 [==============================] - 24s 1s/step - loss: 0.5921 - val_loss: 0.6295
Epoch 14/100
23/23 [==============================] - 25s 1s/step - loss: 0.5929 - val_loss: 0.6220
Epoch 15/100
23/23 [==============================] - 24s 1s/step - loss: 0.5861 - val_loss: 0.6256
Epoch 16/100
23/23 [==============================] - 24s 1s/step - loss: 0.5991 - val_loss: 0.6350
Epoch 17/100
23/23 [==============================] - 24s 1s/step - loss: 0.5840 - val_loss: 0.6298
Epoch 18/100
23/23 [==============================] - 24s 1s/step - loss: 0.5842 - val_loss: 0.6320
Epoch 19/100
23/23 [==============================] - 24s 1s/step - loss: 0.5909 - val_loss: 0.6288
Epoch 20/100
23/23 [==============================] - 24s 1s/step - loss: 0.5899 - val_loss: 0.6350
Epoch 21/100
23/23 [==============================] - 24s 1s/step - loss: 0.5930 - val_loss: 0.6336
Epoch 22/100
23/23 [==============================] - 24s 1s/step - loss: 0.5866 - val_loss: 0.6145
Epoch 23/100
23/23 [==============================] - 24s 1s/step - loss: 0.5821 - val_loss: 0.6106
Epoch 24/100
23/23 [==============================] - 24s 1s/step - loss: 0.5988 - val_loss: 0.6384
Epoch 25/100
23/23 [==============================] - 24s 1s/step - loss: 0.5931 - val_loss: 0.6264
Epoch 26/100
23/23 [==============================] - 25s 1s/step - loss: 0.5949 - val_loss: 0.6282
Epoch 27/100
23/23 [==============================] - 26s 1s/step - loss: 0.5876 - val_loss: 0.6242
Epoch 28/100
23/23 [==============================] - 26s 1s/step - loss: 0.5940 - val_loss: 0.6283
Epoch 29/100
23/23 [==============================] - 25s 1s/step - loss: 0.5866 - val_loss: 0.6154
Epoch 30/100
23/23 [==============================] - 26s 1s/step - loss: 0.5927 - val_loss: 0.6268
Epoch 31/100
23/23 [==============================] - 26s 1s/step - loss: 0.5915 - val_loss: 0.6234
Epoch 32/100
23/23 [==============================] - 25s 1s/step - loss: 0.5873 - val_loss: 0.6239
Epoch 33/100
23/23 [==============================] - 23s 1s/step - loss: 0.5802 - val_loss: 0.6086
Epoch 34/100
23/23 [==============================] - 23s 986ms/step - loss: 0.5945 - val_loss: 0.6145
Epoch 35/100
23/23 [==============================] - 24s 1s/step - loss: 0.5607 - val_loss: 0.6159
Epoch 36/100
23/23 [==============================] - 25s 1s/step - loss: 0.6014 - val_loss: 0.6270
Epoch 37/100
23/23 [==============================] - 24s 1s/step - loss: 0.5912 - val_loss: 0.6197
Epoch 38/100
23/23 [==============================] - 24s 1s/step - loss: 0.5938 - val_loss: 0.6152
Epoch 39/100
23/23 [==============================] - 24s 1s/step - loss: 0.5740 - val_loss: 0.6130
Epoch 40/100
23/23 [==============================] - 25s 1s/step - loss: 0.5864 - val_loss: 0.6167
Epoch 41/100
23/23 [==============================] - 25s 1s/step - loss: 0.5925 - val_loss: 0.6153
Epoch 42/100
23/23 [==============================] - 24s 1s/step - loss: 0.5936 - val_loss: 0.6188
Epoch 43/100
23/23 [==============================] - 24s 1s/step - loss: 0.5817 - val_loss: 0.6040
Epoch 44/100
23/23 [==============================] - 24s 1s/step - loss: 0.5842 - val_loss: 0.6157
Epoch 45/100
23/23 [==============================] - 24s 1s/step - loss: 0.5898 - val_loss: 0.6077
Epoch 46/100
23/23 [==============================] - 24s 1s/step - loss: 0.5816 - val_loss: 0.6158
Epoch 47/100
23/23 [==============================] - 24s 1s/step - loss: 0.5805 - val_loss: 0.5984
Epoch 48/100
23/23 [==============================] - 24s 1s/step - loss: 0.5886 - val_loss: 0.6175
Epoch 49/100
23/23 [==============================] - 26s 1s/step - loss: 0.5950 - val_loss: 0.5994
Epoch 50/100
23/23 [==============================] - 26s 1s/step - loss: 0.5802 - val_loss: 0.6015
Epoch 51/100
23/23 [==============================] - 26s 1s/step - loss: 0.5872 - val_loss: 0.5949
Epoch 52/100
23/23 [==============================] - 26s 1s/step - loss: 0.5785 - val_loss: 0.5953
Epoch 53/100
23/23 [==============================] - 26s 1s/step - loss: 0.5982 - val_loss: 0.5883
Epoch 54/100
23/23 [==============================] - 27s 1s/step - loss: 0.5909 - val_loss: 0.6110
Epoch 55/100
23/23 [==============================] - 29s 1s/step - loss: 0.5668 - val_loss: 0.6015
Epoch 56/100
23/23 [==============================] - 26s 1s/step - loss: 0.5910 - val_loss: 0.5947
Epoch 57/100
23/23 [==============================] - 26s 1s/step - loss: 0.5896 - val_loss: 0.6000
Epoch 58/100
23/23 [==============================] - 26s 1s/step - loss: 0.5902 - val_loss: 0.5989
Epoch 59/100
23/23 [==============================] - 26s 1s/step - loss: 0.5836 - val_loss: 0.6000
Epoch 60/100
23/23 [==============================] - 26s 1s/step - loss: 0.5768 - val_loss: 0.6033
Epoch 61/100
23/23 [==============================] - 26s 1s/step - loss: 0.5877 - val_loss: 0.5985
Epoch 62/100
23/23 [==============================] - 25s 1s/step - loss: 0.5966 - val_loss: 0.6032
Epoch 63/100
23/23 [==============================] - 24s 1s/step - loss: 0.5714 - val_loss: 0.5903
Epoch 64/100
23/23 [==============================] - 26s 1s/step - loss: 0.5960 - val_loss: 0.5985
Epoch 65/100
23/23 [==============================] - 25s 1s/step - loss: 0.5766 - val_loss: 0.5965
Epoch 66/100
23/23 [==============================] - 27s 1s/step - loss: 0.5645 - val_loss: 0.5923
Epoch 67/100
23/23 [==============================] - 26s 1s/step - loss: 0.5945 - val_loss: 0.5987
Epoch 68/100
23/23 [==============================] - 26s 1s/step - loss: 0.5795 - val_loss: 0.5865
Epoch 69/100
23/23 [==============================] - 27s 1s/step - loss: 0.5912 - val_loss: 0.6212
Epoch 70/100
23/23 [==============================] - 27s 1s/step - loss: 0.5848 - val_loss: 0.6007
Epoch 71/100
23/23 [==============================] - 25s 1s/step - loss: 0.5794 - val_loss: 0.5962
Epoch 72/100
23/23 [==============================] - 27s 1s/step - loss: 0.5826 - val_loss: 0.5953
Epoch 73/100
23/23 [==============================] - 26s 1s/step - loss: 0.5829 - val_loss: 0.5958
Epoch 74/100
23/23 [==============================] - 26s 1s/step - loss: 0.5770 - val_loss: 0.6000
Epoch 75/100
23/23 [==============================] - 26s 1s/step - loss: 0.5770 - val_loss: 0.5893
Epoch 76/100
23/23 [==============================] - 26s 1s/step - loss: 0.6018 - val_loss: 0.5990
Epoch 77/100
23/23 [==============================] - 25s 1s/step - loss: 0.5867 - val_loss: 0.5941
Epoch 78/100
23/23 [==============================] - 24s 1s/step - loss: 0.5667 - val_loss: 0.5963
Epoch 79/100
23/23 [==============================] - 26s 1s/step - loss: 0.5952 - val_loss: 0.5967
Epoch 80/100
23/23 [==============================] - 27s 1s/step - loss: 0.5771 - val_loss: 0.5925
Epoch 81/100
23/23 [==============================] - 26s 1s/step - loss: 0.5739 - val_loss: 0.5881
Epoch 82/100
23/23 [==============================] - 26s 1s/step - loss: 0.5996 - val_loss: 0.5993
Epoch 83/100
23/23 [==============================] - 26s 1s/step - loss: 0.5775 - val_loss: 0.5937
Epoch 84/100
23/23 [==============================] - 26s 1s/step - loss: 0.5852 - val_loss: 0.5925
Epoch 85/100
23/23 [==============================] - 25s 1s/step - loss: 0.5651 - val_loss: 0.5915
Epoch 86/100
23/23 [==============================] - 29s 1s/step - loss: 0.5846 - val_loss: 0.5973
Epoch 87/100
23/23 [==============================] - 26s 1s/step - loss: 0.5876 - val_loss: 0.5941
Epoch 88/100
23/23 [==============================] - 26s 1s/step - loss: 0.5770 - val_loss: 0.5901
Epoch 89/100
23/23 [==============================] - 26s 1s/step - loss: 0.5759 - val_loss: 0.6058
Epoch 90/100
23/23 [==============================] - 26s 1s/step - loss: 0.5846 - val_loss: 0.6024
Epoch 91/100
23/23 [==============================] - 26s 1s/step - loss: 0.5931 - val_loss: 0.5972
Epoch 92/100
23/23 [==============================] - 24s 1s/step - loss: 0.5679 - val_loss: 0.5895
Epoch 93/100
23/23 [==============================] - 26s 1s/step - loss: 0.5701 - val_loss: 0.6013
Epoch 94/100
23/23 [==============================] - 26s 1s/step - loss: 0.5799 - val_loss: 0.5892
Epoch 95/100
23/23 [==============================] - 26s 1s/step - loss: 0.5811 - val_loss: 0.5929
Epoch 96/100
23/23 [==============================] - 26s 1s/step - loss: 0.5794 - val_loss: 0.5937
Epoch 97/100
23/23 [==============================] - 26s 1s/step - loss: 0.5801 - val_loss: 0.5902
Epoch 98/100
23/23 [==============================] - 27s 1s/step - loss: 0.5749 - val_loss: 0.5859
Epoch 99/100
23/23 [==============================] - 25s 1s/step - loss: 0.5897 - val_loss: 0.5935
Epoch 100/100
23/23 [==============================] - 26s 1s/step - loss: 0.5775 - val_loss: 0.5918
saving state after 100 epochs
continuing to get a full learning curve, including overfitted regime
Epoch 101/400
23/23 [==============================] - 26s 1s/step - loss: 0.5820 - val_loss: 0.5846
Epoch 102/400
23/23 [==============================] - 26s 1s/step - loss: 0.5897 - val_loss: 0.5884
Epoch 103/400
23/23 [==============================] - 26s 1s/step - loss: 0.5755 - val_loss: 0.5979
Epoch 104/400
23/23 [==============================] - 26s 1s/step - loss: 0.5794 - val_loss: 0.5895
Epoch 105/400
23/23 [==============================] - 26s 1s/step - loss: 0.5787 - val_loss: 0.5906
Epoch 106/400
23/23 [==============================] - 26s 1s/step - loss: 0.5770 - val_loss: 0.5923
Epoch 107/400
23/23 [==============================] - 24s 1s/step - loss: 0.5842 - val_loss: 0.5951
Epoch 108/400
23/23 [==============================] - 26s 1s/step - loss: 0.5835 - val_loss: 0.5982
Epoch 109/400
23/23 [==============================] - 26s 1s/step - loss: 0.5770 - val_loss: 0.5858
Epoch 110/400
23/23 [==============================] - 25s 1s/step - loss: 0.5661 - val_loss: 0.5910
Epoch 111/400
23/23 [==============================] - 28s 1s/step - loss: 0.5768 - val_loss: 0.5897
Epoch 112/400
23/23 [==============================] - 25s 1s/step - loss: 0.5910 - val_loss: 0.5930
Epoch 113/400
23/23 [==============================] - 26s 1s/step - loss: 0.5542 - val_loss: 0.5832
Epoch 114/400
23/23 [==============================] - 26s 1s/step - loss: 0.5835 - val_loss: 0.5896
Epoch 115/400
23/23 [==============================] - 27s 1s/step - loss: 0.5829 - val_loss: 0.5911
Epoch 116/400
23/23 [==============================] - 27s 1s/step - loss: 0.5677 - val_loss: 0.5914
Epoch 117/400
23/23 [==============================] - 27s 1s/step - loss: 0.5771 - val_loss: 0.5965
Epoch 118/400
23/23 [==============================] - 26s 1s/step - loss: 0.5893 - val_loss: 0.5912
Epoch 119/400
23/23 [==============================] - 25s 1s/step - loss: 0.5680 - val_loss: 0.5998
Epoch 120/400
23/23 [==============================] - 27s 1s/step - loss: 0.5864 - val_loss: 0.5958
Epoch 121/400
23/23 [==============================] - 25s 1s/step - loss: 0.5819 - val_loss: 0.5957
Epoch 122/400
23/23 [==============================] - 24s 1s/step - loss: 0.5786 - val_loss: 0.5917
Epoch 123/400
23/23 [==============================] - 25s 1s/step - loss: 0.5620 - val_loss: 0.5939
Epoch 124/400
23/23 [==============================] - 26s 1s/step - loss: 0.5837 - val_loss: 0.5902
Epoch 125/400
23/23 [==============================] - 26s 1s/step - loss: 0.5795 - val_loss: 0.5928
Epoch 126/400
23/23 [==============================] - 26s 1s/step - loss: 0.5800 - val_loss: 0.5932
Epoch 127/400
23/23 [==============================] - 27s 1s/step - loss: 0.5837 - val_loss: 0.6014
Epoch 128/400
23/23 [==============================] - 26s 1s/step - loss: 0.5870 - val_loss: 0.5950
Epoch 129/400
23/23 [==============================] - 26s 1s/step - loss: 0.5777 - val_loss: 0.6029
Epoch 130/400
23/23 [==============================] - 27s 1s/step - loss: 0.5815 - val_loss: 0.5934
Epoch 131/400
23/23 [==============================] - 27s 1s/step - loss: 0.5820 - val_loss: 0.5952
Epoch 132/400
23/23 [==============================] - 27s 1s/step - loss: 0.5726 - val_loss: 0.5925
Epoch 133/400
23/23 [==============================] - 25s 1s/step - loss: 0.5832 - val_loss: 0.5977
Epoch 134/400
23/23 [==============================] - 26s 1s/step - loss: 0.5833 - val_loss: 0.5934
Epoch 135/400
23/23 [==============================] - 27s 1s/step - loss: 0.5635 - val_loss: 0.5963
Epoch 136/400
23/23 [==============================] - 26s 1s/step - loss: 0.5907 - val_loss: 0.5909
Epoch 137/400
23/23 [==============================] - 23s 993ms/step - loss: 0.5755 - val_loss: 0.5902
Epoch 138/400
23/23 [==============================] - 26s 1s/step - loss: 0.5782 - val_loss: 0.5940
Epoch 139/400
23/23 [==============================] - 26s 1s/step - loss: 0.5592 - val_loss: 0.5869
Epoch 140/400
23/23 [==============================] - 26s 1s/step - loss: 0.6018 - val_loss: 0.6110
Epoch 141/400
23/23 [==============================] - 27s 1s/step - loss: 0.5807 - val_loss: 0.5946
Epoch 142/400
23/23 [==============================] - 26s 1s/step - loss: 0.5756 - val_loss: 0.5953
Epoch 143/400
23/23 [==============================] - 26s 1s/step - loss: 0.5673 - val_loss: 0.5918
Epoch 144/400
23/23 [==============================] - 26s 1s/step - loss: 0.5806 - val_loss: 0.5946
Epoch 145/400
23/23 [==============================] - 26s 1s/step - loss: 0.5759 - val_loss: 0.5889
Epoch 146/400
23/23 [==============================] - 26s 1s/step - loss: 0.5733 - val_loss: 0.5897
Epoch 147/400
23/23 [==============================] - 26s 1s/step - loss: 0.5686 - val_loss: 0.5904
Epoch 148/400
23/23 [==============================] - 26s 1s/step - loss: 0.5813 - val_loss: 0.5937
Epoch 149/400
23/23 [==============================] - 25s 1s/step - loss: 0.5760 - val_loss: 0.5981
Epoch 150/400
23/23 [==============================] - 25s 1s/step - loss: 0.5725 - val_loss: 0.5971
Epoch 151/400
23/23 [==============================] - 25s 1s/step - loss: 0.5769 - val_loss: 0.5951
Epoch 152/400
23/23 [==============================] - 24s 1s/step - loss: 0.5721 - val_loss: 0.5953
Epoch 153/400
23/23 [==============================] - 26s 1s/step - loss: 0.5684 - val_loss: 0.5888
Epoch 154/400
23/23 [==============================] - 25s 1s/step - loss: 0.5982 - val_loss: 0.5940
Epoch 155/400
23/23 [==============================] - 26s 1s/step - loss: 0.5681 - val_loss: 0.5863
Epoch 156/400
23/23 [==============================] - 26s 1s/step - loss: 0.5790 - val_loss: 0.5887
Epoch 157/400
23/23 [==============================] - 27s 1s/step - loss: 0.5670 - val_loss: 0.5850
Epoch 158/400
23/23 [==============================] - 26s 1s/step - loss: 0.5681 - val_loss: 0.5912
Epoch 159/400
23/23 [==============================] - 26s 1s/step - loss: 0.5825 - val_loss: 0.5903
Epoch 160/400
23/23 [==============================] - 27s 1s/step - loss: 0.5743 - val_loss: 0.5895
Epoch 161/400
23/23 [==============================] - 26s 1s/step - loss: 0.5594 - val_loss: 0.5926
Epoch 162/400
23/23 [==============================] - 26s 1s/step - loss: 0.5785 - val_loss: 0.5950
Epoch 163/400
23/23 [==============================] - 26s 1s/step - loss: 0.5670 - val_loss: 0.5949
Epoch 164/400
23/23 [==============================] - 27s 1s/step - loss: 0.5710 - val_loss: 0.6064
Epoch 165/400
23/23 [==============================] - 24s 1s/step - loss: 0.5854 - val_loss: 0.6138
Epoch 166/400
23/23 [==============================] - 24s 1s/step - loss: 0.5550 - val_loss: 0.5946
Epoch 167/400
23/23 [==============================] - 26s 1s/step - loss: 0.5745 - val_loss: 0.5960
Epoch 168/400
23/23 [==============================] - 26s 1s/step - loss: 0.5766 - val_loss: 0.5908
Epoch 169/400
23/23 [==============================] - 26s 1s/step - loss: 0.5746 - val_loss: 0.5834
Epoch 170/400
23/23 [==============================] - 26s 1s/step - loss: 0.5734 - val_loss: 0.5851
Epoch 171/400
23/23 [==============================] - 26s 1s/step - loss: 0.5748 - val_loss: 0.5822
Epoch 172/400
23/23 [==============================] - 26s 1s/step - loss: 0.5722 - val_loss: 0.5993
Epoch 173/400
23/23 [==============================] - 26s 1s/step - loss: 0.5629 - val_loss: 0.5897
Epoch 174/400
23/23 [==============================] - 26s 1s/step - loss: 0.5759 - val_loss: 0.5935
Epoch 175/400
23/23 [==============================] - 26s 1s/step - loss: 0.5815 - val_loss: 0.5936
Epoch 176/400
23/23 [==============================] - 26s 1s/step - loss: 0.5651 - val_loss: 0.5887
Epoch 177/400
23/23 [==============================] - 27s 1s/step - loss: 0.5776 - val_loss: 0.5956
Epoch 178/400
23/23 [==============================] - 26s 1s/step - loss: 0.5662 - val_loss: 0.5885
Epoch 179/400
23/23 [==============================] - 24s 1s/step - loss: 0.5776 - val_loss: 0.5882
Epoch 180/400
23/23 [==============================] - 25s 1s/step - loss: 0.5634 - val_loss: 0.5851
Epoch 181/400
23/23 [==============================] - 26s 1s/step - loss: 0.5820 - val_loss: 0.5855
Epoch 182/400
23/23 [==============================] - 26s 1s/step - loss: 0.5552 - val_loss: 0.6053
Epoch 183/400
23/23 [==============================] - 25s 1s/step - loss: 0.5702 - val_loss: 0.5854
Epoch 184/400
23/23 [==============================] - 27s 1s/step - loss: 0.5825 - val_loss: 0.5925
Epoch 185/400
23/23 [==============================] - 25s 1s/step - loss: 0.5800 - val_loss: 0.5967
Epoch 186/400
23/23 [==============================] - 26s 1s/step - loss: 0.5665 - val_loss: 0.5976
Epoch 187/400
23/23 [==============================] - 26s 1s/step - loss: 0.5800 - val_loss: 0.5917
Epoch 188/400
23/23 [==============================] - 26s 1s/step - loss: 0.5686 - val_loss: 0.5771
Epoch 189/400
23/23 [==============================] - 27s 1s/step - loss: 0.5677 - val_loss: 0.6001
Epoch 190/400
23/23 [==============================] - 25s 1s/step - loss: 0.5701 - val_loss: 0.5883
Epoch 191/400
23/23 [==============================] - 26s 1s/step - loss: 0.5664 - val_loss: 0.5837
Epoch 192/400
23/23 [==============================] - 26s 1s/step - loss: 0.5590 - val_loss: 0.5784
Epoch 193/400
23/23 [==============================] - 25s 1s/step - loss: 0.5652 - val_loss: 0.5826
Epoch 194/400
23/23 [==============================] - 23s 1s/step - loss: 0.5771 - val_loss: 0.5894
Epoch 195/400
23/23 [==============================] - 22s 974ms/step - loss: 0.5735 - val_loss: 0.6077
Epoch 196/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5703 - val_loss: 0.5990
Epoch 197/400
23/23 [==============================] - 23s 1s/step - loss: 0.5716 - val_loss: 0.6045
Epoch 198/400
23/23 [==============================] - 22s 962ms/step - loss: 0.5695 - val_loss: 0.5895
Epoch 199/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5768 - val_loss: 0.5949
Epoch 200/400
23/23 [==============================] - 23s 978ms/step - loss: 0.5609 - val_loss: 0.6004
Epoch 201/400
23/23 [==============================] - 23s 990ms/step - loss: 0.5625 - val_loss: 0.5983
Epoch 202/400
23/23 [==============================] - 22s 978ms/step - loss: 0.5608 - val_loss: 0.5945
Epoch 203/400
23/23 [==============================] - 23s 1s/step - loss: 0.5675 - val_loss: 0.6113
Epoch 204/400
23/23 [==============================] - 22s 964ms/step - loss: 0.5624 - val_loss: 0.5938
Epoch 205/400
23/23 [==============================] - 23s 1s/step - loss: 0.5602 - val_loss: 0.6052
Epoch 206/400
23/23 [==============================] - 22s 954ms/step - loss: 0.5878 - val_loss: 0.6295
Epoch 207/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5770 - val_loss: 0.6025
Epoch 208/400
23/23 [==============================] - 22s 971ms/step - loss: 0.5518 - val_loss: 0.5941
Epoch 209/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5722 - val_loss: 0.5900
Epoch 210/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5575 - val_loss: 0.6177
Epoch 211/400
23/23 [==============================] - 23s 1s/step - loss: 0.5667 - val_loss: 0.5988
Epoch 212/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5801 - val_loss: 0.5932
Epoch 213/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5474 - val_loss: 0.6106
Epoch 214/400
23/23 [==============================] - 22s 958ms/step - loss: 0.5723 - val_loss: 0.6145
Epoch 215/400
23/23 [==============================] - 23s 1s/step - loss: 0.5495 - val_loss: 0.6236
Epoch 216/400
23/23 [==============================] - 22s 978ms/step - loss: 0.5730 - val_loss: 0.6093
Epoch 217/400
23/23 [==============================] - 22s 938ms/step - loss: 0.5659 - val_loss: 0.6058
Epoch 218/400
23/23 [==============================] - 23s 1s/step - loss: 0.5615 - val_loss: 0.5990
Epoch 219/400
23/23 [==============================] - 22s 973ms/step - loss: 0.5622 - val_loss: 0.5968
Epoch 220/400
23/23 [==============================] - 22s 973ms/step - loss: 0.5796 - val_loss: 0.6030
Epoch 221/400
23/23 [==============================] - 22s 969ms/step - loss: 0.5607 - val_loss: 0.6026
Epoch 222/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5609 - val_loss: 0.5959
Epoch 223/400
23/23 [==============================] - 22s 975ms/step - loss: 0.5674 - val_loss: 0.6017
Epoch 224/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5635 - val_loss: 0.5950
Epoch 225/400
23/23 [==============================] - 23s 985ms/step - loss: 0.5695 - val_loss: 0.6107
Epoch 226/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5577 - val_loss: 0.5943
Epoch 227/400
23/23 [==============================] - 23s 1s/step - loss: 0.5583 - val_loss: 0.6095
Epoch 228/400
23/23 [==============================] - 22s 961ms/step - loss: 0.5501 - val_loss: 0.6007
Epoch 229/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5599 - val_loss: 0.6006
Epoch 230/400
23/23 [==============================] - 23s 989ms/step - loss: 0.5545 - val_loss: 0.5879
Epoch 231/400
23/23 [==============================] - 23s 1s/step - loss: 0.5849 - val_loss: 0.6058
Epoch 232/400
23/23 [==============================] - 22s 953ms/step - loss: 0.5508 - val_loss: 0.5918
Epoch 233/400
23/23 [==============================] - 23s 986ms/step - loss: 0.5573 - val_loss: 0.5974
Epoch 234/400
23/23 [==============================] - 23s 1s/step - loss: 0.5751 - val_loss: 0.5971
Epoch 235/400
23/23 [==============================] - 22s 976ms/step - loss: 0.5597 - val_loss: 0.6027
Epoch 236/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5548 - val_loss: 0.5986
Epoch 237/400
23/23 [==============================] - 22s 958ms/step - loss: 0.5740 - val_loss: 0.5922
Epoch 238/400
23/23 [==============================] - 23s 1s/step - loss: 0.5430 - val_loss: 0.6065
Epoch 239/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5711 - val_loss: 0.6089
Epoch 240/400
23/23 [==============================] - 22s 973ms/step - loss: 0.5528 - val_loss: 0.5935
Epoch 241/400
23/23 [==============================] - 23s 978ms/step - loss: 0.5672 - val_loss: 0.6017
Epoch 242/400
23/23 [==============================] - 23s 983ms/step - loss: 0.5547 - val_loss: 0.6208
Epoch 243/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5687 - val_loss: 0.6124
Epoch 244/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5639 - val_loss: 0.6055
Epoch 245/400
23/23 [==============================] - 22s 967ms/step - loss: 0.5595 - val_loss: 0.6039
Epoch 246/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5656 - val_loss: 0.6085
Epoch 247/400
23/23 [==============================] - 23s 1s/step - loss: 0.5597 - val_loss: 0.6159
Epoch 248/400
23/23 [==============================] - 22s 944ms/step - loss: 0.5684 - val_loss: 0.6047
Epoch 249/400
23/23 [==============================] - 23s 983ms/step - loss: 0.5566 - val_loss: 0.5957
Epoch 250/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5517 - val_loss: 0.6050
Epoch 251/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5418 - val_loss: 0.6056
Epoch 252/400
23/23 [==============================] - 23s 1s/step - loss: 0.5681 - val_loss: 0.6220
Epoch 253/400
23/23 [==============================] - 22s 956ms/step - loss: 0.5678 - val_loss: 0.6105
Epoch 254/400
23/23 [==============================] - 23s 1s/step - loss: 0.5470 - val_loss: 0.6208
Epoch 255/400
23/23 [==============================] - 23s 978ms/step - loss: 0.5627 - val_loss: 0.6084
Epoch 256/400
23/23 [==============================] - 22s 951ms/step - loss: 0.5548 - val_loss: 0.6049
Epoch 257/400
23/23 [==============================] - 23s 1s/step - loss: 0.5353 - val_loss: 0.6310
Epoch 258/400
23/23 [==============================] - 22s 956ms/step - loss: 0.5738 - val_loss: 0.6196
Epoch 259/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5583 - val_loss: 0.6048
Epoch 260/400
23/23 [==============================] - 23s 1s/step - loss: 0.5559 - val_loss: 0.6083
Epoch 261/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5510 - val_loss: 0.6084
Epoch 262/400
23/23 [==============================] - 22s 972ms/step - loss: 0.5581 - val_loss: 0.6242
Epoch 263/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5453 - val_loss: 0.6202
Epoch 264/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5665 - val_loss: 0.6208
Epoch 265/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5591 - val_loss: 0.6371
Epoch 266/400
23/23 [==============================] - 23s 983ms/step - loss: 0.5639 - val_loss: 0.6201
Epoch 267/400
23/23 [==============================] - 22s 976ms/step - loss: 0.5588 - val_loss: 0.6127
Epoch 268/400
23/23 [==============================] - 22s 978ms/step - loss: 0.5547 - val_loss: 0.6228
Epoch 269/400
23/23 [==============================] - 22s 975ms/step - loss: 0.5535 - val_loss: 0.6288
Epoch 270/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5516 - val_loss: 0.6291
Epoch 271/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5573 - val_loss: 0.6212
Epoch 272/400
23/23 [==============================] - 22s 968ms/step - loss: 0.5479 - val_loss: 0.6282
Epoch 273/400
23/23 [==============================] - 23s 1s/step - loss: 0.5539 - val_loss: 0.6314
Epoch 274/400
23/23 [==============================] - 22s 959ms/step - loss: 0.5613 - val_loss: 0.6196
Epoch 275/400
23/23 [==============================] - 23s 1s/step - loss: 0.5363 - val_loss: 0.6167
Epoch 276/400
23/23 [==============================] - 22s 956ms/step - loss: 0.5427 - val_loss: 0.6319
Epoch 277/400
23/23 [==============================] - 23s 1s/step - loss: 0.5471 - val_loss: 0.6346
Epoch 278/400
23/23 [==============================] - 22s 957ms/step - loss: 0.5688 - val_loss: 0.6096
Epoch 279/400
23/23 [==============================] - 23s 987ms/step - loss: 0.5603 - val_loss: 0.6129
Epoch 280/400
23/23 [==============================] - 23s 1s/step - loss: 0.5434 - val_loss: 0.6272
Epoch 281/400
23/23 [==============================] - 22s 956ms/step - loss: 0.5560 - val_loss: 0.6250
Epoch 282/400
23/23 [==============================] - 23s 1s/step - loss: 0.5618 - val_loss: 0.6395
Epoch 283/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5563 - val_loss: 0.6330
Epoch 284/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5534 - val_loss: 0.6284
Epoch 285/400
23/23 [==============================] - 22s 962ms/step - loss: 0.5569 - val_loss: 0.6240
Epoch 286/400
23/23 [==============================] - 23s 1s/step - loss: 0.5605 - val_loss: 0.6103
Epoch 287/400
23/23 [==============================] - 22s 947ms/step - loss: 0.5630 - val_loss: 0.6245
Epoch 288/400
23/23 [==============================] - 23s 1s/step - loss: 0.5558 - val_loss: 0.6099
Epoch 289/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5300 - val_loss: 0.6531
Epoch 290/400
23/23 [==============================] - 23s 978ms/step - loss: 0.5610 - val_loss: 0.6254
Epoch 291/400
23/23 [==============================] - 23s 978ms/step - loss: 0.5454 - val_loss: 0.6382
Epoch 292/400
23/23 [==============================] - 22s 971ms/step - loss: 0.5539 - val_loss: 0.6229
Epoch 293/400
23/23 [==============================] - 22s 974ms/step - loss: 0.5534 - val_loss: 0.6265
Epoch 294/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5488 - val_loss: 0.6240
Epoch 295/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5461 - val_loss: 0.6330
Epoch 296/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5453 - val_loss: 0.6443
Epoch 297/400
23/23 [==============================] - 22s 970ms/step - loss: 0.5419 - val_loss: 0.6527
Epoch 298/400
23/23 [==============================] - 22s 975ms/step - loss: 0.5280 - val_loss: 0.6484
Epoch 299/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5655 - val_loss: 0.6446
Epoch 300/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5446 - val_loss: 0.6576
Epoch 301/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5430 - val_loss: 0.6880
Epoch 302/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5444 - val_loss: 0.6754
Epoch 303/400
23/23 [==============================] - 23s 990ms/step - loss: 0.5548 - val_loss: 0.6506
Epoch 304/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5574 - val_loss: 0.6369
Epoch 305/400
23/23 [==============================] - 23s 1s/step - loss: 0.5433 - val_loss: 0.6543
Epoch 306/400
23/23 [==============================] - 23s 985ms/step - loss: 0.5567 - val_loss: 0.6507
Epoch 307/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5638 - val_loss: 0.6482
Epoch 308/400
23/23 [==============================] - 23s 987ms/step - loss: 0.5445 - val_loss: 0.6577
Epoch 309/400
23/23 [==============================] - 22s 974ms/step - loss: 0.5502 - val_loss: 0.6546
Epoch 310/400
23/23 [==============================] - 23s 978ms/step - loss: 0.5532 - val_loss: 0.6282
Epoch 311/400
23/23 [==============================] - 22s 976ms/step - loss: 0.5428 - val_loss: 0.6494
Epoch 312/400
23/23 [==============================] - 23s 980ms/step - loss: 0.5445 - val_loss: 0.6255
Epoch 313/400
23/23 [==============================] - 22s 978ms/step - loss: 0.5501 - val_loss: 0.6324
Epoch 314/400
23/23 [==============================] - 22s 975ms/step - loss: 0.5439 - val_loss: 0.6298
Epoch 315/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5566 - val_loss: 0.6355
Epoch 316/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5336 - val_loss: 0.6383
Epoch 317/400
23/23 [==============================] - 22s 969ms/step - loss: 0.5471 - val_loss: 0.6220
Epoch 318/400
23/23 [==============================] - 22s 976ms/step - loss: 0.5337 - val_loss: 0.6365
Epoch 319/400
23/23 [==============================] - 23s 985ms/step - loss: 0.5452 - val_loss: 0.6313
Epoch 320/400
23/23 [==============================] - 22s 978ms/step - loss: 0.5635 - val_loss: 0.6226
Epoch 321/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5387 - val_loss: 0.6362
Epoch 322/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5440 - val_loss: 0.6407
Epoch 323/400
23/23 [==============================] - 23s 1s/step - loss: 0.5382 - val_loss: 0.6560
Epoch 324/400
23/23 [==============================] - 22s 936ms/step - loss: 0.5460 - val_loss: 0.6416
Epoch 325/400
23/23 [==============================] - 23s 1s/step - loss: 0.5479 - val_loss: 0.6457
Epoch 326/400
23/23 [==============================] - 22s 954ms/step - loss: 0.5381 - val_loss: 0.6730
Epoch 327/400
23/23 [==============================] - 23s 1s/step - loss: 0.5253 - val_loss: 0.6559
Epoch 328/400
23/23 [==============================] - 23s 979ms/step - loss: 0.5456 - val_loss: 0.6724
Epoch 329/400
23/23 [==============================] - 23s 993ms/step - loss: 0.5555 - val_loss: 0.6677
Epoch 330/400
23/23 [==============================] - 22s 973ms/step - loss: 0.5312 - val_loss: 0.6551
Epoch 331/400
23/23 [==============================] - 22s 951ms/step - loss: 0.5424 - val_loss: 0.6651
Epoch 332/400
23/23 [==============================] - 23s 1s/step - loss: 0.5352 - val_loss: 0.6570
Epoch 333/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5266 - val_loss: 0.6639
Epoch 334/400
23/23 [==============================] - 23s 983ms/step - loss: 0.5512 - val_loss: 0.6485
Epoch 335/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5457 - val_loss: 0.6664
Epoch 336/400
23/23 [==============================] - 23s 984ms/step - loss: 0.5332 - val_loss: 0.6668
Epoch 337/400
23/23 [==============================] - 23s 983ms/step - loss: 0.5416 - val_loss: 0.6532
Epoch 338/400
23/23 [==============================] - 22s 960ms/step - loss: 0.5298 - val_loss: 0.6720
Epoch 339/400
23/23 [==============================] - 23s 1s/step - loss: 0.5247 - val_loss: 0.6785
Epoch 340/400
23/23 [==============================] - 23s 982ms/step - loss: 0.5444 - val_loss: 0.6735
Epoch 341/400
23/23 [==============================] - 22s 974ms/step - loss: 0.5350 - val_loss: 0.6724
Epoch 342/400
23/23 [==============================] - 22s 977ms/step - loss: 0.5601 - val_loss: 0.6761
Epoch 343/400
23/23 [==============================] - 23s 981ms/step - loss: 0.5259 - val_loss: 0.6759
Epoch 344/400
23/23 [==============================] - 22s 976ms/step - loss: 0.5355 - val_loss: 0.6723
[... epochs 345-399 proceed similarly: training loss plateaus around 0.51-0.56 while validation loss slowly drifts up from roughly 0.67 to 0.77 ...]
Epoch 400/400
23/23 [==============================] - 22s 955ms/step - loss: 0.5190 - val_loss: 0.7613
restoring model to state at epoch = 100

In [84]:
logged_history = pd.read_csv(logger_filename)
logged_history.head()


Out[84]:
epoch loss val_loss
0 0 0.674815 0.637952
1 1 0.650405 0.637426
2 2 0.595751 0.663026
3 3 0.594203 0.658145
4 4 0.594962 0.648077

In [85]:
color_CNN = "b"
color_RF = "g"


linestyle_CNN = "solid"
linestyle_RF = "dashed"

linewidth=3

In [86]:
def plot_learning_curve(logged_history, with_convolution=False, plot_kwargs=None):
    """Plot training and validation loss vs. epoch, optionally smoothed."""
    # avoid a mutable default argument
    if plot_kwargs is None:
        plot_kwargs = {}

    with mpl.rc_context(rc={"figure.figsize": (10,6)}):

        loss = logged_history["loss"]
        val_loss = logged_history["val_loss"]

        if with_convolution:
            # smooth both curves with a 5-epoch moving average
            simple_conv = lambda x: np.convolve(x, np.ones(5)/5, mode="valid")

            loss = simple_conv(loss)
            val_loss = simple_conv(val_loss)

        # `prior_loss` (defined earlier in the notebook) is drawn as a constant baseline
        plt.axhline(prior_loss, label="Initial Bias", 
                    linestyle="dashed", color="black",
                    **plot_kwargs,
                   )

        plt.plot(val_loss, label="Validation", 
                 **plot_kwargs,
                )
        plt.plot(loss, label="Training",
                 **plot_kwargs,
                )

        plt.xlabel("Epoch")
        plt.ylabel("Loss\n(avg. binary cross-entropy)")

        plt.legend()

plot_learning_curve(logged_history,)



In [87]:
plot_learning_curve(logged_history, 
                    with_convolution=True, plot_kwargs=dict(linewidth=linewidth),
                   )



In [88]:
logged_history_for_thesis = pd.read_csv("training.thesis.log")
logged_history_for_thesis.head()

plot_learning_curve(logged_history_for_thesis, 
                    with_convolution=True, plot_kwargs=dict(linewidth=linewidth),
                   )

plt.xlim(right=350)

plot_filename = "plots_for_thesis/learning_curve-smoothed"
plt.tight_layout()
plt.savefig(plot_filename + ".pdf")
plt.savefig(plot_filename + ".png")


9) Look at validation results

9a) First save CNN-predicted probabilities

This also saves "predictions" for the training data, just for completeness.


In [89]:
overwrite = False
use_cached_if_exists = True
class_probs_filename = "class_probs.csv"

if use_cached_if_exists and pathlib.Path(class_probs_filename).is_file():
    df_class_probs = pd.read_csv(class_probs_filename)
else:
    X_transformed = np.array([datagen.standardize(X_img)
                              for X_img in X])
    class_probs = model.predict_proba(X_transformed).flatten()
    
    df_class_probs = pd.DataFrame({
        "HSC_id": HSC_ids,
        "CNN_prob": class_probs,
        "testing": [HSC_id in HSC_ids[testing_set_indices] for HSC_id in HSC_ids],
        "target": Y,
    })

    if overwrite or (not pathlib.Path(class_probs_filename).is_file()):
        df_class_probs.to_csv(class_probs_filename, index=False)

df_class_probs.head()


Out[89]:
HSC_id CNN_prob testing target
0 43158176442374224 0.433963 False False
1 43158176442374373 0.232534 False False
2 43158176442374445 0.429375 False True
3 43158176442375078 0.412902 False True
4 43158176442375086 0.197128 False False

9b) Combined CNN probability with RF prob


In [90]:
def logit(p):
    """Log-odds; the inverse of `expit` (equivalent to scipy.special.logit)."""
    return np.log(p/(1-p))

def expit(x):
    """Logistic sigmoid; the inverse of `logit` (equivalent to scipy.special.expit)."""
    return (1 + np.exp(-x))**-1

def combine_probabilities(prob_a, prob_b, prior_prob_a, prior_prob_b, prior_prob_overall):
    """Combine two classifiers' posteriors in logit space: strip each model's own
    prior, add the evidence terms, and re-attach the desired overall prior."""
    logit_post_a = logit(prob_a)
    logit_post_b = logit(prob_b)
    logit_prior_a = logit(prior_prob_a)
    logit_prior_b = logit(prior_prob_b)
    
    logit_prior_overall = logit(prior_prob_overall)
    
    p_combined = expit(logit_prior_overall + logit_post_a + logit_post_b - logit_prior_a - logit_prior_b)
    return p_combined
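For reference, here's the identity combine_probabilities implements. It is a naive-Bayes-style combination in logit space: it implicitly assumes the two classifiers' evidence is conditionally independent given the true class, removes each model's own prior, and re-attaches the prior we want the combined score calibrated against:

$$p_\mathrm{comb} = \mathrm{expit}\!\Bigl(\mathrm{logit}(\pi_\mathrm{overall}) + \bigl[\mathrm{logit}(p_a) - \mathrm{logit}(\pi_a)\bigr] + \bigl[\mathrm{logit}(p_b) - \mathrm{logit}(\pi_b)\bigr]\Bigr)$$

In the next cell, $p_a$ is the softened RF probability, $p_b$ is the CNN probability, and the priors $\pi_a$, $\pi_b$, $\pi_\mathrm{overall}$ are estimated from sample means.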

In [91]:
hdf_file = pathlib.Path.cwd().parent / "catalog_only_classifier" / "results_cross-validated_all.hdf5"
df_catalog_only = pd.read_hdf(hdf_file)
df_catalog_only = df_catalog_only[["HSC_id", "RF_prob", "target"]]
df_catalog_only = df_catalog_only.set_index("HSC_id").reset_index()
df_catalog_only.head()

df_combined_probs = pd.DataFrame(dict(
    HSC_id = df_catalog_only.HSC_id,
    RF_prob = df_catalog_only.RF_prob,
    CNN_prob = [np.nan]*df_catalog_only.shape[0],
    target = df_catalog_only.target,
    testing = True,
    weight = 1,
    has_CNN = False,
))
df_combined_probs = df_combined_probs.set_index("HSC_id")

for _, row in df_class_probs.iterrows():
    df_combined_probs.loc[row.HSC_id, "CNN_prob"] = row.CNN_prob
    df_combined_probs.loc[row.HSC_id, "testing"] = row.testing
    df_combined_probs.loc[row.HSC_id, "weight"] = df_class_probs.testing.mean()**-1
    df_combined_probs.loc[row.HSC_id, "has_CNN"] = True
    
df_combined_probs["RF_prob_softened"] = (df_combined_probs.RF_prob * 1000 + 1) / (1000 + 2) 
RF_softened_prior = df_combined_probs.RF_prob_softened.mean()
CNN_prior = df_combined_probs[df_combined_probs.has_CNN & ~df_combined_probs.testing].target.mean()
df_combined_probs["combined_prob"] = df_combined_probs.apply(lambda row:
                                                             combine_probabilities(row.RF_prob_softened, 
                                                                                   row.CNN_prob,
                                                                                   RF_softened_prior, 
                                                                                   CNN_prior,
                                                                                   RF_softened_prior,
                                                                                  )
                                                             if row.has_CNN else row.RF_prob_softened,
                                                             axis="columns",
                                                            )
    
df_testing = df_combined_probs[df_combined_probs.testing]

df_combined_probs.head()


Out[91]:
RF_prob CNN_prob target testing weight has_CNN RF_prob_softened combined_prob
HSC_id
43158176442354198 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354210 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354213 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354230 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354240 0.0 NaN False True 1.0 False 0.000998 0.000998

9c) Plots for just CNN probabilities


In [66]:
from sklearn import metrics
from sklearn.metrics import roc_auc_score

with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    df_testing_tmp = df_testing[df_testing.has_CNN]
    fpr, tpr, _ = metrics.roc_curve(df_testing_tmp.target, df_testing_tmp.CNN_prob)
    roc_auc = roc_auc_score(df_testing_tmp.target, df_testing_tmp.CNN_prob)

    plt.plot(fpr, tpr, label="CNN (AUC = {:.2})".format(roc_auc), color=color_CNN, linewidth=linewidth)
    plt.plot([0,1], [0,1], linestyle="dotted", color="black", label="Random guessing",
             linewidth=linewidth,
            )

    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")

    plt.title("Within CNN subsample")

    plt.legend(loc="best")
    
    plt.tight_layout()
    plot_filename = "plots_for_thesis/ROC-CNN_sample"
    plt.savefig(plot_filename + ".pdf")
    plt.savefig(plot_filename + ".png")



In [67]:
from sklearn.metrics import average_precision_score
with mpl.rc_context(rc={"figure.figsize": (10,6)}):
    df_testing_tmp = df_testing[df_testing.has_CNN]
    precision, recall, _ = metrics.precision_recall_curve(df_testing_tmp.target, df_testing_tmp.CNN_prob)
    pr_auc = average_precision_score(df_testing_tmp.target, df_testing_tmp.CNN_prob)

    plt.plot(recall, precision, label="AUC = {:.3}".format(pr_auc), color=color_CNN, linewidth=linewidth)
    
    plt.plot([0,1], [Y[testing_set_indices].mean()]*2, linestyle="dotted", color="black", 
             label="Random guessing",
             linewidth=linewidth,
            )


    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("Completeness")
    plt.ylabel("Purity ")

    plt.title("Within CNN subsample")

    plt.legend(loc="best")
    
    plt.tight_layout()
    plot_filename = "plots_for_thesis/purity_completeness-CNN_sample"
    plt.savefig(plot_filename + ".pdf")
    plt.savefig(plot_filename + ".png")


Now run some metrics for CNN + RF

First, I need to find a way to combine the two scores from the models.

For just creating scores (e.g. for PR and ROC curves) I can do the following:

1) If the object didn't pass the RF cut, then leave its score as the RF probability.
2) If the object did pass the RF cut, then use $\mathrm{score} = 1 + p_\mathrm{CNN}$, so everything that passed the cut outranks everything that didn't. This is effectively what I'm already doing with my thresholded cut (see the sketch below).

But what about the actual combined probability? Ideally I want a binary cross-entropy, to see whether adding the CNN helped at all. Also, there's still information in the RF probability that can be used.

For now, just use the simple logit-space combination defined above, and see if we can come up with something better later.
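Here's a minimal sketch of that two-stage scoring rule. The threshold name RF_cut and its value below are placeholders (not the cut actually used elsewhere in this notebook); column names follow df_combined_probs.

import numpy as np

def two_stage_score(row, RF_cut=0.1):
    # ranking score only (for ROC / PR curves); not a calibrated probability
    if np.isnan(row.CNN_prob) or (row.RF_prob < RF_cut):
        # object never reached the CNN stage: rank it by its RF probability
        return row.RF_prob
    # object passed the RF cut: 1 + p_CNN, so every passer outranks every non-passer
    return 1 + row.CNN_prob

# e.g.: scores = df_combined_probs.apply(two_stage_score, axis="columns")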


In [68]:
df_combined_probs[df_combined_probs.has_CNN].head()


Out[68]:
RF_prob CNN_prob target testing weight has_CNN RF_prob_softened combined_prob
HSC_id
43158176442374224 0.302 0.433963 False False 5.002681 True 0.302395 0.467475
43158176442374373 0.113 0.232534 False False 5.002681 True 0.113772 0.093172
43158176442374445 0.438 0.429375 True False 5.002681 True 0.438124 0.607818
43158176442375078 0.574 0.412902 True False 5.002681 True 0.573852 0.714417
43158176442375086 0.174 0.197128 False False 5.002681 True 0.174651 0.120678

In [69]:
df_tmp = df_combined_probs[df_combined_probs.has_CNN]
plt.hexbin(df_tmp.RF_prob, df_tmp.CNN_prob, gridsize=30)


Out[69]:
<matplotlib.collections.PolyCollection at 0x13631be80>

In [70]:
df_testing.head()


Out[70]:
RF_prob CNN_prob target testing weight has_CNN RF_prob_softened combined_prob
HSC_id
43158176442354198 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354210 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354213 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354230 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354240 0.0 NaN False True 1.0 False 0.000998 0.000998

In [71]:
from scipy.special import expit
# thresholds spaced uniformly in logit space (densest near 0 and 1),
# padded with values just below 0 and just above 1 so the curves reach their endpoints
threshold_probs = expit(np.linspace(-9, 6, num=2000))
threshold_probs = np.array([-1e-6, *threshold_probs, 1+1e-6])

def get_purities(key, df_results=df_testing, threshold_probs=threshold_probs, weighted=True):
    """Weighted purity (precision) at each threshold, using the given score column."""
    purities = np.empty_like(threshold_probs)
    df_tmp = df_results[[key, "target", "weight"]].copy()
    if not weighted:
        df_tmp["weight"] = 1
    
    for i, threshold_prob in enumerate(threshold_probs):
        mask = df_tmp[key] > threshold_prob
        # NaN when nothing passes the threshold (the source of the RuntimeWarning further down)
        purities[i] = (df_tmp["target"][mask] * df_tmp["weight"][mask]).sum() / df_tmp["weight"][mask].sum()
    
    return purities

def get_completenesses(key, df_results=df_testing, threshold_probs=threshold_probs, weighted=True):
    """Weighted completeness (recall) at each threshold, using the given score column."""
    completenesses = np.empty_like(threshold_probs)
    df_tmp = df_results[[key, "target", "weight"]]
    df_tmp = df_tmp[df_tmp.target].copy()  # completeness only considers the true dwarfs
    if not weighted:
        df_tmp["weight"] = 1

    for i, threshold_prob in enumerate(threshold_probs):
        mask = df_tmp[key] > threshold_prob

        completenesses[i] = df_tmp[mask].weight.sum() / df_tmp.weight.sum()
    
    return completenesses
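In symbols, for a threshold $t$, weights $w_i$, labels $y_i \in \{0, 1\}$ (target), and scores $p_i$ (the chosen key column), these helpers compute

$$\mathrm{purity}(t) = \frac{\sum_{i:\, p_i > t} w_i\, y_i}{\sum_{i:\, p_i > t} w_i}, \qquad \mathrm{completeness}(t) = \frac{\sum_{i:\, y_i = 1,\; p_i > t} w_i}{\sum_{i:\, y_i = 1} w_i}$$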

In [72]:
print("ROC AUC - RF (original): ", roc_auc_score(df_combined_probs.target, df_combined_probs.RF_prob_softened))
print("ROC AUC - RF (weighted): ", roc_auc_score(df_testing.target, df_testing.RF_prob_softened, sample_weight=df_testing.weight))


ROC AUC - RF (original):  0.9715697894047796
ROC AUC - RF (weighted):  0.9727745254070185

In [73]:
print("PR AUC - RF (original): ", sklearn.metrics.average_precision_score(df_combined_probs.target, df_combined_probs.RF_prob_softened))
print("PR AUC - RF (weighted): ", sklearn.metrics.average_precision_score(df_testing.target, df_testing.RF_prob_softened, sample_weight=df_testing.weight))


PR AUC - RF (original):  0.4717104271298889
PR AUC - RF (weighted):  0.4821918238086995

In [74]:
print("ROC AUC - combined (weighted): ", roc_auc_score(df_testing.target, df_testing.combined_prob, sample_weight=df_testing.weight))


ROC AUC - combined (weighted):  0.9727389828610948

In [75]:
print("PR AUC - RF (weighted):       ", sklearn.metrics.average_precision_score(df_testing.target, df_testing.RF_prob_softened, sample_weight=df_testing.weight))
print("PR AUC - combined (weighted): ", sklearn.metrics.average_precision_score(df_testing.target, df_testing.combined_prob, sample_weight=df_testing.weight))


PR AUC - RF (weighted):        0.4821918238086995
PR AUC - combined (weighted):  0.4815925383121601

In [76]:
from sklearn import metrics
from sklearn.metrics import roc_auc_score

with mpl.rc_context(rc={"figure.figsize": np.array((10,6))}):
    fpr, tpr, _ = metrics.roc_curve(df_testing.target, df_testing.combined_prob, sample_weight = df_testing.weight)
    roc_auc = roc_auc_score(df_testing.target, df_testing.combined_prob, sample_weight=df_testing.weight)

    plt.plot(fpr, tpr, label="CNN+RF  (AUC = {:.5})".format(roc_auc), color=color_CNN, linewidth=linewidth,
             linestyle=linestyle_CNN)
    
    fpr, tpr, _ = metrics.roc_curve(df_testing.target, df_testing.RF_prob_softened,
                                    sample_weight=df_testing.weight,
                                   )
    roc_auc_RF = roc_auc_score(df_testing.target, df_testing.RF_prob_softened,
                               sample_weight=df_testing.weight,
                              )
    plt.plot(fpr, tpr, label="RF            (AUC = {:.5})".format(roc_auc_RF), 
             color=color_RF, linewidth=linewidth, linestyle=linestyle_RF)
    
    
    plt.plot([0,1], [0,1], linestyle="dotted", color="black", label="Random guessing",
             linewidth=linewidth,
            )

#     plt.xlim(0,1)
#     plt.ylim(0,1)

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")

    plt.title("CNN and RF Combined")

    plt.legend(loc="best")
    
    plt.tight_layout()
    plot_filename = "plots_for_thesis/ROC-CNN_and_RF"
    plt.savefig(plot_filename + ".pdf")
    plt.savefig(plot_filename + ".png")



In [77]:
from sklearn.metrics import average_precision_score
with mpl.rc_context(rc={"figure.figsize": (10,6)}):    
    pr_auc = average_precision_score(df_testing.target, df_testing.combined_prob, sample_weight=df_testing.weight)
    recall = get_completenesses("combined_prob")
    precision = get_purities("combined_prob")
    
    plt.plot(recall, precision, label="CNN+RF (AUC = {:.4})".format(pr_auc), 
             color=color_CNN, linewidth=linewidth)

    pr_auc_RF = average_precision_score(df_testing.target, df_testing.RF_prob_softened,
                                        sample_weight=df_testing.weight)
    recall_RF = get_completenesses("RF_prob_softened")
    precision_RF = get_purities("RF_prob_softened")

    plt.plot(recall_RF, precision_RF, label="RF           (AUC = {:.4})".format(pr_auc_RF), 
             color=color_RF, linewidth=linewidth)
    
    plt.xlim(0,1)
    plt.ylim(0,1)

    plt.xlabel("Completeness")
    plt.ylabel("Purity ")

    plt.title("CNN and RF Combined")

    plt.legend(loc="best")
    
    plt.tight_layout()
    plot_filename = "plots_for_thesis/purity_completeness-CNN_and_RF"
    plt.savefig(plot_filename + ".pdf")
    plt.savefig(plot_filename + ".png")


/Users/egentry/anaconda3/envs/tf36/lib/python3.6/site-packages/ipykernel_launcher.py:14: RuntimeWarning: invalid value encountered in double_scalars
  

In [78]:
df_testing.head()


Out[78]:
RF_prob CNN_prob target testing weight has_CNN RF_prob_softened combined_prob
HSC_id
43158176442354198 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354210 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354213 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354230 0.0 NaN False True 1.0 False 0.000998 0.000998
43158176442354240 0.0 NaN False True 1.0 False 0.000998 0.000998

In [79]:
with plt.rc_context({"figure.figsize":1.5*np.array((8,6))}):
    theoretical_probs=np.linspace(0,1,num=11)
    empirical_probs_RF = np.empty(theoretical_probs.size-1)
    num_in_bin_RF = np.empty_like(empirical_probs_RF)

    empirical_probs_combined = np.empty(theoretical_probs.size-1)
    num_in_bin_combined = np.empty_like(empirical_probs_combined)

    for i in range(theoretical_probs.size-1):
        prob_lim_low  = theoretical_probs[i]
        prob_lim_high = theoretical_probs[i+1]

        mask_RF = (df_testing["RF_prob_softened"] >= prob_lim_low) & (df_testing["RF_prob_softened"] < prob_lim_high)
        empirical_probs_RF[i] = df_testing["target"][mask_RF].mean()
        num_in_bin_RF[i] = df_testing["target"][mask_RF].size

        mask_combined = (df_testing["combined_prob"] >= prob_lim_low) & (df_testing["combined_prob"] < prob_lim_high)
        empirical_probs_combined[i] = df_testing["target"][mask_combined].mean()
        num_in_bin_combined[i] = df_testing["target"][mask_combined].size

    f, (ax1, ax2) = plt.subplots(2, sharex=True, 
                                 gridspec_kw = {'height_ratios':[1, 3]},
                                )

    ax1.plot(theoretical_probs, [num_in_bin_combined[0], *num_in_bin_combined],
             drawstyle="steps", color=color_CNN,
             linewidth=linewidth,
            )
    
    ax1.plot(theoretical_probs, [num_in_bin_RF[0], *num_in_bin_RF],
             drawstyle="steps", color=color_RF,
             linewidth=linewidth,
            )

    ax1.set_yscale("log")
    ax1.set_ylim(bottom=10**-.5, top=10**6.5)
    ax1.yaxis.set_ticks([1e0, 1e3, 1e6])
    ax1.set_ylabel("Number of \nGalaxies in Bin")

    # ax.step already draws step lines, so no "steps" linestyle is needed
    ax2.step(theoretical_probs, [empirical_probs_RF[0], *empirical_probs_RF], 
             color=color_RF, label="RF",
             linewidth=linewidth,
            )
    
    ax2.step(theoretical_probs, [empirical_probs_combined[0], *empirical_probs_combined], 
             color=color_CNN, label="RF + CNN",
             linewidth=linewidth,
            )

    # ax2.plot(theoretical_probs, theoretical_probs-.05, 
    #          drawstyle="steps", color="black", label="ideal", linestyle="dotted")

    ax2.fill_between(theoretical_probs, theoretical_probs-theoretical_probs[1], theoretical_probs, 
                     step="pre", color="black", label="ideal", alpha=.2,
                     linewidth=linewidth,
                    )

    plt.xlabel("Reported Probability")
    plt.ylabel("Actual (Binned) Probability")

    plt.legend(loc="upper left")

    plt.xlim(0,1)
    plt.ylim(0,1)


    plt.tight_layout()


    filename = "plots_for_thesis/probability-calibration-RF_and_CNN"
    plt.tight_layout()
    plt.savefig(filename + ".pdf")
    plt.savefig(filename + ".png")


Now get the combined cross entropy
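For reference, sklearn's log_loss with sample_weight (and the default normalize=True) returns the weight-averaged binary cross-entropy,

$$\mathcal{L} = -\,\frac{\sum_i w_i \bigl[\, y_i \ln p_i + (1 - y_i) \ln (1 - p_i) \,\bigr]}{\sum_i w_i},$$

where $w_i$ is the weight column, $y_i$ the target, and $p_i$ the probability being evaluated.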


In [80]:
sklearn.metrics.log_loss(df_testing.target,
                         df_testing.combined_prob,
                         sample_weight=df_testing.weight,
                        )


Out[80]:
0.007040964268274968

In [81]:
sklearn.metrics.log_loss(df_testing.target,
                         df_testing.RF_prob_softened,
                         sample_weight=df_testing.weight,
                        )


Out[81]:
0.007018665013401135

In [82]:
sklearn.metrics.log_loss(df_testing[df_testing.has_CNN].target,
                         df_testing[df_testing.has_CNN].combined_prob,
                         sample_weight=df_testing[df_testing.has_CNN].weight,
                        )


Out[82]:
0.521149599081223

In [83]:
sklearn.metrics.log_loss(df_testing[df_testing.has_CNN].target,
                         df_testing[df_testing.has_CNN].RF_prob,
                         sample_weight=df_testing[df_testing.has_CNN].weight,
                        )


Out[83]:
0.5173064221591772

In [ ]:

Misc analysis

Apply a threshold

For now, just threshold at the prior class probability (estimated from the training set).


In [ ]:
# `class_probs` (from the cell that built df_class_probs) covers the full sample,
# so restrict it to the testing set before comparing against Y[testing_set_indices] below
# (if the cached CSV was loaded instead, class_probs = df_class_probs.CNN_prob.values)
predicted_classes = class_probs[testing_set_indices] > (Y[training_set_indices].mean())
predicted_classes.mean()

In [ ]:
confusion_matrix = metrics.confusion_matrix(Y[testing_set_indices], predicted_classes)
confusion_matrix

In [ ]:
print("number of dwarfs (true)     : ", Y[testing_set_indices].sum())
print("number of dwarfs (predicted): ", predicted_classes.sum())

In [ ]:
import itertools
# adapted from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

confusion_matrix_normalized = confusion_matrix.astype('float') / confusion_matrix.sum(axis=1)[:, np.newaxis]

plt.imshow(confusion_matrix_normalized, interpolation='nearest',cmap="gray_r")
# plt.title(title)
plt.colorbar()
tick_marks = np.arange(2)
classes = [False, True]
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = 'd'
thresh = 1 / 2.
for i, j in itertools.product(range(confusion_matrix.shape[0]), range(confusion_matrix.shape[1])):
    plt.text(j, i, format(confusion_matrix[i, j], fmt),
             fontdict={"size":20},
             horizontalalignment="center",
             color="white" if confusion_matrix_normalized[i, j] > thresh else "black")

plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.tight_layout()

In [ ]:
print("  i - Y_true[i] - Y_pred[i] -  error?")
print("-------------------------------------")
for i in range(predicted_classes.size):
    print("{:>3} -    {:>1d}      -    {:d}      -   {:>2d}".format(i, 
                                                    Y[testing_set_indices][i], 
                                                    predicted_classes[i],
                                                    (Y[testing_set_indices][i] != predicted_classes[i]), 
                                                   ))

Analyze Errors


In [ ]:
HSC_ids

In [ ]:
df.loc[HSC_ids[testing_set_indices]].head()

In [ ]:
COSMOS_filename = os.path.join(dwarfz.data_dir_default, 
                               "COSMOS_reference.sqlite")
COSMOS = dwarfz.datasets.COSMOS(COSMOS_filename)

In [ ]:
HSC_filename = os.path.join(dwarfz.data_dir_default, 
                            "HSC_COSMOS_median_forced.sqlite3")
HSC = dwarfz.datasets.HSC(HSC_filename)

In [ ]:
matches_filename = os.path.join(dwarfz.data_dir_default, 
                                "matches.sqlite3")
matches_df = dwarfz.matching.Matches.load_from_filename(matches_filename)

In [ ]:
combined = matches_df[matches_df.match].copy()
combined["ra"]       = COSMOS.df.loc[combined.index].ra
combined["dec"]      = COSMOS.df.loc[combined.index].dec
combined["photo_z"]  = COSMOS.df.loc[combined.index].photo_z
combined["log_mass"] = COSMOS.df.loc[combined.index].mass_med
combined["active"]   = COSMOS.df.loc[combined.index].classification

combined = combined.set_index("catalog_2_ids")

combined.head()

In [ ]:
df_features_testing = combined.loc[HSC_ids[testing_set_indices]]

In [ ]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
# mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
#            C=class_probs,
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Number of objects",
            )

plt.suptitle("All Objects")

In [ ]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
# mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
           # restrict class_probs to the testing set, then to the masked rows,
           # so the color values align with the plotted points
           C=class_probs[testing_set_indices][np.asarray(mask)],
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Average Predicted\nProb. within bin",
            )


plt.suptitle("All Objects")

^^ Huh, that's a pretty uniform-looking spread. It doesn't really seem to be trending in a useful direction (either near the desired boundaries or as you move further away from them).


In [ ]:
df_tmp = df_features_testing.loc[HSC_ids[testing_set_indices]]
df_tmp["error"] =   np.array(Y[testing_set_indices], dtype=int) \
                  - np.array(predicted_classes, dtype=int)


mask = (df_tmp.photo_z < .5)
mask &= (df_tmp.error == -1)

print(sum(mask))

plt.hexbin(df_tmp.photo_z[mask], 
           df_tmp.log_mass[mask],
           # same alignment as above: testing-set probabilities, then the mask
           C=class_probs[testing_set_indices][np.asarray(mask)],
           gridsize=20,
           cmap="Blues",
           vmin=0,
          )

plt.xlabel("photo z")
plt.ylabel("log M_star")

plt.gca().add_patch(
    patches.Rectangle([0, 8], 
                      .15, 1, 
                      fill=False, 
                      linewidth=3,
                      color="red",
                     ),
)

plt.colorbar(label="Average Predicted\nProb. within bin",
            )

plt.suptitle("False Positives")

In [ ]: