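# 03 Linear Regression Model

Fit a single-weight linear model $\hat{y} = xW$ to a small dataset, using hand-written gradient descent on the mean squared error.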

``````

In [8]:

import torch
from torch import nn
from torch.autograd import Variable

``````
``````

In [9]:

import matplotlib.pyplot as plt
%matplotlib inline
torch.manual_seed(1)

``````
``````

Out[9]:

<torch._C.Generator at 0x7fe89848d240>

``````

### Prepare Data

``````

In [41]:

# X and Y training data.
# The default toy set lies exactly on y = x, so the least-squares weight is W = 1.
# Set USE_LARGER_DATASET = True to fit a 15-point dataset whose points do not
# lie on a single line instead.
USE_LARGER_DATASET = False

if USE_LARGER_DATASET:
    x_train = torch.Tensor([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                            [9.779], [6.182], [7.59], [2.167], [7.042],
                            [10.791], [5.313], [7.997], [3.1]])
    y_train = torch.Tensor([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                            [3.366], [2.596], [2.53], [1.221], [2.827],
                            [3.465], [1.65], [2.904], [1.3]])
else:
    x_train = torch.Tensor([[1], [2], [3]])
    y_train = torch.Tensor([[1], [2], [3]])

# Both tensors are (n_samples, 1): one input feature and one target per row.
x, y = Variable(x_train), Variable(y_train)

plt.scatter(x.data.numpy(), y.data.numpy())
plt.title('Training data')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

``````
``````

``````

## Naive Linear Regression Model

### Define Linear Regression Model

``````

In [13]:

# Initial weight: one random 1x1 matrix, so the model's prediction for the
# (n_samples, 1) input x is the matrix product x @ W, also (n_samples, 1).
W = Variable(torch.rand((1, 1)))
x, W, x @ W

``````
``````

Out[13]:

(Variable containing:
1
2
3
[torch.FloatTensor of size 3x1], Variable containing:
0.7203
[torch.FloatTensor of size 1x1], Variable containing:
0.7203
1.4406
2.1610
[torch.FloatTensor of size 3x1])

``````

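### Train the Linear Regression Model

Minimise the mean squared error by gradient descent. The gradient $\partial\,\text{cost}/\partial W$ is computed by hand rather than with autograd.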

``````

In [17]:

# Plain gradient descent on the mean squared error, with the gradient derived
# by hand instead of via autograd:
#   cost(W)     = mean((xW - y)^2)
#   dcost / dW  = 2/n * x^T (xW - y)
cost_func = torch.nn.MSELoss()   # Our mean squared cost function
lr = 0.01                        # learning rate
n_steps = 300                    # number of gradient-descent updates
plot_every = 5                   # draw the current fit every `plot_every` steps

for step in range(n_steps):
    prediction = x.mm(W)                   # Our model: XW
    cost = cost_func(prediction, y)        # must be (1. prediction, 2. training target y)
    # dCost/dW. The factor 2 comes from differentiating the square (it was
    # previously missing). x.t().mm(...) gives a (features x 1) matrix with the
    # same shape as W, so the update also works for multi-column x.
    gradient = 2 * x.t().mm(prediction - y) / len(x)
    # In-place update; W was created without requires_grad, so no autograd
    # graph is involved.
    W -= lr * gradient

    if step % plot_every == 0:
        # Plot the fit that produced `cost`, i.e. W before this step's update.
        # No plt.pause()/plt.ioff(): with the inline backend each plt.show()
        # renders a static image, so pausing only slows the loop, and ioff()
        # would switch interactive mode off for the rest of the session.
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.title(f'step {step}, cost {float(cost):.4f}')
        plt.show()

print('Linear Model Optimization is Done!')

``````
``````

``````