Flavours of Gradient Descent

A quick recap of the Gradient Descent method: This is an iterative algorithm to minize a loss function $L(x)$, where we start with a guess of what the answer should be - and then take steps proportional to the gradient at the current point.

$x = x_0$ (initial guess)

Until Convergence is achieved:

$x_{i+1} = x_{i} - \eta\nabla_L(x_i)$

For example, Let's say $L(x) = x^2 - 2x + 1$ and we start at $x0 = 2$. Coding the Gradient Descent method in Python:


In [4]:
%matplotlib inline

In [3]:
import numpy as np

def L(x):
    return x**2 - 2*x + 1

def L_prime(x):
    return 2*x - 2


def converged(x_prev, x, epsilon):
    "Return True if the abs value of all elements in x-x_prev are <= epsilon."
    
    absdiff = np.abs(x-x_prev)
    return np.all(absdiff <= epsilon)


def gradient_descent(f_prime, x_0, learning_rate=0.2, n_iters=100, epsilon=1E-8):
    x = x_0
    
    for _ in range(n_iters):
        x_prev = x
        x -= learning_rate*f_prime(x)
        
        if converged(x_prev, x, epsilon):
            break
            
    return x

x_min = gradient_descent(L_prime, 2)

print('Minimum value of L(x) = x**2 - 2*x + 1.0 is [%.2f] at x = [%.2f]' % (L(x_min), x_min))


Minimum value of L(x) = x**2 - 2*x + 1.0 is [0.00] at x = [1.00]

Batch Gradient Descent

In most supervised ML applications, we will try to learn a pattern from a number of labeled examples. In Batch Gradient Descent, each iteration loops over entire set of examples.

So, let's build 1-layer network of Linear Perceptrons to classify Fisher's IRIS dataset (again!). Remember that a Linear Perceptron can only distinguish between two classes.

Since there are 3 classes, our mini-network will have 3 Perceptrons. We'll channel the output of each Perceptron $w_i^T + b$ into a softmax function to pick the final label. We'll train this network using Batch Gradient Descent.

Getting Data


In [2]:
import seaborn as sns
import pandas as pd

iris_df = sns.load_dataset('iris')
print('Columns: %s' % (iris_df.columns.values, ))
print('Labels:  %s' % (pd.unique(iris_df['species']), ))

iris_df.head(5)


Columns: ['sepal_length' 'sepal_width' 'petal_length' 'petal_width' 'species']
Labels:  ['setosa' 'versicolor' 'virginica']
Out[2]:
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa

The Softmax Function

The softmax function is a technique to apply a probabilistic classifier by making a probability distribution out of a set of values $(v_1, v_2, ..., v_n)$ which may or may not satisfy all the features of probability distribution:

  • $v_i >= 0$
  • $\sum_{i=1}^n v_i = 1$

The probability distribution is the Gibbs Distribution: $v'_i = \frac {\exp {v_i}} {\sum_{j=1}^n\exp {v_j})}$ for $i = 1, 2, ... n$.


In [5]:
def softmax(x):
    # Uncomment to find out why we shouldn't do it this way...
    # return np.exp(x) / np.sum(np.exp(x))
    scaled_x = x - np.max(x)    
    result = np.exp(scaled_x) / np.sum(np.exp(scaled_x))
    return result

a = np.array([-500.9, 2000, 7, 11, 12, -15, 100])
sm_a = softmax(a)
print('Softmax(%s) = %s' % (a, sm_a))


Softmax([ -500.9  2000.      7.     11.     12.    -15.    100. ]) = [ 0.  1.  0.  0.  0.  0.  0.]

Non-linear Perceptron With SoftMax

With softmax, we typically use the cross-entropy error as the function to minimize.

The Cross Entropy Error for a given input $X = (x_1, x_2, ..., x_n)$, where each $x_i$ is a vector, is given by:

$L(x) = - \frac {1}{n} \sum_{i=1}^n y_i^T log(\hat{y_i})$

Where

  • The sum runs over $X = (x_1, x_2, ..., x_n)$.
  • Each $y_i$ is the 1-of-n encoded label of the $i$-th example, so it's also a vector. For example, if the labels in order are ('apple', 'banana', 'orange') and the label of $x_i$ is 'banana', then $y_i = [0, 1, 0]$.
  • $\hat{y_i}$ is the softmax output for $x_i$ from the network.
  • The term $y_i^T log(\hat{y_i})$ is the vector dot product between $y_i$ and $log(\hat{y_i})$.

One of n Encoding


In [6]:
def encode_1_of_n(ordered_labels, y):
    label2idx = dict((label, idx)
                     for idx, label in enumerate(ordered_labels))
    
    def encode_one(y_i):        
        enc = np.zeros(len(ordered_labels))
        enc[label2idx[y_i]] = 1.0
        return enc
    
    return np.array([x for x in map(encode_one, y)])

encode_1_of_n(['apple', 'banana', 'orange'], 
              ['apple', 'banana', 'orange', 'apple', 'apple'])


Out[6]:
array([[ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 1.,  0.,  0.]])

Cross Entropy Error


In [7]:
def cross_entropy_loss(Y, Y_hat):
    entropy_sum = 0.0    
    log_Y_hat = np.log(Y_hat)
    
    for y, y_hat in zip(Y, log_Y_hat):        
        entropy_sum += np.dot(y, y_hat)  
    
    return -entropy_sum/Y.shape[0]

Y_tst = np.array([[1, 0, 0], 
                  [0, 1, 0]])

# log(Y_hat_tst1) is the same as Y_tst, so we expect the x-entropy error to be the min (-1) in this case.
print(Y_tst)
Y_hat_tst1 = np.array([[np.e, 1, 1,],
                     [1, np.e, 1]])
print(Y_hat_tst1)
print(cross_entropy_loss(Y_tst, Y_hat_tst1))
print()

# expect it to be > -1
Y_hat_tst2 = np.array([[1, 1, 1,],
                     [1, np.e, 1]])
print(Y_hat_tst2)
print(cross_entropy_loss(Y_tst, Y_hat_tst2))
print()


[[1 0 0]
 [0 1 0]]
[[ 2.71828183  1.          1.        ]
 [ 1.          2.71828183  1.        ]]
-1.0

[[ 1.          1.          1.        ]
 [ 1.          2.71828183  1.        ]]
-0.5

Gradient of the Cross Entropy Error

The Gradient update step in Gradient Descent when the Loss Function uses Cross Entropy Error is:

$w_i^{j+1} = w_i^{j} - \eta [\frac {\partial L} {\partial w_i}]^{j}$


In [11]:
import pandas as pd


class OneLayerNetworkWithSoftMax:
    
    def __init__(self):
        self.w, self.bias = None, 0.0
        self.optimiser = None
        self.output = None
        
    def init_weights(self, X, Y):
        """
        Initialize a 2D weight matrix as a Dataframe with 
        dim(n_labels*n_features).         
        """
        self.labels = np.unique(Y)
              
        w_init = np.random.randn(len(self.labels), X.shape[1])        
        self.w = pd.DataFrame(data=w_init)
        self.w.index.name = 'node_id'

    def predict(self, x):
        """
        Return the predicted label of x using current weights.
        """
        output = self.forward(x, update=False)
        max_label_idx = np.argmax(output)
        return self.labels[max_label_idx]

    def forward(self, x, update=True):
        """
        Calculate softmax(w^Tx+b) for x using current $w_i$ s.
        """
        #output = self.w.apply(lambda row: np.dot(row, x), axis=1)
        output = np.dot(self.w, x)
        output += self.bias
        
        output = softmax(output) 
        if update:
            self.output = output        
        return output
    
    def backward(self, x, y, learning_rate):
        """
        Executes the weight update step
        
            grad = (self.output - y)          
            
            for i in range(len(grad)):
                dw[i] -= grad[i] * x
                
            w -= learning_rate * dw
        
        :param x: one sample vector.
        :param y: One-hot encoded label for x.
        """
        
        # [y_hat1 - y1, y_hat2-y2, ... ]
        y_hat_min_y = self.output - y
        
        # Transpose the above to a column vector
        # and then multiply x with each element
        # to produce a 2D array (n_labels*n_features), same as w
        error_grad = np.apply_along_axis(lambda z: z*x , 
            1, np.atleast_2d(y_hat_min_y).T)
        dw = learning_rate * error_grad
        return dw
    
    def print_weight_diff(self, i, w_old, diff_only=True):
        if not diff_only:        
            print('Before Iteration [%s]: weights are: \n%s' % 
                  (i+1, w_old))

            print('After Iteration [%s]: weights are: \n%s' % 
                  (i+1, self.w))
        
        w_diff = np.abs(w_old - self.w)
        print('After Iteration [%s]: weights diff: \n%s' % 
              (i+1, w_diff))
                
    def _gen_minibatch(self, X, Y, mb_size):
        """Generates `mb_size` sized chunks from X and Y."""
        n_samples = X.shape[0]
        indices = np.arange(n_samples)
        np.random.shuffle(indices)

        for start in range(0, n_samples, mb_size):        
            yield X[start:start+mb_size, :], Y[start:start+mb_size, :]
    
    def _update_batch(self, i, X_batch, Y_batch, learning_rate, print_every=100):        
        w_old = self.w.copy()
        
        dw = []            
        for x, y in zip(X_batch, Y_batch):
            self.forward(x)
            dw_item = self.backward(x, y, learning_rate)
            dw.append(dw_item)
            
        dw_batch = np.mean(dw, axis=0)
        self.w -= dw_batch

        if (i == 0) or ((i+1) % print_every == 0):
            self.print_weight_diff(i, w_old)
    
    def train(self, X, Y, 
              n_iters=1000, 
              learning_rate=0.2,
              minibatch_size=30,
              epsilon=1E-8):
        """
        Entry point for the Minibatch SGD training method.
                      
        Calls forward+backward for each (x_i, y_i) pair and adjusts the
        weight w accordingly.        
        """
        self.init_weights(X, Y)    
        Y = encode_1_of_n(self.labels, Y)
        
        n_samples = X.shape[0]     
               
        # MiniBatch SGD
        for i in range(n_iters):
            for X_batch, Y_batch in self._gen_minibatch(X, Y, minibatch_size):
                self._update_batch(i, X_batch, Y_batch, learning_rate)
                
# Set aside test data
label_grouper = iris_df.groupby('species')
test = label_grouper.head(10).set_index('species')
train = label_grouper.tail(100).set_index('species')

# Train the Network
X_train, Y_train = train.as_matrix(), train.index.values
nn = OneLayerNetworkWithSoftMax()
nn.train(X_train, Y_train)

# Test
results = test.apply(lambda row : nn.predict(row.as_matrix()), axis=1)
results.name = 'predicted_label'
results.index.name = 'expected_label'

results.reset_index()


After Iteration [1]: weights diff: 
                0         1         2         3
node_id                                        
0        0.157134  0.109942  0.043641  0.007615
1        0.034134  0.024803  0.009005  0.001791
2        0.123000  0.085140  0.034636  0.005824
After Iteration [1]: weights diff: 
                0         1         2         3
node_id                                        
0        0.390608  0.180224  0.286775  0.091272
1        0.401743  0.187803  0.289938  0.091786
2        0.011135  0.007579  0.003163  0.000514
After Iteration [1]: weights diff: 
                0         1         2         3
node_id                                        
0        1.194903  0.549488  0.859244  0.267766
1        1.195796  0.549898  0.859858  0.267953
2        0.000893  0.000410  0.000614  0.000187
After Iteration [1]: weights diff: 
                    0             1             2             3
node_id                                                        
0        3.149479e-07  1.342144e-07  2.516804e-07  8.022062e-08
1        8.746413e-01  3.893221e-01  7.539808e-01  2.726606e-01
2        8.746416e-01  3.893223e-01  7.539810e-01  2.726607e-01
After Iteration [1]: weights diff: 
                    0             1             2             3
node_id                                                        
0        1.726641e-07  8.037789e-08  1.463753e-07  5.344492e-08
1        8.617145e-05  4.268113e-05  7.251887e-05  2.889487e-05
2        8.634411e-05  4.276151e-05  7.266525e-05  2.894831e-05
After Iteration [100]: weights diff: 
                    0             1             2             3
node_id                                                        
0        2.279157e-02  1.529074e-02  7.037264e-03  1.123195e-03
1        2.279157e-02  1.529074e-02  7.037264e-03  1.123195e-03
2        4.014566e-13  2.682299e-13  1.358913e-13  1.998401e-14
After Iteration [100]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001958  0.001070  0.006294  0.002406
1        0.122217  0.055516  0.095648  0.030965
2        0.120259  0.056586  0.089354  0.028559
After Iteration [100]: weights diff: 
                0         1         2         3
node_id                                        
0        0.015737  0.007362  0.010743  0.003297
1        0.213279  0.097849  0.166981  0.054197
2        0.197542  0.090486  0.156237  0.050900
After Iteration [100]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001000  0.000484  0.000659  0.000215
1        0.298726  0.137498  0.249880  0.090784
2        0.299726  0.137981  0.250539  0.091000
After Iteration [100]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000035  0.000016  0.000028  0.000010
1        0.007685  0.003551  0.006063  0.002215
2        0.007720  0.003567  0.006091  0.002225
After Iteration [200]: weights diff: 
                    0             1             2         3
node_id                                                    
0        1.031856e-02  6.861618e-03  3.227568e-03  0.000495
1        1.031856e-02  6.861618e-03  3.227568e-03  0.000495
2        2.220446e-15  1.332268e-15  8.881784e-16  0.000000
After Iteration [200]: weights diff: 
                0         1         2         3
node_id                                        
0        0.002410  0.002156  0.001197  0.000635
1        0.049655  0.022499  0.040443  0.013079
2        0.052065  0.024655  0.039246  0.012445
After Iteration [200]: weights diff: 
                0         1         2         3
node_id                                        
0        0.010006  0.004752  0.006782  0.002117
1        0.224828  0.104561  0.177301  0.058148
2        0.214822  0.099809  0.170518  0.056031
After Iteration [200]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000720  0.000349  0.000462  0.000154
1        0.195135  0.089791  0.162007  0.058457
2        0.195855  0.090140  0.162470  0.058611
After Iteration [200]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000030  0.000014  0.000024  0.000009
1        0.066797  0.030670  0.053013  0.018442
2        0.066827  0.030684  0.053036  0.018451
After Iteration [300]: weights diff: 
                0         1         2         3
node_id                                        
0        0.007104  0.004709  0.002232  0.000334
1        0.007104  0.004709  0.002232  0.000334
2        0.000000  0.000000  0.000000  0.000000
After Iteration [300]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001561  0.001440  0.001031  0.000515
1        0.068747  0.031915  0.053940  0.017396
2        0.070308  0.033356  0.052909  0.016881
After Iteration [300]: weights diff: 
                0         1         2         3
node_id                                        
0        0.006701  0.003206  0.004527  0.001427
1        0.240006  0.111852  0.189540  0.062387
2        0.233305  0.108647  0.185013  0.060960
After Iteration [300]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000484  0.000235  0.000306  0.000104
1        0.139921  0.064663  0.115456  0.041626
2        0.140404  0.064898  0.115763  0.041729
After Iteration [300]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000017  0.000008  0.000014  0.000005
1        0.160093  0.073254  0.127421  0.043473
2        0.160110  0.073262  0.127435  0.043478
After Iteration [400]: weights diff: 
                0         1        2         3
node_id                                       
0        0.006417  0.004247  0.00202  0.000297
1        0.006417  0.004247  0.00202  0.000297
2        0.000000  0.000000  0.00000  0.000000
After Iteration [400]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001134  0.001156  0.001129  0.000530
1        0.086787  0.040568  0.067167  0.021700
2        0.087920  0.041724  0.066038  0.021169
After Iteration [400]: weights diff: 
                0         1         2         3
node_id                                        
0        0.005674  0.002728  0.003829  0.001215
1        0.252977  0.117850  0.199696  0.065841
2        0.247302  0.115122  0.195866  0.064626
After Iteration [400]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000389  0.000189  0.000245  0.000083
1        0.099408  0.046400  0.081400  0.029573
2        0.099797  0.046589  0.081645  0.029657
After Iteration [400]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000012  0.000005  0.000009  0.000003
1        0.232315  0.106128  0.185057  0.062951
2        0.232326  0.106134  0.185066  0.062955
After Iteration [500]: weights diff: 
               0         1         2         3
node_id                                       
0        0.00613  0.004052  0.001934  0.000279
1        0.00613  0.004052  0.001934  0.000279
2        0.00000  0.000000  0.000000  0.000000
After Iteration [500]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001177  0.001135  0.001001  0.000484
1        0.081663  0.038154  0.063299  0.020445
2        0.082840  0.039288  0.062298  0.019961
After Iteration [500]: weights diff: 
                0         1         2         3
node_id                                        
0        0.005489  0.002650  0.003701  0.001182
1        0.259569  0.120859  0.204829  0.067620
2        0.254080  0.118209  0.201128  0.066437
After Iteration [500]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000372  0.000181  0.000233  0.000080
1        0.073677  0.034795  0.059834  0.022015
2        0.074050  0.034976  0.060067  0.022094
After Iteration [500]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000009  0.000004  0.000007  0.000003
1        0.260033  0.118742  0.207285  0.070280
2        0.260042  0.118746  0.207291  0.070282
After Iteration [600]: weights diff: 
                0         1         2         3
node_id                                        
0        0.005567  0.003677  0.001761  0.000251
1        0.005567  0.003677  0.001761  0.000251
2        0.000000  0.000000  0.000000  0.000000
After Iteration [600]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001328  0.001145  0.000733  0.000387
1        0.067291  0.031355  0.052442  0.016913
2        0.068619  0.032500  0.051710  0.016526
After Iteration [600]: weights diff: 
                0         1         2         3
node_id                                        
0        0.005234  0.002536  0.003524  0.001132
1        0.260220  0.121115  0.205400  0.067886
2        0.254987  0.118579  0.201877  0.066754
After Iteration [600]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000374  0.000182  0.000233  0.000080
1        0.059132  0.028199  0.047690  0.017765
2        0.059505  0.028381  0.047923  0.017846
After Iteration [600]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000007  0.000003  0.000005  0.000002
1        0.261734  0.119512  0.208822  0.070377
2        0.261740  0.119515  0.208827  0.070379
After Iteration [700]: weights diff: 
                0         1         2         3
node_id                                        
0        0.004876  0.003218  0.001548  0.000218
1        0.004876  0.003218  0.001548  0.000218
2        0.000000  0.000000  0.000000  0.000000
After Iteration [700]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001401  0.001112  0.000489  0.000293
1        0.053766  0.024973  0.042119  0.013568
2        0.055167  0.026085  0.041631  0.013275
After Iteration [700]: weights diff: 
                0         1         2         3
node_id                                        
0        0.004795  0.002330  0.003223  0.001040
1        0.257620  0.119858  0.203486  0.067319
2        0.252825  0.117528  0.200264  0.066279
After Iteration [700]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000375  0.000182  0.000232  0.000080
1        0.050653  0.024325  0.040647  0.015291
2        0.051028  0.024507  0.040879  0.015371
After Iteration [700]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000005  0.000002  0.000004  0.000001
1        0.255018  0.116444  0.203643  0.068181
2        0.255023  0.116447  0.203647  0.068182
After Iteration [800]: weights diff: 
                0         1         2         3
node_id                                        
0        0.004236  0.002794  0.001348  0.000189
1        0.004236  0.002794  0.001348  0.000189
2        0.000000  0.000000  0.000000  0.000000
After Iteration [800]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001393  0.001047  0.000315  0.000221
1        0.043346  0.020071  0.034100  0.010979
2        0.044739  0.021118  0.033785  0.010758
After Iteration [800]: weights diff: 
                0         1         2         3
node_id                                        
0        0.004306  0.002097  0.002889  0.000935
1        0.253877  0.118065  0.200691  0.066444
2        0.249570  0.115968  0.197803  0.065509
After Iteration [800]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000372  0.000181  0.000229  0.000080
1        0.045268  0.021844  0.036197  0.013717
2        0.045640  0.022025  0.036426  0.013796
After Iteration [800]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000004  0.000002  0.000003  0.000001
1        0.247045  0.112803  0.197425  0.065709
2        0.247049  0.112804  0.197429  0.065710
After Iteration [900]: weights diff: 
                0         1         2         3
node_id                                        
0        0.003699  0.002439  0.001181  0.000164
1        0.003699  0.002439  0.001181  0.000164
2        0.000000  0.000000  0.000000  0.000000
After Iteration [900]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001341  0.000973  0.000201  0.000170
1        0.035702  0.016484  0.028181  0.009073
2        0.037043  0.017456  0.027980  0.008903
After Iteration [900]: weights diff: 
                0         1         2         3
node_id                                        
0        0.003850  0.001878  0.002578  0.000837
1        0.249943  0.116181  0.197738  0.065504
2        0.246093  0.114303  0.195160  0.064667
After Iteration [900]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000366  0.000178  0.000225  0.000078
1        0.041568  0.020127  0.033154  0.012632
2        0.041934  0.020306  0.033379  0.012710
After Iteration [900]: weights diff: 
                0         1         2             3
node_id                                            
0        0.000003  0.000001  0.000002  8.848048e-07
1        0.239816  0.109500  0.191765  6.351111e-02
2        0.239819  0.109502  0.191767  6.351200e-02
After Iteration [1000]: weights diff: 
                0        1         2         3
node_id                                       
0        0.003263  0.00215  0.001044  0.000144
1        0.003263  0.00215  0.001044  0.000144
2        0.000000  0.00000  0.000000  0.000000
After Iteration [1000]: weights diff: 
                0         1         2         3
node_id                                        
0        0.001273  0.000899  0.000125  0.000135
1        0.030092  0.013858  0.023817  0.007669
2        0.031365  0.014758  0.023692  0.007535
After Iteration [1000]: weights diff: 
                0         1         2         3
node_id                                        
0        0.003452  0.001687  0.002308  0.000751
1        0.246190  0.114381  0.194910  0.064595
2        0.242738  0.112694  0.192602  0.063844
After Iteration [1000]: weights diff: 
                0         1         2         3
node_id                                        
0        0.000358  0.000175  0.000219  0.000077
1        0.038876  0.018872  0.030948  0.011840
2        0.039235  0.019046  0.031167  0.011917
After Iteration [1000]: weights diff: 
                0         1         2             3
node_id                                            
0        0.000002  0.000001  0.000002  7.051972e-07
1        0.233661  0.106690  0.186932  6.166419e-02
2        0.233663  0.106691  0.186934  6.166489e-02
Out[11]:
expected_label predicted_label
0 setosa setosa
1 setosa setosa
2 setosa setosa
3 setosa setosa
4 setosa setosa
5 setosa setosa
6 setosa setosa
7 setosa setosa
8 setosa setosa
9 setosa setosa
10 versicolor versicolor
11 versicolor versicolor
12 versicolor versicolor
13 versicolor versicolor
14 versicolor versicolor
15 versicolor versicolor
16 versicolor versicolor
17 versicolor versicolor
18 versicolor versicolor
19 versicolor versicolor
20 virginica virginica
21 virginica virginica
22 virginica virginica
23 virginica virginica
24 virginica virginica
25 virginica virginica
26 virginica virginica
27 virginica virginica
28 virginica virginica
29 virginica virginica

Gradient of the Cross Entropy Error

Recap We know the the cross entropy error is the average of the vector products between the 1-hot enconding of label and the softmax output.

$L = - \frac {1}{n} \sum_{i=1}^n Y_i^T ln(\hat Y_i)$

Where the sum runs over all of the $n$ input samples.

This is a complex derivation, and we need to approach it step-by step. First, let's work out what the $i$-th sample contributes to the gradient of L, i.e. the derivative of - $Y_i^Tln(\hat Y_i)$.

Let's draw the structure of the Network using networkx for a 2-class problem, so we have 2 input nodes.


In [ ]:
import networkx as nx
from matplotlib import pylab

G = nx.DiGraph()
G.add_edges_from(
    [('i', 'n1'), 
     ('i', 'n2'),
     ('n1', 's1'),
     ('n2', 's1'),
     ('n1', 's2'),
     ('n2', 's2'),
     ('s1', 'y1'),
     ('s2', 'y2'),
    ])

pos = {'i': (1, 1), 
       'n1': (2, 0), 'n2': (2, 2),
       's1': (3, 0), 's2': (3, 2),
       'y1': (4, 0), 'y2': (4, 2),
      }

labels = {'i': r'$x_i$',
         'n1': r'$w_1$', 'n2': r'$w_2$',
         's1': r'$s_1$', # r'$\frac {\exp(z_{i1})} {S_i}$', 
         's2': r'$s_2$', # r'$\frac {\exp(z_{i2})} {S_i}$'         
         }

edge_labels = {('i', 'n1'): r'$x_i$', 
               ('i', 'n2'): r'$x_i$',
               ('n1', 's1'): r'$w_1^Tx_i$',
               ('n1', 's2'): r'$w_1^Tx_i$', 
               ('n2', 's1'): r'$w_2^Tx_i$', 
               ('n2', 's2'): r'$w_2^Tx_i$',
               ('n2', 's1'): r'$w_2^Tx_i$',
               ('s1', 'y1'): r'$\frac {\exp(z_{i1})} {S_i}$',
               ('s2', 'y2'): r'$\frac {\exp(z_{i2})} {S_i}$',
              }
nx.draw(G, pos=pos, node_size=1000)
nx.draw_networkx_labels(G,pos,labels, font_size=15, color='white')
nx.draw_networkx_edge_labels(G, pos=pos, 
    edge_labels=edge_labels, font_size=15)

To calculate the derivative $- Y_i^Tln(\hat {Y_i})$, we will break down the vector sum:

$L_i = - [y_1 ln (\hat {y_1}) + y_2 ln (\hat {y_2}) + ... ]$, where

  • Each of $(y_1, y_2, ...)$ is an element of the one hot encoded label for sample $x_i$, so only one of them is 1, all the others are 0.
  • Each of $(\hat {y_1}, \hat {y_2}, ...)$ is an element of the softmax output for input $x_i$.

We know that

$\begin{equation} y_1 ln (\hat {y_1}) = y_1 ln \frac {\exp(z_{i1})} {\exp(z_{i1}) + \exp(z_{i2}) + ...} \\ y_2 ln (\hat {y_2}) = y_2 ln \frac {\exp(z_{i2})} {\exp(z_{i1}) + \exp(z_{i2}) + ...} \\ \vdots \end{equation}$

Where $z_{i1} = w_1^Tx_i, z_{i2} = w_2^Tx_i$, and so on.

Our end goal is to calculate $(\frac {\partial L_i}{\partial w_1}, \frac {\partial L_i}{\partial w_2}, ...)$. We can use the Chain rule to produce:

$\begin{equation} \frac {\partial L_i}{\partial w_1} = \frac {\partial L_i} {\partial z_{i1}} \frac {\partial z_{i1}}{\partial w_1} \\ \frac {\partial L_i}{\partial w_2} = \frac {\partial L_i} {\partial z_{i2}} \frac {\partial z_{i2}}{\partial w_2} \\ \vdots \end{equation}$

The denominator is the same for all of $(\hat {y_1}, \hat {y_2}, ...)$, so let's call that $S_i$.

$S_i = \exp(z_{i1}) + \exp(z_{i1}) + ...$

So the equations above become:

$\begin{equation} y_1 ln(\hat {y_1}) = z_{i1} - ln(S_i) \\ y_2 ln(\hat {y_2}) = z_{i2} - ln(S_i) \\ \vdots \end{equation}$

Taking the partial derivative of all these equations w.r.t $z_{i1}$, we get:

$\begin{equation} \frac {y_1} {\hat {y_1}} \frac {\partial \hat {y_1}} {\partial z_{i1}} = y_1(1 - \frac {z_{i1}} {S_i}) = y_1(1 - \hat {y_1}) \\ \frac {y_2} {\hat {y_2}} \frac {\partial \hat {y_2}} {\partial z_{i1}} = - y_2 \frac {z_{i1}} {S_i} = - y_2 \hat {y_1} \\ \vdots \end{equation}$

Thus, we can express $\frac {\partial L_i} {\partial z_{j1}}$ as:

$\frac {\partial L_i} {\partial z_{j1}} = [y1(\hat {y_1} - 1) + y2 \hat {y_1} + y3 \hat{y_1}+ ...] = [\hat {y_1}(y_1 + y_2 + ...) - y_1] = (\hat {y_1} - y_1)$

Since exactly 1 of $(y_1 + y_2 + ...)$ is 1, and all the others are zero.

Similarly, we can prove:

$\begin{equation} \frac {\partial L_i} {\partial z_{j2}} = (\hat {y_2} - y_2) \\ \frac {\partial L_i} {\partial z_{j3}} = (\hat {y_3} - y_3) \\ \vdots \end{equation}$

Noting that

$\begin{equation} \frac {\partial z_{i1}} {\partial w_1} = x_i \\ \frac {\partial z_{i2}} {\partial w_2} = x_i \\ \vdots \end{equation}$

We finally arrive at the result:

$\begin{equation} \frac {\partial L_i} {\partial w_1} = - (\hat {y_1} - y_1)x_i \\ \frac {\partial L_i} {\partial w_2} = - (\hat {y_2} - y_2)x_i \\ \vdots \end{equation}$


In [ ]: