Линейная регрессия и стохастический градиентный спуск

Задание основано на материалах лекций по линейной регрессии и градиентному спуску. Вы будете прогнозировать выручку компании в зависимости от уровня ее инвестиций в рекламу по TV, в газетах и по радио.

Вы научитесь:

  • решать задачу восстановления линейной регрессии
  • реализовывать стохастический градиентный спуск для ее настройки
  • решать задачу линейной регрессии аналитически

Введение

Линейная регрессия - один из наиболее хорошо изученных методов машинного обучения, позволяющий прогнозировать значения количественного признака в виде линейной комбинации прочих признаков с параметрами - весами модели. Оптимальные (в смысле минимальности некоторого функционала ошибки) параметры линейной регрессии можно найти аналитически с помощью нормального уравнения или численно с помощью методов оптимизации.

Линейная регрессия использует простой функционал качества - среднеквадратичную ошибку. Мы будем работать с выборкой, содержащей 3 признака. Для настройки параметров (весов) модели решается следующая задача: $$\Large \frac{1}{\ell}\sum_{i=1}^\ell{{((w_0 + w_1x_{i1} + w_2x_{i2} + w_3x_{i3}) - y_i)}^2} \rightarrow \min_{w_0, w_1, w_2, w_3},$$ где $x_{i1}, x_{i2}, x_{i3}$ - значения признаков $i$-го объекта, $y_i$ - значение целевого признака $i$-го объекта, $\ell$ - число объектов в обучающей выборке.

Градиентный спуск

Параметры $w_0, w_1, w_2, w_3$, по которым минимизируется среднеквадратичная ошибка, можно находить численно с помощью градиентного спуска. Градиентный шаг для весов будет выглядеть следующим образом: $$\Large w_0 \leftarrow w_0 - \frac{2\eta}{\ell} \sum_{i=1}^\ell{{((w_0 + w_1x_{i1} + w_2x_{i2} + w_3x_{i3}) - y_i)}}$$ $$\Large w_j \leftarrow w_j - \frac{2\eta}{\ell} \sum_{i=1}^\ell{{x_{ij}((w_0 + w_1x_{i1} + w_2x_{i2} + w_3x_{i3}) - y_i)}},\ j \in \{1,2,3\}$$ Здесь $\eta$ - параметр, шаг градиентного спуска.

Стохастический градиентный спуск

Проблема градиентного спуска, описанного выше, в том, что на больших выборках считать на каждом шаге градиент по всем имеющимся данным может быть очень вычислительно сложно. В стохастическом варианте градиентного спуска поправки для весов вычисляются только с учетом одного случайно взятого объекта обучающей выборки: $$\Large w_0 \leftarrow w_0 - \frac{2\eta}{\ell} {((w_0 + w_1x_{k1} + w_2x_{k2} + w_3x_{k3}) - y_k)}$$ $$\Large w_j \leftarrow w_j - \frac{2\eta}{\ell} {x_{kj}((w_0 + w_1x_{k1} + w_2x_{k2} + w_3x_{k3}) - y_k)},\ j \in \{1,2,3\},$$ где $k$ - случайный индекс, $k \in \{1, \ldots, \ell\}$.

Нормальное уравнение

Нахождение вектора оптимальных весов $w$ может быть сделано и аналитически. Мы хотим найти такой вектор весов $w$, чтобы вектор $y$, приближающий целевой признак, получался умножением матрицы $X$ (состоящей из всех признаков объектов обучающей выборки, кроме целевого) на вектор весов $w$. То есть, чтобы выполнялось матричное уравнение: $$\Large y = Xw$$ Домножением слева на $X^T$ получаем: $$\Large X^Ty = X^TXw$$ Это хорошо, поскольку теперь матрица $X^TX$ - квадратная, и можно найти решение (вектор $w$) в виде: $$\Large w = {(X^TX)}^{-1}X^Ty$$ Матрица ${(X^TX)}^{-1}X^T$ - псевдообратная для матрицы $X$. В NumPy такую матрицу можно вычислить с помощью функции numpy.linalg.pinv.

Однако, нахождение псевдообратной матрицы - операция вычислительно сложная и нестабильная в случае малого определителя матрицы $X$ (проблема мультиколлинеарности). На практике лучше находить вектор весов $w$ решением матричного уравнения $$\Large X^TXw = X^Ty$$Это может быть сделано с помощью функции numpy.linalg.solve.

Но все же на практике для больших матриц $X$ быстрее работает градиентный спуск, особенно его стохастическая версия.

Инструкции по выполнению

В начале напишем простую функцию для записи ответов в текстовый файл. Ответами будут числа, полученные в ходе решения этого задания, округленные до 3 знаков после запятой. Полученные файлы после выполнения задания надо отправить в форму на странице задания на Coursera.org.


In [1]:
def write_answer_to_file(answer, filename):
    with open(filename, 'w') as f_out:
        f_out.write(str(round(answer, 3)))

1. Загрузите данные из файла advertising.csv в объект pandas DataFrame. Источник данных.


In [3]:
import pandas as pd
adver_data = pd.read_csv('advertising.csv')

Посмотрите на первые 5 записей и на статистику признаков в этом наборе данных.


In [4]:
adver_data.head(5)


Out[4]:
TV Radio Newspaper Sales
1 230.1 37.8 69.2 22.1
2 44.5 39.3 45.1 10.4
3 17.2 45.9 69.3 9.3
4 151.5 41.3 58.5 18.5
5 180.8 10.8 58.4 12.9

In [7]:
adver_data.describe()


Out[7]:
TV Radio Newspaper Sales
count 200.000000 200.000000 200.000000 200.000000
mean 147.042500 23.264000 30.554000 14.022500
std 85.854236 14.846809 21.778621 5.217457
min 0.700000 0.000000 0.300000 1.600000
25% 74.375000 9.975000 12.750000 10.375000
50% 149.750000 22.900000 25.750000 12.900000
75% 218.825000 36.525000 45.100000 17.400000
max 296.400000 49.600000 114.000000 27.000000

Создайте массивы NumPy X из столбцов TV, Radio и Newspaper и y - из столбца Sales. Используйте атрибут values объекта pandas DataFrame.


In [11]:
import numpy as np
data = adver_data
X = np.array(data[['TV', 'Radio', 'Newspaper']])
y = np.array(data['Sales'])
X


Out[11]:
array([[ 230.1,   37.8,   69.2],
       [  44.5,   39.3,   45.1],
       [  17.2,   45.9,   69.3],
       [ 151.5,   41.3,   58.5],
       [ 180.8,   10.8,   58.4],
       [   8.7,   48.9,   75. ],
       [  57.5,   32.8,   23.5],
       [ 120.2,   19.6,   11.6],
       [   8.6,    2.1,    1. ],
       [ 199.8,    2.6,   21.2],
       [  66.1,    5.8,   24.2],
       [ 214.7,   24. ,    4. ],
       [  23.8,   35.1,   65.9],
       [  97.5,    7.6,    7.2],
       [ 204.1,   32.9,   46. ],
       [ 195.4,   47.7,   52.9],
       [  67.8,   36.6,  114. ],
       [ 281.4,   39.6,   55.8],
       [  69.2,   20.5,   18.3],
       [ 147.3,   23.9,   19.1],
       [ 218.4,   27.7,   53.4],
       [ 237.4,    5.1,   23.5],
       [  13.2,   15.9,   49.6],
       [ 228.3,   16.9,   26.2],
       [  62.3,   12.6,   18.3],
       [ 262.9,    3.5,   19.5],
       [ 142.9,   29.3,   12.6],
       [ 240.1,   16.7,   22.9],
       [ 248.8,   27.1,   22.9],
       [  70.6,   16. ,   40.8],
       [ 292.9,   28.3,   43.2],
       [ 112.9,   17.4,   38.6],
       [  97.2,    1.5,   30. ],
       [ 265.6,   20. ,    0.3],
       [  95.7,    1.4,    7.4],
       [ 290.7,    4.1,    8.5],
       [ 266.9,   43.8,    5. ],
       [  74.7,   49.4,   45.7],
       [  43.1,   26.7,   35.1],
       [ 228. ,   37.7,   32. ],
       [ 202.5,   22.3,   31.6],
       [ 177. ,   33.4,   38.7],
       [ 293.6,   27.7,    1.8],
       [ 206.9,    8.4,   26.4],
       [  25.1,   25.7,   43.3],
       [ 175.1,   22.5,   31.5],
       [  89.7,    9.9,   35.7],
       [ 239.9,   41.5,   18.5],
       [ 227.2,   15.8,   49.9],
       [  66.9,   11.7,   36.8],
       [ 199.8,    3.1,   34.6],
       [ 100.4,    9.6,    3.6],
       [ 216.4,   41.7,   39.6],
       [ 182.6,   46.2,   58.7],
       [ 262.7,   28.8,   15.9],
       [ 198.9,   49.4,   60. ],
       [   7.3,   28.1,   41.4],
       [ 136.2,   19.2,   16.6],
       [ 210.8,   49.6,   37.7],
       [ 210.7,   29.5,    9.3],
       [  53.5,    2. ,   21.4],
       [ 261.3,   42.7,   54.7],
       [ 239.3,   15.5,   27.3],
       [ 102.7,   29.6,    8.4],
       [ 131.1,   42.8,   28.9],
       [  69. ,    9.3,    0.9],
       [  31.5,   24.6,    2.2],
       [ 139.3,   14.5,   10.2],
       [ 237.4,   27.5,   11. ],
       [ 216.8,   43.9,   27.2],
       [ 199.1,   30.6,   38.7],
       [ 109.8,   14.3,   31.7],
       [  26.8,   33. ,   19.3],
       [ 129.4,    5.7,   31.3],
       [ 213.4,   24.6,   13.1],
       [  16.9,   43.7,   89.4],
       [  27.5,    1.6,   20.7],
       [ 120.5,   28.5,   14.2],
       [   5.4,   29.9,    9.4],
       [ 116. ,    7.7,   23.1],
       [  76.4,   26.7,   22.3],
       [ 239.8,    4.1,   36.9],
       [  75.3,   20.3,   32.5],
       [  68.4,   44.5,   35.6],
       [ 213.5,   43. ,   33.8],
       [ 193.2,   18.4,   65.7],
       [  76.3,   27.5,   16. ],
       [ 110.7,   40.6,   63.2],
       [  88.3,   25.5,   73.4],
       [ 109.8,   47.8,   51.4],
       [ 134.3,    4.9,    9.3],
       [  28.6,    1.5,   33. ],
       [ 217.7,   33.5,   59. ],
       [ 250.9,   36.5,   72.3],
       [ 107.4,   14. ,   10.9],
       [ 163.3,   31.6,   52.9],
       [ 197.6,    3.5,    5.9],
       [ 184.9,   21. ,   22. ],
       [ 289.7,   42.3,   51.2],
       [ 135.2,   41.7,   45.9],
       [ 222.4,    4.3,   49.8],
       [ 296.4,   36.3,  100.9],
       [ 280.2,   10.1,   21.4],
       [ 187.9,   17.2,   17.9],
       [ 238.2,   34.3,    5.3],
       [ 137.9,   46.4,   59. ],
       [  25. ,   11. ,   29.7],
       [  90.4,    0.3,   23.2],
       [  13.1,    0.4,   25.6],
       [ 255.4,   26.9,    5.5],
       [ 225.8,    8.2,   56.5],
       [ 241.7,   38. ,   23.2],
       [ 175.7,   15.4,    2.4],
       [ 209.6,   20.6,   10.7],
       [  78.2,   46.8,   34.5],
       [  75.1,   35. ,   52.7],
       [ 139.2,   14.3,   25.6],
       [  76.4,    0.8,   14.8],
       [ 125.7,   36.9,   79.2],
       [  19.4,   16. ,   22.3],
       [ 141.3,   26.8,   46.2],
       [  18.8,   21.7,   50.4],
       [ 224. ,    2.4,   15.6],
       [ 123.1,   34.6,   12.4],
       [ 229.5,   32.3,   74.2],
       [  87.2,   11.8,   25.9],
       [   7.8,   38.9,   50.6],
       [  80.2,    0. ,    9.2],
       [ 220.3,   49. ,    3.2],
       [  59.6,   12. ,   43.1],
       [   0.7,   39.6,    8.7],
       [ 265.2,    2.9,   43. ],
       [   8.4,   27.2,    2.1],
       [ 219.8,   33.5,   45.1],
       [  36.9,   38.6,   65.6],
       [  48.3,   47. ,    8.5],
       [  25.6,   39. ,    9.3],
       [ 273.7,   28.9,   59.7],
       [  43. ,   25.9,   20.5],
       [ 184.9,   43.9,    1.7],
       [  73.4,   17. ,   12.9],
       [ 193.7,   35.4,   75.6],
       [ 220.5,   33.2,   37.9],
       [ 104.6,    5.7,   34.4],
       [  96.2,   14.8,   38.9],
       [ 140.3,    1.9,    9. ],
       [ 240.1,    7.3,    8.7],
       [ 243.2,   49. ,   44.3],
       [  38. ,   40.3,   11.9],
       [  44.7,   25.8,   20.6],
       [ 280.7,   13.9,   37. ],
       [ 121. ,    8.4,   48.7],
       [ 197.6,   23.3,   14.2],
       [ 171.3,   39.7,   37.7],
       [ 187.8,   21.1,    9.5],
       [   4.1,   11.6,    5.7],
       [  93.9,   43.5,   50.5],
       [ 149.8,    1.3,   24.3],
       [  11.7,   36.9,   45.2],
       [ 131.7,   18.4,   34.6],
       [ 172.5,   18.1,   30.7],
       [  85.7,   35.8,   49.3],
       [ 188.4,   18.1,   25.6],
       [ 163.5,   36.8,    7.4],
       [ 117.2,   14.7,    5.4],
       [ 234.5,    3.4,   84.8],
       [  17.9,   37.6,   21.6],
       [ 206.8,    5.2,   19.4],
       [ 215.4,   23.6,   57.6],
       [ 284.3,   10.6,    6.4],
       [  50. ,   11.6,   18.4],
       [ 164.5,   20.9,   47.4],
       [  19.6,   20.1,   17. ],
       [ 168.4,    7.1,   12.8],
       [ 222.4,    3.4,   13.1],
       [ 276.9,   48.9,   41.8],
       [ 248.4,   30.2,   20.3],
       [ 170.2,    7.8,   35.2],
       [ 276.7,    2.3,   23.7],
       [ 165.6,   10. ,   17.6],
       [ 156.6,    2.6,    8.3],
       [ 218.5,    5.4,   27.4],
       [  56.2,    5.7,   29.7],
       [ 287.6,   43. ,   71.8],
       [ 253.8,   21.3,   30. ],
       [ 205. ,   45.1,   19.6],
       [ 139.5,    2.1,   26.6],
       [ 191.1,   28.7,   18.2],
       [ 286. ,   13.9,    3.7],
       [  18.7,   12.1,   23.4],
       [  39.5,   41.1,    5.8],
       [  75.5,   10.8,    6. ],
       [  17.2,    4.1,   31.6],
       [ 166.8,   42. ,    3.6],
       [ 149.7,   35.6,    6. ],
       [  38.2,    3.7,   13.8],
       [  94.2,    4.9,    8.1],
       [ 177. ,    9.3,    6.4],
       [ 283.6,   42. ,   66.2],
       [ 232.1,    8.6,    8.7]])

Отмасштабируйте столбцы матрицы X, вычтя из каждого значения среднее по соответствующему столбцу и поделив результат на стандартное отклонение. Для определенности, используйте методы mean и std векторов NumPy (реализация std в Pandas может отличаться). Обратите внимание, что в numpy вызов функции .mean() без параметров возвращает среднее по всем элементам массива, а не по столбцам, как в pandas. Чтобы произвести вычисление по столбцам, необходимо указать параметр axis.


In [14]:
means, stds = np.mean(X, axis=0), np.std(X, axis=0)


Out[14]:
array([ 147.0425,   23.264 ,   30.554 ])

In [15]:
X = (X - means) / stds
X


Out[15]:
array([[  9.69852266e-01,   9.81522472e-01,   1.77894547e+00],
       [ -1.19737623e+00,   1.08280781e+00,   6.69578760e-01],
       [ -1.51615499e+00,   1.52846331e+00,   1.78354865e+00],
       [  5.20496822e-02,   1.21785493e+00,   1.28640506e+00],
       [  3.94182198e-01,  -8.41613655e-01,   1.28180188e+00],
       [ -1.61540845e+00,   1.73103399e+00,   2.04592999e+00],
       [ -1.04557682e+00,   6.43904671e-01,  -3.24708413e-01],
       [ -3.13436589e-01,  -2.47406325e-01,  -8.72486994e-01],
       [ -1.61657614e+00,  -1.42906863e+00,  -1.36042422e+00],
       [  6.16042873e-01,  -1.39530685e+00,  -4.30581584e-01],
       [ -9.45155670e-01,  -1.17923146e+00,  -2.92486143e-01],
       [  7.90028350e-01,   4.96973404e-02,  -1.22232878e+00],
       [ -1.43908760e+00,   7.99208859e-01,   1.62704048e+00],
       [ -5.78501712e-01,  -1.05768905e+00,  -1.07502697e+00],
       [  6.66253447e-01,   6.50657027e-01,   7.11007392e-01],
       [  5.64664612e-01,   1.65000572e+00,   1.02862691e+00],
       [ -9.25304978e-01,   9.00494200e-01,   3.84117072e+00],
       [  1.56887609e+00,   1.10306488e+00,   1.16211917e+00],
       [ -9.08957349e-01,  -1.86635121e-01,  -5.64073843e-01],
       [  3.00679600e-03,   4.29449843e-02,  -5.27248393e-01],
       [  8.33232798e-01,   2.99534513e-01,   1.05164281e+00],
       [  1.05509347e+00,  -1.22649795e+00,  -3.24708413e-01],
       [ -1.56286250e+00,  -4.97243498e-01,   8.76721921e-01],
       [  9.48833887e-01,  -4.29719938e-01,  -2.00422516e-01],
       [ -9.89527805e-01,  -7.20071247e-01,  -5.64073843e-01],
       [  1.35285385e+00,  -1.33453565e+00,  -5.08835667e-01],
       [ -4.83714657e-02,   4.07572210e-01,  -8.26455181e-01],
       [  1.08662104e+00,  -4.43224650e-01,  -3.52327501e-01],
       [  1.18820988e+00,   2.59020377e-01,  -3.52327501e-01],
       [ -8.92609721e-01,  -4.90491142e-01,   4.71641962e-01],
       [  1.70316018e+00,   3.40048650e-01,   5.82118314e-01],
       [ -3.98677796e-01,  -3.95958157e-01,   3.70371972e-01],
       [ -5.82004775e-01,  -1.46958277e+00,  -2.55016247e-02],
       [  1.38438142e+00,  -2.20396901e-01,  -1.39264649e+00],
       [ -5.99520091e-01,  -1.47633512e+00,  -1.06582061e+00],
       [  1.67747105e+00,  -1.29402151e+00,  -1.01518562e+00],
       [  1.39956136e+00,   1.38666383e+00,  -1.17629696e+00],
       [ -8.44734522e-01,   1.76479577e+00,   6.97197848e-01],
       [ -1.21372386e+00,   2.32010953e-01,   2.09260624e-01],
       [  9.45330823e-01,   9.74770116e-01,   6.65620024e-02],
       [  6.47570443e-01,  -6.50927121e-02,   4.81492770e-02],
       [  3.49810063e-01,   6.84418807e-01,   3.74975153e-01],
       [  1.71133400e+00,   2.99534513e-01,  -1.32359877e+00],
       [  6.98948705e-01,  -1.00367020e+00,  -1.91216154e-01],
       [ -1.42390765e+00,   1.64487393e-01,   5.86721496e-01],
       [  3.27623995e-01,  -5.15880000e-02,   4.35460956e-02],
       [ -6.69581357e-01,  -9.02384859e-01,   2.36879713e-01],
       [  1.08428567e+00,   1.23135965e+00,  -5.54867481e-01],
       [  9.35989321e-01,  -5.03995854e-01,   8.90531465e-01],
       [ -9.35814168e-01,  -7.80842451e-01,   2.87514708e-01],
       [  6.16042873e-01,  -1.36154507e+00,   1.86244718e-01],
       [ -5.44638766e-01,  -9.22641928e-01,  -1.24074150e+00],
       [  8.09879042e-01,   1.24486436e+00,   4.16403786e-01],
       [  4.15200577e-01,   1.54872038e+00,   1.29561142e+00],
       [  1.35051848e+00,   3.73810430e-01,  -6.74550196e-01],
       [  6.05533683e-01,   1.76479577e+00,   1.35545278e+00],
       [ -1.63175608e+00,   3.26543937e-01,   4.99261050e-01],
       [ -1.26606546e-01,  -2.74415749e-01,  -6.42327927e-01],
       [  7.44488528e-01,   1.77830048e+00,   3.28943340e-01],
       [  7.43320840e-01,   4.21076922e-01,  -9.78360166e-01],
       [ -1.09228433e+00,  -1.43582099e+00,  -4.21375221e-01],
       [  1.33417085e+00,   1.31238792e+00,   1.11148417e+00],
       [  1.07727954e+00,  -5.24252922e-01,  -1.49787521e-01],
       [ -5.17781948e-01,   4.27829278e-01,  -1.01978880e+00],
       [ -1.86158622e-01,   1.31914027e+00,  -7.61366196e-02],
       [ -9.11292725e-01,  -9.42898996e-01,  -1.36502740e+00],
       [ -1.34917564e+00,   9.02114765e-02,  -1.30518604e+00],
       [ -9.04082253e-02,  -5.91776482e-01,  -9.36931533e-01],
       [  1.05509347e+00,   2.86029801e-01,  -9.00106083e-01],
       [  8.14549794e-01,   1.39341619e+00,  -1.54390703e-01],
       [  6.07869059e-01,   4.95352838e-01,   3.74975153e-01],
       [ -4.34876116e-01,  -6.05281194e-01,   5.27524584e-02],
       [ -1.40405696e+00,   6.57409383e-01,  -5.18042030e-01],
       [ -2.06009314e-01,  -1.18598381e+00,   3.43397329e-02],
       [  7.74848409e-01,   9.02114765e-02,  -8.03439274e-01],
       [ -1.51965805e+00,   1.37991148e+00,   2.70878810e+00],
       [ -1.39588315e+00,  -1.46283041e+00,  -4.53597491e-01],
       [ -3.09933525e-01,   3.53553362e-01,  -7.52804279e-01],
       [ -1.65394214e+00,   4.48086346e-01,  -9.73756984e-01],
       [ -3.62479475e-01,  -1.05093669e+00,  -3.43121138e-01],
       [ -8.24883830e-01,   2.32010953e-01,  -3.79946589e-01],
       [  1.08311798e+00,  -1.29402151e+00,   2.92117889e-01],
       [ -8.37728396e-01,  -2.00139833e-01,   8.95779092e-02],
       [ -9.18298852e-01,   1.43393033e+00,   2.32276531e-01],
       [  7.76016097e-01,   1.33264499e+00,   1.49419267e-01],
       [  5.38975481e-01,  -3.28434597e-01,   1.61783412e+00],
       [ -8.26051518e-01,   2.86029801e-01,  -6.69947015e-01],
       [ -4.24366926e-01,   1.17058844e+00,   1.50275459e+00],
       [ -6.85928986e-01,   1.50982681e-01,   1.97227908e+00],
       [ -4.34876116e-01,   1.65675807e+00,   9.59579186e-01],
       [ -1.48792614e-01,  -1.24000266e+00,  -9.78360166e-01],
       [ -1.38303858e+00,  -1.46958277e+00,   1.12593816e-01],
       [  8.25058983e-01,   6.91171163e-01,   1.30942097e+00],
       [  1.21273132e+00,   8.93741844e-01,   1.92164409e+00],
       [ -4.62900623e-01,  -6.25538262e-01,  -9.04709264e-01],
       [  1.89836839e-01,   5.62876398e-01,   1.02862691e+00],
       [  5.90353742e-01,  -1.33453565e+00,  -1.13486833e+00],
       [  4.42057396e-01,  -1.52873340e-01,  -3.93756133e-01],
       [  1.66579418e+00,   1.28537849e+00,   9.50372823e-01],
       [ -1.38283424e-01,   1.24486436e+00,   7.06404211e-01],
       [  8.79940308e-01,  -1.28051680e+00,   8.85928284e-01],
       [  1.74402926e+00,   8.80237132e-01,   3.23815396e+00],
       [  1.55486384e+00,  -8.88880147e-01,  -4.21375221e-01],
       [  4.77088029e-01,  -4.09462869e-01,  -5.82486569e-01],
       [  1.06443498e+00,   7.45190011e-01,  -1.16248742e+00],
       [ -1.06755854e-01,   1.56222509e+00,   1.30942097e+00],
       [ -1.42507534e+00,  -8.28108943e-01,  -3.93111688e-02],
       [ -6.61407543e-01,  -1.55061104e+00,  -3.38517957e-01],
       [ -1.56403019e+00,  -1.54385868e+00,  -2.28041604e-01],
       [  1.26527727e+00,   2.45515665e-01,  -1.15328106e+00],
       [  9.19641692e-01,  -1.01717491e+00,   1.19434143e+00],
       [  1.10530405e+00,   9.95027184e-01,  -3.38517957e-01],
       [  3.34630122e-01,  -5.31005278e-01,  -1.29597968e+00],
       [  7.30476274e-01,  -1.79882765e-01,  -9.13915627e-01],
       [ -8.03865450e-01,   1.58923451e+00,   1.81641536e-01],
       [ -8.40063771e-01,   7.92456503e-01,   1.01942054e+00],
       [ -9.15759131e-02,  -6.05281194e-01,  -2.28041604e-01],
       [ -8.24883830e-01,  -1.51684926e+00,  -7.25185191e-01],
       [ -2.49213762e-01,   9.20751268e-01,   2.23926360e+00],
       [ -1.49046586e+00,  -4.90491142e-01,  -3.79946589e-01],
       [ -6.70544700e-02,   2.38763309e-01,   7.20213755e-01],
       [ -1.49747198e+00,  -1.05606848e-01,   9.13547372e-01],
       [  8.98623313e-01,  -1.40881156e+00,  -6.88359740e-01],
       [ -2.79573643e-01,   7.65447079e-01,  -8.35661544e-01],
       [  9.62846140e-01,   6.10142891e-01,   2.00910454e+00],
       [ -6.98773552e-01,  -7.74090095e-01,  -2.14232060e-01],
       [ -1.62591764e+00,   1.05579839e+00,   9.22753735e-01],
       [ -7.80511695e-01,  -1.57086811e+00,  -9.82963347e-01],
       [  8.55418865e-01,   1.73778635e+00,  -1.25915423e+00],
       [ -1.02105537e+00,  -7.60585383e-01,   5.77515133e-01],
       [ -1.70882347e+00,   1.10306488e+00,  -1.00597925e+00],
       [  1.37971067e+00,  -1.37504978e+00,   5.72911952e-01],
       [ -1.61891151e+00,   2.65772733e-01,  -1.30978922e+00],
       [  8.49580427e-01,   6.91171163e-01,   6.69578760e-01],
       [ -1.28612050e+00,   1.03554132e+00,   1.61323094e+00],
       [ -1.15300409e+00,   1.60273923e+00,  -1.01518562e+00],
       [ -1.41806922e+00,   1.06255074e+00,  -9.78360166e-01],
       [  1.47896413e+00,   3.80562786e-01,   1.34164324e+00],
       [ -1.21489154e+00,   1.77992105e-01,  -4.62803854e-01],
       [  4.42057396e-01,   1.39341619e+00,  -1.32820195e+00],
       [ -8.59914463e-01,  -4.22967582e-01,  -8.12645637e-01],
       [  5.44813920e-01,   8.19465927e-01,   2.07354907e+00],
       [  8.57754241e-01,   6.70914095e-01,   3.38149702e-01],
       [ -4.95595880e-01,  -1.18598381e+00,   1.77038355e-01],
       [ -5.93681653e-01,  -5.71519414e-01,   3.84181516e-01],
       [ -7.87313476e-02,  -1.44257334e+00,  -9.92169710e-01],
       [  1.08662104e+00,  -1.07794612e+00,  -1.00597925e+00],
       [  1.12281936e+00,   1.73778635e+00,   6.32753309e-01],
       [ -1.27327593e+00,   1.15033137e+00,  -8.58677450e-01],
       [ -1.19504085e+00,   1.71239749e-01,  -4.58200672e-01],
       [  1.56070228e+00,  -6.32290618e-01,   2.96721070e-01],
       [ -3.04095087e-01,  -1.00367020e+00,   8.35293289e-01],
       [  5.90353742e-01,   2.43084817e-03,  -7.52804279e-01],
       [  2.83251860e-01,   1.10981724e+00,   3.28943340e-01],
       [  4.75920341e-01,  -1.46120984e-01,  -9.69153803e-01],
       [ -1.66912209e+00,  -7.87594807e-01,  -1.14407469e+00],
       [ -6.20538471e-01,   1.36640677e+00,   9.18150553e-01],
       [  3.21989902e-02,  -1.48308748e+00,  -2.87882962e-01],
       [ -1.58037782e+00,   9.20751268e-01,   6.74181942e-01],
       [ -1.79152496e-01,  -3.28434597e-01,   1.86244718e-01],
       [  2.97264113e-01,  -3.48691665e-01,   6.72064478e-03],
       [ -7.16288868e-01,   8.46475352e-01,   8.62912377e-01],
       [  4.82926468e-01,  -3.48691665e-01,  -2.28041604e-01],
       [  1.92172214e-01,   9.13998912e-01,  -1.06582061e+00],
       [ -3.48467222e-01,  -5.78271770e-01,  -1.15788424e+00],
       [  1.02123053e+00,  -1.34128800e+00,   2.49704176e+00],
       [ -1.50798117e+00,   9.68017760e-01,  -4.12168859e-01],
       [  6.97781017e-01,  -1.21974559e+00,  -5.13438849e-01],
       [  7.98202165e-01,   2.26879163e-02,   1.24497643e+00],
       [  1.60273904e+00,  -8.55118367e-01,  -1.11185242e+00],
       [ -1.13315340e+00,  -7.87594807e-01,  -5.59470662e-01],
       [  2.03849092e-01,  -1.59625696e-01,   7.75451931e-01],
       [ -1.48813048e+00,  -2.13644545e-01,  -6.23915201e-01],
       [  2.49388915e-01,  -1.09145083e+00,  -8.17248818e-01],
       [  8.79940308e-01,  -1.34128800e+00,  -8.03439274e-01],
       [  1.51633014e+00,   1.73103399e+00,   5.17673775e-01],
       [  1.18353913e+00,   4.68343414e-01,  -4.72010216e-01],
       [  2.70407294e-01,  -1.04418434e+00,   2.13863806e-01],
       [  1.51399477e+00,  -1.41556392e+00,  -3.15502050e-01],
       [  2.16693657e-01,  -8.95632503e-01,  -5.96296113e-01],
       [  1.11601758e-01,  -1.39530685e+00,  -1.02439198e+00],
       [  8.34400486e-01,  -1.20624088e+00,  -1.45184340e-01],
       [ -1.06075676e+00,  -1.18598381e+00,  -3.93111688e-02],
       [  1.64127273e+00,   1.33264499e+00,   1.89862818e+00],
       [  1.24659427e+00,  -1.32616272e-01,  -2.55016247e-02],
       [  6.76762637e-01,   1.47444446e+00,  -5.04232486e-01],
       [ -8.80728498e-02,  -1.42906863e+00,  -1.82009791e-01],
       [  5.14454038e-01,   3.67058074e-01,  -5.68677025e-01],
       [  1.62258973e+00,  -6.32290618e-01,  -1.23613832e+00],
       [ -1.49863967e+00,  -7.53833027e-01,  -3.29311594e-01],
       [ -1.25576062e+00,   1.20435022e+00,  -1.13947151e+00],
       [ -8.35393020e-01,  -8.41613655e-01,  -1.13026515e+00],
       [ -1.51615499e+00,  -1.29402151e+00,   4.81492770e-02],
       [  2.30705910e-01,   1.26512143e+00,  -1.24074150e+00],
       [  3.10313024e-02,   8.32970639e-01,  -1.13026515e+00],
       [ -1.27094056e+00,  -1.32103093e+00,  -7.71217005e-01],
       [ -6.17035408e-01,  -1.24000266e+00,  -1.03359834e+00],
       [  3.49810063e-01,  -9.42898996e-01,  -1.11185242e+00],
       [  1.59456522e+00,   1.26512143e+00,   1.64085003e+00],
       [  9.93206022e-01,  -9.90165488e-01,  -1.00597925e+00]])

Добавьте к матрице X столбец из единиц, используя методы hstack, ones и reshape библиотеки NumPy. Вектор из единиц нужен для того, чтобы не обрабатывать отдельно коэффициент $w_0$ линейной регрессии.


In [19]:
X = np.hstack((X, np.ones(len(X)).reshape(len(X), 1)))
X


Out[19]:
array([[  9.69852266e-01,   9.81522472e-01,   1.77894547e+00,
          1.00000000e+00],
       [ -1.19737623e+00,   1.08280781e+00,   6.69578760e-01,
          1.00000000e+00],
       [ -1.51615499e+00,   1.52846331e+00,   1.78354865e+00,
          1.00000000e+00],
       [  5.20496822e-02,   1.21785493e+00,   1.28640506e+00,
          1.00000000e+00],
       [  3.94182198e-01,  -8.41613655e-01,   1.28180188e+00,
          1.00000000e+00],
       [ -1.61540845e+00,   1.73103399e+00,   2.04592999e+00,
          1.00000000e+00],
       [ -1.04557682e+00,   6.43904671e-01,  -3.24708413e-01,
          1.00000000e+00],
       [ -3.13436589e-01,  -2.47406325e-01,  -8.72486994e-01,
          1.00000000e+00],
       [ -1.61657614e+00,  -1.42906863e+00,  -1.36042422e+00,
          1.00000000e+00],
       [  6.16042873e-01,  -1.39530685e+00,  -4.30581584e-01,
          1.00000000e+00],
       [ -9.45155670e-01,  -1.17923146e+00,  -2.92486143e-01,
          1.00000000e+00],
       [  7.90028350e-01,   4.96973404e-02,  -1.22232878e+00,
          1.00000000e+00],
       [ -1.43908760e+00,   7.99208859e-01,   1.62704048e+00,
          1.00000000e+00],
       [ -5.78501712e-01,  -1.05768905e+00,  -1.07502697e+00,
          1.00000000e+00],
       [  6.66253447e-01,   6.50657027e-01,   7.11007392e-01,
          1.00000000e+00],
       [  5.64664612e-01,   1.65000572e+00,   1.02862691e+00,
          1.00000000e+00],
       [ -9.25304978e-01,   9.00494200e-01,   3.84117072e+00,
          1.00000000e+00],
       [  1.56887609e+00,   1.10306488e+00,   1.16211917e+00,
          1.00000000e+00],
       [ -9.08957349e-01,  -1.86635121e-01,  -5.64073843e-01,
          1.00000000e+00],
       [  3.00679600e-03,   4.29449843e-02,  -5.27248393e-01,
          1.00000000e+00],
       [  8.33232798e-01,   2.99534513e-01,   1.05164281e+00,
          1.00000000e+00],
       [  1.05509347e+00,  -1.22649795e+00,  -3.24708413e-01,
          1.00000000e+00],
       [ -1.56286250e+00,  -4.97243498e-01,   8.76721921e-01,
          1.00000000e+00],
       [  9.48833887e-01,  -4.29719938e-01,  -2.00422516e-01,
          1.00000000e+00],
       [ -9.89527805e-01,  -7.20071247e-01,  -5.64073843e-01,
          1.00000000e+00],
       [  1.35285385e+00,  -1.33453565e+00,  -5.08835667e-01,
          1.00000000e+00],
       [ -4.83714657e-02,   4.07572210e-01,  -8.26455181e-01,
          1.00000000e+00],
       [  1.08662104e+00,  -4.43224650e-01,  -3.52327501e-01,
          1.00000000e+00],
       [  1.18820988e+00,   2.59020377e-01,  -3.52327501e-01,
          1.00000000e+00],
       [ -8.92609721e-01,  -4.90491142e-01,   4.71641962e-01,
          1.00000000e+00],
       [  1.70316018e+00,   3.40048650e-01,   5.82118314e-01,
          1.00000000e+00],
       [ -3.98677796e-01,  -3.95958157e-01,   3.70371972e-01,
          1.00000000e+00],
       [ -5.82004775e-01,  -1.46958277e+00,  -2.55016247e-02,
          1.00000000e+00],
       [  1.38438142e+00,  -2.20396901e-01,  -1.39264649e+00,
          1.00000000e+00],
       [ -5.99520091e-01,  -1.47633512e+00,  -1.06582061e+00,
          1.00000000e+00],
       [  1.67747105e+00,  -1.29402151e+00,  -1.01518562e+00,
          1.00000000e+00],
       [  1.39956136e+00,   1.38666383e+00,  -1.17629696e+00,
          1.00000000e+00],
       [ -8.44734522e-01,   1.76479577e+00,   6.97197848e-01,
          1.00000000e+00],
       [ -1.21372386e+00,   2.32010953e-01,   2.09260624e-01,
          1.00000000e+00],
       [  9.45330823e-01,   9.74770116e-01,   6.65620024e-02,
          1.00000000e+00],
       [  6.47570443e-01,  -6.50927121e-02,   4.81492770e-02,
          1.00000000e+00],
       [  3.49810063e-01,   6.84418807e-01,   3.74975153e-01,
          1.00000000e+00],
       [  1.71133400e+00,   2.99534513e-01,  -1.32359877e+00,
          1.00000000e+00],
       [  6.98948705e-01,  -1.00367020e+00,  -1.91216154e-01,
          1.00000000e+00],
       [ -1.42390765e+00,   1.64487393e-01,   5.86721496e-01,
          1.00000000e+00],
       [  3.27623995e-01,  -5.15880000e-02,   4.35460956e-02,
          1.00000000e+00],
       [ -6.69581357e-01,  -9.02384859e-01,   2.36879713e-01,
          1.00000000e+00],
       [  1.08428567e+00,   1.23135965e+00,  -5.54867481e-01,
          1.00000000e+00],
       [  9.35989321e-01,  -5.03995854e-01,   8.90531465e-01,
          1.00000000e+00],
       [ -9.35814168e-01,  -7.80842451e-01,   2.87514708e-01,
          1.00000000e+00],
       [  6.16042873e-01,  -1.36154507e+00,   1.86244718e-01,
          1.00000000e+00],
       [ -5.44638766e-01,  -9.22641928e-01,  -1.24074150e+00,
          1.00000000e+00],
       [  8.09879042e-01,   1.24486436e+00,   4.16403786e-01,
          1.00000000e+00],
       [  4.15200577e-01,   1.54872038e+00,   1.29561142e+00,
          1.00000000e+00],
       [  1.35051848e+00,   3.73810430e-01,  -6.74550196e-01,
          1.00000000e+00],
       [  6.05533683e-01,   1.76479577e+00,   1.35545278e+00,
          1.00000000e+00],
       [ -1.63175608e+00,   3.26543937e-01,   4.99261050e-01,
          1.00000000e+00],
       [ -1.26606546e-01,  -2.74415749e-01,  -6.42327927e-01,
          1.00000000e+00],
       [  7.44488528e-01,   1.77830048e+00,   3.28943340e-01,
          1.00000000e+00],
       [  7.43320840e-01,   4.21076922e-01,  -9.78360166e-01,
          1.00000000e+00],
       [ -1.09228433e+00,  -1.43582099e+00,  -4.21375221e-01,
          1.00000000e+00],
       [  1.33417085e+00,   1.31238792e+00,   1.11148417e+00,
          1.00000000e+00],
       [  1.07727954e+00,  -5.24252922e-01,  -1.49787521e-01,
          1.00000000e+00],
       [ -5.17781948e-01,   4.27829278e-01,  -1.01978880e+00,
          1.00000000e+00],
       [ -1.86158622e-01,   1.31914027e+00,  -7.61366196e-02,
          1.00000000e+00],
       [ -9.11292725e-01,  -9.42898996e-01,  -1.36502740e+00,
          1.00000000e+00],
       [ -1.34917564e+00,   9.02114765e-02,  -1.30518604e+00,
          1.00000000e+00],
       [ -9.04082253e-02,  -5.91776482e-01,  -9.36931533e-01,
          1.00000000e+00],
       [  1.05509347e+00,   2.86029801e-01,  -9.00106083e-01,
          1.00000000e+00],
       [  8.14549794e-01,   1.39341619e+00,  -1.54390703e-01,
          1.00000000e+00],
       [  6.07869059e-01,   4.95352838e-01,   3.74975153e-01,
          1.00000000e+00],
       [ -4.34876116e-01,  -6.05281194e-01,   5.27524584e-02,
          1.00000000e+00],
       [ -1.40405696e+00,   6.57409383e-01,  -5.18042030e-01,
          1.00000000e+00],
       [ -2.06009314e-01,  -1.18598381e+00,   3.43397329e-02,
          1.00000000e+00],
       [  7.74848409e-01,   9.02114765e-02,  -8.03439274e-01,
          1.00000000e+00],
       [ -1.51965805e+00,   1.37991148e+00,   2.70878810e+00,
          1.00000000e+00],
       [ -1.39588315e+00,  -1.46283041e+00,  -4.53597491e-01,
          1.00000000e+00],
       [ -3.09933525e-01,   3.53553362e-01,  -7.52804279e-01,
          1.00000000e+00],
       [ -1.65394214e+00,   4.48086346e-01,  -9.73756984e-01,
          1.00000000e+00],
       [ -3.62479475e-01,  -1.05093669e+00,  -3.43121138e-01,
          1.00000000e+00],
       [ -8.24883830e-01,   2.32010953e-01,  -3.79946589e-01,
          1.00000000e+00],
       [  1.08311798e+00,  -1.29402151e+00,   2.92117889e-01,
          1.00000000e+00],
       [ -8.37728396e-01,  -2.00139833e-01,   8.95779092e-02,
          1.00000000e+00],
       [ -9.18298852e-01,   1.43393033e+00,   2.32276531e-01,
          1.00000000e+00],
       [  7.76016097e-01,   1.33264499e+00,   1.49419267e-01,
          1.00000000e+00],
       [  5.38975481e-01,  -3.28434597e-01,   1.61783412e+00,
          1.00000000e+00],
       [ -8.26051518e-01,   2.86029801e-01,  -6.69947015e-01,
          1.00000000e+00],
       [ -4.24366926e-01,   1.17058844e+00,   1.50275459e+00,
          1.00000000e+00],
       [ -6.85928986e-01,   1.50982681e-01,   1.97227908e+00,
          1.00000000e+00],
       [ -4.34876116e-01,   1.65675807e+00,   9.59579186e-01,
          1.00000000e+00],
       [ -1.48792614e-01,  -1.24000266e+00,  -9.78360166e-01,
          1.00000000e+00],
       [ -1.38303858e+00,  -1.46958277e+00,   1.12593816e-01,
          1.00000000e+00],
       [  8.25058983e-01,   6.91171163e-01,   1.30942097e+00,
          1.00000000e+00],
       [  1.21273132e+00,   8.93741844e-01,   1.92164409e+00,
          1.00000000e+00],
       [ -4.62900623e-01,  -6.25538262e-01,  -9.04709264e-01,
          1.00000000e+00],
       [  1.89836839e-01,   5.62876398e-01,   1.02862691e+00,
          1.00000000e+00],
       [  5.90353742e-01,  -1.33453565e+00,  -1.13486833e+00,
          1.00000000e+00],
       [  4.42057396e-01,  -1.52873340e-01,  -3.93756133e-01,
          1.00000000e+00],
       [  1.66579418e+00,   1.28537849e+00,   9.50372823e-01,
          1.00000000e+00],
       [ -1.38283424e-01,   1.24486436e+00,   7.06404211e-01,
          1.00000000e+00],
       [  8.79940308e-01,  -1.28051680e+00,   8.85928284e-01,
          1.00000000e+00],
       [  1.74402926e+00,   8.80237132e-01,   3.23815396e+00,
          1.00000000e+00],
       [  1.55486384e+00,  -8.88880147e-01,  -4.21375221e-01,
          1.00000000e+00],
       [  4.77088029e-01,  -4.09462869e-01,  -5.82486569e-01,
          1.00000000e+00],
       [  1.06443498e+00,   7.45190011e-01,  -1.16248742e+00,
          1.00000000e+00],
       [ -1.06755854e-01,   1.56222509e+00,   1.30942097e+00,
          1.00000000e+00],
       [ -1.42507534e+00,  -8.28108943e-01,  -3.93111688e-02,
          1.00000000e+00],
       [ -6.61407543e-01,  -1.55061104e+00,  -3.38517957e-01,
          1.00000000e+00],
       [ -1.56403019e+00,  -1.54385868e+00,  -2.28041604e-01,
          1.00000000e+00],
       [  1.26527727e+00,   2.45515665e-01,  -1.15328106e+00,
          1.00000000e+00],
       [  9.19641692e-01,  -1.01717491e+00,   1.19434143e+00,
          1.00000000e+00],
       [  1.10530405e+00,   9.95027184e-01,  -3.38517957e-01,
          1.00000000e+00],
       [  3.34630122e-01,  -5.31005278e-01,  -1.29597968e+00,
          1.00000000e+00],
       [  7.30476274e-01,  -1.79882765e-01,  -9.13915627e-01,
          1.00000000e+00],
       [ -8.03865450e-01,   1.58923451e+00,   1.81641536e-01,
          1.00000000e+00],
       [ -8.40063771e-01,   7.92456503e-01,   1.01942054e+00,
          1.00000000e+00],
       [ -9.15759131e-02,  -6.05281194e-01,  -2.28041604e-01,
          1.00000000e+00],
       [ -8.24883830e-01,  -1.51684926e+00,  -7.25185191e-01,
          1.00000000e+00],
       [ -2.49213762e-01,   9.20751268e-01,   2.23926360e+00,
          1.00000000e+00],
       [ -1.49046586e+00,  -4.90491142e-01,  -3.79946589e-01,
          1.00000000e+00],
       [ -6.70544700e-02,   2.38763309e-01,   7.20213755e-01,
          1.00000000e+00],
       [ -1.49747198e+00,  -1.05606848e-01,   9.13547372e-01,
          1.00000000e+00],
       [  8.98623313e-01,  -1.40881156e+00,  -6.88359740e-01,
          1.00000000e+00],
       [ -2.79573643e-01,   7.65447079e-01,  -8.35661544e-01,
          1.00000000e+00],
       [  9.62846140e-01,   6.10142891e-01,   2.00910454e+00,
          1.00000000e+00],
       [ -6.98773552e-01,  -7.74090095e-01,  -2.14232060e-01,
          1.00000000e+00],
       [ -1.62591764e+00,   1.05579839e+00,   9.22753735e-01,
          1.00000000e+00],
       [ -7.80511695e-01,  -1.57086811e+00,  -9.82963347e-01,
          1.00000000e+00],
       [  8.55418865e-01,   1.73778635e+00,  -1.25915423e+00,
          1.00000000e+00],
       [ -1.02105537e+00,  -7.60585383e-01,   5.77515133e-01,
          1.00000000e+00],
       [ -1.70882347e+00,   1.10306488e+00,  -1.00597925e+00,
          1.00000000e+00],
       [  1.37971067e+00,  -1.37504978e+00,   5.72911952e-01,
          1.00000000e+00],
       [ -1.61891151e+00,   2.65772733e-01,  -1.30978922e+00,
          1.00000000e+00],
       [  8.49580427e-01,   6.91171163e-01,   6.69578760e-01,
          1.00000000e+00],
       [ -1.28612050e+00,   1.03554132e+00,   1.61323094e+00,
          1.00000000e+00],
       [ -1.15300409e+00,   1.60273923e+00,  -1.01518562e+00,
          1.00000000e+00],
       [ -1.41806922e+00,   1.06255074e+00,  -9.78360166e-01,
          1.00000000e+00],
       [  1.47896413e+00,   3.80562786e-01,   1.34164324e+00,
          1.00000000e+00],
       [ -1.21489154e+00,   1.77992105e-01,  -4.62803854e-01,
          1.00000000e+00],
       [  4.42057396e-01,   1.39341619e+00,  -1.32820195e+00,
          1.00000000e+00],
       [ -8.59914463e-01,  -4.22967582e-01,  -8.12645637e-01,
          1.00000000e+00],
       [  5.44813920e-01,   8.19465927e-01,   2.07354907e+00,
          1.00000000e+00],
       [  8.57754241e-01,   6.70914095e-01,   3.38149702e-01,
          1.00000000e+00],
       [ -4.95595880e-01,  -1.18598381e+00,   1.77038355e-01,
          1.00000000e+00],
       [ -5.93681653e-01,  -5.71519414e-01,   3.84181516e-01,
          1.00000000e+00],
       [ -7.87313476e-02,  -1.44257334e+00,  -9.92169710e-01,
          1.00000000e+00],
       [  1.08662104e+00,  -1.07794612e+00,  -1.00597925e+00,
          1.00000000e+00],
       [  1.12281936e+00,   1.73778635e+00,   6.32753309e-01,
          1.00000000e+00],
       [ -1.27327593e+00,   1.15033137e+00,  -8.58677450e-01,
          1.00000000e+00],
       [ -1.19504085e+00,   1.71239749e-01,  -4.58200672e-01,
          1.00000000e+00],
       [  1.56070228e+00,  -6.32290618e-01,   2.96721070e-01,
          1.00000000e+00],
       [ -3.04095087e-01,  -1.00367020e+00,   8.35293289e-01,
          1.00000000e+00],
       [  5.90353742e-01,   2.43084817e-03,  -7.52804279e-01,
          1.00000000e+00],
       [  2.83251860e-01,   1.10981724e+00,   3.28943340e-01,
          1.00000000e+00],
       [  4.75920341e-01,  -1.46120984e-01,  -9.69153803e-01,
          1.00000000e+00],
       [ -1.66912209e+00,  -7.87594807e-01,  -1.14407469e+00,
          1.00000000e+00],
       [ -6.20538471e-01,   1.36640677e+00,   9.18150553e-01,
          1.00000000e+00],
       [  3.21989902e-02,  -1.48308748e+00,  -2.87882962e-01,
          1.00000000e+00],
       [ -1.58037782e+00,   9.20751268e-01,   6.74181942e-01,
          1.00000000e+00],
       [ -1.79152496e-01,  -3.28434597e-01,   1.86244718e-01,
          1.00000000e+00],
       [  2.97264113e-01,  -3.48691665e-01,   6.72064478e-03,
          1.00000000e+00],
       [ -7.16288868e-01,   8.46475352e-01,   8.62912377e-01,
          1.00000000e+00],
       [  4.82926468e-01,  -3.48691665e-01,  -2.28041604e-01,
          1.00000000e+00],
       [  1.92172214e-01,   9.13998912e-01,  -1.06582061e+00,
          1.00000000e+00],
       [ -3.48467222e-01,  -5.78271770e-01,  -1.15788424e+00,
          1.00000000e+00],
       [  1.02123053e+00,  -1.34128800e+00,   2.49704176e+00,
          1.00000000e+00],
       [ -1.50798117e+00,   9.68017760e-01,  -4.12168859e-01,
          1.00000000e+00],
       [  6.97781017e-01,  -1.21974559e+00,  -5.13438849e-01,
          1.00000000e+00],
       [  7.98202165e-01,   2.26879163e-02,   1.24497643e+00,
          1.00000000e+00],
       [  1.60273904e+00,  -8.55118367e-01,  -1.11185242e+00,
          1.00000000e+00],
       [ -1.13315340e+00,  -7.87594807e-01,  -5.59470662e-01,
          1.00000000e+00],
       [  2.03849092e-01,  -1.59625696e-01,   7.75451931e-01,
          1.00000000e+00],
       [ -1.48813048e+00,  -2.13644545e-01,  -6.23915201e-01,
          1.00000000e+00],
       [  2.49388915e-01,  -1.09145083e+00,  -8.17248818e-01,
          1.00000000e+00],
       [  8.79940308e-01,  -1.34128800e+00,  -8.03439274e-01,
          1.00000000e+00],
       [  1.51633014e+00,   1.73103399e+00,   5.17673775e-01,
          1.00000000e+00],
       [  1.18353913e+00,   4.68343414e-01,  -4.72010216e-01,
          1.00000000e+00],
       [  2.70407294e-01,  -1.04418434e+00,   2.13863806e-01,
          1.00000000e+00],
       [  1.51399477e+00,  -1.41556392e+00,  -3.15502050e-01,
          1.00000000e+00],
       [  2.16693657e-01,  -8.95632503e-01,  -5.96296113e-01,
          1.00000000e+00],
       [  1.11601758e-01,  -1.39530685e+00,  -1.02439198e+00,
          1.00000000e+00],
       [  8.34400486e-01,  -1.20624088e+00,  -1.45184340e-01,
          1.00000000e+00],
       [ -1.06075676e+00,  -1.18598381e+00,  -3.93111688e-02,
          1.00000000e+00],
       [  1.64127273e+00,   1.33264499e+00,   1.89862818e+00,
          1.00000000e+00],
       [  1.24659427e+00,  -1.32616272e-01,  -2.55016247e-02,
          1.00000000e+00],
       [  6.76762637e-01,   1.47444446e+00,  -5.04232486e-01,
          1.00000000e+00],
       [ -8.80728498e-02,  -1.42906863e+00,  -1.82009791e-01,
          1.00000000e+00],
       [  5.14454038e-01,   3.67058074e-01,  -5.68677025e-01,
          1.00000000e+00],
       [  1.62258973e+00,  -6.32290618e-01,  -1.23613832e+00,
          1.00000000e+00],
       [ -1.49863967e+00,  -7.53833027e-01,  -3.29311594e-01,
          1.00000000e+00],
       [ -1.25576062e+00,   1.20435022e+00,  -1.13947151e+00,
          1.00000000e+00],
       [ -8.35393020e-01,  -8.41613655e-01,  -1.13026515e+00,
          1.00000000e+00],
       [ -1.51615499e+00,  -1.29402151e+00,   4.81492770e-02,
          1.00000000e+00],
       [  2.30705910e-01,   1.26512143e+00,  -1.24074150e+00,
          1.00000000e+00],
       [  3.10313024e-02,   8.32970639e-01,  -1.13026515e+00,
          1.00000000e+00],
       [ -1.27094056e+00,  -1.32103093e+00,  -7.71217005e-01,
          1.00000000e+00],
       [ -6.17035408e-01,  -1.24000266e+00,  -1.03359834e+00,
          1.00000000e+00],
       [  3.49810063e-01,  -9.42898996e-01,  -1.11185242e+00,
          1.00000000e+00],
       [  1.59456522e+00,   1.26512143e+00,   1.64085003e+00,
          1.00000000e+00],
       [  9.93206022e-01,  -9.90165488e-01,  -1.00597925e+00,
          1.00000000e+00]])

2. Реализуйте функцию mserror - среднеквадратичную ошибку прогноза. Она принимает два аргумента - объекты Series y (значения целевого признака) и y_pred (предсказанные значения). Не используйте в этой функции циклы - тогда она будет вычислительно неэффективной.


In [20]:
def mserror(y, y_pred):
    return np.sum(np.square(y - y_pred)) / len(y)

Какова среднеквадратичная ошибка прогноза значений Sales, если всегда предсказывать медианное значение Sales по исходной выборке? Запишите ответ в файл '1.txt'.


In [22]:
answer1 = mserror(y, np.median(y))
print(answer1)
write_answer_to_file(answer1, '1.txt')


28.34575

3. Реализуйте функцию normal_equation, которая по заданным матрицам (массивам NumPy) X и y вычисляет вектор весов $w$ согласно нормальному уравнению линейной регрессии.


In [23]:
def normal_equation(X, y):
    return np.dot(np.linalg.pinv(X), y)

In [24]:
norm_eq_weights = normal_equation(X, y)
print(norm_eq_weights)


[  3.91925365   2.79206274  -0.02253861  14.0225    ]

Какие продажи предсказываются линейной моделью с весами, найденными с помощью нормального уравнения, в случае средних инвестиций в рекламу по ТВ, радио и в газетах? (то есть при нулевых значениях масштабированных признаков TV, Radio и Newspaper). Запишите ответ в файл '2.txt'.


In [26]:
answer2 = np.sum(np.array([0, 0, 0, 1]) * norm_eq_weights)
print(answer2)
write_answer_to_file(answer2, '2.txt')


14.0225

4. Напишите функцию linear_prediction, которая принимает на вход матрицу X и вектор весов линейной модели w, а возвращает вектор прогнозов в виде линейной комбинации столбцов матрицы X с весами w.


In [29]:
def linear_prediction(X, w):
    return np.dot(X, w)

Какова среднеквадратичная ошибка прогноза значений Sales в виде линейной модели с весами, найденными с помощью нормального уравнения? Запишите ответ в файл '3.txt'.


In [30]:
answer3 = mserror(y, linear_prediction(X, norm_eq_weights))
print(answer3)
write_answer_to_file(answer3, '3.txt')


2.78412631451

5. Напишите функцию stochastic_gradient_step, реализующую шаг стохастического градиентного спуска для линейной регрессии. Функция должна принимать матрицу X, вектора y и w, число train_ind - индекс объекта обучающей выборки (строки матрицы X), по которому считается изменение весов, а также число $\eta$ (eta) - шаг градиентного спуска (по умолчанию eta=0.01). Результатом будет вектор обновленных весов. Наша реализация функции будет явно написана для данных с 3 признаками, но несложно модифицировать для любого числа признаков, можете это сделать.


In [66]:
def stochastic_gradient_step(X, y, w, train_ind, eta=0.01):
    x = X[train_ind] * (np.sum(X[train_ind] * w) - y[train_ind]) * (2 / X.shape[0])
    grad0 = x[0]
    grad1 = x[1]
    grad2 = x[2]
    grad3 = x[3]
    return  w - eta * np.array([grad0, grad1, grad2, grad3])

6. Напишите функцию stochastic_gradient_descent, реализующую стохастический градиентный спуск для линейной регрессии. Функция принимает на вход следующие аргументы:

  • X - матрица, соответствующая обучающей выборке
  • y - вектор значений целевого признака
  • w_init - вектор начальных весов модели
  • eta - шаг градиентного спуска (по умолчанию 0.01)
  • max_iter - максимальное число итераций градиентного спуска (по умолчанию 10000)
  • max_weight_dist - максимальное евклидово расстояние между векторами весов на соседних итерациях градиентного спуска, при котором алгоритм прекращает работу (по умолчанию 1e-8)
  • seed - число, используемое для воспроизводимости сгенерированных псевдослучайных чисел (по умолчанию 42)
  • verbose - флаг печати информации (например, для отладки, по умолчанию False)

На каждой итерации в вектор (список) должно записываться текущее значение среднеквадратичной ошибки. Функция должна возвращать вектор весов $w$, а также вектор (список) ошибок.


In [67]:
def stochastic_gradient_descent(X, y, w_init, eta=1e-2, max_iter=1e4,
                                max_weight_dist=1e-8, seed=42, verbose=False):
    # Инициализируем расстояние между векторами весов на соседних
    # итерациях большим числом. 
    weight_dist = np.inf
    # Инициализируем вектор весов
    w = w_init
    # Сюда будем записывать ошибки на каждой итерации
    errors = [mserror(y, linear_prediction(X, w))]
    # Счетчик итераций
    iter_num = 0
    # Будем порождать псевдослучайные числа 
    # (номер объекта, который будет менять веса), а для воспроизводимости
    # этой последовательности псевдослучайных чисел используем seed.
    np.random.seed(seed)
        
    # Основной цикл
    while weight_dist > max_weight_dist and iter_num < max_iter: 
        # порождаем псевдослучайный 
        # индекс объекта обучающей выборки
        random_ind = np.random.randint(X.shape[0])
        
        # Ваш код здесь
        new_w = stochastic_gradient_step(X, y, w, random_ind, eta)
        errors.append(mserror(y, linear_prediction(X, new_w)))
        
        if verbose:
            print (errors[-1])
        weight_dist = np.sqrt(np.sum(np.square(new_w - w)))
        w = new_w
        iter_num += 1
        
    return w, errors

Запустите $10^5$ итераций стохастического градиентного спуска. Укажите вектор начальных весов w_init, состоящий из нулей. Оставьте параметры eta и seed равными их значениям по умолчанию (eta=0.01, seed=42 - это важно для проверки ответов).


In [68]:
%%time
stoch_grad_desc_weights, stoch_errors_by_iter = stochastic_gradient_descent(X, y, np.zeros(4), max_iter = 1e5, 
                                                                            verbose = False)


CPU times: user 3.03 s, sys: 108 ms, total: 3.14 s
Wall time: 3.06 s

Посмотрим, чему равна ошибка на первых 50 итерациях стохастического градиентного спуска. Видим, что ошибка не обязательно уменьшается на каждой итерации.


In [69]:
%pylab inline
plot(range(50), stoch_errors_by_iter[:50])
xlabel('Iteration number')
ylabel('MSE')


Populating the interactive namespace from numpy and matplotlib
Out[69]:
<matplotlib.text.Text at 0x117bd86a0>

Теперь посмотрим на зависимость ошибки от номера итерации для $10^5$ итераций стохастического градиентного спуска. Видим, что алгоритм сходится.


In [70]:
%pylab inline
plot(range(len(stoch_errors_by_iter)), stoch_errors_by_iter)
xlabel('Iteration number')
ylabel('MSE')


Populating the interactive namespace from numpy and matplotlib
Out[70]:
<matplotlib.text.Text at 0x115989cc0>

Посмотрим на вектор весов, к которому сошелся метод.


In [71]:
stoch_grad_desc_weights


Out[71]:
array([  3.91069256e+00,   2.78209808e+00,  -8.10462217e-03,
         1.40190566e+01])

Посмотрим на среднеквадратичную ошибку на последней итерации.


In [72]:
stoch_errors_by_iter[-20:]


Out[72]:
[2.7843972510768391,
 2.7843952768710505,
 2.7843948977269988,
 2.7844106449734669,
 2.7844129911978883,
 2.7844098550761265,
 2.7844101247343165,
 2.7844087286300403,
 2.78440988033792,
 2.7844088948741694,
 2.784411783755071,
 2.7844120305197793,
 2.7844124321697179,
 2.7844040186178045,
 2.78440697194468,
 2.7844125046389747,
 2.784410075611238,
 2.7844107986666473,
 2.7844125883527591,
 2.7844125884067039]

Какова среднеквадратичная ошибка прогноза значений Sales в виде линейной модели с весами, найденными с помощью градиентного спуска? Запишите ответ в файл '4.txt'.


In [73]:
answer4 = mserror(y, linear_prediction(X, stoch_grad_desc_weights))
print(answer4)
write_answer_to_file(answer4, '4.txt')


2.78441258841

Ответами к заданию будут текстовые файлы, полученные в ходе этого решения. Обратите внимание, что отправленные файлы не должны содержать пустую строку в конце. Данный нюанс является ограничением платформы Coursera. Мы работаем над исправлением этого ограничения.


In [ ]: