In [1]:
import numpy as np
# scipy.special for the sigmoid function expit()
import scipy.special as special
import matplotlib.pyplot as plt
# used the link for plotting
# https://matplotlib.org/examples/showcase/bachelors_degrees_by_gender.html

In [2]:
# neural network class definition
class neuralNetwork:
    
    # initialise the neural network
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        # set number of nodes in each input, hidden, output layer
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes
        
        # link weight matrices, wih and who
        # weights inside the arrays are w_i_j, where link is from node i to node j in the next layer
        # w11 w21
        # w12 w22 etc 
        self.wih = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
        self.who = np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))

        # learning rate
        self.lr = learningrate
        
        # activation function is the sigmoid function
        self.activation_function = lambda x: special.expit(x)
        
        pass

    
    # train the neural network
    def train(self, inputs_list, targets_list):
        # convert inputs list to 2d array
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T
        
        # calculate signals into hidden layer
        hidden_inputs = np.dot(self.wih, inputs)
        # calculate the signals emerging from hidden layer
        hidden_outputs = self.activation_function(hidden_inputs)
        
        # calculate signals into final output layer
        final_inputs = np.dot(self.who, hidden_outputs)
        # calculate the signals emerging from final output layer
        final_outputs = self.activation_function(final_inputs)
        
        # output layer error is the (target - actual)
        output_errors = targets - final_outputs
        
        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes
        hidden_errors = np.dot(self.who.T, output_errors) 
        
        # update the weights for the links between the hidden and output layers
        self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)), np.transpose(hidden_outputs))
        
        # update the weights for the links between the input and hidden layers
        self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))
        
        #print("Output errors:", output_errors)
        return hidden_errors, output_errors
    
    # query the neural network
    def query(self, inputs_list):
        # convert inputs list to 2d array
        inputs = np.array(inputs_list, ndmin=2).T
        
        # calculate signals into hidden layer
        hidden_inputs = np.dot(self.wih, inputs)
        # calculate the signals emerging from hidden layer
        hidden_outputs = self.activation_function(hidden_inputs)
        
        # calculate signals into final output layer
        final_inputs = np.dot(self.who, hidden_outputs)
        # calculate the signals emerging from final output layer
        final_outputs = self.activation_function(final_inputs)
        
        return final_outputs

In [3]:
# number of input, hidden and output nodes
input_nodes = 4
hidden_nodes = 4
output_nodes = 3

# learning rate is 0.3
learning_rate = 0.3

# create instance of neural network
n = neuralNetwork(input_nodes,hidden_nodes,output_nodes, learning_rate)

In [4]:
data = [[0.645569620253164, 0.795454545454545, 0.202898550724638, 0.08, 0.333333333333333],
[0.620253164556962, 0.681818181818182, 0.202898550724638, 0.08, 0.333333333333333],
[0.594936708860759, 0.727272727272727, 0.188405797101449, 0.08, 0.333333333333333],
[0.582278481012658, 0.704545454545454, 0.217391304347826, 0.08, 0.333333333333333],
[0.632911392405063, 0.818181818181818, 0.202898550724638, 0.08, 0.333333333333333],
[0.683544303797468, 0.886363636363636, 0.246376811594203, 0.16, 0.333333333333333],
[0.582278481012658, 0.772727272727273, 0.202898550724638, 0.12, 0.333333333333333],
[0.632911392405063, 0.772727272727273, 0.217391304347826, 0.08, 0.333333333333333],
[0.556962025316456, 0.659090909090909, 0.202898550724638, 0.08, 0.333333333333333],
[0.620253164556962, 0.704545454545454, 0.217391304347826, 0.04, 0.333333333333333],
[0.683544303797468, 0.840909090909091, 0.217391304347826, 0.08, 0.333333333333333],
[0.607594936708861, 0.772727272727273, 0.231884057971014, 0.08, 0.333333333333333],
[0.607594936708861, 0.681818181818182, 0.202898550724638, 0.04, 0.333333333333333],
[0.544303797468354, 0.681818181818182, 0.159420289855072, 0.04, 0.333333333333333],
[0.734177215189873, 0.909090909090909, 0.173913043478261, 0.08, 0.333333333333333],
[0.721518987341772, 1, 0.217391304347826, 0.16, 0.333333333333333],
[0.683544303797468, 0.886363636363636, 0.188405797101449, 0.16, 0.333333333333333],
[0.645569620253164, 0.795454545454545, 0.202898550724638, 0.12, 0.333333333333333],
[0.721518987341772, 0.863636363636364, 0.246376811594203, 0.12, 0.333333333333333],
[0.645569620253164, 0.863636363636364, 0.217391304347826, 0.12, 0.333333333333333],
[0.683544303797468, 0.772727272727273, 0.246376811594203, 0.08, 0.333333333333333],
[0.645569620253164, 0.840909090909091, 0.217391304347826, 0.16, 0.333333333333333],
[0.582278481012658, 0.818181818181818, 0.144927536231884, 0.08, 0.333333333333333],
[0.645569620253164, 0.75, 0.246376811594203, 0.2, 0.333333333333333],
[0.607594936708861, 0.772727272727273, 0.27536231884058, 0.08, 0.333333333333333],
[0.632911392405063, 0.681818181818182, 0.231884057971014, 0.08, 0.333333333333333],
[0.632911392405063, 0.772727272727273, 0.231884057971014, 0.16, 0.333333333333333],
[0.658227848101266, 0.795454545454545, 0.217391304347826, 0.08, 0.333333333333333],
[0.658227848101266, 0.772727272727273, 0.202898550724638, 0.08, 0.333333333333333],
[0.594936708860759, 0.727272727272727, 0.231884057971014, 0.08, 0.333333333333333],
[0.607594936708861, 0.704545454545454, 0.231884057971014, 0.08, 0.333333333333333],
[0.683544303797468, 0.772727272727273, 0.217391304347826, 0.16, 0.333333333333333],
[0.658227848101266, 0.931818181818182, 0.217391304347826, 0.04, 0.333333333333333],
[0.69620253164557, 0.954545454545454, 0.202898550724638, 0.08, 0.333333333333333],
[0.620253164556962, 0.704545454545454, 0.217391304347826, 0.04, 0.333333333333333],
[0.632911392405063, 0.727272727272727, 0.173913043478261, 0.08, 0.333333333333333],
[0.69620253164557, 0.795454545454545, 0.188405797101449, 0.08, 0.333333333333333],
[0.620253164556962, 0.704545454545454, 0.217391304347826, 0.04, 0.333333333333333],
[0.556962025316456, 0.681818181818182, 0.188405797101449, 0.08, 0.333333333333333],
[0.645569620253164, 0.772727272727273, 0.217391304347826, 0.08, 0.333333333333333],
[0.632911392405063, 0.795454545454545, 0.188405797101449, 0.12, 0.333333333333333],
[0.569620253164557, 0.522727272727273, 0.188405797101449, 0.12, 0.333333333333333],
[0.556962025316456, 0.727272727272727, 0.188405797101449, 0.08, 0.333333333333333],
[0.632911392405063, 0.795454545454545, 0.231884057971014, 0.24, 0.333333333333333],
[0.645569620253164, 0.863636363636364, 0.27536231884058, 0.16, 0.333333333333333],
[0.607594936708861, 0.681818181818182, 0.202898550724638, 0.12, 0.333333333333333],
[0.645569620253164, 0.863636363636364, 0.231884057971014, 0.08, 0.333333333333333],
[0.582278481012658, 0.727272727272727, 0.202898550724638, 0.08, 0.333333333333333],
[0.670886075949367, 0.840909090909091, 0.217391304347826, 0.08, 0.333333333333333],
[0.632911392405063, 0.75, 0.202898550724638, 0.08, 0.333333333333333],
[0.886075949367089, 0.727272727272727, 0.681159420289855, 0.56, 0.666666666666667],
[0.810126582278481, 0.727272727272727, 0.652173913043478, 0.6, 0.666666666666667],
[0.873417721518987, 0.704545454545454, 0.710144927536232, 0.6, 0.666666666666667],
[0.69620253164557, 0.522727272727273, 0.579710144927536, 0.52, 0.666666666666667],
[0.822784810126582, 0.636363636363636, 0.666666666666667, 0.6, 0.666666666666667],
[0.721518987341772, 0.636363636363636, 0.652173913043478, 0.52, 0.666666666666667],
[0.79746835443038, 0.75, 0.681159420289855, 0.64, 0.666666666666667],
[0.620253164556962, 0.545454545454545, 0.478260869565217, 0.4, 0.666666666666667],
[0.835443037974683, 0.659090909090909, 0.666666666666667, 0.52, 0.666666666666667],
[0.658227848101266, 0.613636363636364, 0.565217391304348, 0.56, 0.666666666666667],
[0.632911392405063, 0.454545454545455, 0.507246376811594, 0.4, 0.666666666666667],
[0.746835443037975, 0.681818181818182, 0.608695652173913, 0.6, 0.666666666666667],
[0.759493670886076, 0.5, 0.579710144927536, 0.4, 0.666666666666667],
[0.772151898734177, 0.659090909090909, 0.681159420289855, 0.56, 0.666666666666667],
[0.708860759493671, 0.659090909090909, 0.521739130434783, 0.52, 0.666666666666667],
[0.848101265822785, 0.704545454545454, 0.63768115942029, 0.56, 0.666666666666667],
[0.708860759493671, 0.681818181818182, 0.652173913043478, 0.6, 0.666666666666667],
[0.734177215189873, 0.613636363636364, 0.594202898550725, 0.4, 0.666666666666667],
[0.784810126582278, 0.5, 0.652173913043478, 0.6, 0.666666666666667],
[0.708860759493671, 0.568181818181818, 0.565217391304348, 0.44, 0.666666666666667],
[0.746835443037975, 0.727272727272727, 0.695652173913043, 0.72, 0.666666666666667],
[0.772151898734177, 0.636363636363636, 0.579710144927536, 0.52, 0.666666666666667],
[0.79746835443038, 0.568181818181818, 0.710144927536232, 0.6, 0.666666666666667],
[0.772151898734177, 0.636363636363636, 0.681159420289855, 0.48, 0.666666666666667],
[0.810126582278481, 0.659090909090909, 0.623188405797101, 0.52, 0.666666666666667],
[0.835443037974683, 0.681818181818182, 0.63768115942029, 0.56, 0.666666666666667],
[0.860759493670886, 0.636363636363636, 0.695652173913043, 0.56, 0.666666666666667],
[0.848101265822785, 0.681818181818182, 0.72463768115942, 0.68, 0.666666666666667],
[0.759493670886076, 0.659090909090909, 0.652173913043478, 0.6, 0.666666666666667],
[0.721518987341772, 0.590909090909091, 0.507246376811594, 0.4, 0.666666666666667],
[0.69620253164557, 0.545454545454545, 0.550724637681159, 0.44, 0.666666666666667],
[0.69620253164557, 0.545454545454545, 0.536231884057971, 0.4, 0.666666666666667],
[0.734177215189873, 0.613636363636364, 0.565217391304348, 0.48, 0.666666666666667],
[0.759493670886076, 0.613636363636364, 0.739130434782609, 0.64, 0.666666666666667],
[0.683544303797468, 0.681818181818182, 0.652173913043478, 0.6, 0.666666666666667],
[0.759493670886076, 0.772727272727273, 0.652173913043478, 0.64, 0.666666666666667],
[0.848101265822785, 0.704545454545454, 0.681159420289855, 0.6, 0.666666666666667],
[0.79746835443038, 0.522727272727273, 0.63768115942029, 0.52, 0.666666666666667],
[0.708860759493671, 0.681818181818182, 0.594202898550725, 0.52, 0.666666666666667],
[0.69620253164557, 0.568181818181818, 0.579710144927536, 0.52, 0.666666666666667],
[0.69620253164557, 0.590909090909091, 0.63768115942029, 0.48, 0.666666666666667],
[0.772151898734177, 0.681818181818182, 0.666666666666667, 0.56, 0.666666666666667],
[0.734177215189873, 0.590909090909091, 0.579710144927536, 0.48, 0.666666666666667],
[0.632911392405063, 0.522727272727273, 0.478260869565217, 0.4, 0.666666666666667],
[0.708860759493671, 0.613636363636364, 0.608695652173913, 0.52, 0.666666666666667],
[0.721518987341772, 0.681818181818182, 0.608695652173913, 0.48, 0.666666666666667],
[0.721518987341772, 0.659090909090909, 0.608695652173913, 0.52, 0.666666666666667],
[0.784810126582278, 0.659090909090909, 0.623188405797101, 0.52, 0.666666666666667],
[0.645569620253164, 0.568181818181818, 0.434782608695652, 0.44, 0.666666666666667],
[0.721518987341772, 0.636363636363636, 0.594202898550725, 0.52, 0.666666666666667],
[0.79746835443038, 0.75, 0.869565217391304, 1, 1],
[0.734177215189873, 0.613636363636364, 0.739130434782609, 0.76, 1],
[0.89873417721519, 0.681818181818182, 0.855072463768116, 0.84, 1],
[0.79746835443038, 0.659090909090909, 0.811594202898551, 0.72, 1],
[0.822784810126582, 0.681818181818182, 0.840579710144927, 0.88, 1],
[0.962025316455696, 0.681818181818182, 0.956521739130435, 0.84, 1],
[0.620253164556962, 0.568181818181818, 0.652173913043478, 0.68, 1],
[0.924050632911392, 0.659090909090909, 0.91304347826087, 0.72, 1],
[0.848101265822785, 0.568181818181818, 0.840579710144927, 0.72, 1],
[0.911392405063291, 0.818181818181818, 0.884057971014493, 1, 1],
[0.822784810126582, 0.727272727272727, 0.739130434782609, 0.8, 1],
[0.810126582278481, 0.613636363636364, 0.768115942028985, 0.76, 1],
[0.860759493670886, 0.681818181818182, 0.797101449275362, 0.84, 1],
[0.721518987341772, 0.568181818181818, 0.72463768115942, 0.8, 1],
[0.734177215189873, 0.636363636363636, 0.739130434782609, 0.96, 1],
[0.810126582278481, 0.727272727272727, 0.768115942028985, 0.92, 1],
[0.822784810126582, 0.681818181818182, 0.797101449275362, 0.72, 1],
[0.974683544303797, 0.863636363636364, 0.971014492753623, 0.88, 1],
[0.974683544303797, 0.590909090909091, 1, 0.92, 1],
[0.759493670886076, 0.5, 0.72463768115942, 0.6, 1],
[0.873417721518987, 0.727272727272727, 0.826086956521739, 0.92, 1],
[0.708860759493671, 0.636363636363636, 0.710144927536232, 0.8, 1],
[0.974683544303797, 0.636363636363636, 0.971014492753623, 0.8, 1],
[0.79746835443038, 0.613636363636364, 0.710144927536232, 0.72, 1],
[0.848101265822785, 0.75, 0.826086956521739, 0.84, 1],
[0.911392405063291, 0.727272727272727, 0.869565217391304, 0.72, 1],
[0.784810126582278, 0.636363636363636, 0.695652173913043, 0.72, 1],
[0.772151898734177, 0.681818181818182, 0.710144927536232, 0.72, 1],
[0.810126582278481, 0.636363636363636, 0.811594202898551, 0.84, 1],
[0.911392405063291, 0.681818181818182, 0.840579710144927, 0.64, 1],
[0.936708860759494, 0.636363636363636, 0.884057971014493, 0.76, 1],
[1, 0.863636363636364, 0.927536231884058, 0.8, 1],
[0.810126582278481, 0.636363636363636, 0.811594202898551, 0.88, 1],
[0.79746835443038, 0.636363636363636, 0.739130434782609, 0.6, 1],
[0.772151898734177, 0.590909090909091, 0.811594202898551, 0.56, 1],
[0.974683544303797, 0.681818181818182, 0.884057971014493, 0.92, 1],
[0.79746835443038, 0.772727272727273, 0.811594202898551, 0.96, 1],
[0.810126582278481, 0.704545454545454, 0.797101449275362, 0.72, 1],
[0.759493670886076, 0.681818181818182, 0.695652173913043, 0.72, 1],
[0.873417721518987, 0.704545454545454, 0.782608695652174, 0.84, 1],
[0.848101265822785, 0.704545454545454, 0.811594202898551, 0.96, 1],
[0.873417721518987, 0.704545454545454, 0.739130434782609, 0.92, 1],
[0.734177215189873, 0.613636363636364, 0.739130434782609, 0.76, 1],
[0.860759493670886, 0.727272727272727, 0.855072463768116, 0.92, 1],
[0.848101265822785, 0.75, 0.826086956521739, 1, 1],
[0.848101265822785, 0.681818181818182, 0.753623188405797, 0.92, 1],
[0.79746835443038, 0.568181818181818, 0.72463768115942, 0.76, 1],
[0.822784810126582, 0.681818181818182, 0.753623188405797, 0.8, 1],
[0.784810126582278, 0.772727272727273, 0.782608695652174, 0.92, 1],
[0.746835443037975, 0.681818181818182, 0.739130434782609, 0.72, 1]]

In [ ]:


In [ ]:


In [5]:
output_errors


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-5-7fee7a349a37> in <module>()
----> 1 output_errors

NameError: name 'output_errors' is not defined

In [100]:
# weight_changes should have 12 arrays 150 numbers
weight_changes = []

for i in range (0, hidden_nodes*output_nodes):
    weights = []
    weight_changes.append(weights)
    pass

output_errors = []
for i in range(0, output_nodes):
    errors = []
    output_errors.append(errors)

iterations = []
iter_count = 1
for i in data:
    # train the NN with each instance of data
    h_errors, o_errors = n.train(i[0:4], i[4:5])
    print(o_errors)
    
    count = 0
    for weight_row in n.who:
        for element in weight_row:
            weight_changes[count].append(element)
            count += 1
    
    count = 0
    for error in o_errors:
        output_errors[count].append(error[0])
        count += 1
    
    iterations.append(iter_count)
    iter_count += 1
    pass


[[ 0.01026773]
 [ 0.00596765]
 [ 0.00949208]]
[[-0.00818832]
 [-0.00484233]
 [-0.01040229]]
[[ 0.00222463]
 [ 0.00462725]
 [-0.00017695]]
[[-0.0104124 ]
 [-0.00560178]
 [-0.01306138]]
[[ 0.01299122]
 [ 0.00829143]
 [ 0.01223773]]
