Train the StarNet Model

This notebook takes you through the steps of training a StarNet model.

  • Required Python packages: numpy, h5py, keras
  • Required data files: training_data.h5, mean_and_std.npy

Note: We use TensorFlow as the Keras backend.


In [1]:
import numpy as np
import h5py
import random

from keras.models import Model
from keras.layers import Input, Dense, InputLayer, Flatten, Reshape
from keras.layers.convolutional import Conv1D
from keras.layers.convolutional import MaxPooling1D
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.utils import HDF5Matrix

datadir = ""
training_set = datadir + 'training_data.h5'
normalization_data = datadir + 'mean_and_std.npy'


Using TensorFlow backend.

Normalization

Write a function to normalize the output labels. Each label will be normalized to have approximately zero mean and unit variance.

NOTE: This is necessary to put the output labels on a similar scale so that the model trains properly. The process is reversed at the test stage to restore the output labels to their physical units.


In [3]:
mean_and_std = np.load(normalization_data)
mean_labels = mean_and_std[0]
std_labels = mean_and_std[1]

def normalize(labels):
    # Normalize labels
    return (labels-mean_labels) / std_labels
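
The reverse of this transform is applied at the test stage to recover the physical units. A minimal sketch of that inverse mapping (the helper name is illustrative and not defined elsewhere in this notebook):

def denormalize(norm_labels):
    # Invert normalize(): map normalized labels back to physical units
    # using the same mean and standard deviation arrays
    return norm_labels * std_labels + mean_labels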

Obtain training data

Here we will collect the output labels for the training and cross-validation sets, then normalize each.

Next we will create an HDF5Matrix for the training and cross-validation input spectra rather than loading them all into memory. This is useful to save RAM when training the model.


In [4]:
# Define the number of training spectra
num_train = 41000

# Load labels
with h5py.File(training_set, 'r') as F:
    y_train = np.hstack((F['TEFF'][0:num_train], F['LOGG'][0:num_train], F['FE_H'][0:num_train]))
    y_cv = np.hstack((F['TEFF'][num_train:], F['LOGG'][num_train:], F['FE_H'][num_train:]))

# Normalize labels
y_train = normalize(y_train)
y_cv = normalize(y_cv)

# Create the spectra training and cv datasets
x_train = HDF5Matrix(training_set, 'spectrum',
                     start=0, end=num_train)
x_cv = HDF5Matrix(training_set, 'spectrum',
                  start=num_train, end=None)

# Define the number of output labels
num_labels = y_train.shape[1]

# Define the number of wavelength bins in each spectrum
num_fluxes = x_train.shape[1]

print('Each spectrum contains ' + str(num_fluxes) + ' wavelength bins')
print('Training set includes ' + str(x_train.shape[0]) + 
      ' spectra and the cross-validation set includes ' + str(x_cv.shape[0])+' spectra')


Each spectrum contains 7214 wavelength bins
Training set includes 41000 spectra and the cross-validation set includes 3784 spectra
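
Note that HDF5Matrix was removed in later Keras releases. If it is unavailable in your installation, one memory-friendly alternative is a keras.utils.Sequence that reads batches directly from the HDF5 file. The sketch below is illustrative (the class name and constructor arguments are not part of StarNet):

from keras.utils import Sequence

class SpectraSequence(Sequence):
    """Serve batches of spectra and labels straight from the HDF5 file."""
    def __init__(self, h5_path, dataset, labels, start, end, batch_size):
        self.f = h5py.File(h5_path, 'r')
        self.data = self.f[dataset]
        self.labels = labels
        self.start = start
        self.end = end
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil((self.end - self.start) / float(self.batch_size)))

    def __getitem__(self, idx):
        i0 = self.start + idx * self.batch_size
        i1 = min(i0 + self.batch_size, self.end)
        return self.data[i0:i1], self.labels[i0 - self.start:i1 - self.start]

With older Keras versions such a Sequence is passed to model.fit_generator; in tf.keras, model.fit accepts it directly.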

Build the StarNet model architecture

The StarNet architecture is built with:

  • input layer
  • 2 convolutional layers
  • 1 maxpooling layer followed by flattening for the fully connected layers
  • 2 fully connected layers
  • output layer

First, let's define some model variables.


In [5]:
# activation function used following every layer except for the output layer
activation = 'relu'

# model weight initializer
initializer = 'he_normal'

# number of filters used in the convolutional layers
num_filters = [4,16]

# length of the filters in the convolutional layers
filter_length = 8

# length of the maxpooling window 
pool_length = 4

# number of nodes in each of the hidden fully connected layers
num_hidden = [256,128]

# number of spectra fed into model at once during training
batch_size = 64

# maximum number of epochs for model training
max_epochs = 30

# initial learning rate for optimization algorithm
lr = 0.0007
    
# exponential decay rate for the 1st moment estimates for optimization algorithm
beta_1 = 0.9

# exponential decay rate for the 2nd moment estimates for optimization algorithm
beta_2 = 0.999

# a small constant for numerical stability for optimization algorithm
optimizer_epsilon = 1e-08

In [6]:
# Input spectra
input_spec = Input(shape=(num_fluxes,), name='starnet_input_x')

# Reshape spectra for CNN layers
cur_in = Reshape((num_fluxes, 1))(input_spec)

# CNN layers
cur_in = Conv1D(kernel_initializer=initializer, activation=activation, 
                padding="same", filters=num_filters[0], kernel_size=filter_length)(cur_in)
cur_in = Conv1D(kernel_initializer=initializer, activation=activation,
                padding="same", filters=num_filters[1], kernel_size=filter_length)(cur_in)

# Max pooling layer
cur_in = MaxPooling1D(pool_size=pool_length)(cur_in)

# Flatten the current input for the fully-connected layers
cur_in = Flatten()(cur_in)

# Fully-connected layers
cur_in = Dense(units=num_hidden[0], kernel_initializer=initializer, 
               activation=activation)(cur_in)
cur_in = Dense(units=num_hidden[1], kernel_initializer=initializer, 
               activation=activation)(cur_in)

# Output nodes
output_label = Dense(units=num_labels, activation="linear",
                     name='starnet_output_y')(cur_in)

model = Model(input_spec, output_label, name='StarNet')

More model techniques

  • The Adam optimizer is the gradient descent algorithm used to minimize the loss function
  • EarlyStopping evaluates the model on the cross-validation set after every epoch and stops training if the cv loss has not decreased by min_delta within patience epochs
  • ReduceLROnPlateau is a form of learning rate decay: the learning rate is reduced by a factor of factor whenever the training loss has not decreased by epsilon within patience epochs, until it reaches min_lr

In [7]:
# Default parameters for the training callbacks
early_stopping_min_delta = 0.0001
early_stopping_patience = 4
reduce_lr_factor = 0.5
reduce_lr_epsilon = 0.0009
reduce_lr_patience = 2
reduce_lr_min = 0.00008

# loss function to minimize
loss_function = 'mean_squared_error'

# additional metric: mean absolute error
metrics = ['mae']

In [8]:
optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=optimizer_epsilon, decay=0.0)

early_stopping = EarlyStopping(monitor='val_loss', min_delta=early_stopping_min_delta,
                               patience=early_stopping_patience, verbose=2, mode='min')

reduce_lr = ReduceLROnPlateau(monitor='loss', factor=reduce_lr_factor, epsilon=reduce_lr_epsilon,
                              patience=reduce_lr_patience, min_lr=reduce_lr_min, mode='min', verbose=2)

Compile model


In [10]:
model.compile(optimizer=optimizer, loss=loss_function, metrics=metrics)
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
starnet_input_x (InputLayer) (None, 7214)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7214, 1)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 7214, 4)           36        
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 7214, 16)          528       
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 1803, 16)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 28848)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               7385344   
_________________________________________________________________
dense_2 (Dense)              (None, 128)               32896     
_________________________________________________________________
starnet_output_y (Dense)     (None, 3)                 387       
=================================================================
Total params: 7,419,191
Trainable params: 7,419,191
Non-trainable params: 0
_________________________________________________________________
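
For reference, these parameter counts follow directly from the layer shapes: conv1d_1 has (8×1 + 1)×4 = 36 weights, conv1d_2 has (8×4 + 1)×16 = 528, dense_1 has 28848×256 + 256 = 7,385,344, dense_2 has 256×128 + 128 = 32,896, and the output layer has 128×3 + 3 = 387.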

Train model


In [11]:
model.fit(x_train, y_train, validation_data=(x_cv, y_cv),
          epochs=max_epochs, verbose=1, shuffle='batch',
          callbacks=[early_stopping, reduce_lr])


Train on 41000 samples, validate on 3784 samples
Epoch 1/30
41000/41000 [==============================] - 209s 5ms/step - loss: 0.1755 - mean_absolute_error: 0.2274 - val_loss: 0.0316 - val_mean_absolute_error: 0.1300
Epoch 2/30
41000/41000 [==============================] - 209s 5ms/step - loss: 0.0262 - mean_absolute_error: 0.1148 - val_loss: 0.0141 - val_mean_absolute_error: 0.0880
Epoch 3/30
41000/41000 [==============================] - 211s 5ms/step - loss: 0.0140 - mean_absolute_error: 0.0843 - val_loss: 0.0098 - val_mean_absolute_error: 0.0676
Epoch 4/30
41000/41000 [==============================] - 210s 5ms/step - loss: 0.0112 - mean_absolute_error: 0.0738 - val_loss: 0.0089 - val_mean_absolute_error: 0.0667
Epoch 5/30
41000/41000 [==============================] - 121s 3ms/step - loss: 0.0091 - mean_absolute_error: 0.0668 - val_loss: 0.0079 - val_mean_absolute_error: 0.0645
Epoch 6/30
41000/41000 [==============================] - 138s 3ms/step - loss: 0.0084 - mean_absolute_error: 0.0635 - val_loss: 0.0062 - val_mean_absolute_error: 0.0533
Epoch 7/30
41000/41000 [==============================] - 136s 3ms/step - loss: 0.0073 - mean_absolute_error: 0.0590 - val_loss: 0.0068 - val_mean_absolute_error: 0.0590
Epoch 8/30
41000/41000 [==============================] - 137s 3ms/step - loss: 0.0078 - mean_absolute_error: 0.0594 - val_loss: 0.0073 - val_mean_absolute_error: 0.0550
Epoch 9/30
41000/41000 [==============================] - 137s 3ms/step - loss: 0.0061 - mean_absolute_error: 0.0544 - val_loss: 0.0065 - val_mean_absolute_error: 0.0539
Epoch 10/30
41000/41000 [==============================] - 136s 3ms/step - loss: 0.0057 - mean_absolute_error: 0.0528 - val_loss: 0.0050 - val_mean_absolute_error: 0.0483
Epoch 11/30
41000/41000 [==============================] - 136s 3ms/step - loss: 0.0051 - mean_absolute_error: 0.0499 - val_loss: 0.0050 - val_mean_absolute_error: 0.0507
Epoch 12/30
41000/41000 [==============================] - 135s 3ms/step - loss: 0.0047 - mean_absolute_error: 0.0479 - val_loss: 0.0044 - val_mean_absolute_error: 0.0445
Epoch 13/30
41000/41000 [==============================] - 105s 3ms/step - loss: 0.0041 - mean_absolute_error: 0.0460 - val_loss: 0.0042 - val_mean_absolute_error: 0.0431
Epoch 14/30
41000/41000 [==============================] - 107s 3ms/step - loss: 0.0039 - mean_absolute_error: 0.0449 - val_loss: 0.0046 - val_mean_absolute_error: 0.0462
Epoch 15/30
41000/41000 [==============================] - 108s 3ms/step - loss: 0.0036 - mean_absolute_error: 0.0432 - val_loss: 0.0042 - val_mean_absolute_error: 0.0453
Epoch 16/30
41000/41000 [==============================] - 105s 3ms/step - loss: 0.0035 - mean_absolute_error: 0.0428 - val_loss: 0.0033 - val_mean_absolute_error: 0.0401

Epoch 00016: ReduceLROnPlateau reducing learning rate to 0.0003499999875202775.
Epoch 17/30
41000/41000 [==============================] - 101s 2ms/step - loss: 0.0027 - mean_absolute_error: 0.0368 - val_loss: 0.0034 - val_mean_absolute_error: 0.0405
Epoch 18/30
41000/41000 [==============================] - 104s 3ms/step - loss: 0.0025 - mean_absolute_error: 0.0359 - val_loss: 0.0028 - val_mean_absolute_error: 0.0357
Epoch 19/30
41000/41000 [==============================] - 145s 4ms/step - loss: 0.0024 - mean_absolute_error: 0.0354 - val_loss: 0.0029 - val_mean_absolute_error: 0.0366
Epoch 20/30
41000/41000 [==============================] - 146s 4ms/step - loss: 0.0024 - mean_absolute_error: 0.0353 - val_loss: 0.0031 - val_mean_absolute_error: 0.0388

Epoch 00020: ReduceLROnPlateau reducing learning rate to 0.00017499999376013875.
Epoch 21/30
41000/41000 [==============================] - 145s 4ms/step - loss: 0.0019 - mean_absolute_error: 0.0317 - val_loss: 0.0027 - val_mean_absolute_error: 0.0361
Epoch 22/30
41000/41000 [==============================] - 145s 4ms/step - loss: 0.0019 - mean_absolute_error: 0.0316 - val_loss: 0.0028 - val_mean_absolute_error: 0.0366

Epoch 00022: ReduceLROnPlateau reducing learning rate to 8.749999688006938e-05.
Epoch 00022: early stopping
Out[11]:
<keras.callbacks.History at 0x7fbf95ecfb00>

Save model


In [12]:
starnet_model = 'starnet_cnn.h5'
model.save(datadir + starnet_model)
print(starnet_model+' saved.')


starnet_cnn.h5 saved.
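
At test time the saved model can be reloaded and its normalized predictions mapped back to physical units (the reverse of the normalization above). A minimal sketch, assuming an array of test spectra x_test with 7214 wavelength bins per spectrum (not created in this notebook):

from keras.models import load_model

# Reload the trained network and the normalization data
model = load_model(datadir + 'starnet_cnn.h5')
mean_and_std = np.load(datadir + 'mean_and_std.npy')

# Predict normalized labels for x_test (hypothetical test spectra),
# then restore TEFF, LOGG, and FE_H to their physical units
y_pred = model.predict(x_test) * mean_and_std[1] + mean_and_std[0]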