# A training example in Pytorch

## Introduction

In this notebook we will train a neural network to do a simple task: a classification task. As explained in the first week of lectures, classification basically means finding a decision boundary over a space of real numbers. For visualization purposes we will work with a 2D example: the decision boundary will be a circle. More precisely, it will be the unit circle in the plane.

### Sampling

We will generate points \$(x_1,x_2)\$ to classify, and their class \$y\$. The actual decision function is \$y=1_{x_1^2+x_2^2<1}\$.

To have a balanced dataset with about as many points in each class, we will sample uniformly over polar coordinates, within the disk of radius 2 centered at the origin.
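As a quick sanity check on the balance claim (a sketch, using numpy only): a radius drawn uniformly on \$[0,2]\$ falls below 1 with probability 1/2, so about half the points land inside the unit circle.

```python
import numpy as np

rng = np.random.default_rng(0)
radius = rng.uniform(0, 2, size=100_000)  # uniform radius, as in the sampler below
print((radius < 1).mean())  # close to 0.5: the two classes are balanced on average
```

Note that sampling uniformly in polar coordinates is deliberately *not* uniform over the disk; it over-samples the center, which is exactly what balances the classes.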

``````

In [1]:

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
print(torch.__version__)
%matplotlib inline

``````
``````

0.4.1

``````
``````

In [2]:

def sample_points(n):
    # returns (X,Y), where X of shape (n,2) is the numpy array of points
    # and Y is the (n,) array of classes
    radius = np.random.uniform(low=0,high=2,size=n).reshape(-1,1)   # uniform radius
    angle = np.random.uniform(low=0,high=2*np.pi,size=n).reshape(-1,1) # uniform angle
    x1 = radius*np.cos(angle)
    x2 = radius*np.sin(angle)
    y = (radius<1).astype(int).reshape(-1)  # class 1 inside the unit circle
    x = np.concatenate([x1,x2],axis=1)
    return x,y

``````
``````

In [3]:

# Generate the data
trainx,trainy = sample_points(10000)
valx,valy = sample_points(500)
testx,testy = sample_points(500)

print(trainx.shape,trainy.shape)

``````
``````

(10000, 2) (10000,)

``````

Our model will be a multi-layer perceptron with one hidden layer, and an output of size 2 since we have two classes. Since it is a binary classification task we could also use a single output with a zero threshold, but we will use two to illustrate the PyTorch cross-entropy loss, `nn.CrossEntropyLoss` (with one output, you would use `nn.BCEWithLogitsLoss`).
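The two formulations really are equivalent: with two logits \$(z_0,z_1)\$, the softmax probability of class 1 is \$\sigma(z_1-z_0)\$, so cross-entropy on the pair equals binary cross-entropy on the single logit \$z_1-z_0\$. A small check (illustrative values, nothing beyond torch assumed):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[0.3, -1.2], [2.0, 0.5]])  # (n, 2) two-output scores
labels = torch.tensor([0, 1])                     # class indices

ce = nn.CrossEntropyLoss()(logits, labels)
# BCE-with-logits on the difference of the two logits, with float targets
bce = nn.BCEWithLogitsLoss()(logits[:, 1] - logits[:, 0], labels.float())
print(torch.allclose(ce, bce))  # True: the two losses agree
```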

As you know from the lectures, such a model cannot represent a circular boundary exactly, but it can represent a polygonal boundary whose number of sides is the number of neurons in the hidden layer. For example, with 6 hidden neurons the model could compute a hexagonal boundary that approximates the unit circle.

Of course the trained model won't compute an actual hexagon, since the ReLU activation is not a hard threshold and the final layer's weights are free (it does not have to compute an AND). We can actually expect better accuracy than a hexagon could achieve.
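To make the polygon intuition concrete, here is a sketch with hand-set (not trained) weights: six hidden ReLU units, one per side of a hexagon with apothem 1. Each unit fires only when the point crosses its side, so the "inside" logit wins exactly when no side is crossed. The scaling constants (10, 1) are arbitrary illustration choices.

```python
import numpy as np
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Linear(6, 2))
angles = np.arange(6) * np.pi / 3  # six side normals, 60 degrees apart
with torch.no_grad():
    normals = np.stack([np.cos(angles), np.sin(angles)], axis=1)
    net[0].weight.copy_(torch.tensor(normals, dtype=torch.float))
    net[0].bias.fill_(-1.0)   # each side sits at distance 1 from the origin
    # class 0 (outside) collects the violations; class 1 (inside) is penalized by them
    net[2].weight.copy_(torch.tensor([[10.0] * 6, [-10.0] * 6]))
    net[2].bias.copy_(torch.tensor([0.0, 1.0]))

points = torch.tensor([[0.0, 0.0], [1.5, 1.5]])
print(net(points).argmax(dim=1))  # tensor([1, 0]): origin inside, far point outside
```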

``````

In [4]:

def generate_single_hidden_MLP(n_hidden_neurons):
    return nn.Sequential(nn.Linear(2,n_hidden_neurons),nn.ReLU(),nn.Linear(n_hidden_neurons,2))

model1 = generate_single_hidden_MLP(6)

``````

To train our model, we will need to feed it tensors. Let's convert our generated numpy arrays:

``````

In [5]:

trainx = torch.from_numpy(trainx).float()
valx = torch.from_numpy(valx).float()
testx = torch.from_numpy(testx).float()

trainy = torch.from_numpy(trainy).long()
valy = torch.from_numpy(valy).long()
testy = torch.from_numpy(testy).long()
print(trainx.type(),trainy.type())

``````
``````

torch.FloatTensor torch.LongTensor

``````

Now we will define our training routine. There is the question of whether to perform training on the CPU or the GPU. The best approach is to use a flag variable that you set when you actually run the training.

``````

In [6]:

def training_routine(net,dataset,n_iters,gpu):
    # organize the data
    train_data,train_labels,val_data,val_labels = dataset

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),lr=0.01)

    # use the flag
    if gpu:
        train_data,train_labels = train_data.cuda(),train_labels.cuda()
        val_data,val_labels = val_data.cuda(),val_labels.cuda()
        net = net.cuda() # the network parameters also need to be on the GPU!
        print("Using GPU")
    else:
        print("Using CPU")
    for i in range(n_iters):
        # forward pass
        train_output = net(train_data)
        train_loss = criterion(train_output,train_labels)
        # backward pass and optimization
        optimizer.zero_grad() # clear the gradients accumulated by the previous step
        train_loss.backward()
        optimizer.step()

        # Once every 100 iterations, print statistics
        if i%100==0:
            print("At iteration",i)
            # compute the accuracy of the prediction
            train_prediction = train_output.cpu().detach().argmax(dim=1)
            train_accuracy = (train_prediction.numpy()==train_labels.cpu().numpy()).mean()
            # Now for the validation set
            val_output = net(val_data)
            val_loss = criterion(val_output,val_labels)
            # compute the accuracy of the prediction
            val_prediction = val_output.cpu().detach().argmax(dim=1)
            val_accuracy = (val_prediction.numpy()==val_labels.cpu().numpy()).mean()
            print("Training loss :",train_loss.cpu().detach().numpy())
            print("Training accuracy :",train_accuracy)
            print("Validation loss :",val_loss.cpu().detach().numpy())
            print("Validation accuracy :",val_accuracy)

    net = net.cpu()

``````
``````

In [7]:

dataset = trainx,trainy,valx,valy

``````
``````

In [8]:

gpu = True
gpu = gpu and torch.cuda.is_available() # to know if you actually can use the GPU
begin = time.time()
training_routine(model1,dataset,10000,gpu)
end=time.time()

``````
``````

Using GPU
At iteration 0
Training loss : 0.6987106204032898
Training accuracy : 0.5003
Validation loss : 0.7039246559143066
Validation accuracy : 0.486
At iteration 100
Training loss : 0.677510678768158
Training accuracy : 0.4203
Validation loss : 0.6795798540115356
Validation accuracy : 0.402
At iteration 200
Training loss : 0.6638249158859253
Training accuracy : 0.5779
Validation loss : 0.6638974547386169
Validation accuracy : 0.562
At iteration 300
Training loss : 0.6521827578544617
Training accuracy : 0.6372
Validation loss : 0.6509451270103455
Validation accuracy : 0.608
At iteration 400
Training loss : 0.6414715051651001
Training accuracy : 0.6623
Validation loss : 0.6396153569221497
Validation accuracy : 0.65
At iteration 500
Training loss : 0.6312623620033264
Training accuracy : 0.6757
Validation loss : 0.6290354132652283
Validation accuracy : 0.672
At iteration 600
Training loss : 0.6213157176971436
Training accuracy : 0.6864
Validation loss : 0.6188277006149292
Validation accuracy : 0.698
At iteration 700
Training loss : 0.6115304827690125
Training accuracy : 0.6946
Validation loss : 0.6090017557144165
Validation accuracy : 0.708
At iteration 800
Training loss : 0.6015798449516296
Training accuracy : 0.7031
Validation loss : 0.5991077423095703
Validation accuracy : 0.714
At iteration 900
Training loss : 0.5912153124809265
Training accuracy : 0.71
Validation loss : 0.588752031326294
Validation accuracy : 0.716
At iteration 1000
Training loss : 0.5803276896476746
Training accuracy : 0.7187
Validation loss : 0.5778717994689941
Validation accuracy : 0.718
At iteration 1100
Training loss : 0.5689370036125183
Training accuracy : 0.7294
Validation loss : 0.5664529800415039
Validation accuracy : 0.724
At iteration 1200
Training loss : 0.5569323301315308
Training accuracy : 0.7387
Validation loss : 0.5542795658111572
Validation accuracy : 0.734
At iteration 1300
Training loss : 0.5442454814910889
Training accuracy : 0.7472
Validation loss : 0.5413668155670166
Validation accuracy : 0.74
At iteration 1400
Training loss : 0.530874490737915
Training accuracy : 0.757
Validation loss : 0.5276691317558289
Validation accuracy : 0.76
At iteration 1500
Training loss : 0.5168139338493347
Training accuracy : 0.7674
Validation loss : 0.5130881667137146
Validation accuracy : 0.766
At iteration 1600
Training loss : 0.5021050572395325
Training accuracy : 0.7788
Validation loss : 0.497661828994751
Validation accuracy : 0.772
At iteration 1700
Training loss : 0.48768314719200134
Training accuracy : 0.791
Validation loss : 0.4825124442577362
Validation accuracy : 0.788
At iteration 1800
Training loss : 0.4740116596221924
Training accuracy : 0.8022
Validation loss : 0.4681008458137512
Validation accuracy : 0.806
At iteration 1900
Training loss : 0.4608691334724426
Training accuracy : 0.8119
Validation loss : 0.4541640877723694
Validation accuracy : 0.81
At iteration 2000
Training loss : 0.4481540620326996
Training accuracy : 0.8197
Validation loss : 0.44063761830329895
Validation accuracy : 0.814
At iteration 2100
Training loss : 0.4358036518096924
Training accuracy : 0.8264
Validation loss : 0.4274945855140686
Validation accuracy : 0.822
At iteration 2200
Training loss : 0.42370855808258057
Training accuracy : 0.832
Validation loss : 0.41458141803741455
Validation accuracy : 0.834
At iteration 2300
Training loss : 0.411865770816803
Training accuracy : 0.8383
Validation loss : 0.401969850063324
Validation accuracy : 0.84
At iteration 2400
Training loss : 0.40023329854011536
Training accuracy : 0.8449
Validation loss : 0.3895937204360962
Validation accuracy : 0.84
At iteration 2500
Training loss : 0.3888087868690491
Training accuracy : 0.85
Validation loss : 0.3774207830429077
Validation accuracy : 0.854
At iteration 2600
Training loss : 0.3775351941585541
Training accuracy : 0.8573
Validation loss : 0.3655126392841339
Validation accuracy : 0.868
At iteration 2700
Training loss : 0.36636418104171753
Training accuracy : 0.8643
Validation loss : 0.35366711020469666
Validation accuracy : 0.872
At iteration 2800
Training loss : 0.3553362488746643
Training accuracy : 0.8725
Validation loss : 0.34201210737228394
Validation accuracy : 0.88
At iteration 2900
Training loss : 0.34448328614234924
Training accuracy : 0.8824
Validation loss : 0.3305943012237549
Validation accuracy : 0.88
At iteration 3000
Training loss : 0.3338693082332611
Training accuracy : 0.8905
Validation loss : 0.319495290517807
Validation accuracy : 0.894
At iteration 3100
Training loss : 0.3234810531139374
Training accuracy : 0.8972
Validation loss : 0.3086738884449005
Validation accuracy : 0.898
At iteration 3200
Training loss : 0.31330394744873047
Training accuracy : 0.903
Validation loss : 0.2982039749622345
Validation accuracy : 0.908
At iteration 3300
Training loss : 0.30343005061149597
Training accuracy : 0.9083
Validation loss : 0.28819984197616577
Validation accuracy : 0.916
At iteration 3400
Training loss : 0.29375043511390686
Training accuracy : 0.9161
Validation loss : 0.278478741645813
Validation accuracy : 0.924
At iteration 3500
Training loss : 0.2843717634677887
Training accuracy : 0.9227
Validation loss : 0.2690211832523346
Validation accuracy : 0.926
At iteration 3600
Training loss : 0.27524057030677795
Training accuracy : 0.9274
Validation loss : 0.25984010100364685
Validation accuracy : 0.928
At iteration 3700
Training loss : 0.26641568541526794
Training accuracy : 0.9318
Validation loss : 0.25092747807502747
Validation accuracy : 0.942
At iteration 3800
Training loss : 0.25782617926597595
Training accuracy : 0.936
Validation loss : 0.242435485124588
Validation accuracy : 0.944
At iteration 3900
Training loss : 0.24958781898021698
Training accuracy : 0.9397
Validation loss : 0.234419584274292
Validation accuracy : 0.946
At iteration 4000
Training loss : 0.24167028069496155
Training accuracy : 0.9442
Validation loss : 0.22682949900627136
Validation accuracy : 0.95
At iteration 4100
Training loss : 0.2340223342180252
Training accuracy : 0.9464
Validation loss : 0.2195090651512146
Validation accuracy : 0.956
At iteration 4200
Training loss : 0.22662721574306488
Training accuracy : 0.9478
Validation loss : 0.2125617414712906
Validation accuracy : 0.956
At iteration 4300
Training loss : 0.2196005880832672
Training accuracy : 0.9508
Validation loss : 0.20601828396320343
Validation accuracy : 0.956
At iteration 4400
Training loss : 0.2129788100719452
Training accuracy : 0.9531
Validation loss : 0.19988633692264557
Validation accuracy : 0.962
At iteration 4500
Training loss : 0.20669223368167877
Training accuracy : 0.9554
Validation loss : 0.19415733218193054
Validation accuracy : 0.966
At iteration 4600
Training loss : 0.20078632235527039
Training accuracy : 0.9574
Validation loss : 0.1888844072818756
Validation accuracy : 0.966
At iteration 4700
Training loss : 0.1952987015247345
Training accuracy : 0.9594
Validation loss : 0.1840348094701767
Validation accuracy : 0.968
At iteration 4800
Training loss : 0.19018003344535828
Training accuracy : 0.9613
Validation loss : 0.17957104742527008
Validation accuracy : 0.968
At iteration 4900
Training loss : 0.18544895946979523
Training accuracy : 0.963
Validation loss : 0.1754365712404251
Validation accuracy : 0.966
At iteration 5000
Training loss : 0.18104971945285797
Training accuracy : 0.9648
Validation loss : 0.17157526314258575
Validation accuracy : 0.966
At iteration 5100
Training loss : 0.17694155871868134
Training accuracy : 0.9659
Validation loss : 0.16793076694011688
Validation accuracy : 0.968
At iteration 5200
Training loss : 0.1731051802635193
Training accuracy : 0.9672
Validation loss : 0.1645321249961853
Validation accuracy : 0.97
At iteration 5300
Training loss : 0.16951316595077515
Training accuracy : 0.9676
Validation loss : 0.1613655537366867
Validation accuracy : 0.97
At iteration 5400
Training loss : 0.16613084077835083
Training accuracy : 0.9679
Validation loss : 0.1583479344844818
Validation accuracy : 0.97
At iteration 5500
Training loss : 0.16294611990451813
Training accuracy : 0.9692
Validation loss : 0.1555103212594986
Validation accuracy : 0.972
At iteration 5600
Training loss : 0.1599428504705429
Training accuracy : 0.9697
Validation loss : 0.1528404951095581
Validation accuracy : 0.974
At iteration 5700
Training loss : 0.1571037620306015
Training accuracy : 0.9703
Validation loss : 0.1503140926361084
Validation accuracy : 0.974
At iteration 5800
Training loss : 0.15441352128982544
Training accuracy : 0.9711
Validation loss : 0.1479216367006302
Validation accuracy : 0.97
At iteration 5900
Training loss : 0.15184079110622406
Training accuracy : 0.9717
Validation loss : 0.145643949508667
Validation accuracy : 0.968
At iteration 6000
Training loss : 0.14939665794372559
Training accuracy : 0.9717
Validation loss : 0.14346875250339508
Validation accuracy : 0.968
At iteration 6100
Training loss : 0.14706821739673615
Training accuracy : 0.9721
Validation loss : 0.14139312505722046
Validation accuracy : 0.968
At iteration 6200
Training loss : 0.14484672248363495
Training accuracy : 0.9724
Validation loss : 0.13942237198352814
Validation accuracy : 0.968
At iteration 6300
Training loss : 0.14271709322929382
Training accuracy : 0.9724
Validation loss : 0.13753719627857208
Validation accuracy : 0.968
At iteration 6400
Training loss : 0.14068393409252167
Training accuracy : 0.9729
Validation loss : 0.1357404738664627
Validation accuracy : 0.964
At iteration 6500
Training loss : 0.13873538374900818
Training accuracy : 0.9734
Validation loss : 0.1339964121580124
Validation accuracy : 0.964
At iteration 6600
Training loss : 0.13686779141426086
Training accuracy : 0.9737
Validation loss : 0.1323196142911911
Validation accuracy : 0.964
At iteration 6700
Training loss : 0.1350780576467514
Training accuracy : 0.9745
Validation loss : 0.13071401417255402
Validation accuracy : 0.964
At iteration 6800
Training loss : 0.1333550363779068
Training accuracy : 0.9745
Validation loss : 0.12917150557041168
Validation accuracy : 0.964
At iteration 6900
Training loss : 0.13169458508491516
Training accuracy : 0.9748
Validation loss : 0.12768438458442688
Validation accuracy : 0.966
At iteration 7000
Training loss : 0.13009630143642426
Training accuracy : 0.9757
Validation loss : 0.12625116109848022
Validation accuracy : 0.968
At iteration 7100
Training loss : 0.12855537235736847
Training accuracy : 0.976
Validation loss : 0.12486620247364044
Validation accuracy : 0.968
At iteration 7200
Training loss : 0.12706787884235382
Training accuracy : 0.976
Validation loss : 0.12352976948022842
Validation accuracy : 0.968
At iteration 7300
Training loss : 0.1256323903799057
Training accuracy : 0.9765
Validation loss : 0.12222747504711151
Validation accuracy : 0.968
At iteration 7400
Training loss : 0.12424390763044357
Training accuracy : 0.9767
Validation loss : 0.12096820771694183
Validation accuracy : 0.968
At iteration 7500
Training loss : 0.12289751321077347
Training accuracy : 0.9773
Validation loss : 0.1197456642985344
Validation accuracy : 0.968
At iteration 7600
Training loss : 0.1215951070189476
Training accuracy : 0.9775
Validation loss : 0.11856357008218765
Validation accuracy : 0.968
At iteration 7700
Training loss : 0.12033259123563766
Training accuracy : 0.9774
Validation loss : 0.11741720139980316
Validation accuracy : 0.968
At iteration 7800
Training loss : 0.11910898238420486
Training accuracy : 0.9774
Validation loss : 0.11630597710609436
Validation accuracy : 0.968
At iteration 7900
Training loss : 0.1179240271449089
Training accuracy : 0.9774
Validation loss : 0.11522701382637024
Validation accuracy : 0.968
At iteration 8000
Training loss : 0.11677491664886475
Training accuracy : 0.9775
Validation loss : 0.11418116837739944
Validation accuracy : 0.968
At iteration 8100
Training loss : 0.1156592071056366
Training accuracy : 0.9775
Validation loss : 0.11317238211631775
Validation accuracy : 0.968
At iteration 8200
Training loss : 0.11457314342260361
Training accuracy : 0.9775
Validation loss : 0.11218325793743134
Validation accuracy : 0.968
At iteration 8300
Training loss : 0.11351557821035385
Training accuracy : 0.9777
Validation loss : 0.11121871322393417
Validation accuracy : 0.968
At iteration 8400
Training loss : 0.11248715966939926
Training accuracy : 0.9782
Validation loss : 0.11028134822845459
Validation accuracy : 0.968
At iteration 8500
Training loss : 0.11148744821548462
Training accuracy : 0.9786
Validation loss : 0.10937269032001495
Validation accuracy : 0.968
At iteration 8600
Training loss : 0.1105133444070816
Training accuracy : 0.9786
Validation loss : 0.10849720239639282
Validation accuracy : 0.97
At iteration 8700
Training loss : 0.10956472158432007
Training accuracy : 0.9785
Validation loss : 0.10764889419078827
Validation accuracy : 0.97
At iteration 8800
Training loss : 0.10864181071519852
Training accuracy : 0.9787
Validation loss : 0.1068202406167984
Validation accuracy : 0.97
At iteration 8900
Training loss : 0.10774435102939606
Training accuracy : 0.9787
Validation loss : 0.10601337999105453
Validation accuracy : 0.97
At iteration 9000
Training loss : 0.10686799138784409
Training accuracy : 0.9788
Validation loss : 0.10522656887769699
Validation accuracy : 0.97
At iteration 9100
Training loss : 0.10601214319467545
Training accuracy : 0.979
Validation loss : 0.104451484978199
Validation accuracy : 0.972
At iteration 9200
Training loss : 0.10517784208059311
Training accuracy : 0.9791
Validation loss : 0.10369487851858139
Validation accuracy : 0.972
At iteration 9300
Training loss : 0.10436240583658218
Training accuracy : 0.9792
Validation loss : 0.10296136140823364
Validation accuracy : 0.972
At iteration 9400
Training loss : 0.103565514087677
Training accuracy : 0.9792
Validation loss : 0.10224708914756775
Validation accuracy : 0.974
At iteration 9500
Training loss : 0.10278775542974472
Training accuracy : 0.9793
Validation loss : 0.10154952853918076
Validation accuracy : 0.976
At iteration 9600
Training loss : 0.10202813893556595
Training accuracy : 0.9793
Validation loss : 0.10086623579263687
Validation accuracy : 0.976
At iteration 9700
Training loss : 0.10128509253263474
Training accuracy : 0.9792
Validation loss : 0.10019765049219131
Validation accuracy : 0.976
At iteration 9800
Training loss : 0.10055863112211227
Training accuracy : 0.9795
Validation loss : 0.09954356402158737
Validation accuracy : 0.978
At iteration 9900
Training loss : 0.09984857589006424
Training accuracy : 0.9796
Validation loss : 0.09890390187501907
Validation accuracy : 0.978

``````
``````

In [9]:

print("Training time :",end-begin)

``````
``````

Training time : 24.634289741516113

``````
``````

In [10]:

# Let's try with 3 hidden neurons.
model2 = generate_single_hidden_MLP(3)
training_routine(model2,dataset,10000,gpu)

``````
``````

Using GPU
At iteration 0
Training loss : 0.8042043447494507
Training accuracy : 0.5003
Validation loss : 0.8238264322280884
Validation accuracy : 0.486
At iteration 100
Training loss : 0.7185286283493042
Training accuracy : 0.5003
Validation loss : 0.7303446531295776
Validation accuracy : 0.486
At iteration 200
Training loss : 0.6989850401878357
Training accuracy : 0.3706
Validation loss : 0.7071495652198792
Validation accuracy : 0.346
At iteration 300
Training loss : 0.6923447251319885
Training accuracy : 0.387
Validation loss : 0.6988030672073364
Validation accuracy : 0.358
At iteration 400
Training loss : 0.6878220438957214
Training accuracy : 0.4534
Validation loss : 0.6936196684837341
Validation accuracy : 0.418
At iteration 500
Training loss : 0.682556688785553
Training accuracy : 0.5406
Validation loss : 0.6881617307662964
Validation accuracy : 0.512
At iteration 600
Training loss : 0.6763601899147034
Training accuracy : 0.5856
Validation loss : 0.6821454763412476
Validation accuracy : 0.566
At iteration 700
Training loss : 0.6690481305122375
Training accuracy : 0.6297
Validation loss : 0.6751475930213928
Validation accuracy : 0.604
At iteration 800
Training loss : 0.6603956818580627
Training accuracy : 0.6717
Validation loss : 0.6668155789375305
Validation accuracy : 0.652
At iteration 900
Training loss : 0.6502493023872375
Training accuracy : 0.7081
Validation loss : 0.6572006344795227
Validation accuracy : 0.676
At iteration 1000
Training loss : 0.6385756731033325
Training accuracy : 0.7383
Validation loss : 0.6457823514938354
Validation accuracy : 0.71
At iteration 1100
Training loss : 0.6254290342330933
Training accuracy : 0.7579
Validation loss : 0.6329132914543152
Validation accuracy : 0.734
At iteration 1200
Training loss : 0.6114764213562012
Training accuracy : 0.7647
Validation loss : 0.6191123127937317
Validation accuracy : 0.742
At iteration 1300
Training loss : 0.5972586274147034
Training accuracy : 0.7678
Validation loss : 0.6051972508430481
Validation accuracy : 0.748
At iteration 1400
Training loss : 0.5834115147590637
Training accuracy : 0.7713
Validation loss : 0.591754674911499
Validation accuracy : 0.752
At iteration 1500
Training loss : 0.5704270601272583
Training accuracy : 0.7727
Validation loss : 0.579190194606781
Validation accuracy : 0.76
At iteration 1600
Training loss : 0.5584383606910706
Training accuracy : 0.774
Validation loss : 0.5677558183670044
Validation accuracy : 0.764
At iteration 1700
Training loss : 0.5474631190299988
Training accuracy : 0.7773
Validation loss : 0.5574163198471069
Validation accuracy : 0.77
At iteration 1800
Training loss : 0.537483274936676
Training accuracy : 0.7793
Validation loss : 0.548203706741333
Validation accuracy : 0.768
At iteration 1900
Training loss : 0.5284117460250854
Training accuracy : 0.7812
Validation loss : 0.539940357208252
Validation accuracy : 0.766
At iteration 2000
Training loss : 0.5201975107192993
Training accuracy : 0.7829
Validation loss : 0.5325564742088318
Validation accuracy : 0.768
At iteration 2100
Training loss : 0.5127653479576111
Training accuracy : 0.7829
Validation loss : 0.5259811282157898
Validation accuracy : 0.768
At iteration 2200
Training loss : 0.5060506463050842
Training accuracy : 0.7832
Validation loss : 0.5201577544212341
Validation accuracy : 0.772
At iteration 2300
Training loss : 0.50001460313797
Training accuracy : 0.7834
Validation loss : 0.514948308467865
Validation accuracy : 0.772
At iteration 2400
Training loss : 0.4945991635322571
Training accuracy : 0.7834
Validation loss : 0.5102963447570801
Validation accuracy : 0.772
At iteration 2500
Training loss : 0.4897620677947998
Training accuracy : 0.7831
Validation loss : 0.5061904788017273
Validation accuracy : 0.772
At iteration 2600
Training loss : 0.4854539930820465
Training accuracy : 0.7829
Validation loss : 0.5025847554206848
Validation accuracy : 0.774
At iteration 2700
Training loss : 0.4816124141216278
Training accuracy : 0.7833
Validation loss : 0.49941909313201904
Validation accuracy : 0.774
At iteration 2800
Training loss : 0.4781963527202606
Training accuracy : 0.7827
Validation loss : 0.4966699182987213
Validation accuracy : 0.774
At iteration 2900
Training loss : 0.47516390681266785
Training accuracy : 0.7831
Validation loss : 0.4942916929721832
Validation accuracy : 0.772
At iteration 3000
Training loss : 0.4724833369255066
Training accuracy : 0.7828
Validation loss : 0.49222010374069214
Validation accuracy : 0.772
At iteration 3100
Training loss : 0.4701261818408966
Training accuracy : 0.7824
Validation loss : 0.4904025197029114
Validation accuracy : 0.772
At iteration 3200
Training loss : 0.4680247902870178
Training accuracy : 0.7817
Validation loss : 0.48880332708358765
Validation accuracy : 0.772
At iteration 3300
Training loss : 0.4661575257778168
Training accuracy : 0.7817
Validation loss : 0.4874080419540405
Validation accuracy : 0.772
At iteration 3400
Training loss : 0.4645048975944519
Training accuracy : 0.7809
Validation loss : 0.4861815273761749
Validation accuracy : 0.77
At iteration 3500
Training loss : 0.4630364179611206
Training accuracy : 0.7806
Validation loss : 0.48512333631515503
Validation accuracy : 0.768
At iteration 3600
Training loss : 0.46173372864723206
Training accuracy : 0.78
Validation loss : 0.48421892523765564
Validation accuracy : 0.766
At iteration 3700
Training loss : 0.460572212934494
Training accuracy : 0.7788
Validation loss : 0.48346036672592163
Validation accuracy : 0.764
At iteration 3800
Training loss : 0.4595365822315216
Training accuracy : 0.7787
Validation loss : 0.4828377068042755
Validation accuracy : 0.762
At iteration 3900
Training loss : 0.45861244201660156
Training accuracy : 0.7785
Validation loss : 0.4823111593723297
Validation accuracy : 0.762
At iteration 4000
Training loss : 0.4577837288379669
Training accuracy : 0.7787
Validation loss : 0.48185357451438904
Validation accuracy : 0.762
At iteration 4100
Training loss : 0.4570505917072296
Training accuracy : 0.7782
Validation loss : 0.4814351201057434
Validation accuracy : 0.764
At iteration 4200
Training loss : 0.4563976526260376
Training accuracy : 0.7771
Validation loss : 0.4810810983181
Validation accuracy : 0.766
At iteration 4300
Training loss : 0.455814927816391
Training accuracy : 0.7766
Validation loss : 0.4807562232017517
Validation accuracy : 0.768
At iteration 4400
Training loss : 0.45528584718704224
Training accuracy : 0.7769
Validation loss : 0.4804733395576477
Validation accuracy : 0.766
At iteration 4500
Training loss : 0.4548111855983734
Training accuracy : 0.777
Validation loss : 0.48023200035095215
Validation accuracy : 0.766
At iteration 4600
Training loss : 0.45438677072525024
Training accuracy : 0.7764
Validation loss : 0.48000213503837585
Validation accuracy : 0.766
At iteration 4700
Training loss : 0.45399174094200134
Training accuracy : 0.7762
Validation loss : 0.47981971502304077
Validation accuracy : 0.764
At iteration 4800
Training loss : 0.4536362886428833
Training accuracy : 0.7759
Validation loss : 0.4796644151210785
Validation accuracy : 0.762
At iteration 4900
Training loss : 0.4533211886882782
Training accuracy : 0.7759
Validation loss : 0.47953447699546814
Validation accuracy : 0.766
At iteration 5000
Training loss : 0.453032523393631
Training accuracy : 0.7755
Validation loss : 0.4794190227985382
Validation accuracy : 0.768
At iteration 5100
Training loss : 0.4527663588523865
Training accuracy : 0.776
Validation loss : 0.47931841015815735
Validation accuracy : 0.766
At iteration 5200
Training loss : 0.45251747965812683
Training accuracy : 0.7756
Validation loss : 0.47925683856010437
Validation accuracy : 0.764
At iteration 5300
Training loss : 0.4522855877876282
Training accuracy : 0.7753
Validation loss : 0.47922566533088684
Validation accuracy : 0.764
At iteration 5400
Training loss : 0.452070027589798
Training accuracy : 0.7756
Validation loss : 0.47918781638145447
Validation accuracy : 0.762
At iteration 5500
Training loss : 0.45187774300575256
Training accuracy : 0.7757
Validation loss : 0.479141503572464
Validation accuracy : 0.762
At iteration 5600
Training loss : 0.45170825719833374
Training accuracy : 0.7756
Validation loss : 0.47909456491470337
Validation accuracy : 0.76
At iteration 5700
Training loss : 0.45156118273735046
Training accuracy : 0.7756
Validation loss : 0.47905534505844116
Validation accuracy : 0.76
At iteration 5800
Training loss : 0.4514290392398834
Training accuracy : 0.7755
Validation loss : 0.4790177643299103
Validation accuracy : 0.76
At iteration 5900
Training loss : 0.45129701495170593
Training accuracy : 0.7752
Validation loss : 0.47899454832077026
Validation accuracy : 0.76
At iteration 6000
Training loss : 0.4511772096157074
Training accuracy : 0.7755
Validation loss : 0.47897839546203613
Validation accuracy : 0.76
At iteration 6100
Training loss : 0.45107242465019226
Training accuracy : 0.775
Validation loss : 0.47897568345069885
Validation accuracy : 0.758
At iteration 6200
Training loss : 0.4509792923927307
Training accuracy : 0.7752
Validation loss : 0.4789806306362152
Validation accuracy : 0.756
At iteration 6300
Training loss : 0.4508935511112213
Training accuracy : 0.7757
Validation loss : 0.47898975014686584
Validation accuracy : 0.758
At iteration 6400
Training loss : 0.4508166015148163
Training accuracy : 0.7756
Validation loss : 0.4789983630180359
Validation accuracy : 0.758
At iteration 6500
Training loss : 0.4507419466972351
Training accuracy : 0.7751
Validation loss : 0.47901594638824463
Validation accuracy : 0.758
At iteration 6600
Training loss : 0.450667142868042
Training accuracy : 0.775
Validation loss : 0.47904884815216064
Validation accuracy : 0.758
At iteration 6700
Training loss : 0.4505937099456787
Training accuracy : 0.7753
Validation loss : 0.4791039824485779
Validation accuracy : 0.76
At iteration 6800
Training loss : 0.45052626729011536
Training accuracy : 0.7751
Validation loss : 0.4791721701622009
Validation accuracy : 0.76
At iteration 6900
Training loss : 0.4504624009132385
Training accuracy : 0.7754
Validation loss : 0.47922977805137634
Validation accuracy : 0.76
At iteration 7000
Training loss : 0.45040565729141235
Training accuracy : 0.7753
Validation loss : 0.4792513847351074
Validation accuracy : 0.76
At iteration 7100
Training loss : 0.4503517150878906
Training accuracy : 0.7749
Validation loss : 0.47927895188331604
Validation accuracy : 0.762
At iteration 7200
Training loss : 0.4502948224544525
Training accuracy : 0.7754
Validation loss : 0.47932857275009155
Validation accuracy : 0.762
At iteration 7300
Training loss : 0.4502376317977905
Training accuracy : 0.7754
Validation loss : 0.47938016057014465
Validation accuracy : 0.76
At iteration 7400
Training loss : 0.45018404722213745
Training accuracy : 0.7758
Validation loss : 0.47943389415740967
Validation accuracy : 0.76
At iteration 7500
Training loss : 0.450132817029953
Training accuracy : 0.7757
Validation loss : 0.4794860780239105
Validation accuracy : 0.76
At iteration 7600
Training loss : 0.4500790536403656
Training accuracy : 0.7758
Validation loss : 0.4795333743095398
Validation accuracy : 0.76
At iteration 7700
Training loss : 0.45002856850624084
Training accuracy : 0.7761
Validation loss : 0.4795854985713959
Validation accuracy : 0.76
At iteration 7800
Training loss : 0.4499852955341339
Training accuracy : 0.7761
Validation loss : 0.4796288013458252
Validation accuracy : 0.758
At iteration 7900
Training loss : 0.4499473571777344
Training accuracy : 0.7763
Validation loss : 0.4796689450740814
Validation accuracy : 0.758
At iteration 8000
Training loss : 0.4499092698097229
Training accuracy : 0.7763
Validation loss : 0.47970810532569885
Validation accuracy : 0.758
At iteration 8100
Training loss : 0.44987720251083374
Training accuracy : 0.7761
Validation loss : 0.4797445237636566
Validation accuracy : 0.758
At iteration 8200
Training loss : 0.44984671473503113
Training accuracy : 0.7761
Validation loss : 0.47978121042251587
Validation accuracy : 0.758
At iteration 8300
Training loss : 0.44981464743614197
Training accuracy : 0.7761
Validation loss : 0.4798373579978943
Validation accuracy : 0.758
At iteration 8400
Training loss : 0.4497825801372528
Training accuracy : 0.7761
Validation loss : 0.4798918664455414
Validation accuracy : 0.758
At iteration 8500
Training loss : 0.4497530162334442
Training accuracy : 0.7763
Validation loss : 0.4799533188343048
Validation accuracy : 0.76
At iteration 8600
Training loss : 0.44972261786460876
Training accuracy : 0.7761
Validation loss : 0.4800124168395996
Validation accuracy : 0.762
At iteration 8700
Training loss : 0.44969257712364197
Training accuracy : 0.7762
Validation loss : 0.4800698161125183
Validation accuracy : 0.762
At iteration 8800
Training loss : 0.44966450333595276
Training accuracy : 0.7762
Validation loss : 0.48012039065361023
Validation accuracy : 0.76
At iteration 8900
Training loss : 0.44963324069976807
Training accuracy : 0.7762
Validation loss : 0.48017287254333496
Validation accuracy : 0.758
At iteration 9000
Training loss : 0.4496088922023773
Training accuracy : 0.7762
Validation loss : 0.48021143674850464
Validation accuracy : 0.756
At iteration 9100
Training loss : 0.44958609342575073
Training accuracy : 0.7763
Validation loss : 0.4802546799182892
Validation accuracy : 0.758
At iteration 9200
Training loss : 0.4495634138584137
Training accuracy : 0.7765
Validation loss : 0.4802907109260559
Validation accuracy : 0.76
At iteration 9300
Training loss : 0.4495399296283722
Training accuracy : 0.7764
Validation loss : 0.48030900955200195
Validation accuracy : 0.76
At iteration 9400
Training loss : 0.44951802492141724
Training accuracy : 0.7766
Validation loss : 0.4803306460380554
Validation accuracy : 0.76
At iteration 9500
Training loss : 0.4495003819465637
Training accuracy : 0.7769
Validation loss : 0.48035526275634766
Validation accuracy : 0.758
At iteration 9600
Training loss : 0.4494825303554535
Training accuracy : 0.7771
Validation loss : 0.48038244247436523
Validation accuracy : 0.758
At iteration 9700
Training loss : 0.44945982098579407
Training accuracy : 0.7772
Validation loss : 0.48043206334114075
Validation accuracy : 0.758
At iteration 9800
Training loss : 0.4494328498840332
Training accuracy : 0.7772
Validation loss : 0.48048070073127747
Validation accuracy : 0.758
At iteration 9900
Training loss : 0.4494076669216156
Training accuracy : 0.7773
Validation loss : 0.48052552342414856
Validation accuracy : 0.758

``````
``````

In [11]:

# Predicted class for each test point: argmax over the two output logits
out = model2(testx).argmax(dim=1).detach().numpy()
green = testx.numpy()[np.where(out==1)]  # predicted inside the circle
red = testx.numpy()[np.where(out==0)]    # predicted outside
print(green.shape,red.shape)

``````
``````

(288, 2) (212, 2)

``````
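
Since `testy` holds the true labels, the same `argmax` predictions can also be scored against them. A minimal sketch of that accuracy computation, using a freshly initialized stand-in model and regenerated test data (the trained `model2` and the notebook's tensors are not reproduced here, so the numbers will differ):

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in test set: 500 points sampled in polar coordinates within the
# radius-2 disk, labeled by the unit circle (as in sample_points above)
rng = np.random.default_rng(0)
angle = rng.uniform(0, 2 * np.pi, size=500)
radius = rng.uniform(0, 2, size=500)
x = np.stack([radius * np.cos(angle), radius * np.sin(angle)], axis=1).astype(np.float32)
y = (x[:, 0] ** 2 + x[:, 1] ** 2 < 1).astype(np.int64)
testx, testy = torch.from_numpy(x), torch.from_numpy(y)

# Stand-in (untrained) model with 6 hidden units; substitute a trained model
model = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Linear(6, 2))

# Accuracy = fraction of argmax predictions that match the labels
pred = model(testx).argmax(dim=1)
accuracy = (pred == testy).float().mean().item()
print(accuracy)
```

With an untrained model this hovers around chance; after training it should match the test-set figures reported in the logs above.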
``````

In [12]:

def print_model(model, datapoints):
    # Classify each point and split into predicted class 1 (green) and class 0 (red)
    out = model(datapoints).argmax(dim=1).detach().numpy()
    green = datapoints.numpy()[np.where(out==1)]
    red = datapoints.numpy()[np.where(out==0)]

    # The true decision boundary: the unit circle
    circle1 = plt.Circle((0, 0), 1, color='y')
    circle2 = plt.Circle((0, 0), 1, color='b', fill=False)

    fig, ax = plt.subplots() # note we must use plt.subplots, not plt.subplot
    # (or if you have an existing figure)
    # fig = plt.gcf()
    # ax = fig.gca()
    ax.add_artist(circle1)
    ax.add_artist(circle2)
    plt.xlim((-2,2))
    plt.ylim((-2,2))

    plt.scatter(x=green[:,0], y=green[:,1], color='g')
    plt.scatter(x=red[:,0], y=red[:,1], color='r')

``````
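
`print_model` only colors the test points themselves. An alternative way to see the learned polygonal boundary is to classify a dense grid over the square \$[-2,2]^2\$. A sketch, again with an untrained stand-in model (substitute `model1`, `model2`, or `model3`):

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in model (illustrative; replace with a trained model)
model = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Linear(6, 2))

# Dense 200x200 grid over [-2, 2]^2
xs = np.linspace(-2, 2, 200, dtype=np.float32)
gx, gy = np.meshgrid(xs, xs)
grid = torch.from_numpy(np.stack([gx.ravel(), gy.ravel()], axis=1))

# Classify every grid point; no_grad since we only need predictions
with torch.no_grad():
    labels = model(grid).argmax(dim=1).numpy().reshape(gx.shape)

# labels is a 200x200 array of class ids (0 or 1);
# plt.contourf(gx, gy, labels) would shade the two predicted regions
print(labels.shape)
```

The filled regions make the piecewise-linear shape of the boundary much easier to see than the scatter plot alone.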
``````

In [13]:

print_model(model1,testx)

``````
``````

``````
``````

In [14]:

print_model(model2,testx)

``````
``````

``````
``````

In [15]:

# An even smaller model: only 2 hidden neurons
model3 = generate_single_hidden_MLP(2)
training_routine(model3,dataset,10000,gpu)

``````
``````

Using GPU
At iteration 0
Training loss : 0.757908821105957
Training accuracy : 0.5003
Validation loss : 0.7690756320953369
Validation accuracy : 0.486
At iteration 100
Training loss : 0.7138646841049194
Training accuracy : 0.5003
Validation loss : 0.7210879921913147
Validation accuracy : 0.486
At iteration 200
Training loss : 0.6944175958633423
Training accuracy : 0.5003
Validation loss : 0.6991903781890869
Validation accuracy : 0.486
At iteration 300
Training loss : 0.6836276054382324
Training accuracy : 0.5003
Validation loss : 0.6868808269500732
Validation accuracy : 0.486
At iteration 400
Training loss : 0.6757462024688721
Training accuracy : 0.5003
Validation loss : 0.6780419945716858
Validation accuracy : 0.486
At iteration 500
Training loss : 0.6687736511230469
Training accuracy : 0.5142
Validation loss : 0.6704592108726501
Validation accuracy : 0.49
At iteration 600
Training loss : 0.6619855761528015
Training accuracy : 0.6194
Validation loss : 0.663285493850708
Validation accuracy : 0.604
At iteration 700
Training loss : 0.6551350951194763
Training accuracy : 0.6494
Validation loss : 0.6561704874038696
Validation accuracy : 0.626
At iteration 800
Training loss : 0.6481907367706299
Training accuracy : 0.6631
Validation loss : 0.6490899324417114
Validation accuracy : 0.642
At iteration 900
Training loss : 0.6411333084106445
Training accuracy : 0.6722
Validation loss : 0.6420141458511353
Validation accuracy : 0.652
At iteration 1000
Training loss : 0.6339496374130249
Training accuracy : 0.6807
Validation loss : 0.6349431276321411
Validation accuracy : 0.658
At iteration 1100
Training loss : 0.6266373991966248
Training accuracy : 0.6901
Validation loss : 0.6278645992279053
Validation accuracy : 0.668
At iteration 1200
Training loss : 0.6191753149032593
Training accuracy : 0.698
Validation loss : 0.6208072304725647
Validation accuracy : 0.674
At iteration 1300
Training loss : 0.6114997267723083
Training accuracy : 0.7033
Validation loss : 0.6136322021484375
Validation accuracy : 0.678
At iteration 1400
Training loss : 0.6035844683647156
Training accuracy : 0.7107
Validation loss : 0.6063424944877625
Validation accuracy : 0.688
At iteration 1500
Training loss : 0.5954177975654602
Training accuracy : 0.7153
Validation loss : 0.5988637804985046
Validation accuracy : 0.69
At iteration 1600
Training loss : 0.5870065689086914
Training accuracy : 0.721
Validation loss : 0.5911092758178711
Validation accuracy : 0.698
At iteration 1700
Training loss : 0.5783757567405701
Training accuracy : 0.7266
Validation loss : 0.5830110907554626
Validation accuracy : 0.71
At iteration 1800
Training loss : 0.5696817636489868
Training accuracy : 0.7345
Validation loss : 0.5746251344680786
Validation accuracy : 0.718
At iteration 1900
Training loss : 0.5610436201095581
Training accuracy : 0.7402
Validation loss : 0.5661273002624512
Validation accuracy : 0.73
At iteration 2000
Training loss : 0.5525501370429993
Training accuracy : 0.7451
Validation loss : 0.5577940344810486
Validation accuracy : 0.736
At iteration 2100
Training loss : 0.5443151593208313
Training accuracy : 0.7512
Validation loss : 0.5496269464492798
Validation accuracy : 0.748
At iteration 2200
Training loss : 0.5363830924034119
Training accuracy : 0.757
Validation loss : 0.5417766571044922
Validation accuracy : 0.752
At iteration 2300
Training loss : 0.5288269519805908
Training accuracy : 0.7627
Validation loss : 0.5342857241630554
Validation accuracy : 0.752
At iteration 2400
Training loss : 0.5217110514640808
Training accuracy : 0.7678
Validation loss : 0.5271515250205994
Validation accuracy : 0.76
At iteration 2500
Training loss : 0.5150945782661438
Training accuracy : 0.7723
Validation loss : 0.5204028487205505
Validation accuracy : 0.766
At iteration 2600
Training loss : 0.5089967846870422
Training accuracy : 0.7776
Validation loss : 0.5140964984893799
Validation accuracy : 0.77
At iteration 2700
Training loss : 0.503364622592926
Training accuracy : 0.7826
Validation loss : 0.5082263946533203
Validation accuracy : 0.774
At iteration 2800
Training loss : 0.49825260043144226
Training accuracy : 0.7857
Validation loss : 0.5027372241020203
Validation accuracy : 0.78
At iteration 2900
Training loss : 0.4936380386352539
Training accuracy : 0.7862
Validation loss : 0.4977012276649475
Validation accuracy : 0.786
At iteration 3000
Training loss : 0.4894702732563019
Training accuracy : 0.7861
Validation loss : 0.49305543303489685
Validation accuracy : 0.782
At iteration 3100
Training loss : 0.4856936037540436
Training accuracy : 0.7869
Validation loss : 0.4888104498386383
Validation accuracy : 0.782
At iteration 3200
Training loss : 0.4822947382926941
Training accuracy : 0.7868
Validation loss : 0.4849281311035156
Validation accuracy : 0.786
At iteration 3300
Training loss : 0.47922685742378235
Training accuracy : 0.7873
Validation loss : 0.48135992884635925
Validation accuracy : 0.784
At iteration 3400
Training loss : 0.47646695375442505
Training accuracy : 0.7864
Validation loss : 0.47807401418685913
Validation accuracy : 0.784
At iteration 3500
Training loss : 0.47399234771728516
Training accuracy : 0.7874
Validation loss : 0.4750390648841858
Validation accuracy : 0.788
At iteration 3600
Training loss : 0.4717704653739929
Training accuracy : 0.7868
Validation loss : 0.4723238945007324
Validation accuracy : 0.788
At iteration 3700
Training loss : 0.4697822332382202
Training accuracy : 0.7871
Validation loss : 0.4698861241340637
Validation accuracy : 0.788
At iteration 3800
Training loss : 0.46801266074180603
Training accuracy : 0.7863
Validation loss : 0.46771350502967834
Validation accuracy : 0.788
At iteration 3900
Training loss : 0.4664306640625
Training accuracy : 0.7868
Validation loss : 0.4657498002052307
Validation accuracy : 0.79
At iteration 4000
Training loss : 0.46500927209854126
Training accuracy : 0.7871
Validation loss : 0.4639551639556885
Validation accuracy : 0.79
At iteration 4100
Training loss : 0.4637335538864136
Training accuracy : 0.7875
Validation loss : 0.4623580574989319
Validation accuracy : 0.79
At iteration 4200
Training loss : 0.46258631348609924
Training accuracy : 0.7874
Validation loss : 0.46093276143074036
Validation accuracy : 0.794
At iteration 4300
Training loss : 0.461558997631073
Training accuracy : 0.7875
Validation loss : 0.4596591591835022
Validation accuracy : 0.794
At iteration 4400
Training loss : 0.46064233779907227
Training accuracy : 0.7879
Validation loss : 0.458560049533844
Validation accuracy : 0.796
At iteration 4500
Training loss : 0.4598313570022583
Training accuracy : 0.7881
Validation loss : 0.4575856626033783
Validation accuracy : 0.796
At iteration 4600
Training loss : 0.45910653471946716
Training accuracy : 0.7874
Validation loss : 0.4567101001739502
Validation accuracy : 0.794
At iteration 4700
Training loss : 0.45845508575439453
Training accuracy : 0.7867
Validation loss : 0.4559449553489685
Validation accuracy : 0.794
At iteration 4800
Training loss : 0.4578685164451599
Training accuracy : 0.786
Validation loss : 0.45528444647789
Validation accuracy : 0.796
At iteration 4900
Training loss : 0.45733726024627686
Training accuracy : 0.7856
Validation loss : 0.454721063375473
Validation accuracy : 0.796
At iteration 5000
Training loss : 0.4568570852279663
Training accuracy : 0.7851
Validation loss : 0.45425066351890564
Validation accuracy : 0.796
At iteration 5100
Training loss : 0.45642563700675964
Training accuracy : 0.7847
Validation loss : 0.4538639187812805
Validation accuracy : 0.792
At iteration 5200
Training loss : 0.45603397488594055
Training accuracy : 0.7843
Validation loss : 0.45351094007492065
Validation accuracy : 0.79
At iteration 5300
Training loss : 0.45567500591278076
Training accuracy : 0.7841
Validation loss : 0.453240305185318
Validation accuracy : 0.79
At iteration 5400
Training loss : 0.45534569025039673
Training accuracy : 0.7836
Validation loss : 0.45302337408065796
Validation accuracy : 0.79
At iteration 5500
Training loss : 0.45504093170166016
Training accuracy : 0.7839
Validation loss : 0.45283132791519165
Validation accuracy : 0.79
At iteration 5600
Training loss : 0.45476263761520386
Training accuracy : 0.7836
Validation loss : 0.45265713334083557
Validation accuracy : 0.79
At iteration 5700
Training loss : 0.4545077979564667
Training accuracy : 0.7835
Validation loss : 0.45250192284584045
Validation accuracy : 0.788
At iteration 5800
Training loss : 0.4542773962020874
Training accuracy : 0.7835
Validation loss : 0.45235541462898254
Validation accuracy : 0.79
At iteration 5900
Training loss : 0.4540669023990631
Training accuracy : 0.7828
Validation loss : 0.4522089660167694
Validation accuracy : 0.79
At iteration 6000
Training loss : 0.45387357473373413
Training accuracy : 0.7823
Validation loss : 0.4520772397518158
Validation accuracy : 0.79
At iteration 6100
Training loss : 0.4536908268928528
Training accuracy : 0.7819
Validation loss : 0.451982706785202
Validation accuracy : 0.79
At iteration 6200
Training loss : 0.4535170793533325
Training accuracy : 0.7818
Validation loss : 0.45192718505859375
Validation accuracy : 0.79
At iteration 6300
Training loss : 0.45335546135902405
Training accuracy : 0.7822
Validation loss : 0.45187908411026
Validation accuracy : 0.792
At iteration 6400
Training loss : 0.4532018005847931
Training accuracy : 0.7824
Validation loss : 0.45187172293663025
Validation accuracy : 0.792
At iteration 6500
Training loss : 0.4530511200428009
Training accuracy : 0.7826
Validation loss : 0.451901376247406
Validation accuracy : 0.792
At iteration 6600
Training loss : 0.45290425419807434
Training accuracy : 0.7825
Validation loss : 0.45195072889328003
Validation accuracy : 0.792
At iteration 6700
Training loss : 0.452772319316864
Training accuracy : 0.7824
Validation loss : 0.45195135474205017
Validation accuracy : 0.792
At iteration 6800
Training loss : 0.4526498019695282
Training accuracy : 0.7821
Validation loss : 0.4519382119178772
Validation accuracy : 0.792
At iteration 6900
Training loss : 0.4525333046913147
Training accuracy : 0.7821
Validation loss : 0.4519408941268921
Validation accuracy : 0.792
At iteration 7000
Training loss : 0.4524204730987549
Training accuracy : 0.7822
Validation loss : 0.4519646465778351
Validation accuracy : 0.79
At iteration 7100
Training loss : 0.45229971408843994
Training accuracy : 0.7825
Validation loss : 0.45204782485961914
Validation accuracy : 0.79
At iteration 7200
Training loss : 0.45219022035598755
Training accuracy : 0.7822
Validation loss : 0.4521273374557495
Validation accuracy : 0.792
At iteration 7300
Training loss : 0.45208197832107544
Training accuracy : 0.7821
Validation loss : 0.4522210657596588
Validation accuracy : 0.792
At iteration 7400
Training loss : 0.4519691467285156
Training accuracy : 0.7819
Validation loss : 0.4523455798625946
Validation accuracy : 0.792
At iteration 7500
Training loss : 0.45185211300849915
Training accuracy : 0.7824
Validation loss : 0.4524751603603363
Validation accuracy : 0.79
At iteration 7600
Training loss : 0.45173853635787964
Training accuracy : 0.7827
Validation loss : 0.45261380076408386
Validation accuracy : 0.79
At iteration 7700
Training loss : 0.4516219198703766
Training accuracy : 0.7826
Validation loss : 0.4527660012245178
Validation accuracy : 0.788
At iteration 7800
Training loss : 0.45150718092918396
Training accuracy : 0.7821
Validation loss : 0.45293357968330383
Validation accuracy : 0.788
At iteration 7900
Training loss : 0.45139312744140625
Training accuracy : 0.782
Validation loss : 0.453081876039505
Validation accuracy : 0.786
At iteration 8000
Training loss : 0.45127978920936584
Training accuracy : 0.7822
Validation loss : 0.45318156480789185
Validation accuracy : 0.788
At iteration 8100
Training loss : 0.4511800706386566
Training accuracy : 0.7825
Validation loss : 0.45325249433517456
Validation accuracy : 0.788
At iteration 8200
Training loss : 0.45108383893966675
Training accuracy : 0.7826
Validation loss : 0.45331111550331116
Validation accuracy : 0.788
At iteration 8300
Training loss : 0.45099425315856934
Training accuracy : 0.7827
Validation loss : 0.4533732533454895
Validation accuracy : 0.788
At iteration 8400
Training loss : 0.45091482996940613
Training accuracy : 0.7828
Validation loss : 0.4534011781215668
Validation accuracy : 0.788
At iteration 8500
Training loss : 0.45083779096603394
Training accuracy : 0.7826
Validation loss : 0.45342811942100525
Validation accuracy : 0.788
At iteration 8600
Training loss : 0.45076048374176025
Training accuracy : 0.7826
Validation loss : 0.4534529745578766
Validation accuracy : 0.786
At iteration 8700
Training loss : 0.4506836533546448
Training accuracy : 0.7825
Validation loss : 0.45347943902015686
Validation accuracy : 0.786
At iteration 8800
Training loss : 0.4506029784679413
Training accuracy : 0.7826
Validation loss : 0.4535260498523712
Validation accuracy : 0.784
At iteration 8900
Training loss : 0.4505268633365631
Training accuracy : 0.7829
Validation loss : 0.45358648896217346
Validation accuracy : 0.78
At iteration 9000
Training loss : 0.45046234130859375
Training accuracy : 0.7828
Validation loss : 0.4536360502243042
Validation accuracy : 0.78
At iteration 9100
Training loss : 0.450396865606308
Training accuracy : 0.7831
Validation loss : 0.4536837339401245
Validation accuracy : 0.78
At iteration 9200
Training loss : 0.45033028721809387
Training accuracy : 0.7828
Validation loss : 0.4537392556667328
Validation accuracy : 0.78
At iteration 9300
Training loss : 0.4502713978290558
Training accuracy : 0.7831
Validation loss : 0.4537900686264038
Validation accuracy : 0.78
At iteration 9400
Training loss : 0.4502134323120117
Training accuracy : 0.7831
Validation loss : 0.4538417160511017
Validation accuracy : 0.78
At iteration 9500
Training loss : 0.4501570165157318
Training accuracy : 0.7832
Validation loss : 0.453892320394516
Validation accuracy : 0.78
At iteration 9600
Training loss : 0.4501034617424011
Training accuracy : 0.7832
Validation loss : 0.45394226908683777
Validation accuracy : 0.78
At iteration 9700
Training loss : 0.4500454068183899
Training accuracy : 0.783
Validation loss : 0.4540122449398041
Validation accuracy : 0.78
At iteration 9800
Training loss : 0.44999000430107117
Training accuracy : 0.7828
Validation loss : 0.4540848433971405
Validation accuracy : 0.782
At iteration 9900
Training loss : 0.44993647933006287
Training accuracy : 0.7824
Validation loss : 0.45415589213371277
Validation accuracy : 0.782

``````
``````

In [16]:

print_model(model3,testx)

``````
``````

``````