[[-0.01165783]
 [-0.02046618]
 [-0.01039962]]
[[-0.0080325 ]
 [-0.00680854]
 [-0.00973113]]
[[ 0.00307005]
 [ 0.00077795]
 [ 0.00202214]]
[[-0.01450616]
 [-0.00518987]
 [-0.01823769]]
[[ 0.00509723]
 [ 0.00684527]
 [ 0.00285141]]
[[ 0.01473771]
 [ 0.00622357]
 [ 0.01510237]]
[[-0.00233556]
 [-0.00285589]
 [-0.0037263 ]]
[[ 0.00478297]
 [ 0.00869352]
 [ 0.00195798]]
[[ 0.0138887 ]
 [ 0.02243124]
 [ 0.00931505]]
[[ 0.03437426]
 [ 0.02013042]
 [ 0.03584449]]
[[ 0.01229933]
 [-0.0030166 ]
 [ 0.01461761]]
[[ 0.00320545]
 [-0.00629274]
 [ 0.00428364]]
[[-0.00333383]
 [-0.0069001 ]
 [-0.00369445]]
[[-0.00209618]
 [-0.01212803]
 [-0.00069562]]
[[ 0.00274396]
 [-0.00377965]
 [ 0.00294018]]
[[-0.00454533]
 [-0.00905838]
 [-0.00456915]]
[[-0.01296772]
 [-0.01786481]
 [-0.01261096]]
[[ 0.02494519]
 [ 0.02351974]
 [ 0.02299327]]
[[-0.05291122]
 [-0.05221999]
 [-0.05238473]]
[[-0.01492488]
 [-0.01463713]
 [-0.01604858]]
[[-0.01532013]
 [-0.01237711]
 [-0.01695455]]
[[-0.02703084]
 [-0.02796788]
 [-0.02713773]]
[[ 0.00882401]
 [ 0.00402169]
 [ 0.00847893]]
[[ 0.00929897]
 [ 0.00550272]
 [ 0.00869374]]
[[-0.00884127]
 [-0.00587732]
 [-0.01082068]]
[[-0.01173464]
 [-0.00842034]
 [-0.01364364]]
[[-0.0189687 ]
 [-0.02316708]
 [-0.01831521]]
[[ 0.03661421]
 [ 0.02547899]
 [ 0.03716951]]
[[ 0.0330446 ]
 [ 0.01946516]
 [ 0.03443685]]
[[ 0.00501159]
 [ 0.00730326]
 [ 0.00279299]]
[[ 0.00846106]
 [ 0.00860835]
 [ 0.00682537]]
[[ 0.01643148]
 [ 0.00951443]
 [ 0.01656113]]
[[ 0.00427523]
 [ 0.00666948]
 [ 0.0020643 ]]
[[-0.00681173]
 [ 0.00153942]
 [-0.01036237]]
[[ 0.00309334]
 [ 0.0005131 ]
 [ 0.00237035]]
[[ 0.00153061]
 [-0.00130306]
 [ 0.00096121]]
[[-0.05291611]
 [-0.03494133]
 [-0.05709465]]
[[ 0.00158778]
 [ 0.00698072]
 [-0.00141913]]
[[-0.05246472]
 [-0.05314483]
 [-0.05146675]]
[[-0.0223736 ]
 [-0.02777362]
 [-0.02145214]]
[[-0.01909795]
 [-0.01457467]
 [-0.02084622]]
[[ 0.0155801 ]
 [ 0.00833072]
 [ 0.01576378]]
[[ 0.00083678]
 [ 0.00425751]
 [-0.00148145]]
[[ 0.01692901]
 [ 0.00922434]
 [ 0.01727461]]
[[ 0.00613283]
 [ 0.00499938]
 [ 0.00500845]]
[[-0.0853912 ]
 [-0.08725158]
 [-0.08301678]]
[[-0.14363092]
 [-0.1470333 ]
 [-0.14536776]]
[[-0.149534  ]
 [-0.15330765]
 [-0.15089479]]
[[-0.21108211]
 [-0.21366857]
 [-0.21712563]]
[[-0.17014704]
 [-0.17379701]
 [-0.1731044 ]]
[[-0.11711203]
 [-0.11765742]
 [-0.11860937]]
[[-0.10836848]
 [-0.11047011]
 [-0.10761009]]
[[ 0.01336978]
 [ 0.02338625]
 [ 0.01428462]]
[[-0.02590401]
 [-0.02432428]
 [-0.02042403]]
[[-0.14170352]
 [-0.1423411 ]
 [-0.14547591]]
[[-0.07195703]
 [-0.064747  ]
 [-0.07515468]]
[[-0.07951707]
 [-0.07961661]
 [-0.07820357]]
[[ 0.00058456]
 [ 0.0067298 ]
 [ 0.00435488]]
