In [116]:
import numpy as np
# scipy.special for the sigmoid function expit()
import scipy.special as special
import matplotlib.pyplot as plt
# used the link for plotting
# https://matplotlib.org/examples/showcase/bachelors_degrees_by_gender.html

In [117]:
# neural network class definition
class neuralNetwork:
    
    # initialise the neural network
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        # set number of nodes in each input, hidden, output layer
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes
        
        # link weight matrices, wih and who
        # weights inside the arrays are w_i_j, where link is from node i to node j in the next layer
        # w11 w21
        # w12 w22 etc 
        self.wih = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
        self.who = np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))

        # learning rate
        self.lr = learningrate
        
        # activation function is the sigmoid function
        self.activation_function = lambda x: special.expit(x)
        
        pass

    
    # train the neural network
    def train(self, inputs_list, targets_list):
        # convert inputs list to 2d array
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T
        
        # calculate signals into hidden layer
        hidden_inputs = np.dot(self.wih, inputs)
        # calculate the signals emerging from hidden layer
        hidden_outputs = self.activation_function(hidden_inputs)
        
        # calculate signals into final output layer
        final_inputs = np.dot(self.who, hidden_outputs)
        # calculate the signals emerging from final output layer
        final_outputs = self.activation_function(final_inputs)
        
        # output layer error is the (target - actual)
        output_errors = targets - final_outputs
        
        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes
        hidden_errors = np.dot(self.who.T, output_errors) 
        
        # update the weights for the links between the hidden and output layers
        self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)), np.transpose(hidden_outputs))
        
        # update the weights for the links between the input and hidden layers
        self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))
        
        #print("Output errors:", output_errors)
        return hidden_errors, output_errors
    
    # query the neural network
    def query(self, inputs_list):
        # convert inputs list to 2d array
        inputs = np.array(inputs_list, ndmin=2).T
        
        # calculate signals into hidden layer
        hidden_inputs = np.dot(self.wih, inputs)
        # calculate the signals emerging from hidden layer
        hidden_outputs = self.activation_function(hidden_inputs)
        
        # calculate signals into final output layer
        final_inputs = np.dot(self.who, hidden_outputs)
        # calculate the signals emerging from final output layer
        final_outputs = self.activation_function(final_inputs)
        
        return final_outputs

In [118]:
# number of input, hidden and output nodes
input_nodes = 4
hidden_nodes = 4
output_nodes = 3

# learning rate is 0.3
learning_rate = 0.3

# create instance of neural network
n = neuralNetwork(input_nodes,hidden_nodes,output_nodes, learning_rate)

In [119]:
data = [[0.645569620253164, 0.795454545454545, 0.202898550724638, 0.08, 0.333333333333333],
[0.620253164556962, 0.681818181818182, 0.202898550724638, 0.08, 0.333333333333333],
[0.594936708860759, 0.727272727272727, 0.188405797101449, 0.08, 0.333333333333333],
[0.582278481012658, 0.704545454545454, 0.217391304347826, 0.08, 0.333333333333333],
[0.632911392405063, 0.818181818181818, 0.202898550724638, 0.08, 0.333333333333333],
[0.683544303797468, 0.886363636363636, 0.246376811594203, 0.16, 0.333333333333333],
[0.582278481012658, 0.772727272727273, 0.202898550724638, 0.12, 0.333333333333333],
[0.632911392405063, 0.772727272727273, 0.217391304347826, 0.08, 0.333333333333333],
[0.556962025316456, 0.659090909090909, 0.202898550724638, 0.08, 0.333333333333333],
[0.620253164556962, 0.704545454545454, 0.217391304347826, 0.04, 0.333333333333333],
[0.683544303797468, 0.840909090909091, 0.217391304347826, 0.08, 0.333333333333333],
[0.607594936708861, 0.772727272727273, 0.231884057971014, 0.08, 0.333333333333333],
[0.607594936708861, 0.681818181818182, 0.202898550724638, 0.04, 0.333333333333333],
[0.544303797468354, 0.681818181818182, 0.159420289855072, 0.04, 0.333333333333333],
[0.734177215189873, 0.909090909090909, 0.173913043478261, 0.08, 0.333333333333333],
[0.721518987341772, 1, 0.217391304347826, 0.16, 0.333333333333333],
[0.683544303797468, 0.886363636363636, 0.188405797101449, 0.16, 0.333333333333333],
[0.645569620253164, 0.795454545454545, 0.202898550724638, 0.12, 0.333333333333333],
[0.721518987341772, 0.863636363636364, 0.246376811594203, 0.12, 0.333333333333333],
[0.645569620253164, 0.863636363636364, 0.217391304347826, 0.12, 0.333333333333333],
[0.683544303797468, 0.772727272727273, 0.246376811594203, 0.08, 0.333333333333333],
[0.645569620253164, 0.840909090909091, 0.217391304347826, 0.16, 0.333333333333333],
[0.582278481012658, 0.818181818181818, 0.144927536231884, 0.08, 0.333333333333333],
[0.645569620253164, 0.75, 0.246376811594203, 0.2, 0.333333333333333],
[0.607594936708861, 0.772727272727273, 0.27536231884058, 0.08, 0.333333333333333],
[0.632911392405063, 0.681818181818182, 0.231884057971014, 0.08, 0.333333333333333],
[0.632911392405063, 0.772727272727273, 0.231884057971014, 0.16, 0.333333333333333],
[0.658227848101266, 0.795454545454545, 0.217391304347826, 0.08, 0.333333333333333],
[0.658227848101266, 0.772727272727273, 0.202898550724638, 0.08, 0.333333333333333],
[0.594936708860759, 0.727272727272727, 0.231884057971014, 0.08, 0.333333333333333],
[0.607594936708861, 0.704545454545454, 0.231884057971014, 0.08, 0.333333333333333],
[0.683544303797468, 0.772727272727273, 0.217391304347826, 0.16, 0.333333333333333],
[0.658227848101266, 0.931818181818182, 0.217391304347826, 0.04, 0.333333333333333],
[0.69620253164557, 0.954545454545454, 0.202898550724638, 0.08, 0.333333333333333],
[0.620253164556962, 0.704545454545454, 0.217391304347826, 0.04, 0.333333333333333],
[0.632911392405063, 0.727272727272727, 0.173913043478261, 0.08, 0.333333333333333],
[0.69620253164557, 0.795454545454545, 0.188405797101449, 0.08, 0.333333333333333],
[0.620253164556962, 0.704545454545454, 0.217391304347826, 0.04, 0.333333333333333],
[0.556962025316456, 0.681818181818182, 0.188405797101449, 0.08, 0.333333333333333],
[0.645569620253164, 0.772727272727273, 0.217391304347826, 0.08, 0.333333333333333],
[0.632911392405063, 0.795454545454545, 0.188405797101449, 0.12, 0.333333333333333],
[0.569620253164557, 0.522727272727273, 0.188405797101449, 0.12, 0.333333333333333],
[0.556962025316456, 0.727272727272727, 0.188405797101449, 0.08, 0.333333333333333],
[0.632911392405063, 0.795454545454545, 0.231884057971014, 0.24, 0.333333333333333],
[0.645569620253164, 0.863636363636364, 0.27536231884058, 0.16, 0.333333333333333],
[0.607594936708861, 0.681818181818182, 0.202898550724638, 0.12, 0.333333333333333],
[0.645569620253164, 0.863636363636364, 0.231884057971014, 0.08, 0.333333333333333],
[0.582278481012658, 0.727272727272727, 0.202898550724638, 0.08, 0.333333333333333],
[0.670886075949367, 0.840909090909091, 0.217391304347826, 0.08, 0.333333333333333],
[0.632911392405063, 0.75, 0.202898550724638, 0.08, 0.333333333333333],
[0.886075949367089, 0.727272727272727, 0.681159420289855, 0.56, 0.666666666666667],
[0.810126582278481, 0.727272727272727, 0.652173913043478, 0.6, 0.666666666666667],
[0.873417721518987, 0.704545454545454, 0.710144927536232, 0.6, 0.666666666666667],
[0.69620253164557, 0.522727272727273, 0.579710144927536, 0.52, 0.666666666666667],
[0.822784810126582, 0.636363636363636, 0.666666666666667, 0.6, 0.666666666666667],
[0.721518987341772, 0.636363636363636, 0.652173913043478, 0.52, 0.666666666666667],
[0.79746835443038, 0.75, 0.681159420289855, 0.64, 0.666666666666667],
[0.620253164556962, 0.545454545454545, 0.478260869565217, 0.4, 0.666666666666667],
[0.835443037974683, 0.659090909090909, 0.666666666666667, 0.52, 0.666666666666667],
[0.658227848101266, 0.613636363636364, 0.565217391304348, 0.56, 0.666666666666667],
[0.632911392405063, 0.454545454545455, 0.507246376811594, 0.4, 0.666666666666667],
[0.746835443037975, 0.681818181818182, 0.608695652173913, 0.6, 0.666666666666667],
[0.759493670886076, 0.5, 0.579710144927536, 0.4, 0.666666666666667],
[0.772151898734177, 0.659090909090909, 0.681159420289855, 0.56, 0.666666666666667],
[0.708860759493671, 0.659090909090909, 0.521739130434783, 0.52, 0.666666666666667],
[0.848101265822785, 0.704545454545454, 0.63768115942029, 0.56, 0.666666666666667],
[0.708860759493671, 0.681818181818182, 0.652173913043478, 0.6, 0.666666666666667],
[0.734177215189873, 0.613636363636364, 0.594202898550725, 0.4, 0.666666666666667],
[0.784810126582278, 0.5, 0.652173913043478, 0.6, 0.666666666666667],
[0.708860759493671, 0.568181818181818, 0.565217391304348, 0.44, 0.666666666666667],
[0.746835443037975, 0.727272727272727, 0.695652173913043, 0.72, 0.666666666666667],
[0.772151898734177, 0.636363636363636, 0.579710144927536, 0.52, 0.666666666666667],
[0.79746835443038, 0.568181818181818, 0.710144927536232, 0.6, 0.666666666666667],
[0.772151898734177, 0.636363636363636, 0.681159420289855, 0.48, 0.666666666666667],
[0.810126582278481, 0.659090909090909, 0.623188405797101, 0.52, 0.666666666666667],
[0.835443037974683, 0.681818181818182, 0.63768115942029, 0.56, 0.666666666666667],
[0.860759493670886, 0.636363636363636, 0.695652173913043, 0.56, 0.666666666666667],
[0.848101265822785, 0.681818181818182, 0.72463768115942, 0.68, 0.666666666666667],
[0.759493670886076, 0.659090909090909, 0.652173913043478, 0.6, 0.666666666666667],
[0.721518987341772, 0.590909090909091, 0.507246376811594, 0.4, 0.666666666666667],
[0.69620253164557, 0.545454545454545, 0.550724637681159, 0.44, 0.666666666666667],
[0.69620253164557, 0.545454545454545, 0.536231884057971, 0.4, 0.666666666666667],
[0.734177215189873, 0.613636363636364, 0.565217391304348, 0.48, 0.666666666666667],
[0.759493670886076, 0.613636363636364, 0.739130434782609, 0.64, 0.666666666666667],
[0.683544303797468, 0.681818181818182, 0.652173913043478, 0.6, 0.666666666666667],
[0.759493670886076, 0.772727272727273, 0.652173913043478, 0.64, 0.666666666666667],
[0.848101265822785, 0.704545454545454, 0.681159420289855, 0.6, 0.666666666666667],
[0.79746835443038, 0.522727272727273, 0.63768115942029, 0.52, 0.666666666666667],
[0.708860759493671, 0.681818181818182, 0.594202898550725, 0.52, 0.666666666666667],
[0.69620253164557, 0.568181818181818, 0.579710144927536, 0.52, 0.666666666666667],
[0.69620253164557, 0.590909090909091, 0.63768115942029, 0.48, 0.666666666666667],
[0.772151898734177, 0.681818181818182, 0.666666666666667, 0.56, 0.666666666666667],
[0.734177215189873, 0.590909090909091, 0.579710144927536, 0.48, 0.666666666666667],
[0.632911392405063, 0.522727272727273, 0.478260869565217, 0.4, 0.666666666666667],
[0.708860759493671, 0.613636363636364, 0.608695652173913, 0.52, 0.666666666666667],
[0.721518987341772, 0.681818181818182, 0.608695652173913, 0.48, 0.666666666666667],
[0.721518987341772, 0.659090909090909, 0.608695652173913, 0.52, 0.666666666666667],
[0.784810126582278, 0.659090909090909, 0.623188405797101, 0.52, 0.666666666666667],
[0.645569620253164, 0.568181818181818, 0.434782608695652, 0.44, 0.666666666666667],
[0.721518987341772, 0.636363636363636, 0.594202898550725, 0.52, 0.666666666666667],
[0.79746835443038, 0.75, 0.869565217391304, 1, 1],
[0.734177215189873, 0.613636363636364, 0.739130434782609, 0.76, 1],
[0.89873417721519, 0.681818181818182, 0.855072463768116, 0.84, 1],
[0.79746835443038, 0.659090909090909, 0.811594202898551, 0.72, 1],
[0.822784810126582, 0.681818181818182, 0.840579710144927, 0.88, 1],
[0.962025316455696, 0.681818181818182, 0.956521739130435, 0.84, 1],
[0.620253164556962, 0.568181818181818, 0.652173913043478, 0.68, 1],
[0.924050632911392, 0.659090909090909, 0.91304347826087, 0.72, 1],
[0.848101265822785, 0.568181818181818, 0.840579710144927, 0.72, 1],
[0.911392405063291, 0.818181818181818, 0.884057971014493, 1, 1],
[0.822784810126582, 0.727272727272727, 0.739130434782609, 0.8, 1],
[0.810126582278481, 0.613636363636364, 0.768115942028985, 0.76, 1],
[0.860759493670886, 0.681818181818182, 0.797101449275362, 0.84, 1],
[0.721518987341772, 0.568181818181818, 0.72463768115942, 0.8, 1],
[0.734177215189873, 0.636363636363636, 0.739130434782609, 0.96, 1],
[0.810126582278481, 0.727272727272727, 0.768115942028985, 0.92, 1],
[0.822784810126582, 0.681818181818182, 0.797101449275362, 0.72, 1],
[0.974683544303797, 0.863636363636364, 0.971014492753623, 0.88, 1],
[0.974683544303797, 0.590909090909091, 1, 0.92, 1],
[0.759493670886076, 0.5, 0.72463768115942, 0.6, 1],
[0.873417721518987, 0.727272727272727, 0.826086956521739, 0.92, 1],
[0.708860759493671, 0.636363636363636, 0.710144927536232, 0.8, 1],
[0.974683544303797, 0.636363636363636, 0.971014492753623, 0.8, 1],
[0.79746835443038, 0.613636363636364, 0.710144927536232, 0.72, 1],
[0.848101265822785, 0.75, 0.826086956521739, 0.84, 1],
[0.911392405063291, 0.727272727272727, 0.869565217391304, 0.72, 1],
[0.784810126582278, 0.636363636363636, 0.695652173913043, 0.72, 1],
[0.772151898734177, 0.681818181818182, 0.710144927536232, 0.72, 1],
[0.810126582278481, 0.636363636363636, 0.811594202898551, 0.84, 1],
[0.911392405063291, 0.681818181818182, 0.840579710144927, 0.64, 1],
[0.936708860759494, 0.636363636363636, 0.884057971014493, 0.76, 1],
[1, 0.863636363636364, 0.927536231884058, 0.8, 1],
[0.810126582278481, 0.636363636363636, 0.811594202898551, 0.88, 1],
[0.79746835443038, 0.636363636363636, 0.739130434782609, 0.6, 1],
[0.772151898734177, 0.590909090909091, 0.811594202898551, 0.56, 1],
[0.974683544303797, 0.681818181818182, 0.884057971014493, 0.92, 1],
[0.79746835443038, 0.772727272727273, 0.811594202898551, 0.96, 1],
[0.810126582278481, 0.704545454545454, 0.797101449275362, 0.72, 1],
[0.759493670886076, 0.681818181818182, 0.695652173913043, 0.72, 1],
[0.873417721518987, 0.704545454545454, 0.782608695652174, 0.84, 1],
[0.848101265822785, 0.704545454545454, 0.811594202898551, 0.96, 1],
[0.873417721518987, 0.704545454545454, 0.739130434782609, 0.92, 1],
[0.734177215189873, 0.613636363636364, 0.739130434782609, 0.76, 1],
[0.860759493670886, 0.727272727272727, 0.855072463768116, 0.92, 1],
[0.848101265822785, 0.75, 0.826086956521739, 1, 1],
[0.848101265822785, 0.681818181818182, 0.753623188405797, 0.92, 1],
[0.79746835443038, 0.568181818181818, 0.72463768115942, 0.76, 1],
[0.822784810126582, 0.681818181818182, 0.753623188405797, 0.8, 1],
[0.784810126582278, 0.772727272727273, 0.782608695652174, 0.92, 1],
[0.746835443037975, 0.681818181818182, 0.739130434782609, 0.72, 1]]

In [137]:
# weight_changes should have 12 arrays 150 numbers
weight_changes = []

for i in range (0, hidden_nodes*output_nodes):
    weights = []
    weight_changes.append(weights)
    pass

output_errors = []
for i in range(0, output_nodes):
    errors = []
    output_errors.append(errors)

iterations = []
iter_count = 1
for i in data:
    # train the NN with each instance of data
    h_errors, o_errors = n.train(i[0:4], i[4:5])
    
    count = 0
    for weight_row in n.who:
        for element in weight_row:
            weight_changes[count].append(element)
            count += 1
    
    count = 0
    for error in o_errors:
        output_errors[count].append(error[0])
        count += 1
    
    iterations.append(iter_count)
    iter_count += 1
    pass

In [138]:
def PlotWeights(weight_trail):
    # These are the colors that will be used in the plot
    color_sequence = ['#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
                  '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
                  '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
                  '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5']
    # Common sizes: (10, 7.5) and (12, 9)
    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    
    # Remove the plot frame lines. They are unnecessary here.
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    
    # Ensure that the axis ticks only show up on the bottom and left of the plot.
    # Ticks on the right and top of the plot are generally unnecessary.
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    
    fig.subplots_adjust(left=.06, right=.90, bottom=.02, top=.94)
    # Limit the range of the plot to only where the data is.
    # Avoid unnecessary whitespace.
    ax.set_xlim(0.0, 150.1)
    ax.set_ylim(-3.0, 3.0)
    
    # Make sure your axis ticks are large enough to be easily read.
    # You don't want your viewers squinting to read your plot.
    plt.xticks(range(0, 150, 10), fontsize=14)
    plt.yticks(range(-2, 2), fontsize=14)
    ax.xaxis.set_major_formatter(plt.FuncFormatter('{:.0f}'.format))
    ax.yaxis.set_major_formatter(plt.FuncFormatter('{:.0f}'.format))
    
    # Provide tick lines across the plot to help your viewers trace along
    # the axis ticks. Make sure that the lines are light and small so they
    # don't obscure the primary data lines.
    plt.grid(True, 'major', 'y', ls='--', lw=.5, c='k', alpha=.3)
    
    # Remove the tick marks; they are unnecessary with the tick lines we just
    # plotted.
    plt.tick_params(axis='both', which='both', bottom='off', top='off',
                labelbottom='on', left='off', right='off', labelleft='on')
    
    # Now that the plot is prepared, it's time to actually plot the data!
    # Note that I plotted the majors in order of the highest % in the final year.
    hm_weights = ['w11', 'w12', 'w13', 'w14',
                  'w21', 'w22', 'w23', 'w24',
                  'w31', 'w32', 'w33', 'w34']
    
    hm_errors = ['error1', 'error2', 'error3']
    
    y_offsets = {'Foreign Languages': 0.5, 'English': -0.5,
             'Communications\nand Journalism': 0.75,
             'Art and Performance': -0.25, 'Agriculture': 1.25,
             'Social Sciences and History': 0.25, 'Business': -0.75,
             'Math and Statistics': 0.75, 'Architecture': -0.75,
             'Computer Science': 0.75, 'Engineering': -0.25}
    
    for rank, col in enumerate(hm_weights):
        line = plt.plot(iterations,
                        weight_changes[rank],
                        lw=2.5,
                        color=color_sequence[rank])
    
    for rank, col in enumerate(hm_errors):
        line = plt.plot(iterations,
                        output_errors[rank],
                        lw=2.5,
                        color=color_sequence[rank])
    
    fig.suptitle('Changes in weights as neural network gets trained by each instance', fontsize=18, ha='center')
    
    plt.show()

In [139]:
PlotWeights(weight_changes)



In [ ]: