In [1]:
# Template for the SMC competition for modeling neurons in the superior colliculus

import math
import numpy as np
import h5py

# Please download the file SCNeuronModelCompetition.mat from here:
# https://github.com/santacruzml/fall-17-scml-competition/releases/download/0.0-data/SCNeuronModelCompetition.mat

datafile = h5py.File('SCNeuronModelCompetition.mat', 'r')  # the .mat file is HDF5 (v7.3), so h5py can read it
movie = np.array(datafile.get('trainingmovie_mini'))  # movie chunks for training
frhist = np.array(datafile.get('FRhist_tr'))  # firing rate histograms (chunk, time, neuron)
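# (Optional convenience check, not required for the template) list the variables stored in the file:
print(list(datafile.keys()))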

# A little global normalization for the movie (assuming the movie is a 3D array: chunk x time x pixels).
# It subtracts the overall mean and divides by the overall standard deviation.
def normalize(inputmovie):
    movie_mean = np.mean(inputmovie, axis=(0, 1, 2))  # scalar mean over all chunks, frames, and pixels
    movie_std = np.std(inputmovie, axis=(0, 1, 2))    # scalar standard deviation over the same
    return (inputmovie - movie_mean) / movie_std

movie_norm = normalize(movie)
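# Quick sanity check (not part of the original template): after normalization the movie
# should have roughly zero mean and unit standard deviation.
print(movie_norm.shape)                      # (n_chunks, chunk_length, n_pixels)
print(movie_norm.mean(), movie_norm.std())   # expect approximately 0 and 1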

In [2]:
# Here's the modeling part. This is just a starting point.

import keras
from keras.layers import LSTM, Activation, Dense, BatchNormalization

# This builds a 3-layer LSTM network with batch normalization after each layer.
# No dropout, regularization, or convolutional structure is used.
# As you can see in the summary, most of the parameters are in the first LSTM's input weight matrix.
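# For reference (a rough back-of-the-envelope count, assuming the standard Keras LSTM
# parameterization): an LSTM layer has 4 * nHidden * (n_input + nHidden + 1) weights.
# With roughly 12k input pixels and nHidden = 100 that gives ~5M parameters for the first
# layer, versus 4 * 100 * (100 + 100 + 1) = 80,400 for each later layer (see the summary below).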



movie_chunk_length = movie_norm.shape[1]  # time bins per movie chunk
movie_pix = movie_norm.shape[2]           # pixels per frame (flattened)
nHidden = 100                             # LSTM units per layer
nLayer = 3                                # number of LSTM layers
nSCNeu = frhist.shape[2]                  # number of SC neurons to predict


model = keras.models.Sequential()
# first LSTM layer reads the pixel input; return_sequences=True keeps the full time course
model.add(LSTM(nHidden, input_shape=(movie_chunk_length, movie_pix), return_sequences=True, implementation=2))

for _ in range(nLayer - 1):
    model.add(BatchNormalization(momentum=0))
    model.add(Activation('relu'))
    model.add(LSTM(nHidden, return_sequences=True))

model.add(BatchNormalization(momentum=0))
model.add(Activation('linear'))  # no-op; left as a placeholder in case you want a nonlinearity here

# dense readout to one output per SC neuron; softplus keeps the predicted firing rates non-negative
model.add(Dense(nSCNeu))
model.add(Activation('softplus'))
adamopt = keras.optimizers.Adam(lr=0.001, decay=1e-7)

# Please make sure to use the Poisson likelihood for the loss function
model.compile(optimizer=adamopt, loss='poisson')
model.summary()

# stop training once the validation loss has not improved for 10 epochs
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(movie_norm, frhist, epochs=200, batch_size=32,
                    validation_split=0.2, shuffle=True, callbacks=[early_stopping])
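# For reference (a minimal sketch, not part of the original template): the built-in
# 'poisson' loss used above is the negative Poisson log-likelihood, up to an additive
# constant that does not depend on the prediction. An equivalent custom loss would be:
import keras.backend as K

def poisson_nll(y_true, y_pred):
    # mean over output units of (predicted rate - observed count * log(predicted rate))
    return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)

# model.compile(optimizer=adamopt, loss=poisson_nll)  # would behave the same as loss='poisson'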


Using TensorFlow backend.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 150, 100)          4955600   
_________________________________________________________________
batch_normalization_1 (Batch (None, 150, 100)          400       
_________________________________________________________________
activation_1 (Activation)    (None, 150, 100)          0         
_________________________________________________________________
lstm_2 (LSTM)                (None, 150, 100)          80400     
_________________________________________________________________
batch_normalization_2 (Batch (None, 150, 100)          400       
_________________________________________________________________
activation_2 (Activation)    (None, 150, 100)          0         
_________________________________________________________________
lstm_3 (LSTM)                (None, 150, 100)          80400     
_________________________________________________________________
batch_normalization_3 (Batch (None, 150, 100)          400       
_________________________________________________________________
activation_3 (Activation)    (None, 150, 100)          0         
_________________________________________________________________
dense_1 (Dense)              (None, 150, 54)           5454      
_________________________________________________________________
activation_4 (Activation)    (None, 150, 54)           0         
=================================================================
Total params: 5,123,054
Trainable params: 5,122,454
Non-trainable params: 600
_________________________________________________________________
Train on 230 samples, validate on 58 samples
Epoch 1/200
230/230 [==============================] - 5s - loss: 0.8599 - val_loss: 0.9545
Epoch 2/200
230/230 [==============================] - 4s - loss: 0.7728 - val_loss: 0.7862
Epoch 3/200
230/230 [==============================] - 4s - loss: 0.7454 - val_loss: 0.7602
Epoch 4/200
230/230 [==============================] - 4s - loss: 0.7306 - val_loss: 0.7377
Epoch 5/200
230/230 [==============================] - 4s - loss: 0.7174 - val_loss: 0.7188
Epoch 6/200
230/230 [==============================] - 4s - loss: 0.7028 - val_loss: 0.7634
Epoch 7/200
230/230 [==============================] - 4s - loss: 0.6850 - val_loss: 0.8192
Epoch 8/200
230/230 [==============================] - 4s - loss: 0.6629 - val_loss: 0.6818
Epoch 9/200
230/230 [==============================] - 4s - loss: 0.6358 - val_loss: 0.6442
Epoch 10/200
230/230 [==============================] - 4s - loss: 0.6046 - val_loss: 0.6027
Epoch 11/200
230/230 [==============================] - 4s - loss: 0.5707 - val_loss: 0.6567
Epoch 12/200
230/230 [==============================] - 4s - loss: 0.5339 - val_loss: 0.5659
Epoch 13/200
230/230 [==============================] - 4s - loss: 0.4982 - val_loss: 0.5979
Epoch 14/200
230/230 [==============================] - 4s - loss: 0.4646 - val_loss: 0.5479
Epoch 15/200
230/230 [==============================] - 4s - loss: 0.4354 - val_loss: 0.4462
Epoch 16/200
230/230 [==============================] - 4s - loss: 0.4105 - val_loss: 0.5442
Epoch 17/200
230/230 [==============================] - 4s - loss: 0.3910 - val_loss: 0.4566
Epoch 18/200
230/230 [==============================] - 4s - loss: 0.3755 - val_loss: 0.4069
Epoch 19/200
230/230 [==============================] - 4s - loss: 0.3637 - val_loss: 0.3783
Epoch 20/200
230/230 [==============================] - 4s - loss: 0.3550 - val_loss: 0.4940
Epoch 21/200
230/230 [==============================] - 4s - loss: 0.3495 - val_loss: 0.3781
Epoch 22/200
230/230 [==============================] - 4s - loss: 0.3456 - val_loss: 0.3950
Epoch 23/200
230/230 [==============================] - 4s - loss: 0.3419 - val_loss: 0.3651
Epoch 24/200
230/230 [==============================] - 4s - loss: 0.3382 - val_loss: 0.3890
Epoch 25/200
230/230 [==============================] - 4s - loss: 0.3353 - val_loss: 0.3725
Epoch 26/200
230/230 [==============================] - 4s - loss: 0.3334 - val_loss: 0.4476
Epoch 27/200
230/230 [==============================] - 4s - loss: 0.3318 - val_loss: 0.4215
Epoch 28/200
230/230 [==============================] - 4s - loss: 0.3310 - val_loss: 0.3921
Epoch 29/200
230/230 [==============================] - 4s - loss: 0.3300 - val_loss: 0.3591
Epoch 30/200
230/230 [==============================] - 4s - loss: 0.3300 - val_loss: 0.3645
Epoch 31/200
230/230 [==============================] - 4s - loss: 0.3288 - val_loss: 0.3629
Epoch 32/200
230/230 [==============================] - 4s - loss: 0.3286 - val_loss: 0.3641
Epoch 33/200
230/230 [==============================] - 4s - loss: 0.3276 - val_loss: 0.3714
Epoch 34/200
230/230 [==============================] - 4s - loss: 0.3266 - val_loss: 0.3716
Epoch 35/200
230/230 [==============================] - 4s - loss: 0.3257 - val_loss: 0.3571
Epoch 36/200
230/230 [==============================] - 4s - loss: 0.3260 - val_loss: 0.3844
Epoch 37/200
230/230 [==============================] - 4s - loss: 0.3250 - val_loss: 0.3812
Epoch 38/200
230/230 [==============================] - 4s - loss: 0.3249 - val_loss: 0.3929
Epoch 39/200
230/230 [==============================] - 4s - loss: 0.3248 - val_loss: 0.3730
Epoch 40/200
230/230 [==============================] - 4s - loss: 0.3237 - val_loss: 0.4172
Epoch 41/200
230/230 [==============================] - 4s - loss: 0.3238 - val_loss: 0.3792
Epoch 42/200
230/230 [==============================] - 4s - loss: 0.3230 - val_loss: 0.3663
Epoch 43/200
230/230 [==============================] - 4s - loss: 0.3236 - val_loss: 0.3787
Epoch 44/200
230/230 [==============================] - 4s - loss: 0.3230 - val_loss: 0.3869
Epoch 45/200
230/230 [==============================] - 4s - loss: 0.3229 - val_loss: 0.3746
Epoch 46/200
230/230 [==============================] - 4s - loss: 0.3219 - val_loss: 0.3632

In [3]:
# Check how well the model reproduces the firing rates in the training dataset
%matplotlib inline
import matplotlib.pyplot as plt

output = model.predict(movie_norm)  # predicted firing rates for every movie chunk

n = 31  # pick one example neuron to look at
for m in range(0, 48):
    # plot the firing rate averaged over the 6 repeats of the same movie chunk
    plt.plot(np.mean(frhist[(m * 6):(m + 1) * 6, :, n], axis=0))

    # plot the network prediction for that movie (taking the first of the 6 repeats)
    plt.plot(output[m * 6, :, n])
    plt.show()
    # note: roughly the last 10 movies overlap with the validation split (validation_split=0.2)
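# A rough quantitative check (a sketch, not part of the original template): per-neuron
# correlation between predicted and observed firing rates, restricted to the chunks
# Keras held out for validation (validation_split=0.2 takes the last 20% of samples).
n_samples = frhist.shape[0]
n_val = n_samples - int(n_samples * 0.8)   # should match the samples Keras holds out
pred_val = output[-n_val:].reshape(-1, nSCNeu)
true_val = np.asarray(frhist)[-n_val:].reshape(-1, nSCNeu)

corrs = np.array([np.corrcoef(pred_val[:, i], true_val[:, i])[0, 1] for i in range(nSCNeu)])
print('median per-neuron correlation on held-out chunks: %.3f' % np.median(corrs))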