Graph Convolutional Neural Networks

Graph LeNet5 with PyTorch

Xavier Bresson, Oct. 2017

Implementation of spectral graph ConvNets from:
"Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering"
M. Defferrard, X. Bresson, P. Vandergheynst
Advances in Neural Information Processing Systems (NIPS), pp. 3844-3852, 2016
arXiv preprint: arXiv:1606.09375


In [1]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import pdb  # insert pdb.set_trace() anywhere to drop into the debugger
import collections
import time
import numpy as np

import sys
sys.path.insert(0, 'lib/')
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

if torch.cuda.is_available():
    print('cuda available')
    dtypeFloat = torch.cuda.FloatTensor
    dtypeLong = torch.cuda.LongTensor
    torch.cuda.manual_seed(1)
else:
    print('cuda not available')
    dtypeFloat = torch.FloatTensor
    dtypeLong = torch.LongTensor
    torch.manual_seed(1)


cuda available

MNIST


In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('datasets', one_hot=False) # load data in folder datasets/

train_data = mnist.train.images.astype(np.float32)
val_data = mnist.validation.images.astype(np.float32)
test_data = mnist.test.images.astype(np.float32)
train_labels = mnist.train.labels
val_labels = mnist.validation.labels
test_labels = mnist.test.labels
print(train_data.shape)
print(train_labels.shape)
print(val_data.shape)
print(val_labels.shape)
print(test_data.shape)
print(test_labels.shape)


Extracting datasets/train-images-idx3-ubyte.gz
Extracting datasets/train-labels-idx1-ubyte.gz
Extracting datasets/t10k-images-idx3-ubyte.gz
Extracting datasets/t10k-labels-idx1-ubyte.gz
(55000, 784)
(55000,)
(5000, 784)
(5000,)
(10000, 784)
(10000,)
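
Note: the tensorflow.examples tutorial loader used above has since been removed from TensorFlow. A minimal alternative sketch using torchvision instead (load_mnist_flat is an illustrative helper, not part of the original code), reproducing the same flat float32 arrays and 55000/5000/10000 split:

from torchvision import datasets  # assumption: a torchvision version whose MNIST exposes .data/.targets
import numpy as np

def load_mnist_flat(root='datasets'):
    # download MNIST and flatten each 28x28 image to a 784-vector in [0, 1]
    train = datasets.MNIST(root, train=True, download=True)
    test = datasets.MNIST(root, train=False, download=True)
    X = train.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
    y = train.targets.numpy()
    X_test = test.data.numpy().reshape(-1, 784).astype(np.float32) / 255.0
    y_test = test.targets.numpy()
    # carve the last 5000 training samples off as a validation set
    return X[:55000], y[:55000], X[55000:], y[55000:], X_test, y_test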

Graph


In [3]:
from lib.grid_graph import grid_graph
from lib.coarsening import coarsen
from lib.coarsening import lmax_L
from lib.coarsening import perm_data
from lib.coarsening import rescale_L

# Construct graph
t_start = time.time()
grid_side = 28
number_edges = 8
metric = 'euclidean'
A = grid_graph(grid_side,number_edges,metric) # create graph of Euclidean grid

# Compute coarsened graphs
coarsening_levels = 4
L, perm = coarsen(A, coarsening_levels)

# Compute max eigenvalue of graph Laplacians
lmax = []
for i in range(coarsening_levels):
    lmax.append(lmax_L(L[i]))
print('lmax: ' + str([lmax[i] for i in range(coarsening_levels)]))

# Reindex the nodes (and pad with fake nodes) to match the binary tree structure of the coarsened graphs
train_data = perm_data(train_data, perm)
val_data = perm_data(val_data, perm)
test_data = perm_data(test_data, perm)

print(train_data.shape)
print(val_data.shape)
print(test_data.shape)

print('Execution time: {:.2f}s'.format(time.time() - t_start))
del perm


nb edges:  6396
Heavy Edge Matching coarsening with Xavier version
Layer 0: M_0 = |V| = 928 nodes (144 added), |E| = 3198 edges
Layer 1: M_1 = |V| = 464 nodes (61 added), |E| = 1592 edges
Layer 2: M_2 = |V| = 232 nodes (22 added), |E| = 772 edges
Layer 3: M_3 = |V| = 116 nodes (6 added), |E| = 370 edges
Layer 4: M_4 = |V| = 58 nodes (0 added), |E| = 189 edges
lmax: [1.3857549, 1.3441432, 1.219873, 0.99999928]
(55000, 928)
(5000, 928)
(10000, 928)
Execution time: 1.06s
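
For reference, grid_graph builds a k-nearest-neighbour similarity graph on the 28x28 pixel grid with Gaussian edge weights. A minimal dense sketch of the idea (knn_grid_adjacency is illustrative, not the lib/ implementation):

import numpy as np
import scipy.sparse
import scipy.spatial.distance

def knn_grid_adjacency(side=28, k=8):
    # 2D coordinates of the side x side pixels
    xx, yy = np.meshgrid(np.arange(side), np.arange(side))
    coords = np.stack([xx.ravel(), yy.ravel()], axis=1).astype(np.float32)
    # pairwise Euclidean distances (dense is fine for 784 nodes)
    d = scipy.spatial.distance.squareform(
        scipy.spatial.distance.pdist(coords, 'euclidean'))
    idx = np.argsort(d)[:, 1:k+1]    # k nearest neighbours, self excluded
    dist = np.sort(d)[:, 1:k+1]
    sigma2 = np.mean(dist[:, -1])**2
    w = np.exp(-dist**2 / sigma2)    # Gaussian similarity weights
    n = side * side
    rows = np.repeat(np.arange(n), k)
    A = scipy.sparse.coo_matrix((w.ravel(), (rows, idx.ravel())), shape=(n, n))
    return ((A + A.T) / 2).tocsr()   # symmetrize

coarsen then repeatedly merges strongly connected vertex pairs (heavy-edge matching), padding each level with disconnected fake nodes (the "added" counts above) so that every parent has exactly two children, and perm_data permutes the pixel columns accordingly with zeros at the fake nodes; that is why the data dimension grows from 784 to 928 = 784 + 144.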

Graph ConvNet LeNet5

Layers: CL32-MP4-CL64-MP4-FC512-FC10, i.e. a graph convolutional layer with 32 feature maps, graph max-pooling of size 4, a second graph convolutional layer with 64 feature maps, another max-pooling of size 4, a fully connected layer with 512 hidden units, and a final 10-way output layer.
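
With D = 928 input vertices and Chebyshev order K = 25, this architecture yields the parameter count printed by the cell below; a quick sanity check:

D = 928                      # input vertices after coarsening/padding
CL1_K = CL2_K = 25           # Chebyshev order (filter support size)
CL1_F, CL2_F = 32, 64        # feature maps of the two graph conv layers
FC1_F, FC2_F = 512, 10       # fully connected layer sizes
FC1Fin = CL2_F * (D // 16)   # 64 * 58 = 3712 after two x4 poolings
nb_param = (CL1_K*CL1_F + CL1_F           # CL1 weights + biases
            + CL2_K*CL1_F*CL2_F + CL2_F   # CL2
            + FC1Fin*FC1_F + FC1_F        # FC1
            + FC1_F*FC2_F + FC2_F)        # FC2
print(nb_param)              # 1958282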


In [6]:
# class definitions

class my_sparse_mm(torch.autograd.Function):
    """
    Implementation of a new autograd function for sparse variables, 
    called "my_sparse_mm", by subclassing torch.autograd.Function 
    and implementing the forward and backward passes.
    """
    
    def forward(self, W, x):  # W is a SPARSE matrix
        self.save_for_backward(W, x)
        y = torch.mm(W, x)
        return y
    
    def backward(self, grad_output):
        W, x = self.saved_tensors 
        grad_input = grad_output.clone()
        # chain rule for y = W x : dL/dW = grad_output . x^T, dL/dx = W^T . grad_output
        grad_input_dL_dW = torch.mm(grad_input, x.t())
        grad_input_dL_dx = torch.mm(W.t(), grad_input)
        return grad_input_dL_dW, grad_input_dL_dx
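
# Note: the dense gradient returned above for the sparse W is never applied:
# in this network W is the fixed rescaled graph Laplacian, wrapped in a
# Variable with requires_grad=False, so only dL/dx propagates through.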
    
    
class Graph_ConvNet_LeNet5(nn.Module):
    
    def __init__(self, net_parameters):
        
        print('Graph ConvNet: LeNet5')
        
        super(Graph_ConvNet_LeNet5, self).__init__()
        
        # parameters
        D, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F, FC2_F = net_parameters
        FC1Fin = CL2_F*(D//16)  # two max-poolings of size 4 divide the nb of vertices by 16
        
        # graph CL1
        self.cl1 = nn.Linear(CL1_K, CL1_F) 
        Fin = CL1_K; Fout = CL1_F;
        scale = np.sqrt( 2.0/ (Fin+Fout) )
        self.cl1.weight.data.uniform_(-scale, scale)
        self.cl1.bias.data.fill_(0.0)
        self.CL1_K = CL1_K; self.CL1_F = CL1_F; 
        
        # graph CL2
        self.cl2 = nn.Linear(CL2_K*CL1_F, CL2_F) 
        Fin = CL2_K*CL1_F; Fout = CL2_F;
        scale = np.sqrt( 2.0/ (Fin+Fout) )
        self.cl2.weight.data.uniform_(-scale, scale)
        self.cl2.bias.data.fill_(0.0)
        self.CL2_K = CL2_K; self.CL2_F = CL2_F; 

        # FC1
        self.fc1 = nn.Linear(FC1Fin, FC1_F) 
        Fin = FC1Fin; Fout = FC1_F;
        scale = np.sqrt( 2.0/ (Fin+Fout) )
        self.fc1.weight.data.uniform_(-scale, scale)
        self.fc1.bias.data.fill_(0.0)
        self.FC1Fin = FC1Fin
        
        # FC2
        self.fc2 = nn.Linear(FC1_F, FC2_F)
        Fin = FC1_F; Fout = FC2_F;
        scale = np.sqrt( 2.0/ (Fin+Fout) )
        self.fc2.weight.data.uniform_(-scale, scale)
        self.fc2.bias.data.fill_(0.0)

        # nb of parameters
        nb_param = CL1_K* CL1_F + CL1_F          # CL1
        nb_param += CL2_K* CL1_F* CL2_F + CL2_F  # CL2
        nb_param += FC1Fin* FC1_F + FC1_F        # FC1
        nb_param += FC1_F* FC2_F + FC2_F         # FC2
        print('nb of parameters=',nb_param,'\n')
        
        
    def init_weights(self, W, Fin, Fout):

        scale = np.sqrt( 2.0/ (Fin+Fout) )
        W.uniform_(-scale, scale)

        return W
        
        
    def graph_conv_cheby(self, x, cl, L, lmax, Fout, K):

        # parameters
        # B = batch size
        # V = nb vertices
        # Fin = nb input features
        # Fout = nb output features
        # K = Chebyshev order & support size
        B, V, Fin = x.size(); B, V, Fin = int(B), int(V), int(Fin) 

        # rescale Laplacian to [-1, 1] (the lmax argument is recomputed from L here)
        lmax = lmax_L(L)
        L = rescale_L(L, lmax) 
        
        # convert the scipy sparse matrix L to a PyTorch sparse tensor
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col)).T 
        indices = indices.astype(np.int64)
        indices = torch.from_numpy(indices)
        indices = indices.type(torch.LongTensor)
        L_data = L.data.astype(np.float32)
        L_data = torch.from_numpy(L_data) 
        L_data = L_data.type(torch.FloatTensor)
        L = torch.sparse.FloatTensor(indices, L_data, torch.Size(L.shape))
        L = Variable( L , requires_grad=False)
        if torch.cuda.is_available():
            L = L.cuda()
        
        # transform to Chebyshev basis
        x0 = x.permute(1,2,0).contiguous()  # V x Fin x B
        x0 = x0.view([V, Fin*B])            # V x Fin*B
        x = x0.unsqueeze(0)                 # 1 x V x Fin*B
        
        if K > 1: 
            x1 = my_sparse_mm()(L,x0)              # V x Fin*B
            x = torch.cat((x, x1.unsqueeze(0)),0)  # 2 x V x Fin*B
        for k in range(2, K):
            # Chebyshev recurrence: T_k(L) x = 2 L T_{k-1}(L) x - T_{k-2}(L) x
            x2 = 2 * my_sparse_mm()(L,x1) - x0
            x = torch.cat((x, x2.unsqueeze(0)),0)  # (k+1) x V x Fin*B
            x0, x1 = x1, x2  
        
        x = x.view([K, V, Fin, B])           # K x V x Fin x B     
        x = x.permute(3,1,2,0).contiguous()  # B x V x Fin x K       
        x = x.view([B*V, Fin*K])             # B*V x Fin*K
        
        # Compose linearly Fin features to get Fout features
        x = cl(x)                            # B*V x Fout  
        x = x.view([B, V, Fout])             # B x V x Fout
        
        return x
        
        
    # Graph max-pooling of size p (p must be a power of 2). The coarsening
    # ordered the vertices so that pooling consecutive indices corresponds to
    # pooling neighbours on the graph, which lets us reuse the standard 1D
    # max-pooling operator.
    def graph_max_pool(self, x, p): 
        if p > 1: 
            x = x.permute(0,2,1).contiguous()  # B x F x V
            x = nn.MaxPool1d(p)(x)             # B x F x V/p          
            x = x.permute(0,2,1).contiguous()  # B x V/p x F
            return x  
        else:
            return x    
        
        
    def forward(self, x, d, L, lmax):
        
        # graph CL1
        x = x.unsqueeze(2) # B x V x Fin=1  
        x = self.graph_conv_cheby(x, self.cl1, L[0], lmax[0], self.CL1_F, self.CL1_K)
        x = F.relu(x)
        x = self.graph_max_pool(x, 4)
        
        # graph CL2 (L[2]: one max-pooling of size 4 descends two coarsening levels)
        x = self.graph_conv_cheby(x, self.cl2, L[2], lmax[2], self.CL2_F, self.CL2_K)
        x = F.relu(x)
        x = self.graph_max_pool(x, 4)
        
        # FC1
        x = x.view(-1, self.FC1Fin)
        x = self.fc1(x)
        x = F.relu(x)
        x = nn.Dropout(d)(x)  # d is 0.0 at test time, which disables dropout
        
        # FC2
        x = self.fc2(x)
            
        return x
        
        
    def loss(self, y, y_target, l2_regularization):
    
        loss = nn.CrossEntropyLoss()(y,y_target)

        # L2 regularization over all parameters (weights and biases)
        l2_loss = 0.0
        for param in self.parameters():
            data = param* param
            l2_loss += data.sum()
           
        loss += 0.5* l2_regularization* l2_loss
            
        return loss
    
    
    def update(self, lr):
                
        # build the SGD optimizer (momentum 0.9) at the given learning rate
        update = torch.optim.SGD( self.parameters(), lr=lr, momentum=0.9 )
        
        return update
        
        
    def update_learning_rate(self, optimizer, lr):
   
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        return optimizer

    
    def evaluation(self, y_predicted, test_l):
    
        # classification accuracy (%) over the batch
        _, class_predicted = torch.max(y_predicted.data, 1)
        return 100.0* (class_predicted == test_l).sum()/ y_predicted.size(0)
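
graph_conv_cheby above implements the fast localized filtering of the paper: with the rescaled Laplacian L~ = 2L/lmax - I, the output is y = sum_k T_k(L~) x theta_k, where the Chebyshev polynomials obey T_0 = I, T_1 = L~ and T_k = 2 L~ T_{k-1} - T_{k-2}. A dense numpy sketch of the same computation for a single sample (cheby_filter is illustrative, not the class method):

import numpy as np

def cheby_filter(L_tilde, x, theta):
    # L_tilde: V x V rescaled Laplacian, x: V x Fin signal,
    # theta: K x Fin x Fout filter coefficients (the role of the cl1/cl2 weights)
    K = theta.shape[0]
    Xt = [x]                                        # T_0(L~) x = x
    if K > 1:
        Xt.append(L_tilde @ x)                      # T_1(L~) x = L~ x
    for k in range(2, K):
        Xt.append(2 * L_tilde @ Xt[-1] - Xt[-2])    # T_k = 2 L~ T_{k-1} - T_{k-2}
    return sum(Xt[k] @ theta[k] for k in range(K))  # V x Fout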

In [8]:
# Delete the existing network, if any
try:
    del net
    print('Delete existing network\n')
except NameError:
    print('No existing network to delete\n')



# network parameters
D = train_data.shape[1]
CL1_F = 32
CL1_K = 25
CL2_F = 64
CL2_K = 25
FC1_F = 512
FC2_F = 10
net_parameters = [D, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F, FC2_F]


# instantiate the network
net = Graph_ConvNet_LeNet5(net_parameters)
if torch.cuda.is_available():
    net.cuda()
print(net)


# Weights
L_net = list(net.parameters())


# learning parameters
learning_rate = 0.05
dropout_value = 0.5
l2_regularization = 5e-4 
batch_size = 100
num_epochs = 20
train_size = train_data.shape[0]
nb_iter = int(num_epochs * train_size) // batch_size
print('num_epochs=',num_epochs,', train_size=',train_size,', nb_iter=',nb_iter)


# Optimizer
global_lr = learning_rate
global_step = 0
decay = 0.95
decay_steps = train_size
lr = learning_rate
optimizer = net.update(lr) 
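# Decay schedule applied at the end of each epoch below:
#   lr = global_lr * decay**(global_step // decay_steps)
# With decay_steps = train_size, the learning rate shrinks by a factor 0.95
# once per epoch (e.g. epoch 2 runs at 0.05 * 0.95 = 0.0475, cf. the log).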


# loop over epochs
indices = collections.deque()
for epoch in range(num_epochs):  # loop over the dataset multiple times

    # reshuffle 
    indices.extend(np.random.permutation(train_size)) # rand permutation
    
    # reset time
    t_start = time.time()
    
    # extract batches
    running_loss = 0.0
    running_accuracy = 0
    running_total = 0
    while len(indices) >= batch_size:
        
        # extract batches
        batch_idx = [indices.popleft() for i in range(batch_size)]
        train_x, train_y = train_data[batch_idx,:], train_labels[batch_idx]
        train_x = Variable( torch.FloatTensor(train_x).type(dtypeFloat) , requires_grad=False) 
        train_y = train_y.astype(np.int64)
        train_y = torch.LongTensor(train_y).type(dtypeLong)
        train_y = Variable( train_y , requires_grad=False) 
            
        # Forward 
        y = net.forward(train_x, dropout_value, L, lmax)
        loss = net.loss(y,train_y,l2_regularization) 
        loss_train = loss.data[0]
        
        # Accuracy
        acc_train = net.evaluation(y,train_y.data)
        
        # backward
        loss.backward()
        
        # Update 
        global_step += batch_size # used by the learning-rate decay schedule
        optimizer.step()
        optimizer.zero_grad() # clear gradients before the next backward pass
        
        # loss, accuracy
        running_loss += loss_train
        running_accuracy += acc_train
        running_total += 1
        
        # print        
        if not running_total%100: # print every 100 mini-batches
            print('epoch= %d, i= %4d, loss(batch)= %.4f, accuracy(batch)= %.2f' % (epoch+1, running_total, loss_train, acc_train))
          
       
    # print 
    t_stop = time.time() - t_start
    print('epoch= %d, loss(train)= %.3f, accuracy(train)= %.3f, time= %.3f, lr= %.5f' % 
          (epoch+1, running_loss/running_total, running_accuracy/running_total, t_stop, lr))
 

    # update learning rate 
    lr = global_lr * pow( decay , float(global_step// decay_steps) )
    optimizer = net.update_learning_rate(optimizer, lr)
    
    
    # Test set
    running_accuracy_test = 0
    running_total_test = 0
    indices_test = collections.deque()
    indices_test.extend(range(test_data.shape[0]))
    t_start_test = time.time()
    while len(indices_test) >= batch_size:
        batch_idx_test = [indices_test.popleft() for i in range(batch_size)]
        test_x, test_y = test_data[batch_idx_test,:], test_labels[batch_idx_test]
        test_x = Variable( torch.FloatTensor(test_x).type(dtypeFloat) , requires_grad=False) 
        y = net.forward(test_x, 0.0, L, lmax) 
        test_y = test_y.astype(np.int64)
        test_y = torch.LongTensor(test_y).type(dtypeLong)
        test_y = Variable( test_y , requires_grad=False) 
        acc_test = net.evaluation(y,test_y.data)
        running_accuracy_test += acc_test
        running_total_test += 1
    t_stop_test = time.time() - t_start_test
    print('  accuracy(test) = %.3f %%, time= %.3f' % (running_accuracy_test / running_total_test, t_stop_test))


Delete existing network

Graph ConvNet: LeNet5
nb of parameters= 1958282 

Graph_ConvNet_LeNet5 (
  (cl1): Linear (25 -> 32)
  (cl2): Linear (800 -> 64)
  (fc1): Linear (3712 -> 512)
  (fc2): Linear (512 -> 10)
)
num_epochs= 20 , train_size= 55000 , nb_iter= 11000
epoch= 1, i=  100, loss(batch)= 0.4129, accuracy(batch)= 91.00
epoch= 1, i=  200, loss(batch)= 0.3374, accuracy(batch)= 92.00
epoch= 1, i=  300, loss(batch)= 0.1941, accuracy(batch)= 97.00
epoch= 1, i=  400, loss(batch)= 0.2339, accuracy(batch)= 96.00
epoch= 1, i=  500, loss(batch)= 0.1828, accuracy(batch)= 97.00
epoch= 1, loss(train)= 0.396, accuracy(train)= 90.804, time= 100.878, lr= 0.05000
  accuracy(test) = 97.730 %, time= 9.023
epoch= 2, i=  100, loss(batch)= 0.1696, accuracy(batch)= 98.00
epoch= 2, i=  200, loss(batch)= 0.2761, accuracy(batch)= 94.00
epoch= 2, i=  300, loss(batch)= 0.1676, accuracy(batch)= 98.00
epoch= 2, i=  400, loss(batch)= 0.1188, accuracy(batch)= 100.00
epoch= 2, i=  500, loss(batch)= 0.2046, accuracy(batch)= 95.00
epoch= 2, loss(train)= 0.187, accuracy(train)= 97.549, time= 100.835, lr= 0.04750
  accuracy(test) = 98.270 %, time= 8.987
epoch= 3, i=  100, loss(batch)= 0.1839, accuracy(batch)= 96.00
epoch= 3, i=  200, loss(batch)= 0.1674, accuracy(batch)= 98.00
epoch= 3, i=  300, loss(batch)= 0.1397, accuracy(batch)= 98.00
epoch= 3, i=  400, loss(batch)= 0.1966, accuracy(batch)= 98.00
epoch= 3, i=  500, loss(batch)= 0.1546, accuracy(batch)= 98.00
epoch= 3, loss(train)= 0.155, accuracy(train)= 98.253, time= 101.344, lr= 0.04512
  accuracy(test) = 98.650 %, time= 9.043
epoch= 4, i=  100, loss(batch)= 0.2077, accuracy(batch)= 97.00
epoch= 4, i=  200, loss(batch)= 0.1644, accuracy(batch)= 98.00
epoch= 4, i=  300, loss(batch)= 0.1275, accuracy(batch)= 99.00
epoch= 4, i=  400, loss(batch)= 0.1430, accuracy(batch)= 98.00
epoch= 4, i=  500, loss(batch)= 0.1855, accuracy(batch)= 99.00
epoch= 4, loss(train)= 0.138, accuracy(train)= 98.542, time= 101.353, lr= 0.04287
  accuracy(test) = 98.650 %, time= 9.012
epoch= 5, i=  100, loss(batch)= 0.1073, accuracy(batch)= 99.00
epoch= 5, i=  200, loss(batch)= 0.1340, accuracy(batch)= 98.00
epoch= 5, i=  300, loss(batch)= 0.1362, accuracy(batch)= 99.00
epoch= 5, i=  400, loss(batch)= 0.1478, accuracy(batch)= 98.00
epoch= 5, i=  500, loss(batch)= 0.1501, accuracy(batch)= 98.00
epoch= 5, loss(train)= 0.122, accuracy(train)= 98.795, time= 101.034, lr= 0.04073
  accuracy(test) = 98.740 %, time= 9.012
epoch= 6, i=  100, loss(batch)= 0.1457, accuracy(batch)= 98.00
epoch= 6, i=  200, loss(batch)= 0.1073, accuracy(batch)= 98.00
epoch= 6, i=  300, loss(batch)= 0.1079, accuracy(batch)= 99.00
epoch= 6, i=  400, loss(batch)= 0.0891, accuracy(batch)= 99.00
epoch= 6, i=  500, loss(batch)= 0.0995, accuracy(batch)= 99.00
epoch= 6, loss(train)= 0.111, accuracy(train)= 98.978, time= 101.218, lr= 0.03869
  accuracy(test) = 98.690 %, time= 9.030
epoch= 7, i=  100, loss(batch)= 0.1062, accuracy(batch)= 99.00
epoch= 7, i=  200, loss(batch)= 0.0986, accuracy(batch)= 99.00
epoch= 7, i=  300, loss(batch)= 0.1089, accuracy(batch)= 99.00
epoch= 7, i=  400, loss(batch)= 0.0914, accuracy(batch)= 99.00
epoch= 7, i=  500, loss(batch)= 0.1146, accuracy(batch)= 98.00
epoch= 7, loss(train)= 0.102, accuracy(train)= 99.120, time= 101.632, lr= 0.03675
  accuracy(test) = 98.990 %, time= 9.076
epoch= 8, i=  100, loss(batch)= 0.0832, accuracy(batch)= 100.00
epoch= 8, i=  200, loss(batch)= 0.0797, accuracy(batch)= 100.00
epoch= 8, i=  300, loss(batch)= 0.0762, accuracy(batch)= 100.00
epoch= 8, i=  400, loss(batch)= 0.2062, accuracy(batch)= 97.00
epoch= 8, i=  500, loss(batch)= 0.0970, accuracy(batch)= 99.00
epoch= 8, loss(train)= 0.095, accuracy(train)= 99.164, time= 101.151, lr= 0.03492
  accuracy(test) = 98.990 %, time= 9.021
epoch= 9, i=  100, loss(batch)= 0.1012, accuracy(batch)= 99.00
epoch= 9, i=  200, loss(batch)= 0.0793, accuracy(batch)= 100.00
epoch= 9, i=  300, loss(batch)= 0.0771, accuracy(batch)= 100.00
epoch= 9, i=  400, loss(batch)= 0.0786, accuracy(batch)= 100.00
epoch= 9, i=  500, loss(batch)= 0.0899, accuracy(batch)= 99.00
epoch= 9, loss(train)= 0.089, accuracy(train)= 99.264, time= 101.987, lr= 0.03317
  accuracy(test) = 99.180 %, time= 9.014
epoch= 10, i=  100, loss(batch)= 0.0718, accuracy(batch)= 99.00
epoch= 10, i=  200, loss(batch)= 0.0646, accuracy(batch)= 100.00
epoch= 10, i=  300, loss(batch)= 0.1157, accuracy(batch)= 99.00
epoch= 10, i=  400, loss(batch)= 0.0718, accuracy(batch)= 100.00
epoch= 10, i=  500, loss(batch)= 0.1252, accuracy(batch)= 98.00
epoch= 10, loss(train)= 0.082, accuracy(train)= 99.367, time= 101.591, lr= 0.03151
  accuracy(test) = 99.110 %, time= 9.033
epoch= 11, i=  100, loss(batch)= 0.0681, accuracy(batch)= 100.00
epoch= 11, i=  200, loss(batch)= 0.0846, accuracy(batch)= 99.00
epoch= 11, i=  300, loss(batch)= 0.0703, accuracy(batch)= 99.00
epoch= 11, i=  400, loss(batch)= 0.0682, accuracy(batch)= 100.00
epoch= 11, i=  500, loss(batch)= 0.0870, accuracy(batch)= 100.00
epoch= 11, loss(train)= 0.079, accuracy(train)= 99.378, time= 101.530, lr= 0.02994
  accuracy(test) = 99.060 %, time= 9.017
epoch= 12, i=  100, loss(batch)= 0.0997, accuracy(batch)= 98.00
epoch= 12, i=  200, loss(batch)= 0.1114, accuracy(batch)= 98.00
epoch= 12, i=  300, loss(batch)= 0.0691, accuracy(batch)= 99.00
epoch= 12, i=  400, loss(batch)= 0.0671, accuracy(batch)= 100.00
epoch= 12, i=  500, loss(batch)= 0.0946, accuracy(batch)= 98.00
epoch= 12, loss(train)= 0.075, accuracy(train)= 99.458, time= 101.616, lr= 0.02844
  accuracy(test) = 99.120 %, time= 9.091
epoch= 13, i=  100, loss(batch)= 0.0869, accuracy(batch)= 98.00
epoch= 13, i=  200, loss(batch)= 0.0619, accuracy(batch)= 100.00
epoch= 13, i=  300, loss(batch)= 0.0930, accuracy(batch)= 99.00
epoch= 13, i=  400, loss(batch)= 0.0799, accuracy(batch)= 99.00
epoch= 13, i=  500, loss(batch)= 0.0586, accuracy(batch)= 100.00
epoch= 13, loss(train)= 0.072, accuracy(train)= 99.458, time= 101.749, lr= 0.02702
  accuracy(test) = 99.100 %, time= 9.008
epoch= 14, i=  100, loss(batch)= 0.0618, accuracy(batch)= 100.00
epoch= 14, i=  200, loss(batch)= 0.0577, accuracy(batch)= 100.00
epoch= 14, i=  300, loss(batch)= 0.0583, accuracy(batch)= 100.00
epoch= 14, i=  400, loss(batch)= 0.0607, accuracy(batch)= 100.00
epoch= 14, i=  500, loss(batch)= 0.0760, accuracy(batch)= 99.00
epoch= 14, loss(train)= 0.069, accuracy(train)= 99.516, time= 101.536, lr= 0.02567
  accuracy(test) = 99.180 %, time= 9.018
epoch= 15, i=  100, loss(batch)= 0.0555, accuracy(batch)= 100.00
epoch= 15, i=  200, loss(batch)= 0.0765, accuracy(batch)= 99.00
epoch= 15, i=  300, loss(batch)= 0.0752, accuracy(batch)= 99.00
epoch= 15, i=  400, loss(batch)= 0.0618, accuracy(batch)= 100.00
epoch= 15, i=  500, loss(batch)= 0.0734, accuracy(batch)= 100.00
epoch= 15, loss(train)= 0.066, accuracy(train)= 99.591, time= 101.262, lr= 0.02438
  accuracy(test) = 99.280 %, time= 9.052
epoch= 16, i=  100, loss(batch)= 0.0780, accuracy(batch)= 99.00
epoch= 16, i=  200, loss(batch)= 0.0516, accuracy(batch)= 100.00
epoch= 16, i=  300, loss(batch)= 0.0634, accuracy(batch)= 100.00
epoch= 16, i=  400, loss(batch)= 0.0531, accuracy(batch)= 100.00
epoch= 16, i=  500, loss(batch)= 0.0635, accuracy(batch)= 99.00
epoch= 16, loss(train)= 0.064, accuracy(train)= 99.600, time= 100.793, lr= 0.02316
  accuracy(test) = 99.220 %, time= 9.053
epoch= 17, i=  100, loss(batch)= 0.0884, accuracy(batch)= 99.00
epoch= 17, i=  200, loss(batch)= 0.0694, accuracy(batch)= 99.00
epoch= 17, i=  300, loss(batch)= 0.0541, accuracy(batch)= 100.00
epoch= 17, i=  400, loss(batch)= 0.0574, accuracy(batch)= 100.00
epoch= 17, i=  500, loss(batch)= 0.0516, accuracy(batch)= 100.00
epoch= 17, loss(train)= 0.062, accuracy(train)= 99.596, time= 100.811, lr= 0.02201
  accuracy(test) = 99.200 %, time= 9.063
epoch= 18, i=  100, loss(batch)= 0.0531, accuracy(batch)= 100.00
epoch= 18, i=  200, loss(batch)= 0.0755, accuracy(batch)= 98.00
epoch= 18, i=  300, loss(batch)= 0.0521, accuracy(batch)= 100.00
epoch= 18, i=  400, loss(batch)= 0.0612, accuracy(batch)= 100.00
epoch= 18, i=  500, loss(batch)= 0.0561, accuracy(batch)= 100.00
epoch= 18, loss(train)= 0.061, accuracy(train)= 99.615, time= 100.540, lr= 0.02091
  accuracy(test) = 99.180 %, time= 9.043
epoch= 19, i=  100, loss(batch)= 0.0559, accuracy(batch)= 100.00
epoch= 19, i=  200, loss(batch)= 0.0570, accuracy(batch)= 99.00
epoch= 19, i=  300, loss(batch)= 0.0497, accuracy(batch)= 100.00
epoch= 19, i=  400, loss(batch)= 0.0596, accuracy(batch)= 99.00
epoch= 19, i=  500, loss(batch)= 0.0499, accuracy(batch)= 100.00
epoch= 19, loss(train)= 0.059, accuracy(train)= 99.629, time= 100.761, lr= 0.01986
  accuracy(test) = 99.210 %, time= 9.012
epoch= 20, i=  100, loss(batch)= 0.0495, accuracy(batch)= 100.00
epoch= 20, i=  200, loss(batch)= 0.0549, accuracy(batch)= 100.00
epoch= 20, i=  300, loss(batch)= 0.0477, accuracy(batch)= 100.00
epoch= 20, i=  400, loss(batch)= 0.0674, accuracy(batch)= 100.00
epoch= 20, i=  500, loss(batch)= 0.0522, accuracy(batch)= 100.00
epoch= 20, loss(train)= 0.058, accuracy(train)= 99.664, time= 100.785, lr= 0.01887
  accuracy(test) = 99.190 %, time= 9.055