[[-0.08261723]
 [-0.08283823]
 [-0.08114281]]
[[ 0.02272429]
 [ 0.02747793]
 [ 0.0278938 ]]
[[ 0.01404922]
 [ 0.01606611]
 [ 0.02199251]]
[[-0.11993519]
 [-0.12133738]
 [-0.12133039]]
[[ 0.06265708]
 [ 0.06914167]
 [ 0.06905631]]
[[-0.2155004 ]
 [-0.21935265]
 [-0.22079638]]
[[ 0.02185907]
 [ 0.02870505]
 [ 0.02640915]]
[[-0.17466967]
 [-0.17880538]
 [-0.17782795]]
[[ 0.03008576]
 [ 0.03456854]
 [ 0.03724745]]
[[-0.1544344 ]
 [-0.15711435]
 [-0.15683427]]
[[ 0.0212973 ]
 [ 0.02591063]
 [ 0.02850575]]
[[ 0.04101487]
 [ 0.04498338]
 [ 0.04951533]]
[[ 0.02897513]
 [ 0.03209071]
 [ 0.03771457]]
[[-0.0105677 ]
 [-0.0083598 ]
 [-0.00336603]]
[[-0.09829625]
 [-0.10007853]
 [-0.09631421]]
[[-0.04433082]
 [-0.04265547]
 [-0.04046305]]
[[ 0.10233653]
 [ 0.11000533]
 [ 0.10868916]]
[[ 0.02915948]
 [ 0.03730254]
 [ 0.03362838]]
[[ 0.06208058]
 [ 0.07133518]
 [ 0.06723659]]
[[ 0.04293136]
 [ 0.0490789 ]
 [ 0.04941681]]
[[-0.17546562]
 [-0.17899899]
 [-0.17901025]]
[[-0.05624403]
 [-0.05421977]
 [-0.05385773]]
[[ 0.00997185]
 [ 0.01269804]
 [ 0.01771911]]
[[ 0.01520589]
 [ 0.01779421]
 [ 0.02396028]]
[[-0.03969596]
 [-0.03658432]
 [-0.03614698]]
[[ 0.04322619]
 [ 0.04878758]
 [ 0.05003713]]
[[-0.02568932]
 [-0.02068743]
 [-0.02305846]]
[[-0.00445595]
 [ 0.00124363]
 [-0.00045144]]
[[ 0.01033158]
 [ 0.01379029]
 [ 0.017315  ]]
[[ 0.03464547]
 [ 0.0408589 ]
 [ 0.04060551]]
[[ 0.06180118]
 [ 0.07378266]
 [ 0.064737  ]]
[[-0.00458879]
 [ 0.00021108]
 [-0.00025802]]
[[ 0.05887396]
 [ 0.06436408]
 [ 0.06591159]]
[[ 0.02480542]
 [ 0.02976958]
 [ 0.03092618]]
[[ 0.03585046]
 [ 0.03996848]
 [ 0.0433744 ]]
[[ 0.07626587]
 [ 0.08660737]
 [ 0.0801367 ]]
[[ 0.01445577]
 [ 0.01938445]
 [ 0.01989578]]
[[ 0.0475653 ]
 [ 0.04276287]
 [ 0.04279522]]
[[ 0.07870791]
 [ 0.07372443]
 [ 0.07299221]]
[[ 0.08578328]
 [ 0.07993275]
 [ 0.0807482 ]]
[[ 0.11120555]
 [ 0.10585243]
 [ 0.10624462]]
[[ 0.05416714]
 [ 0.04925381]
 [ 0.04920588]]
[[ 0.06755857]
 [ 0.06198357]
 [ 0.06263571]]
[[ 0.07114243]
 [ 0.06726561]
 [ 0.06496284]]
[[ 0.10619097]
 [ 0.10040067]
 [ 0.1016834 ]]
[[ 0.06317652]
 [ 0.05827844]
 [ 0.0579784 ]]
[[ 0.05258142]
 [ 0.04750412]
 [ 0.04792892]]
[[ 0.10241492]
 [ 0.09680561]
 [ 0.09750807]]
[[ 0.06247254]
 [ 0.05765975]
 [ 0.05726653]]
[[ 0.0594805 ]
 [ 0.05447352]
 [ 0.05453164]]
[[ 0.04511844]
 [ 0.04120238]
 [ 0.04031316]]
[[ 0.03969585]
 [ 0.03578376]
 [ 0.03532053]]
[[ 0.04892762]
 [ 0.04434748]
 [ 0.04425275]]
[[ 0.09087178]
 [ 0.08555172]
 [ 0.08578445]]
[[ 0.07917135]
 [ 0.0733236 ]
 [ 0.07462231]]
[[ 0.0380669 ]
 [ 0.03395471]
 [ 0.03401698]]
[[ 0.07008548]
 [ 0.06595449]
 [ 0.06443073]]
[[ 0.04554331]
 [ 0.04107967]
 [ 0.04112388]]
[[ 0.04743171]
 [ 0.04342639]
 [ 0.04261782]]
[[ 0.04663823]
 [ 0.0421095 ]
 [ 0.04226565]]
[[ 0.06867773]
 [ 0.06398541]
 [ 0.06342368]]
[[ 0.05545419]
 [ 0.05065332]
 [ 0.05075622]]
[[ 0.09339287]
 [ 0.08782967]
 [ 0.08881679]]
[[ 0.06876761]
 [ 0.06410888]
 [ 0.06352141]]
[[ 0.07457855]
 [ 0.06976406]
 [ 0.06931956]]
[[ 0.04077874]
 [ 0.03686373]
 [ 0.0365364 ]]
[[ 0.11444096]
 [ 0.10911347]
 [ 0.11029922]]
[[ 0.04787931]
 [ 0.04347821]
 [ 0.04349137]]
[[ 0.09456707]
 [ 0.08871808]
 [ 0.09046994]]
[[ 0.03794616]
 [ 0.03421692]
 [ 0.03391468]]
[[ 0.09576893]
 [ 0.09106396]
 [ 0.09054721]]
[[ 0.06726714]
 [ 0.06300055]
 [ 0.06201178]]
[[ 0.03810878]
 [ 0.03414476]
 [ 0.03423629]]
[[ 0.03778475]
 [ 0.03395948]
 [ 0.0338517 ]]
[[ 0.05512624]
 [ 0.05069322]
 [ 0.0504166 ]]
[[ 0.05821046]
 [ 0.05389808]
 [ 0.05325824]]
[[ 0.04343431]
 [ 0.03933904]
 [ 0.03924562]]
[[ 0.03674685]
 [ 0.03301584]
 [ 0.03292276]]
[[ 0.0398109 ]
 [ 0.03590293]
 [ 0.03580897]]
[[ 0.04093001]
 [ 0.03739733]
 [ 0.03665448]]
[[ 0.03741531]
 [ 0.03362963]
 [ 0.03358221]]
[[ 0.03630734]
 [ 0.03257778]
 [ 0.03255734]]
[[ 0.03805166]
 [ 0.03431793]
 [ 0.03414942]]
[[ 0.04096937]
 [ 0.03739993]
 [ 0.03676129]]
[[ 0.04397244]
 [ 0.04003119]
 [ 0.03973736]]
[[ 0.03860008]
 [ 0.03484038]
 [ 0.03467695]]
[[ 0.04958232]
 [ 0.04563708]
 [ 0.04499899]]

In [101]:
def PlotWeights(weight_trail):
    # These are the colors that will be used in the plot
    color_sequence = ['#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
                  '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
                  '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
                  '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5']
    # Common sizes: (10, 7.5) and (12, 9)
    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    
    # Remove the plot frame lines. They are unnecessary here.
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    
    # Ensure that the axis ticks only show up on the bottom and left of the plot.
    # Ticks on the right and top of the plot are generally unnecessary.
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    
    fig.subplots_adjust(left=.06, right=.90, bottom=.02, top=.94)
    # Limit the range of the plot to only where the data is.
    # Avoid unnecessary whitespace.
    ax.set_xlim(0.0, 150.1)
    ax.set_ylim(-2.0, 2.0)
    
    # Make sure your axis ticks are large enough to be easily read.
    # You don't want your viewers squinting to read your plot.
    plt.xticks(range(0, 150, 10), fontsize=14)
    plt.yticks(range(-2, 2), fontsize=14)
    ax.xaxis.set_major_formatter(plt.FuncFormatter('{:.0f}'.format))
    ax.yaxis.set_major_formatter(plt.FuncFormatter('{:.0f}'.format))
    
    # Provide tick lines across the plot to help your viewers trace along
    # the axis ticks. Make sure that the lines are light and small so they
    # don't obscure the primary data lines.
    plt.grid(True, 'major', 'y', ls='--', lw=.5, c='k', alpha=.3)
    
    # Remove the tick marks; they are unnecessary with the tick lines we just
    # plotted.
    plt.tick_params(axis='both', which='both', bottom='off', top='off',
                labelbottom='on', left='off', right='off', labelleft='on')
    
    # Now that the plot is prepared, it's time to actually plot the data!
    # Note that I plotted the majors in order of the highest % in the final year.
    hm_weights = ['w11', 'w12', 'w13', 'w14',
                  'w21', 'w22', 'w23', 'w24',
                  'w31', 'w32', 'w33', 'w34']
    
    hm_errors = ['error1', 'error2', 'error3']
    
    y_offsets = {'Foreign Languages': 0.5, 'English': -0.5,
             'Communications\nand Journalism': 0.75,
             'Art and Performance': -0.25, 'Agriculture': 1.25,
             'Social Sciences and History': 0.25, 'Business': -0.75,
             'Math and Statistics': 0.75, 'Architecture': -0.75,
             'Computer Science': 0.75, 'Engineering': -0.25}
    
    for rank, col in enumerate(hm_weights):
        line = plt.plot(iterations,
                        weight_changes[rank],
                        lw=2.5,
                        color=color_sequence[rank])
    
    for rank, col in enumerate(hm_errors):
        line = plt.plot(iterations,
                        output_errors[rank],
                        lw=2.5,
                        color=color_sequence[rank])
    
    fig.suptitle('Changes in weights as neural network gets trained by each instance', fontsize=18, ha='center')
    
    plt.show()

In [102]:
PlotWeights(weight_changes)



In [67]:
n.query([0.822784810126582, 0.681818181818182, 0.753623188405797, 0.8])


Out[67]:
array([[ 0.9599213 ],
       [ 0.93843269],
       [ 0.95758424]])

In [32]:
output_errors


Out[32]:
[[0.020398087820319322,
  -0.013338128727114262,
  0.0041183528392603197,
  -0.01359472570066983,
  0.025898105016439743,
  0.001939657197498057,
  -0.0054089354125250177,
  0.010172638936676359,
  -0.024819739203730373,
  0.0077394290721599091,
  0.030236857321448318,
  0.0038983321761696388,
  0.0045844263382112116,
  0.014532621094985931,
  0.055737830024925938,
  0.037572000455778631,
  0.017597399262652968,
  0.002532674195354967,
  0.011256285245160813,
  0.015678749425748095,
  0.00049244246366231259,
  -0.0064536744657679446,
  0.03652067526311803,
  -0.070317523337230237,
  -0.01175801079179617,
  -0.022242679898582585,
  -0.030012590784151583,
  0.019004351103885819,
  0.017219101508946899,
  -0.0093306563459433778,
  -0.015052439491161806,
  -0.018854321918767125,
  0.05993766967824099,
  0.056662360487280072,
  0.0065608076794226311,
  0.011195519986642033,
  0.027117626367917202,
  0.0056445868416076794,
  -0.013008796099013675,
  0.010001192765617117,
  0.008449095365713255,
  -0.098872049113661931,
  0.0039184566400638499,
  -0.062518993191753802,
  -0.010913708286363399,
  -0.027192012296609114,
  0.03420518196175909,
  0.0038736704620576101,
  0.033509414301396701,
  0.012669667860066569,
  -0.19483501002934656,
  -0.19974520288805531,
  -0.20246139926685736,
  -0.20316092243930484,
  -0.19654911779090889,
  -0.15645505065702658,
  -0.15455380112178618,
  -0.043142163931718658,
  -0.094881107740846127,
  -0.13711510843824415,
  -0.09226785578030694,
  -0.10426606565533247,
  -0.047881745154371713,
  -0.10216252112291702,
  -0.0061588890655511763,
  -0.025369886176932255,
  -0.11013945891614652,
  0.037274328517147848,
  -0.17687906399826481,
  -0.0063435897264427776,
  -0.14381843543593875,
  0.0046768917156786438,
  -0.13954442756846219,
  -0.0024152683557855381,
  0.019419620840517293,
  0.0020781714831447085,
  -0.047436059080703252,
  -0.10922045367033351,
  -0.056302421321372642,
  0.10374439269807545,
  0.0062420994301385369,
  0.041590601227363799,
  0.020004657531640047,
  -0.15119282750710805,
  -0.063403261917794218,
  -0.0066422442689546557,
  -0.011986092787072189,
  -0.06786546880477673,
  0.033822275904294652,
  -0.042902665480630486,
  -0.02305005399896165,
  -0.010668964759859234,
  0.0144904449701313,
  0.044407954640763858,
  -0.027092337023122415,
  0.050464608573233538,
  0.0015395769771717749,
  0.010575692641285217,
  0.058532800700712428,
  -0.016372855673905451,
  0.091271182936510931,
  0.12538523842482685,
  0.1155997042164395,
  0.13417161099312636,
  0.091208737100311188,
  0.093854796959392806,
  0.11580975724990128,
  0.1116745095289845,
  0.093010258320499917,
  0.080638885347415368,
  0.11371095853409363,
  0.091707865004751588,
  0.086066033928597685,
  0.077446843008578958,
  0.067518938450655708,
  0.077615434909464875,
  0.10155494125895415,
  0.085724282840715649,
  0.060408591388315114,
  0.09521419302719647,
  0.070800202858137373,
  0.077572866513233252,
  0.068418983572734682,
  0.08919874122149174,
  0.077187359752997287,
  0.090880379981907455,
  0.088754343319688345,
  0.091370395503654223,
  0.065057820794773069,
  0.096911396247097215,
  0.06879410389465801,
  0.083826197960065074,
  0.060683858343115626,
  0.098712148078909823,
  0.086345398466816436,
  0.058601298749391351,
  0.059929754206145547,
  0.07597420821434353,
  0.08014403521835689,
  0.065135778475318595,
  0.057064381559840283,
  0.06126573284514869,
  0.064937244209748224,
  0.057607561479986136,
  0.055745315160186437,
  0.058739349623815329,
  0.063388591013341933,
  0.06522071493251147,
  0.059584922066650226,
  0.071207021172024043],
 [0.0019985938379496471,
  -0.018050302855523614,
  -0.0041808827152005068,
  -0.01676568654714139,
  0.0064169287222371474,
  -0.019817551826227964,
  -0.013636320708120764,
  -0.0037807465479172686,
  -0.021363989389562577,
  -0.00093411066505333373,
  0.0060598956202667087,
  -0.0069269577040997898,
  -0.0010189069202872214,
  0.012116173231659644,
  0.020848847843424612,
  0.0029553548457619261,
  -0.0058904602795619154,
  -0.011099867619562886,
  -0.011661538978005259,
  -0.0033430298636619171,
  -0.013304873397957051,
  -0.020398913543384189,
  0.020960431474635,
  -0.070062715456617075,
  -0.018187690829866154,
  -0.023870640752749417,
  -0.035857771005537165,
  0.0026670690151848198,
  0.0024935227508430247,
  -0.012646919972248427,
  -0.017081446730157768,
  -0.029374895861299155,
  0.029993229437707714,
  0.024013282608874698,
  0.00094837448831969207,
  0.0024107817652248009,
  0.0077640230163452673,
  0.00044305695553459312,
  -0.0095053214121870178,
  -0.0016668370220155881,
  -0.0034388376529752018,
  -0.076608574159263532,
  0.0020227404541820482,
  -0.064477933873829663,
  -0.023913721791978038,
  -0.025659485877314869,
  0.013706760837796927,
  0.00032436164807780887,
  0.012804026392293721,
  0.0030173596735988406,
  -0.20457426582752314,
  -0.20911006482173222,
  -0.21085255996131957,
  -0.20776048897173682,
  -0.20231757205524592,
  -0.15514427184030632,
  -0.15426921306290353,
  -0.021100833840063338,
  -0.08273445668408097,
  -0.13121584868942904,
  -0.073667293541605394,
  -0.093800645631120338,
  -0.024019442461917895,
  -0.08931722332068448,
  0.017449224062845325,
  -0.0033596264927123753,
  -0.099506617925563923,
  0.067518011627602514,
  -0.17552131970256191,
  0.022771836240780607,
  -0.13843799050752836,
  0.033235097526243762,
  -0.13019321682778295,
  0.029097637330523285,
  0.049846913707922602,
  0.030630231099850747,
  -0.022648175188047626,
  -0.095576108685208827,
  -0.033578860769185592,
  0.13672199766828741,
  0.038736452301099566,
  0.076267068363932378,
  0.050610678261373998,
  -0.14474546449569403,
  -0.040933596209474898,
  0.021297216300879218,
  0.016829995222682581,
  -0.044941759043333529,
  0.065108447409642056,
  -0.017450813910982244,
  0.0060493262351009047,
  0.017409150395087503,
  0.044907822094670036,
  0.078728235985820905,
  -0.0015037141803090348,
  0.080566620948219758,
  0.028906448970980603,
  0.038183215127682968,
  0.088968892413255118,
  0.0084160711870064908,
  0.079545260313953814,
  0.11857196644636658,
  0.10638893123538196,
  0.1276489442116443,
  0.079071130193424599,
  0.081739441323865236,
  0.1074999783314412,
  0.10125132814834237,
  0.08136324285276364,
  0.06680026440693787,
  0.10169400095160863,
  0.079426113377609742,
  0.072911618437285108,
  0.065979436377994527,
  0.055808232965094962,
  0.064456687636654353,
  0.089279822777715956,
  0.071880968748556318,
  0.048328161506730916,
  0.084528258456136762,
  0.057754247296047834,
  0.065734702131106659,
  0.055801648507022628,
  0.076863458639895854,
  0.064008433556949407,
  0.077759515787862665,
  0.076236234188812091,
  0.078596636812728438,
  0.053459508985312709,
  0.084175107928019632,
  0.056408524280693983,
  0.069892585528556705,
  0.049583773437767653,
  0.08647809968069009,
  0.074814115093095768,
  0.046809326561992659,
  0.048358786157981948,
  0.063486324002754557,
  0.067760218580175779,
  0.053064590688979063,
  0.046060166750695686,
  0.049605005422391213,
  0.05437436003477325,
  0.046468422125876763,
  0.044830080547907269,
  0.04774527840727838,
  0.052928077234286164,
  0.053775036977246171,
  0.048455023659693897,
  0.05985298903038716],
 [0.0063423744770376156,
  -0.024936180601649149,
  -0.010095524230302633,
  -0.026344834545322715,
  0.011246085067302458,
  -0.0068652628990448772,
  -0.017960900290603543,
  -0.0025503625628675586,
  -0.037653513617520673,
  -0.0057808068156673387,
  0.017203606035155805,
  -0.008600486252170414,
  -0.0090229582386646134,
  -0.002752269184598144,
  0.04114750457999794,
  0.026094970586216959,
  0.0074805257319819907,
  -0.0075771480280518166,
  0.0028298362115332099,
  0.0043613416735940325,
  -0.008305107955147828,
  -0.014627634312657234,
  0.020029320426734332,
  -0.071407154213241675,
  -0.021530366842863358,
  -0.030462421837103959,
  -0.035983490568532006,
  0.0074921410701088709,
  0.0058548569701523112,
  -0.020002028301627894,
  -0.024659138992110807,
  -0.023809285558274051,
  0.043831022005456188,
  0.042293271408509747,
  -0.0052545875646448637,
  -0.00019287919555516231,
  0.016212141611229358,
  -0.0058513664054515857,
  -0.024793079673356311,
  -0.00014172328300821047,
  -0.0014284165121013492,
  -0.104271717067486,
  -0.0089349184933895098,
  -0.063208439449879139,
  -0.016878138522719188,
  -0.034437435885035728,
  0.02200579414928483,
  -0.0076577633667155864,
  0.02213999655941612,
  0.0022844911136982171,
  -0.19593041466042294,
  -0.20261660906476153,
  -0.20297402528921005,
  -0.20962227793654764,
  -0.19680784225832282,
  -0.15399346475175602,
  -0.14614517298931273,
  -0.034702897011030798,
  -0.074415909238301214,
  -0.13401143466585552,
  -0.088388854016007379,
  -0.088874199673081855,
  -0.027324557648052417,
  -0.084048231703763854,
  0.018642611758246108,
  0.0095264662910062858,
  -0.096954805093976471,
  0.066711352426497061,
  -0.17405187652761878,
  0.018330412533470897,
  -0.13181450600300515,
  0.038445987153425953,
  -0.12625324289998863,
  0.033110366990925733,
  0.05878976636833011,
  0.04222818981538401,
  -0.012280645856503014,
  -0.085071671939820681,
  -0.028726516392175272,
  0.13403015706334254,
  0.031324792393188416,
  0.068113096986839872,
  0.05061432186881798,
  -0.14156997962687157,
  -0.040753240136786051,
  0.031056187316084061,
  0.029539867715247969,
  -0.043680976435483077,
  0.066224764917787837,
  -0.022219581934905919,
  0.0014314305021462248,
  0.023628809789907335,
  0.043665979722240622,
  0.062710730872348464,
  -0.0031463678978924658,
  0.082461912556380468,
  0.030066144262987149,
  0.044353219799115995,
  0.07773506000735908,
  0.0086405349811743415,
  0.085600876141857296,
  0.12128351156158701,
  0.11460418200377442,
  0.13313034191680562,
  0.084461398780507069,
  0.089566182756336543,
  0.10511015932892953,
  0.1090250130899606,
  0.08558910125772845,
  0.074119410987953716,
  0.10828480205055979,
  0.083368527636928258,
  0.078433147709741635,
  0.067523085525840854,
  0.058540162624085457,
  0.069334058267329857,
  0.094505084825008523,
  0.080353333020477313,
  0.053303363556247541,
  0.085302312793038904,
  0.063047071929826615,
  0.067782569508175516,
  0.061378232498543017,
  0.080295710396552145,
  0.069380757114764879,
  0.084419098334214571,
  0.079618331863861025,
  0.08231372048634944,
  0.056929514411305671,
  0.090526528682066343,
  0.061412260895052029,
  0.078121129308297066,
  0.052944305288136095,
  0.08994048520356146,
  0.076977101876406606,
  0.051947683233530828,
  0.0525248665840653,
  0.067562427938415937,
  0.070889031291882731,
  0.057625001843984447,
  0.050177808021236392,
  0.054037072793806429,
  0.056391925583529456,
  0.05082816127114087,
  0.049185879495544538,
  0.051700869275239425,
  0.055392170985078537,
  0.057553488194472924,
  0.052402582925113128,
  0.062649415295468303]]

In [ ]: