A training example in PyTorch

Introduction

Task

In this notebook we will train a neural network on a simple task. It will be a classification task: as explained in the first week of lectures, classification essentially means finding a decision boundary over a space of real numbers. For visualization purposes we will work with a 2D example: the decision boundary will be a circle. More precisely, it will be the unit circle in the plane.

Sampling

We will generate points $(x_1,x_2)$ to classify, along with their class $y$. The true decision function is $y=\mathbf{1}_{x_1^2+x_2^2<1}$.

To get a balanced dataset with roughly as many points in each class, we will sample uniformly in polar coordinates within the circle of center 0 and radius 2: since the radius is uniform on $[0,2]$, each class has probability $1/2$.


In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
print(torch.__version__)
%matplotlib inline


0.4.1

In [2]:
def sample_points(n):
    # returns (X,Y), where X of shape (n,2) is the numpy array of points and Y is the (n) array of classes
    
    radius = np.random.uniform(low=0,high=2,size=n).reshape(-1,1) # uniform radius between 0 and 2
    angle = np.random.uniform(low=0,high=2*np.pi,size=n).reshape(-1,1) # uniform angle
    x1 = radius*np.cos(angle)
    x2 = radius*np.sin(angle)
    y = (radius<1).astype(int).reshape(-1)
    x = np.concatenate([x1,x2],axis=1)
    return x,y
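As a quick sanity check (not part of the original notebook), we can verify empirically that the classes come out balanced; the sampler is repeated here so the cell is self-contained:

```python
import numpy as np

# Same sampler as above: uniform radius in [0,2), uniform angle.
def sample_points(n):
    radius = np.random.uniform(low=0, high=2, size=n).reshape(-1, 1)
    angle = np.random.uniform(low=0, high=2*np.pi, size=n).reshape(-1, 1)
    x = np.concatenate([radius*np.cos(angle), radius*np.sin(angle)], axis=1)
    y = (radius < 1).astype(int).reshape(-1)
    return x, y

np.random.seed(0)
_, y = sample_points(100_000)
balance = y.mean()  # fraction of class-1 points; should be close to 0.5
print(balance)
```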

In [3]:
# Generate the data
trainx,trainy = sample_points(10000)
valx,valy = sample_points(500)
testx,testy = sample_points(500)

print(trainx.shape,trainy.shape)


(10000, 2) (10000,)

Our model will be a multi-layer perceptron with one hidden layer, and an output of size 2 since we have two classes. Since it is a binary classification task we could also use a single output with a zero threshold, but we will use two outputs to illustrate the PyTorch cross-entropy loss, `nn.CrossEntropyLoss` (with one output, you would use `nn.BCEWithLogitsLoss`).
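The two formulations agree: with two logits, `nn.CrossEntropyLoss` matches `nn.BCEWithLogitsLoss` applied to the logit difference, since a softmax over two classes is a sigmoid of that difference. A quick numerical check (a sketch, not part of the notebook):

```python
import torch
import torch.nn as nn

# With two logits (z0, z1), softmax gives p1 = sigmoid(z1 - z0), so the
# 2-class cross-entropy equals binary cross-entropy on the difference.
torch.manual_seed(0)
logits = torch.randn(8, 2)               # fake 2-class logits for 8 samples
labels = torch.randint(0, 2, (8,))       # fake 0/1 labels
ce = nn.CrossEntropyLoss()(logits, labels)
bce = nn.BCEWithLogitsLoss()(logits[:, 1] - logits[:, 0], labels.float())
print(torch.allclose(ce, bce, atol=1e-6))
```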

As you know from the lectures, such a model cannot represent a circular boundary exactly, but it can represent a polygonal boundary whose number of sides is the number of neurons in the hidden layer. For example, with 6 hidden neurons the model could compute a hexagonal boundary that approximates the unit circle, such as:

Of course the trained model won't compute an exact hexagon, since the activation is not a hard threshold and the final layer's weights are unconstrained (it does not have to compute an AND). We can actually expect better accuracy than a hexagon could achieve.
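For intuition, the hexagonal construction alluded to above can be written out by hand: six half-plane tests ANDed together. This sketch uses hard thresholds, unlike the trained ReLU network, and `hexagon_inside` is a name introduced here:

```python
import numpy as np

# Hand-built hexagonal classifier: a point is "inside" iff it lies on the
# inner side of all 6 half-planes whose unit normals point at 6 equally
# spaced angles. With offset 1, the hexagon circumscribes the unit circle.
angles = np.linspace(0, 2*np.pi, 6, endpoint=False)
W = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (6, 2) unit normals

def hexagon_inside(x):
    # x: array of shape (n, 2); returns 0/1 labels of shape (n,)
    return (x @ W.T < 1).all(axis=1).astype(int)

print(hexagon_inside(np.array([[0.0, 0.0], [1.5, 0.0]])))  # → [1 0]
```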


In [4]:
def generate_single_hidden_MLP(n_hidden_neurons):
    return nn.Sequential(nn.Linear(2,n_hidden_neurons),nn.ReLU(),nn.Linear(n_hidden_neurons,2))
model1 = generate_single_hidden_MLP(6)

To train our model, we will need to feed it tensors. Let's convert our generated numpy arrays:


In [5]:
trainx = torch.from_numpy(trainx).float()
valx = torch.from_numpy(valx).float()
testx = torch.from_numpy(testx).float()

trainy = torch.from_numpy(trainy).long()
valy = torch.from_numpy(valy).long()
testy = torch.from_numpy(testy).long()
print(trainx.type(),trainy.type())


torch.FloatTensor torch.LongTensor

Now we will define our training routine. There is the question of whether to perform the training on CPU or GPU. The best approach is to use a flag variable that you set when you actually run the training.


In [6]:
def training_routine(net,dataset,n_iters,gpu):
    # organize the data
    train_data,train_labels,val_data,val_labels = dataset
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),lr=0.01)
    
    # use the flag
    if gpu:
        train_data,train_labels = train_data.cuda(),train_labels.cuda()
        val_data,val_labels = val_data.cuda(),val_labels.cuda()
        net = net.cuda() # the network parameters also need to be on the gpu !
        print("Using GPU")
    else:
        print("Using CPU")
    for i in range(n_iters):
        # forward pass
        train_output = net(train_data)
        train_loss = criterion(train_output,train_labels)
        # backward pass and optimization
        optimizer.zero_grad() # clear gradients left over from the previous iteration
        train_loss.backward()
        optimizer.step()
        
        # Once every 100 iterations, print statistics
        if i%100==0:
            print("At iteration",i)
            # compute the accuracy of the prediction
            train_prediction = train_output.cpu().detach().argmax(dim=1)
            train_accuracy = (train_prediction.numpy()==train_labels.cpu().numpy()).mean() 
            # Now for the validation set (no gradient tracking needed)
            with torch.no_grad():
                val_output = net(val_data)
                val_loss = criterion(val_output,val_labels)
            # compute the accuracy of the prediction
            val_prediction = val_output.cpu().detach().argmax(dim=1)
            val_accuracy = (val_prediction.numpy()==val_labels.cpu().numpy()).mean() 
            print("Training loss :",train_loss.cpu().detach().numpy())
            print("Training accuracy :",train_accuracy)
            print("Validation loss :",val_loss.cpu().detach().numpy())
            print("Validation accuracy :",val_accuracy)
    
    net = net.cpu()

In [7]:
dataset = trainx,trainy,valx,valy

In [8]:
gpu = True
gpu = gpu and torch.cuda.is_available() # to know if you actually can use the GPU
begin = time.time()
training_routine(model1,dataset,10000,gpu)
end=time.time()


Using GPU
At iteration 0
Training loss : 0.6987106204032898
Training accuracy : 0.5003
Validation loss : 0.7039246559143066
Validation accuracy : 0.486
At iteration 100
Training loss : 0.677510678768158
Training accuracy : 0.4203
Validation loss : 0.6795798540115356
Validation accuracy : 0.402
At iteration 200
Training loss : 0.6638249158859253
Training accuracy : 0.5779
Validation loss : 0.6638974547386169
Validation accuracy : 0.562
At iteration 300
Training loss : 0.6521827578544617
Training accuracy : 0.6372
Validation loss : 0.6509451270103455
Validation accuracy : 0.608
At iteration 400
Training loss : 0.6414715051651001
Training accuracy : 0.6623
Validation loss : 0.6396153569221497
Validation accuracy : 0.65
At iteration 500
Training loss : 0.6312623620033264
Training accuracy : 0.6757
Validation loss : 0.6290354132652283
Validation accuracy : 0.672
At iteration 600
Training loss : 0.6213157176971436
Training accuracy : 0.6864
Validation loss : 0.6188277006149292
Validation accuracy : 0.698
At iteration 700
Training loss : 0.6115304827690125
Training accuracy : 0.6946
Validation loss : 0.6090017557144165
Validation accuracy : 0.708
At iteration 800
Training loss : 0.6015798449516296
Training accuracy : 0.7031
Validation loss : 0.5991077423095703
Validation accuracy : 0.714
At iteration 900
Training loss : 0.5912153124809265
Training accuracy : 0.71
Validation loss : 0.588752031326294
Validation accuracy : 0.716
At iteration 1000
Training loss : 0.5803276896476746
Training accuracy : 0.7187
Validation loss : 0.5778717994689941
Validation accuracy : 0.718
At iteration 1100
Training loss : 0.5689370036125183
Training accuracy : 0.7294
Validation loss : 0.5664529800415039
Validation accuracy : 0.724
At iteration 1200
Training loss : 0.5569323301315308
Training accuracy : 0.7387
Validation loss : 0.5542795658111572
Validation accuracy : 0.734
At iteration 1300
Training loss : 0.5442454814910889
Training accuracy : 0.7472
Validation loss : 0.5413668155670166
Validation accuracy : 0.74
At iteration 1400
Training loss : 0.530874490737915
Training accuracy : 0.757
Validation loss : 0.5276691317558289
Validation accuracy : 0.76
At iteration 1500
Training loss : 0.5168139338493347
Training accuracy : 0.7674
Validation loss : 0.5130881667137146
Validation accuracy : 0.766
At iteration 1600
Training loss : 0.5021050572395325
Training accuracy : 0.7788
Validation loss : 0.497661828994751
Validation accuracy : 0.772
At iteration 1700
Training loss : 0.48768314719200134
Training accuracy : 0.791
Validation loss : 0.4825124442577362
Validation accuracy : 0.788
At iteration 1800
Training loss : 0.4740116596221924
Training accuracy : 0.8022
Validation loss : 0.4681008458137512
Validation accuracy : 0.806
At iteration 1900
Training loss : 0.4608691334724426
Training accuracy : 0.8119
Validation loss : 0.4541640877723694
Validation accuracy : 0.81
At iteration 2000
Training loss : 0.4481540620326996
Training accuracy : 0.8197
Validation loss : 0.44063761830329895
Validation accuracy : 0.814
At iteration 2100
Training loss : 0.4358036518096924
Training accuracy : 0.8264
Validation loss : 0.4274945855140686
Validation accuracy : 0.822
At iteration 2200
Training loss : 0.42370855808258057
Training accuracy : 0.832
Validation loss : 0.41458141803741455
Validation accuracy : 0.834
At iteration 2300
Training loss : 0.411865770816803
Training accuracy : 0.8383
Validation loss : 0.401969850063324
Validation accuracy : 0.84
At iteration 2400
Training loss : 0.40023329854011536
Training accuracy : 0.8449
Validation loss : 0.3895937204360962
Validation accuracy : 0.84
At iteration 2500
Training loss : 0.3888087868690491
Training accuracy : 0.85
Validation loss : 0.3774207830429077
Validation accuracy : 0.854
At iteration 2600
Training loss : 0.3775351941585541
Training accuracy : 0.8573
Validation loss : 0.3655126392841339
Validation accuracy : 0.868
At iteration 2700
Training loss : 0.36636418104171753
Training accuracy : 0.8643
Validation loss : 0.35366711020469666
Validation accuracy : 0.872
At iteration 2800
Training loss : 0.3553362488746643
Training accuracy : 0.8725
Validation loss : 0.34201210737228394
Validation accuracy : 0.88
At iteration 2900
Training loss : 0.34448328614234924
Training accuracy : 0.8824
Validation loss : 0.3305943012237549
Validation accuracy : 0.88
At iteration 3000
Training loss : 0.3338693082332611
Training accuracy : 0.8905
Validation loss : 0.319495290517807
Validation accuracy : 0.894
At iteration 3100
Training loss : 0.3234810531139374
Training accuracy : 0.8972
Validation loss : 0.3086738884449005
Validation accuracy : 0.898
At iteration 3200
Training loss : 0.31330394744873047
Training accuracy : 0.903
Validation loss : 0.2982039749622345
Validation accuracy : 0.908
At iteration 3300
Training loss : 0.30343005061149597
Training accuracy : 0.9083
Validation loss : 0.28819984197616577
Validation accuracy : 0.916
At iteration 3400
Training loss : 0.29375043511390686
Training accuracy : 0.9161
Validation loss : 0.278478741645813
Validation accuracy : 0.924
At iteration 3500
Training loss : 0.2843717634677887
Training accuracy : 0.9227
Validation loss : 0.2690211832523346
Validation accuracy : 0.926
At iteration 3600
Training loss : 0.27524057030677795
Training accuracy : 0.9274
Validation loss : 0.25984010100364685
Validation accuracy : 0.928
At iteration 3700
Training loss : 0.26641568541526794
Training accuracy : 0.9318
Validation loss : 0.25092747807502747
Validation accuracy : 0.942
At iteration 3800
Training loss : 0.25782617926597595
Training accuracy : 0.936
Validation loss : 0.242435485124588
Validation accuracy : 0.944
At iteration 3900
Training loss : 0.24958781898021698
Training accuracy : 0.9397
Validation loss : 0.234419584274292
Validation accuracy : 0.946
At iteration 4000
Training loss : 0.24167028069496155
Training accuracy : 0.9442
Validation loss : 0.22682949900627136
Validation accuracy : 0.95
At iteration 4100
Training loss : 0.2340223342180252
Training accuracy : 0.9464
Validation loss : 0.2195090651512146
Validation accuracy : 0.956
At iteration 4200
Training loss : 0.22662721574306488
Training accuracy : 0.9478
Validation loss : 0.2125617414712906
Validation accuracy : 0.956
At iteration 4300
Training loss : 0.2196005880832672
Training accuracy : 0.9508
Validation loss : 0.20601828396320343
Validation accuracy : 0.956
At iteration 4400
Training loss : 0.2129788100719452
Training accuracy : 0.9531
Validation loss : 0.19988633692264557
Validation accuracy : 0.962
At iteration 4500
Training loss : 0.20669223368167877
Training accuracy : 0.9554
Validation loss : 0.19415733218193054
Validation accuracy : 0.966
At iteration 4600
Training loss : 0.20078632235527039
Training accuracy : 0.9574
Validation loss : 0.1888844072818756
Validation accuracy : 0.966
At iteration 4700
Training loss : 0.1952987015247345
Training accuracy : 0.9594
Validation loss : 0.1840348094701767
Validation accuracy : 0.968
At iteration 4800
Training loss : 0.19018003344535828
Training accuracy : 0.9613
Validation loss : 0.17957104742527008
Validation accuracy : 0.968
At iteration 4900
Training loss : 0.18544895946979523
Training accuracy : 0.963
Validation loss : 0.1754365712404251
Validation accuracy : 0.966
At iteration 5000
Training loss : 0.18104971945285797
Training accuracy : 0.9648
Validation loss : 0.17157526314258575
Validation accuracy : 0.966
At iteration 5100
Training loss : 0.17694155871868134
Training accuracy : 0.9659
Validation loss : 0.16793076694011688
Validation accuracy : 0.968
At iteration 5200
Training loss : 0.1731051802635193
Training accuracy : 0.9672
Validation loss : 0.1645321249961853
Validation accuracy : 0.97
At iteration 5300
Training loss : 0.16951316595077515
Training accuracy : 0.9676
Validation loss : 0.1613655537366867
Validation accuracy : 0.97
At iteration 5400
Training loss : 0.16613084077835083
Training accuracy : 0.9679
Validation loss : 0.1583479344844818
Validation accuracy : 0.97
At iteration 5500
Training loss : 0.16294611990451813
Training accuracy : 0.9692
Validation loss : 0.1555103212594986
Validation accuracy : 0.972
At iteration 5600
Training loss : 0.1599428504705429
Training accuracy : 0.9697
Validation loss : 0.1528404951095581
Validation accuracy : 0.974
At iteration 5700
Training loss : 0.1571037620306015
Training accuracy : 0.9703
Validation loss : 0.1503140926361084
Validation accuracy : 0.974
At iteration 5800
Training loss : 0.15441352128982544
Training accuracy : 0.9711
Validation loss : 0.1479216367006302
Validation accuracy : 0.97
At iteration 5900
Training loss : 0.15184079110622406
Training accuracy : 0.9717
Validation loss : 0.145643949508667
Validation accuracy : 0.968
At iteration 6000
Training loss : 0.14939665794372559
Training accuracy : 0.9717
Validation loss : 0.14346875250339508
Validation accuracy : 0.968
At iteration 6100
Training loss : 0.14706821739673615
Training accuracy : 0.9721
Validation loss : 0.14139312505722046
Validation accuracy : 0.968
At iteration 6200
Training loss : 0.14484672248363495
Training accuracy : 0.9724
Validation loss : 0.13942237198352814
Validation accuracy : 0.968
At iteration 6300
Training loss : 0.14271709322929382
Training accuracy : 0.9724
Validation loss : 0.13753719627857208
Validation accuracy : 0.968
At iteration 6400
Training loss : 0.14068393409252167
Training accuracy : 0.9729
Validation loss : 0.1357404738664627
Validation accuracy : 0.964
At iteration 6500
Training loss : 0.13873538374900818
Training accuracy : 0.9734
Validation loss : 0.1339964121580124
Validation accuracy : 0.964
At iteration 6600
Training loss : 0.13686779141426086
Training accuracy : 0.9737
Validation loss : 0.1323196142911911
Validation accuracy : 0.964
At iteration 6700
Training loss : 0.1350780576467514
Training accuracy : 0.9745
Validation loss : 0.13071401417255402
Validation accuracy : 0.964
At iteration 6800
Training loss : 0.1333550363779068
Training accuracy : 0.9745
Validation loss : 0.12917150557041168
Validation accuracy : 0.964
At iteration 6900
Training loss : 0.13169458508491516
Training accuracy : 0.9748
Validation loss : 0.12768438458442688
Validation accuracy : 0.966
At iteration 7000
Training loss : 0.13009630143642426
Training accuracy : 0.9757
Validation loss : 0.12625116109848022
Validation accuracy : 0.968
At iteration 7100
Training loss : 0.12855537235736847
Training accuracy : 0.976
Validation loss : 0.12486620247364044
Validation accuracy : 0.968
At iteration 7200
Training loss : 0.12706787884235382
Training accuracy : 0.976
Validation loss : 0.12352976948022842
Validation accuracy : 0.968
At iteration 7300
Training loss : 0.1256323903799057
Training accuracy : 0.9765
Validation loss : 0.12222747504711151
Validation accuracy : 0.968
At iteration 7400
Training loss : 0.12424390763044357
Training accuracy : 0.9767
Validation loss : 0.12096820771694183
Validation accuracy : 0.968
At iteration 7500
Training loss : 0.12289751321077347
Training accuracy : 0.9773
Validation loss : 0.1197456642985344
Validation accuracy : 0.968
At iteration 7600
Training loss : 0.1215951070189476
Training accuracy : 0.9775
Validation loss : 0.11856357008218765
Validation accuracy : 0.968
At iteration 7700
Training loss : 0.12033259123563766
Training accuracy : 0.9774
Validation loss : 0.11741720139980316
Validation accuracy : 0.968
At iteration 7800
Training loss : 0.11910898238420486
Training accuracy : 0.9774
Validation loss : 0.11630597710609436
Validation accuracy : 0.968
At iteration 7900
Training loss : 0.1179240271449089
Training accuracy : 0.9774
Validation loss : 0.11522701382637024
Validation accuracy : 0.968
At iteration 8000
Training loss : 0.11677491664886475
Training accuracy : 0.9775
Validation loss : 0.11418116837739944
Validation accuracy : 0.968
At iteration 8100
Training loss : 0.1156592071056366
Training accuracy : 0.9775
Validation loss : 0.11317238211631775
Validation accuracy : 0.968
At iteration 8200
Training loss : 0.11457314342260361
Training accuracy : 0.9775
Validation loss : 0.11218325793743134
Validation accuracy : 0.968
At iteration 8300
Training loss : 0.11351557821035385
Training accuracy : 0.9777
Validation loss : 0.11121871322393417
Validation accuracy : 0.968
At iteration 8400
Training loss : 0.11248715966939926
Training accuracy : 0.9782
Validation loss : 0.11028134822845459
Validation accuracy : 0.968
At iteration 8500
Training loss : 0.11148744821548462
Training accuracy : 0.9786
Validation loss : 0.10937269032001495
Validation accuracy : 0.968
At iteration 8600
Training loss : 0.1105133444070816
Training accuracy : 0.9786
Validation loss : 0.10849720239639282
Validation accuracy : 0.97
At iteration 8700
Training loss : 0.10956472158432007
Training accuracy : 0.9785
Validation loss : 0.10764889419078827
Validation accuracy : 0.97
At iteration 8800
Training loss : 0.10864181071519852
Training accuracy : 0.9787
Validation loss : 0.1068202406167984
Validation accuracy : 0.97
At iteration 8900
Training loss : 0.10774435102939606
Training accuracy : 0.9787
Validation loss : 0.10601337999105453
Validation accuracy : 0.97
At iteration 9000
Training loss : 0.10686799138784409
Training accuracy : 0.9788
Validation loss : 0.10522656887769699
Validation accuracy : 0.97
At iteration 9100
Training loss : 0.10601214319467545
Training accuracy : 0.979
Validation loss : 0.104451484978199
Validation accuracy : 0.972
At iteration 9200
Training loss : 0.10517784208059311
Training accuracy : 0.9791
Validation loss : 0.10369487851858139
Validation accuracy : 0.972
At iteration 9300
Training loss : 0.10436240583658218
Training accuracy : 0.9792
Validation loss : 0.10296136140823364
Validation accuracy : 0.972
At iteration 9400
Training loss : 0.103565514087677
Training accuracy : 0.9792
Validation loss : 0.10224708914756775
Validation accuracy : 0.974
At iteration 9500
Training loss : 0.10278775542974472
Training accuracy : 0.9793
Validation loss : 0.10154952853918076
Validation accuracy : 0.976
At iteration 9600
Training loss : 0.10202813893556595
Training accuracy : 0.9793
Validation loss : 0.10086623579263687
Validation accuracy : 0.976
At iteration 9700
Training loss : 0.10128509253263474
Training accuracy : 0.9792
Validation loss : 0.10019765049219131
Validation accuracy : 0.976
At iteration 9800
Training loss : 0.10055863112211227
Training accuracy : 0.9795
Validation loss : 0.09954356402158737
Validation accuracy : 0.978
At iteration 9900
Training loss : 0.09984857589006424
Training accuracy : 0.9796
Validation loss : 0.09890390187501907
Validation accuracy : 0.978

In [9]:
print("Training time :",end-begin)


Training time : 24.634289741516113
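Note that the test set generated earlier is never scored. A minimal evaluation helper could look like the sketch below; `evaluate` is a name introduced here, and it is demonstrated on an untrained stand-in network rather than `model1` so the cell is self-contained:

```python
import torch
import torch.nn as nn

# Hypothetical helper (not in the original notebook): classification
# accuracy of `net` on inputs x of shape (N, 2) and labels y of shape (N,).
def evaluate(net, x, y):
    with torch.no_grad():                # no gradients needed for evaluation
        pred = net(x).argmax(dim=1)      # predicted class per sample
    return (pred == y).float().mean().item()

# Usage with an untrained stand-in model and random data; with the trained
# model you would call evaluate(model1, testx, testy) instead.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Linear(6, 2))
x = torch.randn(100, 2)
y = torch.randint(0, 2, (100,))
acc = evaluate(net, x, y)
print(acc)  # chance-level accuracy for an untrained model
```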

In [10]:
# Let's try with 3 hidden neurons.
model2 = generate_single_hidden_MLP(3) 
training_routine(model2,dataset,10000,gpu)


Using GPU
At iteration 0
Training loss : 0.8042043447494507
Training accuracy : 0.5003
Validation loss : 0.8238264322280884
Validation accuracy : 0.486
At iteration 100
Training loss : 0.7185286283493042
Training accuracy : 0.5003
Validation loss : 0.7303446531295776
Validation accuracy : 0.486
At iteration 200
Training loss : 0.6989850401878357
Training accuracy : 0.3706
Validation loss : 0.7071495652198792
Validation accuracy : 0.346
At iteration 300
Training loss : 0.6923447251319885
Training accuracy : 0.387
Validation loss : 0.6988030672073364
Validation accuracy : 0.358
At iteration 400
Training loss : 0.6878220438957214
Training accuracy : 0.4534
Validation loss : 0.6936196684837341
Validation accuracy : 0.418
At iteration 500
Training loss : 0.682556688785553
Training accuracy : 0.5406
Validation loss : 0.6881617307662964
Validation accuracy : 0.512
At iteration 600
Training loss : 0.6763601899147034
Training accuracy : 0.5856
Validation loss : 0.6821454763412476
Validation accuracy : 0.566
At iteration 700
Training loss : 0.6690481305122375
Training accuracy : 0.6297
Validation loss : 0.6751475930213928
Validation accuracy : 0.604
At iteration 800
Training loss : 0.6603956818580627
Training accuracy : 0.6717
Validation loss : 0.6668155789375305
Validation accuracy : 0.652
At iteration 900
Training loss : 0.6502493023872375
Training accuracy : 0.7081
Validation loss : 0.6572006344795227
Validation accuracy : 0.676
At iteration 1000
Training loss : 0.6385756731033325
Training accuracy : 0.7383
Validation loss : 0.6457823514938354
Validation accuracy : 0.71
At iteration 1100
Training loss : 0.6254290342330933
Training accuracy : 0.7579
Validation loss : 0.6329132914543152
Validation accuracy : 0.734
At iteration 1200
Training loss : 0.6114764213562012
Training accuracy : 0.7647
Validation loss : 0.6191123127937317
Validation accuracy : 0.742
At iteration 1300
Training loss : 0.5972586274147034
Training accuracy : 0.7678
Validation loss : 0.6051972508430481
Validation accuracy : 0.748
At iteration 1400
Training loss : 0.5834115147590637
Training accuracy : 0.7713
Validation loss : 0.591754674911499
Validation accuracy : 0.752
At iteration 1500
Training loss : 0.5704270601272583
Training accuracy : 0.7727
Validation loss : 0.579190194606781
Validation accuracy : 0.76
At iteration 1600
Training loss : 0.5584383606910706
Training accuracy : 0.774
Validation loss : 0.5677558183670044
Validation accuracy : 0.764
At iteration 1700
Training loss : 0.5474631190299988
Training accuracy : 0.7773
Validation loss : 0.5574163198471069
Validation accuracy : 0.77
At iteration 1800
Training loss : 0.537483274936676
Training accuracy : 0.7793
Validation loss : 0.548203706741333
Validation accuracy : 0.768
At iteration 1900
Training loss : 0.5284117460250854
Training accuracy : 0.7812
Validation loss : 0.539940357208252
Validation accuracy : 0.766
At iteration 2000
Training loss : 0.5201975107192993
Training accuracy : 0.7829
Validation loss : 0.5325564742088318
Validation accuracy : 0.768
At iteration 2100
Training loss : 0.5127653479576111
Training accuracy : 0.7829
Validation loss : 0.5259811282157898
Validation accuracy : 0.768
At iteration 2200
Training loss : 0.5060506463050842
Training accuracy : 0.7832
Validation loss : 0.5201577544212341
Validation accuracy : 0.772
At iteration 2300
Training loss : 0.50001460313797
Training accuracy : 0.7834
Validation loss : 0.514948308467865
Validation accuracy : 0.772
At iteration 2400
Training loss : 0.4945991635322571
Training accuracy : 0.7834
Validation loss : 0.5102963447570801
Validation accuracy : 0.772
At iteration 2500
Training loss : 0.4897620677947998
Training accuracy : 0.7831
Validation loss : 0.5061904788017273
Validation accuracy : 0.772
At iteration 2600
Training loss : 0.4854539930820465
Training accuracy : 0.7829
Validation loss : 0.5025847554206848
Validation accuracy : 0.774
At iteration 2700
Training loss : 0.4816124141216278
Training accuracy : 0.7833
Validation loss : 0.49941909313201904
Validation accuracy : 0.774
At iteration 2800
Training loss : 0.4781963527202606
Training accuracy : 0.7827
Validation loss : 0.4966699182987213
Validation accuracy : 0.774
At iteration 2900
Training loss : 0.47516390681266785
Training accuracy : 0.7831
Validation loss : 0.4942916929721832
Validation accuracy : 0.772
At iteration 3000
Training loss : 0.4724833369255066
Training accuracy : 0.7828
Validation loss : 0.49222010374069214
Validation accuracy : 0.772
At iteration 3100
Training loss : 0.4701261818408966
Training accuracy : 0.7824
Validation loss : 0.4904025197029114
Validation accuracy : 0.772
At iteration 3200
Training loss : 0.4680247902870178
Training accuracy : 0.7817
Validation loss : 0.48880332708358765
Validation accuracy : 0.772
At iteration 3300
Training loss : 0.4661575257778168
Training accuracy : 0.7817
Validation loss : 0.4874080419540405
Validation accuracy : 0.772
At iteration 3400
Training loss : 0.4645048975944519
Training accuracy : 0.7809
Validation loss : 0.4861815273761749
Validation accuracy : 0.77
At iteration 3500
Training loss : 0.4630364179611206
Training accuracy : 0.7806
Validation loss : 0.48512333631515503
Validation accuracy : 0.768
At iteration 3600
Training loss : 0.46173372864723206
Training accuracy : 0.78
Validation loss : 0.48421892523765564
Validation accuracy : 0.766
At iteration 3700
Training loss : 0.460572212934494
Training accuracy : 0.7788
Validation loss : 0.48346036672592163
Validation accuracy : 0.764
At iteration 3800
Training loss : 0.4595365822315216
Training accuracy : 0.7787
Validation loss : 0.4828377068042755
Validation accuracy : 0.762
At iteration 3900
Training loss : 0.45861244201660156
Training accuracy : 0.7785
Validation loss : 0.4823111593723297
Validation accuracy : 0.762
At iteration 4000
Training loss : 0.4577837288379669
Training accuracy : 0.7787
Validation loss : 0.48185357451438904
Validation accuracy : 0.762
At iteration 4100
Training loss : 0.4570505917072296
Training accuracy : 0.7782
Validation loss : 0.4814351201057434
Validation accuracy : 0.764
At iteration 4200
Training loss : 0.4563976526260376
Training accuracy : 0.7771
Validation loss : 0.4810810983181
Validation accuracy : 0.766
At iteration 4300
Training loss : 0.455814927816391
Training accuracy : 0.7766
Validation loss : 0.4807562232017517
Validation accuracy : 0.768
At iteration 4400
Training loss : 0.45528584718704224
Training accuracy : 0.7769
Validation loss : 0.4804733395576477
Validation accuracy : 0.766
At iteration 4500
Training loss : 0.4548111855983734
Training accuracy : 0.777
Validation loss : 0.48023200035095215
Validation accuracy : 0.766
At iteration 4600
Training loss : 0.45438677072525024
Training accuracy : 0.7764
Validation loss : 0.48000213503837585
Validation accuracy : 0.766
At iteration 4700
Training loss : 0.45399174094200134
Training accuracy : 0.7762
Validation loss : 0.47981971502304077
Validation accuracy : 0.764
At iteration 4800
Training loss : 0.4536362886428833
Training accuracy : 0.7759
Validation loss : 0.4796644151210785
Validation accuracy : 0.762
At iteration 4900
Training loss : 0.4533211886882782
Training accuracy : 0.7759
Validation loss : 0.47953447699546814
Validation accuracy : 0.766
At iteration 5000
Training loss : 0.453032523393631
Training accuracy : 0.7755
Validation loss : 0.4794190227985382
Validation accuracy : 0.768
At iteration 5100
Training loss : 0.4527663588523865
Training accuracy : 0.776
Validation loss : 0.47931841015815735
Validation accuracy : 0.766
At iteration 5200
Training loss : 0.45251747965812683
Training accuracy : 0.7756
Validation loss : 0.47925683856010437
Validation accuracy : 0.764
At iteration 5300
Training loss : 0.4522855877876282
Training accuracy : 0.7753
Validation loss : 0.47922566533088684
Validation accuracy : 0.764
At iteration 5400
Training loss : 0.452070027589798
Training accuracy : 0.7756
Validation loss : 0.47918781638145447
Validation accuracy : 0.762
At iteration 5500
Training loss : 0.45187774300575256
Training accuracy : 0.7757
Validation loss : 0.479141503572464
Validation accuracy : 0.762
At iteration 5600
Training loss : 0.45170825719833374
Training accuracy : 0.7756
Validation loss : 0.47909456491470337
Validation accuracy : 0.76
At iteration 5700
Training loss : 0.45156118273735046
Training accuracy : 0.7756
Validation loss : 0.47905534505844116
Validation accuracy : 0.76
At iteration 5800
Training loss : 0.4514290392398834
Training accuracy : 0.7755
Validation loss : 0.4790177643299103
Validation accuracy : 0.76
At iteration 5900
Training loss : 0.45129701495170593
Training accuracy : 0.7752
Validation loss : 0.47899454832077026
Validation accuracy : 0.76
At iteration 6000
Training loss : 0.4511772096157074
Training accuracy : 0.7755
Validation loss : 0.47897839546203613
Validation accuracy : 0.76
At iteration 6100
Training loss : 0.45107242465019226
Training accuracy : 0.775
Validation loss : 0.47897568345069885
Validation accuracy : 0.758
At iteration 6200
Training loss : 0.4509792923927307
Training accuracy : 0.7752
Validation loss : 0.4789806306362152
Validation accuracy : 0.756
At iteration 6300
Training loss : 0.4508935511112213
Training accuracy : 0.7757
Validation loss : 0.47898975014686584
Validation accuracy : 0.758
At iteration 6400
Training loss : 0.4508166015148163
Training accuracy : 0.7756
Validation loss : 0.4789983630180359
Validation accuracy : 0.758
At iteration 6500
Training loss : 0.4507419466972351
Training accuracy : 0.7751
Validation loss : 0.47901594638824463
Validation accuracy : 0.758
At iteration 6600
Training loss : 0.450667142868042
Training accuracy : 0.775
Validation loss : 0.47904884815216064
Validation accuracy : 0.758
At iteration 6700
Training loss : 0.4505937099456787
Training accuracy : 0.7753
Validation loss : 0.4791039824485779
Validation accuracy : 0.76
At iteration 6800
Training loss : 0.45052626729011536
Training accuracy : 0.7751
Validation loss : 0.4791721701622009
Validation accuracy : 0.76
At iteration 6900
Training loss : 0.4504624009132385
Training accuracy : 0.7754
Validation loss : 0.47922977805137634
Validation accuracy : 0.76
At iteration 7000
Training loss : 0.45040565729141235
Training accuracy : 0.7753
Validation loss : 0.4792513847351074
Validation accuracy : 0.76
At iteration 7100
Training loss : 0.4503517150878906
Training accuracy : 0.7749
Validation loss : 0.47927895188331604
Validation accuracy : 0.762
At iteration 7200
Training loss : 0.4502948224544525
Training accuracy : 0.7754
Validation loss : 0.47932857275009155
Validation accuracy : 0.762
At iteration 7300
Training loss : 0.4502376317977905
Training accuracy : 0.7754
Validation loss : 0.47938016057014465
Validation accuracy : 0.76
At iteration 7400
Training loss : 0.45018404722213745
Training accuracy : 0.7758
Validation loss : 0.47943389415740967
Validation accuracy : 0.76
At iteration 7500
Training loss : 0.450132817029953
Training accuracy : 0.7757
Validation loss : 0.4794860780239105
Validation accuracy : 0.76
At iteration 7600
Training loss : 0.4500790536403656
Training accuracy : 0.7758
Validation loss : 0.4795333743095398
Validation accuracy : 0.76
At iteration 7700
Training loss : 0.45002856850624084
Training accuracy : 0.7761
Validation loss : 0.4795854985713959
Validation accuracy : 0.76
At iteration 7800
Training loss : 0.4499852955341339
Training accuracy : 0.7761
Validation loss : 0.4796288013458252
Validation accuracy : 0.758
At iteration 7900
Training loss : 0.4499473571777344
Training accuracy : 0.7763
Validation loss : 0.4796689450740814
Validation accuracy : 0.758
At iteration 8000
Training loss : 0.4499092698097229
Training accuracy : 0.7763
Validation loss : 0.47970810532569885
Validation accuracy : 0.758
At iteration 8100
Training loss : 0.44987720251083374
Training accuracy : 0.7761
Validation loss : 0.4797445237636566
Validation accuracy : 0.758
At iteration 8200
Training loss : 0.44984671473503113
Training accuracy : 0.7761
Validation loss : 0.47978121042251587
Validation accuracy : 0.758
At iteration 8300
Training loss : 0.44981464743614197
Training accuracy : 0.7761
Validation loss : 0.4798373579978943
Validation accuracy : 0.758
At iteration 8400
Training loss : 0.4497825801372528
Training accuracy : 0.7761
Validation loss : 0.4798918664455414
Validation accuracy : 0.758
At iteration 8500
Training loss : 0.4497530162334442
Training accuracy : 0.7763
Validation loss : 0.4799533188343048
Validation accuracy : 0.76
At iteration 8600
Training loss : 0.44972261786460876
Training accuracy : 0.7761
Validation loss : 0.4800124168395996
Validation accuracy : 0.762
At iteration 8700
Training loss : 0.44969257712364197
Training accuracy : 0.7762
Validation loss : 0.4800698161125183
Validation accuracy : 0.762
At iteration 8800
Training loss : 0.44966450333595276
Training accuracy : 0.7762
Validation loss : 0.48012039065361023
Validation accuracy : 0.76
At iteration 8900
Training loss : 0.44963324069976807
Training accuracy : 0.7762
Validation loss : 0.48017287254333496
Validation accuracy : 0.758
At iteration 9000
Training loss : 0.4496088922023773
Training accuracy : 0.7762
Validation loss : 0.48021143674850464
Validation accuracy : 0.756
At iteration 9100
Training loss : 0.44958609342575073
Training accuracy : 0.7763
Validation loss : 0.4802546799182892
Validation accuracy : 0.758
At iteration 9200
Training loss : 0.4495634138584137
Training accuracy : 0.7765
Validation loss : 0.4802907109260559
Validation accuracy : 0.76
At iteration 9300
Training loss : 0.4495399296283722
Training accuracy : 0.7764
Validation loss : 0.48030900955200195
Validation accuracy : 0.76
At iteration 9400
Training loss : 0.44951802492141724
Training accuracy : 0.7766
Validation loss : 0.4803306460380554
Validation accuracy : 0.76
At iteration 9500
Training loss : 0.4495003819465637
Training accuracy : 0.7769
Validation loss : 0.48035526275634766
Validation accuracy : 0.758
At iteration 9600
Training loss : 0.4494825303554535
Training accuracy : 0.7771
Validation loss : 0.48038244247436523
Validation accuracy : 0.758
At iteration 9700
Training loss : 0.44945982098579407
Training accuracy : 0.7772
Validation loss : 0.48043206334114075
Validation accuracy : 0.758
At iteration 9800
Training loss : 0.4494328498840332
Training accuracy : 0.7772
Validation loss : 0.48048070073127747
Validation accuracy : 0.758
At iteration 9900
Training loss : 0.4494076669216156
Training accuracy : 0.7773
Validation loss : 0.48052552342414856
Validation accuracy : 0.758

In [11]:
# Split the test points by predicted class
out = model2(testx).argmax(dim=1).detach().numpy()
green = testx.numpy()[np.where(out==1)]
red = testx.numpy()[np.where(out==0)]
print(green.shape,red.shape)


(288, 2) (212, 2)
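Given these hard predictions, we can also report a test accuracy by comparing them against the true labels. A minimal self-contained sketch (using stand-in arrays in place of the notebook's `out` and `testy`):

```python
import numpy as np

# Fraction of predictions that match the true labels.
def accuracy(pred, target):
    return float((pred == target).mean())

# Stand-ins for `out` (predicted classes) and `testy` (true classes).
pred = np.array([1, 0, 1, 1, 0])
target = np.array([1, 0, 0, 1, 0])
print(accuracy(pred, target))  # 0.8
```

In the notebook itself, `accuracy(out, testy)` would give the test accuracy of `model2`.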

In [12]:
def print_model(model, datapoints):
    # Plot the model's predicted classes for `datapoints`,
    # with the true unit circle drawn for reference
    out = model(datapoints).argmax(dim=1).detach().numpy()
    green = datapoints.numpy()[np.where(out == 1)]
    red = datapoints.numpy()[np.where(out == 0)]

    circle1 = plt.Circle((0, 0), 1, color='y')
    circle2 = plt.Circle((0, 0), 1, color='b', fill=False)

    fig, ax = plt.subplots() # note we must use plt.subplots, not plt.subplot
    # (or, if you have an existing figure:)
    # fig = plt.gcf()
    # ax = fig.gca()
    plt.xlim((-2, 2))
    plt.ylim((-2, 2))

    plt.scatter(x=green[:, 0], y=green[:, 1], color='g')  # predicted inside
    plt.scatter(x=red[:, 0], y=red[:, 1], color='r')      # predicted outside

    ax.add_artist(circle1)
    ax.add_artist(circle2)
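A variant worth noting: instead of plotting only the 500 test points, we can classify a dense grid over $[-2,2]^2$ to see the full decision regions. A self-contained sketch (a tiny untrained MLP stands in here for the notebook's trained models):

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in for a trained model: 2 inputs -> 6 hidden -> 2 output logits.
model = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Linear(6, 2))

# Dense 50x50 grid over [-2,2]^2, classified in a single forward pass.
xs = np.linspace(-2, 2, 50, dtype=np.float32)
grid = np.stack(np.meshgrid(xs, xs), axis=-1).reshape(-1, 2)
with torch.no_grad():
    classes = model(torch.from_numpy(grid)).argmax(dim=1).numpy()
print(grid.shape, classes.shape)  # (2500, 2) (2500,)
```

Passing `torch.from_numpy(grid)` as `datapoints` to `print_model` would then draw the whole decision boundary rather than a scatter of test points.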

In [13]:
print_model(model1,testx)



In [14]:
print_model(model2,testx)



In [15]:
model3 = generate_single_hidden_MLP(2) # only 2 hidden neurons this time
training_routine(model3,dataset,10000,gpu)


Using GPU
At iteration 0
Training loss : 0.757908821105957
Training accuracy : 0.5003
Validation loss : 0.7690756320953369
Validation accuracy : 0.486
At iteration 100
Training loss : 0.7138646841049194
Training accuracy : 0.5003
Validation loss : 0.7210879921913147
Validation accuracy : 0.486
At iteration 200
Training loss : 0.6944175958633423
Training accuracy : 0.5003
Validation loss : 0.6991903781890869
Validation accuracy : 0.486
At iteration 300
Training loss : 0.6836276054382324
Training accuracy : 0.5003
Validation loss : 0.6868808269500732
Validation accuracy : 0.486
At iteration 400
Training loss : 0.6757462024688721
Training accuracy : 0.5003
Validation loss : 0.6780419945716858
Validation accuracy : 0.486
At iteration 500
Training loss : 0.6687736511230469
Training accuracy : 0.5142
Validation loss : 0.6704592108726501
Validation accuracy : 0.49
At iteration 600
Training loss : 0.6619855761528015
Training accuracy : 0.6194
Validation loss : 0.663285493850708
Validation accuracy : 0.604
At iteration 700
Training loss : 0.6551350951194763
Training accuracy : 0.6494
Validation loss : 0.6561704874038696
Validation accuracy : 0.626
At iteration 800
Training loss : 0.6481907367706299
Training accuracy : 0.6631
Validation loss : 0.6490899324417114
Validation accuracy : 0.642
At iteration 900
Training loss : 0.6411333084106445
Training accuracy : 0.6722
Validation loss : 0.6420141458511353
Validation accuracy : 0.652
At iteration 1000
Training loss : 0.6339496374130249
Training accuracy : 0.6807
Validation loss : 0.6349431276321411
Validation accuracy : 0.658
At iteration 1100
Training loss : 0.6266373991966248
Training accuracy : 0.6901
Validation loss : 0.6278645992279053
Validation accuracy : 0.668
At iteration 1200
Training loss : 0.6191753149032593
Training accuracy : 0.698
Validation loss : 0.6208072304725647
Validation accuracy : 0.674
At iteration 1300
Training loss : 0.6114997267723083
Training accuracy : 0.7033
Validation loss : 0.6136322021484375
Validation accuracy : 0.678
At iteration 1400
Training loss : 0.6035844683647156
Training accuracy : 0.7107
Validation loss : 0.6063424944877625
Validation accuracy : 0.688
At iteration 1500
Training loss : 0.5954177975654602
Training accuracy : 0.7153
Validation loss : 0.5988637804985046
Validation accuracy : 0.69
At iteration 1600
Training loss : 0.5870065689086914
Training accuracy : 0.721
Validation loss : 0.5911092758178711
Validation accuracy : 0.698
At iteration 1700
Training loss : 0.5783757567405701
Training accuracy : 0.7266
Validation loss : 0.5830110907554626
Validation accuracy : 0.71
At iteration 1800
Training loss : 0.5696817636489868
Training accuracy : 0.7345
Validation loss : 0.5746251344680786
Validation accuracy : 0.718
At iteration 1900
Training loss : 0.5610436201095581
Training accuracy : 0.7402
Validation loss : 0.5661273002624512
Validation accuracy : 0.73
At iteration 2000
Training loss : 0.5525501370429993
Training accuracy : 0.7451
Validation loss : 0.5577940344810486
Validation accuracy : 0.736
At iteration 2100
Training loss : 0.5443151593208313
Training accuracy : 0.7512
Validation loss : 0.5496269464492798
Validation accuracy : 0.748
At iteration 2200
Training loss : 0.5363830924034119
Training accuracy : 0.757
Validation loss : 0.5417766571044922
Validation accuracy : 0.752
At iteration 2300
Training loss : 0.5288269519805908
Training accuracy : 0.7627
Validation loss : 0.5342857241630554
Validation accuracy : 0.752
At iteration 2400
Training loss : 0.5217110514640808
Training accuracy : 0.7678
Validation loss : 0.5271515250205994
Validation accuracy : 0.76
At iteration 2500
Training loss : 0.5150945782661438
Training accuracy : 0.7723
Validation loss : 0.5204028487205505
Validation accuracy : 0.766
At iteration 2600
Training loss : 0.5089967846870422
Training accuracy : 0.7776
Validation loss : 0.5140964984893799
Validation accuracy : 0.77
At iteration 2700
Training loss : 0.503364622592926
Training accuracy : 0.7826
Validation loss : 0.5082263946533203
Validation accuracy : 0.774
At iteration 2800
Training loss : 0.49825260043144226
Training accuracy : 0.7857
Validation loss : 0.5027372241020203
Validation accuracy : 0.78
At iteration 2900
Training loss : 0.4936380386352539
Training accuracy : 0.7862
Validation loss : 0.4977012276649475
Validation accuracy : 0.786
At iteration 3000
Training loss : 0.4894702732563019
Training accuracy : 0.7861
Validation loss : 0.49305543303489685
Validation accuracy : 0.782
At iteration 3100
Training loss : 0.4856936037540436
Training accuracy : 0.7869
Validation loss : 0.4888104498386383
Validation accuracy : 0.782
At iteration 3200
Training loss : 0.4822947382926941
Training accuracy : 0.7868
Validation loss : 0.4849281311035156
Validation accuracy : 0.786
At iteration 3300
Training loss : 0.47922685742378235
Training accuracy : 0.7873
Validation loss : 0.48135992884635925
Validation accuracy : 0.784
At iteration 3400
Training loss : 0.47646695375442505
Training accuracy : 0.7864
Validation loss : 0.47807401418685913
Validation accuracy : 0.784
At iteration 3500
Training loss : 0.47399234771728516
Training accuracy : 0.7874
Validation loss : 0.4750390648841858
Validation accuracy : 0.788
At iteration 3600
Training loss : 0.4717704653739929
Training accuracy : 0.7868
Validation loss : 0.4723238945007324
Validation accuracy : 0.788
At iteration 3700
Training loss : 0.4697822332382202
Training accuracy : 0.7871
Validation loss : 0.4698861241340637
Validation accuracy : 0.788
At iteration 3800
Training loss : 0.46801266074180603
Training accuracy : 0.7863
Validation loss : 0.46771350502967834
Validation accuracy : 0.788
At iteration 3900
Training loss : 0.4664306640625
Training accuracy : 0.7868
Validation loss : 0.4657498002052307
Validation accuracy : 0.79
At iteration 4000
Training loss : 0.46500927209854126
Training accuracy : 0.7871
Validation loss : 0.4639551639556885
Validation accuracy : 0.79
At iteration 4100
Training loss : 0.4637335538864136
Training accuracy : 0.7875
Validation loss : 0.4623580574989319
Validation accuracy : 0.79
At iteration 4200
Training loss : 0.46258631348609924
Training accuracy : 0.7874
Validation loss : 0.46093276143074036
Validation accuracy : 0.794
At iteration 4300
Training loss : 0.461558997631073
Training accuracy : 0.7875
Validation loss : 0.4596591591835022
Validation accuracy : 0.794
At iteration 4400
Training loss : 0.46064233779907227
Training accuracy : 0.7879
Validation loss : 0.458560049533844
Validation accuracy : 0.796
At iteration 4500
Training loss : 0.4598313570022583
Training accuracy : 0.7881
Validation loss : 0.4575856626033783
Validation accuracy : 0.796
At iteration 4600
Training loss : 0.45910653471946716
Training accuracy : 0.7874
Validation loss : 0.4567101001739502
Validation accuracy : 0.794
At iteration 4700
Training loss : 0.45845508575439453
Training accuracy : 0.7867
Validation loss : 0.4559449553489685
Validation accuracy : 0.794
At iteration 4800
Training loss : 0.4578685164451599
Training accuracy : 0.786
Validation loss : 0.45528444647789
Validation accuracy : 0.796
At iteration 4900
Training loss : 0.45733726024627686
Training accuracy : 0.7856
Validation loss : 0.454721063375473
Validation accuracy : 0.796
At iteration 5000
Training loss : 0.4568570852279663
Training accuracy : 0.7851
Validation loss : 0.45425066351890564
Validation accuracy : 0.796
At iteration 5100
Training loss : 0.45642563700675964
Training accuracy : 0.7847
Validation loss : 0.4538639187812805
Validation accuracy : 0.792
At iteration 5200
Training loss : 0.45603397488594055
Training accuracy : 0.7843
Validation loss : 0.45351094007492065
Validation accuracy : 0.79
At iteration 5300
Training loss : 0.45567500591278076
Training accuracy : 0.7841
Validation loss : 0.453240305185318
Validation accuracy : 0.79
At iteration 5400
Training loss : 0.45534569025039673
Training accuracy : 0.7836
Validation loss : 0.45302337408065796
Validation accuracy : 0.79
At iteration 5500
Training loss : 0.45504093170166016
Training accuracy : 0.7839
Validation loss : 0.45283132791519165
Validation accuracy : 0.79
At iteration 5600
Training loss : 0.45476263761520386
Training accuracy : 0.7836
Validation loss : 0.45265713334083557
Validation accuracy : 0.79
At iteration 5700
Training loss : 0.4545077979564667
Training accuracy : 0.7835
Validation loss : 0.45250192284584045
Validation accuracy : 0.788
At iteration 5800
Training loss : 0.4542773962020874
Training accuracy : 0.7835
Validation loss : 0.45235541462898254
Validation accuracy : 0.79
At iteration 5900
Training loss : 0.4540669023990631
Training accuracy : 0.7828
Validation loss : 0.4522089660167694
Validation accuracy : 0.79
At iteration 6000
Training loss : 0.45387357473373413
Training accuracy : 0.7823
Validation loss : 0.4520772397518158
Validation accuracy : 0.79
At iteration 6100
Training loss : 0.4536908268928528
Training accuracy : 0.7819
Validation loss : 0.451982706785202
Validation accuracy : 0.79
At iteration 6200
Training loss : 0.4535170793533325
Training accuracy : 0.7818
Validation loss : 0.45192718505859375
Validation accuracy : 0.79
At iteration 6300
Training loss : 0.45335546135902405
Training accuracy : 0.7822
Validation loss : 0.45187908411026
Validation accuracy : 0.792
At iteration 6400
Training loss : 0.4532018005847931
Training accuracy : 0.7824
Validation loss : 0.45187172293663025
Validation accuracy : 0.792
At iteration 6500
Training loss : 0.4530511200428009
Training accuracy : 0.7826
Validation loss : 0.451901376247406
Validation accuracy : 0.792
At iteration 6600
Training loss : 0.45290425419807434
Training accuracy : 0.7825
Validation loss : 0.45195072889328003
Validation accuracy : 0.792
At iteration 6700
Training loss : 0.452772319316864
Training accuracy : 0.7824
Validation loss : 0.45195135474205017
Validation accuracy : 0.792
At iteration 6800
Training loss : 0.4526498019695282
Training accuracy : 0.7821
Validation loss : 0.4519382119178772
Validation accuracy : 0.792
At iteration 6900
Training loss : 0.4525333046913147
Training accuracy : 0.7821
Validation loss : 0.4519408941268921
Validation accuracy : 0.792
At iteration 7000
Training loss : 0.4524204730987549
Training accuracy : 0.7822
Validation loss : 0.4519646465778351
Validation accuracy : 0.79
At iteration 7100
Training loss : 0.45229971408843994
Training accuracy : 0.7825
Validation loss : 0.45204782485961914
Validation accuracy : 0.79
At iteration 7200
Training loss : 0.45219022035598755
Training accuracy : 0.7822
Validation loss : 0.4521273374557495
Validation accuracy : 0.792
At iteration 7300
Training loss : 0.45208197832107544
Training accuracy : 0.7821
Validation loss : 0.4522210657596588
Validation accuracy : 0.792
At iteration 7400
Training loss : 0.4519691467285156
Training accuracy : 0.7819
Validation loss : 0.4523455798625946
Validation accuracy : 0.792
At iteration 7500
Training loss : 0.45185211300849915
Training accuracy : 0.7824
Validation loss : 0.4524751603603363
Validation accuracy : 0.79
At iteration 7600
Training loss : 0.45173853635787964
Training accuracy : 0.7827
Validation loss : 0.45261380076408386
Validation accuracy : 0.79
At iteration 7700
Training loss : 0.4516219198703766
Training accuracy : 0.7826
Validation loss : 0.4527660012245178
Validation accuracy : 0.788
At iteration 7800
Training loss : 0.45150718092918396
Training accuracy : 0.7821
Validation loss : 0.45293357968330383
Validation accuracy : 0.788
At iteration 7900
Training loss : 0.45139312744140625
Training accuracy : 0.782
Validation loss : 0.453081876039505
Validation accuracy : 0.786
At iteration 8000
Training loss : 0.45127978920936584
Training accuracy : 0.7822
Validation loss : 0.45318156480789185
Validation accuracy : 0.788
At iteration 8100
Training loss : 0.4511800706386566
Training accuracy : 0.7825
Validation loss : 0.45325249433517456
Validation accuracy : 0.788
At iteration 8200
Training loss : 0.45108383893966675
Training accuracy : 0.7826
Validation loss : 0.45331111550331116
Validation accuracy : 0.788
At iteration 8300
Training loss : 0.45099425315856934
Training accuracy : 0.7827
Validation loss : 0.4533732533454895
Validation accuracy : 0.788
At iteration 8400
Training loss : 0.45091482996940613
Training accuracy : 0.7828
Validation loss : 0.4534011781215668
Validation accuracy : 0.788
At iteration 8500
Training loss : 0.45083779096603394
Training accuracy : 0.7826
Validation loss : 0.45342811942100525
Validation accuracy : 0.788
At iteration 8600
Training loss : 0.45076048374176025
Training accuracy : 0.7826
Validation loss : 0.4534529745578766
Validation accuracy : 0.786
At iteration 8700
Training loss : 0.4506836533546448
Training accuracy : 0.7825
Validation loss : 0.45347943902015686
Validation accuracy : 0.786
At iteration 8800
Training loss : 0.4506029784679413
Training accuracy : 0.7826
Validation loss : 0.4535260498523712
Validation accuracy : 0.784
At iteration 8900
Training loss : 0.4505268633365631
Training accuracy : 0.7829
Validation loss : 0.45358648896217346
Validation accuracy : 0.78
At iteration 9000
Training loss : 0.45046234130859375
Training accuracy : 0.7828
Validation loss : 0.4536360502243042
Validation accuracy : 0.78
At iteration 9100
Training loss : 0.450396865606308
Training accuracy : 0.7831
Validation loss : 0.4536837339401245
Validation accuracy : 0.78
At iteration 9200
Training loss : 0.45033028721809387
Training accuracy : 0.7828
Validation loss : 0.4537392556667328
Validation accuracy : 0.78
At iteration 9300
Training loss : 0.4502713978290558
Training accuracy : 0.7831
Validation loss : 0.4537900686264038
Validation accuracy : 0.78
At iteration 9400
Training loss : 0.4502134323120117
Training accuracy : 0.7831
Validation loss : 0.4538417160511017
Validation accuracy : 0.78
At iteration 9500
Training loss : 0.4501570165157318
Training accuracy : 0.7832
Validation loss : 0.453892320394516
Validation accuracy : 0.78
At iteration 9600
Training loss : 0.4501034617424011
Training accuracy : 0.7832
Validation loss : 0.45394226908683777
Validation accuracy : 0.78
At iteration 9700
Training loss : 0.4500454068183899
Training accuracy : 0.783
Validation loss : 0.4540122449398041
Validation accuracy : 0.78
At iteration 9800
Training loss : 0.44999000430107117
Training accuracy : 0.7828
Validation loss : 0.4540848433971405
Validation accuracy : 0.782
At iteration 9900
Training loss : 0.44993647933006287
Training accuracy : 0.7824
Validation loss : 0.45415589213371277
Validation accuracy : 0.782

In [16]:
print_model(model3,testx)
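With only 2 hidden neurons, `model3` plateaus around 78% accuracy. This matches the remark in the introduction: the hidden-layer size bounds the number of sides of the polygonal boundary, and two half-planes can at best carve out a strip, which cannot enclose the circle. As a rough illustration (a Monte-Carlo sketch, not the notebook's code), we can estimate the best accuracy a strip classifier of the form $|x_1| < w$ can reach under the same sampling distribution:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200000

# Same sampling scheme as sample_points: uniform radius in [0,2], uniform angle.
radius = rng.uniform(0, 2, n)
angle = rng.uniform(0, 2 * np.pi, n)
x1 = radius * np.cos(angle)
y = (radius < 1).astype(int)

# Best accuracy of a "strip" classifier: predict inside iff |x1| < w.
best = max(((np.abs(x1) < w).astype(int) == y).mean()
           for w in np.linspace(0.1, 1.5, 29))
print(best)  # around 0.78, close to the plateau reached by model3
```

The strip is just one family of boundaries a 2-neuron network can realize, but its optimum being close to the observed plateau suggests the model is doing about as well as its capacity allows.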