Data Driven CNN StarNet

This notebook builds a supervized learning model StarNet to predict stellar parameters from spectra, assuming we have access to a set of stellar parameters previously estimated.

Summary of the current implementation

  • Inputs: APOGEE DR14 spectra
  • Labels: 3 stellar parameters resulting from the APOGEE pipeline
  • Model: See the build_model routine below

TODO


In [4]:
import numpy as np
import random
import h5py
import time

from keras.models import Sequential
from keras.layers import Dense, Flatten, BatchNormalization, Dropout, Input
from keras.layers.convolutional import Conv1D, MaxPooling1D, AveragePooling1D
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau


from keras_contrib.layers import InstanceNormalization
from keras.layers import RepeatVector,Add
from keras.layers import UpSampling2D, Reshape, Activation
from keras.models import Model
import keras.initializers


Using TensorFlow backend.

Hyper parameters for the model


In [5]:
# activation function used following every layer except for the output layers
activation = 'relu'

# model weight initializer
initializer = 'he_normal'

num_fluxes = 7514
num_labels = 3

# shape of input spectra that is fed into the input layer
input_shape = (None,num_fluxes,1)

# number of filters used in the convolutional layers
num_filters = 8

# length of the filters in the convolutional layers
filter_length = 3

# length of the maxpooling window 
pool_length = 4

# number of nodes in each of the hidden fully connected layers
num_hidden = [256,128]

# number of spectra fed into model at once during training
batch_size = 64

# maximum number of interations for model training
max_epochs = 30

# initial learning rate for optimization algorithm
lr = 0.0001
    
# exponential decay rate for the 1st moment estimates for optimization algorithm
beta_1 = 0.9

# exponential decay rate for the 2nd moment estimates for optimization algorithm
beta_2 = 0.999

# a small constant for numerical stability for optimization algorithm
optimizer_epsilon = 1e-08

early_stopping_min_delta = 0.0001
early_stopping_patience = 4
reduce_lr_factor = 0.5
reuce_lr_epsilon = 0.0009
reduce_lr_patience = 2
reduce_lr_min = 0.00008
loss_function = 'mean_squared_error'

In [6]:
def build_model(input_spec):

    # input conv layer with filter length 1, no bias value
    x = Conv1D(kernel_initializer=keras.initializers.Constant(0.5),
               activation='linear', padding="same", filters=1,
               kernel_size=1,use_bias=False)(input_spec)
    
    # instance normalize to bring each spectrum to zero-mean and unit variance
    normed_spec = InstanceNormalization()(x)
    
    # upsample the spectra so that they can be easily added to the output of the conv blocks
    # this method just repeats the spectra n=num_filters times
    normed_spec = Reshape((num_fluxes,1,1))(normed_spec)
    repeated_spec = UpSampling2D(size=(1, num_filters))(normed_spec)
    
    # reshape spectra and repeated spectra to proper shape for 1D Conv layers
    repeated_spec = Reshape((num_fluxes,num_filters))(repeated_spec)    
    x = Reshape((num_fluxes,1))(normed_spec)
    
    # Conv block w/ InstanceNorm w/ dropout
    x = Conv1D(kernel_initializer=initializer, padding="same", filters=num_filters, 
               kernel_size=filter_length)(x)
    x = Activation('relu')(x)
    x = InstanceNormalization()(x)
    x = Conv1D(kernel_initializer=initializer, padding="same", filters=num_filters, 
               kernel_size=filter_length)(x)
    x = Activation('relu')(x)
    x = InstanceNormalization()(x)
    x = Add()([x, repeated_spec])
    x = Dropout(0.2)(x)

    # Conv block w/ InstanceNorm w/o dropout
    x = Conv1D(kernel_initializer=initializer, padding="same", filters=num_filters, 
               kernel_size=filter_length)(x)
    x = Activation('relu')(x)
    x = InstanceNormalization()(x)
    x = Conv1D(kernel_initializer=initializer, padding="same", filters=num_filters, 
               kernel_size=filter_length)(x)
    x = Activation('relu')(x)
    x = InstanceNormalization()(x)
    x = Add()([x, repeated_spec])

    # Avg pooling w/ dropout (DO NOT APPLY DROPOUT BEFORE POOLING)
    x = AveragePooling1D(pool_size=pool_length)(x)
    x = Dropout(0.2)(x)
    x = Flatten()(x)

    # Fully connected blocks w/ BatchNorm
    x = Dense(num_hidden[0], kernel_initializer=initializer)(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)

    x = Dense(num_hidden[1], kernel_initializer=initializer)(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    # output nodes
    output_pred = Dense(units=num_labels, activation="linear")(x)

    return Model(input_spec,output_pred)

Build and compile model


In [7]:
input_spec = Input(shape=(num_fluxes,1,))
model = build_model(input_spec)
optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=optimizer_epsilon, decay=0.0)

early_stopping = EarlyStopping(monitor='val_loss', min_delta=early_stopping_min_delta, 
                               patience=early_stopping_patience, verbose=2, mode='min')

reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, epsilon=reuce_lr_epsilon, 
                              patience=reduce_lr_patience, min_lr=reduce_lr_min, mode='min', verbose=2)

model.compile(optimizer=optimizer, loss=loss_function)
model.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 7514, 1)      0                                            
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 7514, 1)      1           input_1[0][0]                    
__________________________________________________________________________________________________
instance_normalization_1 (Insta (None, 7514, 1)      2           conv1d_1[0][0]                   
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 7514, 1, 1)   0           instance_normalization_1[0][0]   
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 7514, 1)      0           reshape_1[0][0]                  
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 7514, 8)      32          reshape_3[0][0]                  
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 7514, 8)      0           conv1d_2[0][0]                   
__________________________________________________________________________________________________
instance_normalization_2 (Insta (None, 7514, 8)      2           activation_1[0][0]               
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 7514, 8)      200         instance_normalization_2[0][0]   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 7514, 8)      0           conv1d_3[0][0]                   
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 7514, 8, 1)   0           reshape_1[0][0]                  
__________________________________________________________________________________________________
instance_normalization_3 (Insta (None, 7514, 8)      2           activation_2[0][0]               
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 7514, 8)      0           up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
add_1 (Add)                     (None, 7514, 8)      0           instance_normalization_3[0][0]   
                                                                 reshape_2[0][0]                  
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 7514, 8)      0           add_1[0][0]                      
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 7514, 8)      200         dropout_1[0][0]                  
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 7514, 8)      0           conv1d_4[0][0]                   
__________________________________________________________________________________________________
instance_normalization_4 (Insta (None, 7514, 8)      2           activation_3[0][0]               
__________________________________________________________________________________________________
conv1d_5 (Conv1D)               (None, 7514, 8)      200         instance_normalization_4[0][0]   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 7514, 8)      0           conv1d_5[0][0]                   
__________________________________________________________________________________________________
instance_normalization_5 (Insta (None, 7514, 8)      2           activation_4[0][0]               
__________________________________________________________________________________________________
add_2 (Add)                     (None, 7514, 8)      0           instance_normalization_5[0][0]   
                                                                 reshape_2[0][0]                  
__________________________________________________________________________________________________
average_pooling1d_1 (AveragePoo (None, 1878, 8)      0           add_2[0][0]                      
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 1878, 8)      0           average_pooling1d_1[0][0]        
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 15024)        0           dropout_2[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          3846400     flatten_1[0][0]                  
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 256)          0           dense_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256)          1024        activation_5[0][0]               
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 256)          0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 128)          32896       dropout_3[0][0]                  
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 128)          0           dense_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128)          512         activation_6[0][0]               
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 128)          0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 3)            387         dropout_4[0][0]                  
==================================================================================================
Total params: 3,881,862
Trainable params: 3,881,094
Non-trainable params: 768
__________________________________________________________________________________________________

Load non-normalized spectra


In [8]:
# hack to load pre-computed mean and std-dev for faster normalization
mean_and_std = np.load('/data/stars/apogee/dr14/aspcap_labels_mean_and_std.npy')
mean_labels = mean_and_std[0]
std_labels = mean_and_std[1]
num_labels = mean_and_std.shape[1]

def normalize(lb):
    return (lb-mean_labels)/std_labels

data_file = '/data/stars/apogee/dr14/starnet_training_data.h5'

with h5py.File(data_file,"r") as F:
    spectra = F["spectrum"][:]
    labels = np.column_stack((F["TEFF"][:],F["LOGG"][:],F["FE_H"][:]))
    # Normalize labels
    labels = normalize(labels)
print('Reference set includes '+str(len(spectra))+' individual visit spectra.')

# define the number of wavelength bins (typically 7214)
num_fluxes = spectra.shape[1]
print('Each spectrum contains '+str(num_fluxes)+' wavelength bins')

num_train=int(0.9*len(labels))

# set NaN values to zero
indices_nan = np.where(np.isnan(spectra))
spectra[indices_nan]=0.

# some visit spectra are just zero-vectors... remove these.
spec_std = np.std(spectra,axis=1)
spec_std = spec_std.reshape(spec_std.shape[0],1)
indices = np.where(spec_std!=0.)[0]
spectra = spectra[indices]
labels = labels[indices]

reference_data = np.column_stack((spectra,labels))
np.random.shuffle(reference_data)

train_spectra = reference_data[0:num_train,0:num_fluxes]

# Reshape spectra for convolutional layers
train_spectra = train_spectra.reshape(train_spectra.shape[0], train_spectra.shape[1], 1)
train_labels = reference_data[0:num_train,num_fluxes:]

cv_spectra = reference_data[num_train:,0:num_fluxes]
cv_spectra = cv_spectra.reshape(cv_spectra.shape[0], cv_spectra.shape[1], 1)
cv_labels = reference_data[num_train:,num_fluxes:]

reference_data=[]
spectra=[]
labels=[]

print('Training set includes '+str(len(train_spectra))+' spectra and the cross-validation set includes '+str(len(cv_spectra))+' spectra')


Reference set includes 89554 individual visit spectra.
Each spectrum contains 7514 wavelength bins
Training set includes 80598 spectra and the cross-validation set includes 8784 spectra

In [ ]:
time1 = time.time()

# Train model 
model.fit(train_spectra, train_labels, validation_data=(cv_spectra, cv_labels),
          epochs=max_epochs, batch_size=batch_size, verbose=2,
          callbacks=[reduce_lr,early_stopping])

time2 = time.time()

print("\n" + str(time2-time1) + " seconds for training\n")

# Save model in current directory
model.save('StarNet_DR14.h5')


Train on 80598 samples, validate on 8784 samples
Epoch 1/30
 - 1722s - loss: 0.6490 - val_loss: 0.0766
Epoch 2/30
 - 1743s - loss: 0.3082 - val_loss: 0.0539
Epoch 3/30
 - 1745s - loss: 0.1925 - val_loss: 0.0336
Epoch 4/30

Spectra Normalization

Tentative to replace what stellar spectroscopist call normalization (supression of a global continuum over the whole spectrum) with an input normalization.

This test is simply a convolutional layer with one filter of length 1, followed by an InstanceNormalization layer

First build a model that only includes our input convolutional and instance normalization layers.

Note: I use a constant initialization of 0.5 because if the kernel is < 0. then the normalized spectra are inverted. this probably doesn't matter for the NN but it makes it a lot nicer to plot


In [ ]:
def build_normalizer_model(input_spec):

    # input conv layer with filter length 1 to flatten the shape
    x = Conv1D(kernel_initializer=keras.initializers.Constant(0.5), activation='linear', padding="same", filters=1, 
           kernel_size=1,use_bias=False)(input_spec)
    # instance normalize to bring each spectrum to zero-mean and unit variance
    normed_spec = InstanceNormalization()(x)
    
    return Model(input_spec,normed_spec) 

input_spec = Input(shape=(num_fluxes,1,))
model = build_normalizer_model(input_spec)
model.summary()
normalized_cv = model.predict(cv_spectra)

Plot the input spectra, then the normalized spectra. I will force the second of the two plots to have the same y-axis range to ensure that the range for our normalized spectra are similar to one another


In [ ]:
import matplotlib.pyplot as plt
%matplotlib inline

for i in range(10):
    fig, axes = plt.subplots(2,1,figsize=(70, 10))
    axes[0].plot(cv_spectra[i,:,0],c='b')
    axes[1].plot(normalized_cv[i,:,0],c='r')
    axes[1].set_ylim((-4,4))
    plt.show()

We may want to do some pre-processing clipping to the spectra to elminate the outliers

Stacking

Is the stacking method used on spectra to add them to the output from conv blocks correct?

First extend previous model to include the upsample layer.


In [ ]:
def build_upsample_model(input_spec):

    # input conv layer with filter length 1, no bias value
    x = Conv1D(kernel_initializer=keras.initializers.Constant(0.5), activation='linear', padding="same", filters=1, 
           kernel_size=1,use_bias=False)(input_spec)
    # instance normalize to bring each spectrum to zero-mean and unit variance
    normed_spec = InstanceNormalization()(x)
    
    # upsample the spectra so that they can be easily added to the output of the conv layers
    # this method just repeats the spectra n=num_filters times
    normed_spec = Reshape((num_fluxes,1,1))(normed_spec)
    repeated_spec = UpSampling2D(size=(1, num_filters))(normed_spec)
    repeated_spec = Reshape((num_fluxes,num_filters))(repeated_spec)
    
    return Model(input_spec,repeated_spec) 

input_spec = Input(shape=(num_fluxes,1,))
model = build_upsample_model(input_spec)
model.summary()
upsampled_cv = model.predict(cv_spectra[0:100])

In [ ]:
# Plot the input spectra, then the normalized upsampled spectra
for i in range(5):
    fig, axes = plt.subplots(9,1,figsize=(70, 10))
    axes[0].plot(cv_spectra[i,:,0],c='b')
    for ii in range(8):
        axes[ii+1].plot(upsampled_cv[i,:,ii],c='r')
        axes[ii+1].set_ylim((-4,4))
    plt.show()