03_Linear_Regression_Model


In [8]:
import torch
from torch.autograd import Variable
from torch import nn

In [9]:
import matplotlib.pyplot as plt
%matplotlib inline
torch.manual_seed(1)


Out[9]:
<torch._C.Generator at 0x7fe89848d240>

Prepare Data


In [41]:
# X and Y training data

x_train = torch.Tensor([[1], [2], [3]])
y_train = torch.Tensor([[1], [2], [3]])

# x_train = torch.Tensor([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], 
#                     [9.779], [6.182], [7.59], [2.167], [7.042], 
#                     [10.791], [5.313], [7.997], [3.1]])

# y_train = torch.Tensor([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], 
#                     [3.366], [2.596], [2.53], [1.221], [2.827], 
#                     [3.465], [1.65], [2.904], [1.3]])

x, y = Variable(x_train), Variable(y_train)

plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()


Naive Linear Regression Model

Define Linear Regression Model


In [13]:
W = Variable(torch.rand(1, 1))   # randomly initialized 1x1 weight
x, W, x.mm(W)                    # inspect x, W, and the initial prediction XW


Out[13]:
(Variable containing:
  1
  2
  3
 [torch.FloatTensor of size 3x1], Variable containing:
  0.7203
 [torch.FloatTensor of size 1x1], Variable containing:
  0.7203
  1.4406
  2.1610
 [torch.FloatTensor of size 3x1])
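
Before training, it helps to write out the gradient the loop below uses: with cost(W) = (1/N) * sum_i (x_i*W - y_i)^2, the derivative is dCost/dW = (2/N) * sum_i x_i*(x_i*W - y_i). The loop computes this without the constant factor 2, which simply halves the effective learning rate. A minimal sketch checking the manual gradient against autograd (W_chk is a hypothetical throwaway variable, not used elsewhere):

W_chk = Variable(torch.rand(1, 1), requires_grad=True)
cost = torch.nn.MSELoss()(x.mm(W_chk), y)
cost.backward()                    # autograd's dCost/dW, stored in W_chk.grad
manual = 2 * (x.mm(W_chk) - y).view(-1).dot(x.view(-1)) / len(x)
print(W_chk.grad.data[0][0], manual.data[0])   # the two numbers should agree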

Training Linear Regression Model


In [17]:
plt.ion()   # interactive mode on, so the plot updates in place during training

cost_func = torch.nn.MSELoss()                 # mean squared error cost function
lr = 0.01

for step in range(300):

    prediction = x.mm(W)                       # our model: XW (no bias term)
    cost = cost_func(prediction, y)            # argument order matters: (prediction, target)
    gradient = (prediction-y).view(-1).dot(x.view(-1)) / len(x)  # dCost/dW, omitting the MSE factor of 2 (equivalent to halving lr)
    W -= lr * gradient                         # gradient-descent update of W
    

    if step % 5 == 0:
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.title('cost=%.4f, w=%.4f, grad=%.4f' % (cost.data[0], W.data[0][0], gradient.data[0]), fontdict={'size': 20} )
        plt.show()
        plt.pause(0.1)
        
print('Linear Model Optimization is Done!')

plt.ioff()


Linear Model Optimization is Done!
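
Because this bias-free model has a single parameter, its quadratic cost also has a closed-form optimum: W* = sum_i x_i*y_i / sum_i x_i^2. A quick sketch to compare against the trained W:

xv, yv = x.data.view(-1), y.data.view(-1)
W_star = xv.dot(yv) / xv.dot(xv)   # closed-form least-squares solution (no bias term)
print(W_star, W.data[0][0])        # both should be close to 1.0 for this y = x data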

Prediction (Test)


In [34]:
x_test = Variable(torch.Tensor([[5]]))
y_test = x_test.mm(W)
print(y_test)


Variable containing:
 5.0000
[torch.FloatTensor of size 1x1]

Linear Regression Model w/ nn Module

Define Linear Regression Model w/ nn Module


In [30]:
model = nn.Linear(1, 1, bias=True)    # our model: XW + b
cost_func = nn.MSELoss()              # mean squared error cost function

print(model)
model.weight, model.bias              # randomly initialized parameters


Linear (1 -> 1)
Out[30]:
(Parameter containing:
  0.8711
 [torch.FloatTensor of size 1x1], Parameter containing:
 1.00000e-02 *
   7.7633
 [torch.FloatTensor of size 1])
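
Note the shape convention: nn.Linear stores weight as (out_features, in_features), so its forward pass computes x.mm(weight.t()) + bias. A quick equivalence sketch:

manual_out = x.mm(model.weight.t()) + model.bias   # bias broadcasts over the batch
print((model(x) - manual_out).abs().max())         # should be (near) zero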

Training Your Model w/ optim Module


In [31]:
plt.ion()   # interactive mode on, so the plot updates in place during training

optimizer = torch.optim.SGD(model.parameters(), lr= 0.01)

for step in range(300):
    
    prediction = model(x)               # forward pass: predict from x
    cost = cost_func(prediction, y)     # argument order matters: (prediction, target)

    optimizer.zero_grad()               # clear gradients from the previous step
    cost.backward()                     # compute gradients of the cost w.r.t. the parameters
    optimizer.step()                    # update the parameters using those gradients


    #print ('dL/dw: ', model.weight.grad) 
    #print ('dL/db: ', model.bias.grad)
    
    if step % 5 == 0:
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.title('cost=%.4f, w=%.4f, b=%.4f' % (cost.data[0], model.weight.data[0][0],model.bias.data[0]), fontdict={'size': 20} )
        plt.show()
        plt.pause(0.1)
        
print('Linear Model Optimization is Done!')

plt.ioff()


Linear Model Optimization is Done!
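
For plain SGD, optimizer.step() is equivalent to the manual update below (a sketch, ignoring momentum and weight decay, neither of which is enabled here):

for p in model.parameters():
    if p.grad is not None:
        p.data -= 0.01 * p.grad.data   # p <- p - lr * dCost/dp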

In [33]:
x_test = Variable(torch.Tensor([[7]]))
y_test = model(x_test)

print('input: %.4f, output: %.4f' % (x_test.data[0][0], y_test.data[0][0]) )


input: 7.0000, output: 6.8863

Is "nn.MSELoss()" Convex Cost Function?


In [44]:
W_val, cost_val = [], []

for i in range(-30, 51):
    W = i * 0.1                     # sweep W over [-3.0, 5.0]
    model.weight.data.fill_(W)      # overwrite the trained weight with the sweep value
    cost = cost_func(model(x), y)
    
    #print('{:.2f}, {:.2f}'.format(W, cost.data[0]))
    W_val.append(W)
    cost_val.append(cost.data[0])

# ------------------------------------------ #
plt.plot(W_val, cost_val, 'ro')
plt.ylabel('Cost(W)')
plt.xlabel('W')
plt.show()
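
Yes: for fixed data and fixed bias, cost(W) = mean((W*x_i + b - y_i)^2) is quadratic in W with positive leading coefficient mean(x_i^2), so it is convex with a single minimum. A quick sanity check (sketch) fits a parabola to the sweep above:

import numpy as np
a, b_coef, c = np.polyfit(W_val, cost_val, 2)   # fit cost(W) ~ a*W^2 + b_coef*W + c
print(a)   # should be close to mean(x_i^2) = (1 + 4 + 9) / 3 ≈ 4.67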


Multivariate Linear Regression


In [45]:
import numpy as np

Loading Multivariate Data


In [47]:
xy = np.loadtxt('data-01-test-score.csv', delimiter=',', dtype=np.float32)
x_data = xy[:, 0:-1]
y_data = xy[:, [-1]]

# Make sure the shape and data are OK
print('shape: ', x_data.shape, '\nlength:', len(x_data), '\n', x_data )
print('shape: ', y_data.shape, '\nlength:', len(y_data), '\n', y_data )

x, y = Variable(torch.from_numpy(x_data)), Variable(torch.from_numpy(y_data))


shape:  (25, 3) 
length: 25 
 [[  73.   80.   75.]
 [  93.   88.   93.]
 [  89.   91.   90.]
 [  96.   98.  100.]
 [  73.   66.   70.]
 [  53.   46.   55.]
 [  69.   74.   77.]
 [  47.   56.   60.]
 [  87.   79.   90.]
 [  79.   70.   88.]
 [  69.   70.   73.]
 [  70.   65.   74.]
 [  93.   95.   91.]
 [  79.   80.   73.]
 [  70.   73.   78.]
 [  93.   89.   96.]
 [  78.   75.   68.]
 [  81.   90.   93.]
 [  88.   92.   86.]
 [  78.   83.   77.]
 [  82.   86.   90.]
 [  86.   82.   89.]
 [  78.   83.   85.]
 [  76.   83.   71.]
 [  96.   93.   95.]]
shape:  (25, 1) 
length: 25 
 [[ 152.]
 [ 185.]
 [ 180.]
 [ 196.]
 [ 142.]
 [ 101.]
 [ 149.]
 [ 115.]
 [ 175.]
 [ 164.]
 [ 141.]
 [ 141.]
 [ 184.]
 [ 152.]
 [ 148.]
 [ 192.]
 [ 147.]
 [ 183.]
 [ 177.]
 [ 159.]
 [ 177.]
 [ 175.]
 [ 175.]
 [ 149.]
 [ 192.]]
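
With raw features in the 0-100 range, the gradients are large and the learning rate below has to be tiny (1e-5). A common alternative, sketched here but not used below so the printed results stay reproducible, is to standardize each column first, which usually tolerates a much larger learning rate:

mu, sigma = x_data.mean(axis=0), x_data.std(axis=0)
x_scaled = Variable(torch.from_numpy((x_data - mu) / sigma))   # zero mean, unit variance per feature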

Define Multivariate Linear Regression Model


In [54]:
# Our hypothesis XW+b
mv_model = nn.Linear(3, 1, bias=True)

print( mv_model )
print( 'weight: ', mv_model.weight ) 
print( 'bias: ', mv_model.bias )


Linear (3 -> 1)
weight:  Parameter containing:
 0.0283 -0.3413 -0.0653
[torch.FloatTensor of size 1x3]

bias:  Parameter containing:
 0.4366
[torch.FloatTensor of size 1]

Train Your Model


In [57]:
# cost criterion
cost_func = nn.MSELoss()

# minimize the cost with plain SGD; lr is tiny because the features are unscaled
optimizer = torch.optim.SGD(mv_model.parameters(), lr=1e-5)

# Train the model
for step in range(2001):
    optimizer.zero_grad()
    
    # Our model
    prediction = mv_model(x)
    cost = cost_func(prediction, y)
    cost.backward()    
    optimizer.step()

    if step % 500 == 0:
        print(step, "Cost: ", cost.data.numpy(), "\nPrediction:\n", prediction.data.t().numpy())


0 Cost:  [ 6.7285018] 
Prediction:
 [[ 152.59234619  185.16094971  181.41368103  198.54983521  140.93930054
   105.80461121  149.7220459   112.43860626  174.74307251  164.29562378
   143.77297974  142.9732666   186.40313721  153.69979858  150.79641724
   188.68569946  145.73933411  179.80810547  177.28311157  158.65615845
   175.40196228  174.60574341  166.80606079  151.49542236  191.11352539]]
500 Cost:  [ 6.6017828] 
Prediction:
 [[ 152.62260437  185.09501648  181.39807129  198.57493591  140.82868958
   105.77574921  149.83192444  112.65055084  174.70899963  164.34776306
   143.80509949  142.97581482  186.34199524  153.58158875  150.89616394
   188.668396    145.54377747  179.99052429  177.2338562   158.63076782
   175.501297    174.59153748  166.89741516  151.42222595  191.04138184]]
1000 Cost:  [ 6.50480461] 
Prediction:
 [[ 152.65237427  185.03562927  181.38574219  198.59719849  140.73010254
   105.74672699  149.92819214  112.83607483  174.67489624  164.38641357
   143.83277893  142.97471619  186.29092407  153.48138428  150.98257446
   188.65097046  145.37536621  180.15086365  177.19412231  158.61193848
   175.58795166  174.57675171  166.97789001  151.36398315  190.97790527]]
1500 Cost:  [ 6.43034506] 
Prediction:
 [[ 152.68103027  184.98220825  181.37597656  198.6169281   140.64228821
   105.71826935  150.01255798  112.99858856  174.64154053  164.4145813
   143.85665894  142.97102356  186.24812317  153.39620972  151.05752563
   188.6338501   145.23002625  180.2918396   177.16201782  158.59817505
   175.66357422  174.56185913  167.04878235  151.31761169  190.92204285]]
2000 Cost:  [ 6.37303352] 
Prediction:
 [[ 152.70828247  184.93431091  181.36828613  198.63442993  140.56417847
   105.69086456  150.08654785  113.1410141   174.60952759  164.43470764
   143.87728882  142.96560669  186.21221924  153.32369995  151.12265015
   188.61737061  145.10444641  180.41589355  177.13607788  158.58831787
   175.72969055  174.54730225  167.11123657  151.28074646  190.87286377]]

In [61]:
mv_model.state_dict()


Out[61]:
OrderedDict([('weight', 
               0.4261  0.4996  1.0865
              [torch.FloatTensor of size 1x3]), ('bias', 
               0.2118
              [torch.FloatTensor of size 1])])
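
As a cross-check, the exact least-squares solution can be computed with the normal equations. A sketch using numpy (assumes the design matrix has full column rank):

Xb = np.hstack([x_data, np.ones((len(x_data), 1), dtype=np.float32)])   # append a column of ones for the bias
theta = np.linalg.lstsq(Xb, y_data)[0]
print(theta.ravel())   # the exact optimum the SGD parameters above are still approaching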

Test


In [ ]:
# Predict my score
print("Your score will be ", mv_model(Variable(torch.Tensor([[100, 70, 101]]))).data.numpy())
print("Other scores will be ", mv_model(Variable(torch.Tensor([[60, 70, 110], [90, 100, 80]]))).data.numpy())