Линейная регрессия и стохастический градиентный спуск

Задание основано на материалах лекций по линейной регрессии и градиентному спуску. Вы будете прогнозировать выручку компании в зависимости от уровня ее инвестиций в рекламу по TV, в газетах и по радио.

Вы научитесь:

  • решать задачу восстановления линейной регрессии
  • реализовывать стохастический градиентный спуск для ее настройки
  • решать задачу линейной регрессии аналитически

Введение

Линейная регрессия - один из наиболее хорошо изученных методов машинного обучения, позволяющий прогнозировать значения количественного признака в виде линейной комбинации прочих признаков с параметрами - весами модели. Оптимальные (в смысле минимальности некоторого функционала ошибки) параметры линейной регрессии можно найти аналитически с помощью нормального уравнения или численно с помощью методов оптимизации.

Линейная регрессия использует простой функционал качества - среднеквадратичную ошибку. Мы будем работать с выборкой, содержащей 3 признака. Для настройки параметров (весов) модели решается следующая задача: $$\Large \frac{1}{\ell}\sum_{i=1}^\ell{{((w_0 + w_1x_{i1} + w_2x_{i2} + w_3x_{i3}) - y_i)}^2} \rightarrow \min_{w_0, w_1, w_2, w_3},$$ где $x_{i1}, x_{i2}, x_{i3}$ - значения признаков $i$-го объекта, $y_i$ - значение целевого признака $i$-го объекта, $\ell$ - число объектов в обучающей выборке.

Градиентный спуск

Параметры $w_0, w_1, w_2, w_3$, по которым минимизируется среднеквадратичная ошибка, можно находить численно с помощью градиентного спуска. Градиентный шаг для весов будет выглядеть следующим образом: $$\Large w_0 \leftarrow w_0 - \frac{2\eta}{\ell} \sum_{i=1}^\ell{{((w_0 + w_1x_{i1} + w_2x_{i2} + w_3x_{i3}) - y_i)}}$$ $$\Large w_j \leftarrow w_j - \frac{2\eta}{\ell} \sum_{i=1}^\ell{{x_{ij}((w_0 + w_1x_{i1} + w_2x_{i2} + w_3x_{i3}) - y_i)}},\ j \in \{1,2,3\}$$ Здесь $\eta$ - параметр, шаг градиентного спуска.

Стохастический градиентный спуск

Проблема градиентного спуска, описанного выше, в том, что на больших выборках считать на каждом шаге градиент по всем имеющимся данным может быть очень вычислительно сложно. В стохастическом варианте градиентного спуска поправки для весов вычисляются только с учетом одного случайно взятого объекта обучающей выборки: $$\Large w_0 \leftarrow w_0 - \frac{2\eta}{\ell} {((w_0 + w_1x_{k1} + w_2x_{k2} + w_3x_{k3}) - y_k)}$$ $$\Large w_j \leftarrow w_j - \frac{2\eta}{\ell} {x_{kj}((w_0 + w_1x_{k1} + w_2x_{k2} + w_3x_{k3}) - y_k)},\ j \in \{1,2,3\},$$ где $k$ - случайный индекс, $k \in \{1, \ldots, \ell\}$.

Нормальное уравнение

Нахождение вектора оптимальных весов $w$ может быть сделано и аналитически. Мы хотим найти такой вектор весов $w$, чтобы вектор $y$, приближающий целевой признак, получался умножением матрицы $X$ (состоящей из всех признаков объектов обучающей выборки, кроме целевого) на вектор весов $w$. То есть, чтобы выполнялось матричное уравнение: $$\Large y = Xw$$ Домножением слева на $X^T$ получаем: $$\Large X^Ty = X^TXw$$ Это хорошо, поскольку теперь матрица $X^TX$ - квадратная, и можно найти решение (вектор $w$) в виде: $$\Large w = {(X^TX)}^{-1}X^Ty$$ Матрица ${(X^TX)}^{-1}X^T$ - псевдообратная для матрицы $X$. В NumPy такую матрицу можно вычислить с помощью функции numpy.linalg.pinv.

Однако, нахождение псевдообратной матрицы - операция вычислительно сложная и нестабильная в случае малого определителя матрицы $X$ (проблема мультиколлинеарности). На практике лучше находить вектор весов $w$ решением матричного уравнения $$\Large X^TXw = X^Ty$$Это может быть сделано с помощью функции numpy.linalg.solve.

Но все же на практике для больших матриц $X$ быстрее работает градиентный спуск, особенно его стохастическая версия.

Инструкции по выполнению

В начале напишем простую функцию для записи ответов в текстовый файл. Ответами будут числа, полученные в ходе решения этого задания, округленные до 3 знаков после запятой. Полученные файлы после выполнения задания надо отправить в форму на странице задания на Coursera.org.


In [2]:
def write_answer_to_file(answer, filename):
    with open(filename, 'w') as f_out:
        f_out.write(str(round(answer, 3)))

1. Загрузите данные из файла advertising.csv в объект pandas DataFrame. Источник данных.


In [3]:
import pandas as pd
adver_data = pd.read_csv('advertising.csv')

Посмотрите на первые 5 записей и на статистику признаков в этом наборе данных.


In [4]:
adver_data.head(5)


Out[4]:
TV Radio Newspaper Sales
1 230.1 37.8 69.2 22.1
2 44.5 39.3 45.1 10.4
3 17.2 45.9 69.3 9.3
4 151.5 41.3 58.5 18.5
5 180.8 10.8 58.4 12.9

In [5]:
adver_data.describe()


Out[5]:
TV Radio Newspaper Sales
count 200.000000 200.000000 200.000000 200.000000
mean 147.042500 23.264000 30.554000 14.022500
std 85.854236 14.846809 21.778621 5.217457
min 0.700000 0.000000 0.300000 1.600000
25% 74.375000 9.975000 12.750000 10.375000
50% 149.750000 22.900000 25.750000 12.900000
75% 218.825000 36.525000 45.100000 17.400000
max 296.400000 49.600000 114.000000 27.000000

Создайте массивы NumPy X из столбцов TV, Radio и Newspaper и y - из столбца Sales. Используйте атрибут values объекта pandas DataFrame.


In [6]:
import numpy as np
X = adver_data[['TV', 'Radio', 'Newspaper']].values
y = adver_data[['Sales']].values
print 'X array', X.shape, ':'
print X
print 'y array', y.shape, ':'
print y


X array (200L, 3L) :
[[ 230.1   37.8   69.2]
 [  44.5   39.3   45.1]
 [  17.2   45.9   69.3]
 [ 151.5   41.3   58.5]
 [ 180.8   10.8   58.4]
 [   8.7   48.9   75. ]
 [  57.5   32.8   23.5]
 [ 120.2   19.6   11.6]
 [   8.6    2.1    1. ]
 [ 199.8    2.6   21.2]
 [  66.1    5.8   24.2]
 [ 214.7   24.     4. ]
 [  23.8   35.1   65.9]
 [  97.5    7.6    7.2]
 [ 204.1   32.9   46. ]
 [ 195.4   47.7   52.9]
 [  67.8   36.6  114. ]
 [ 281.4   39.6   55.8]
 [  69.2   20.5   18.3]
 [ 147.3   23.9   19.1]
 [ 218.4   27.7   53.4]
 [ 237.4    5.1   23.5]
 [  13.2   15.9   49.6]
 [ 228.3   16.9   26.2]
 [  62.3   12.6   18.3]
 [ 262.9    3.5   19.5]
 [ 142.9   29.3   12.6]
 [ 240.1   16.7   22.9]
 [ 248.8   27.1   22.9]
 [  70.6   16.    40.8]
 [ 292.9   28.3   43.2]
 [ 112.9   17.4   38.6]
 [  97.2    1.5   30. ]
 [ 265.6   20.     0.3]
 [  95.7    1.4    7.4]
 [ 290.7    4.1    8.5]
 [ 266.9   43.8    5. ]
 [  74.7   49.4   45.7]
 [  43.1   26.7   35.1]
 [ 228.    37.7   32. ]
 [ 202.5   22.3   31.6]
 [ 177.    33.4   38.7]
 [ 293.6   27.7    1.8]
 [ 206.9    8.4   26.4]
 [  25.1   25.7   43.3]
 [ 175.1   22.5   31.5]
 [  89.7    9.9   35.7]
 [ 239.9   41.5   18.5]
 [ 227.2   15.8   49.9]
 [  66.9   11.7   36.8]
 [ 199.8    3.1   34.6]
 [ 100.4    9.6    3.6]
 [ 216.4   41.7   39.6]
 [ 182.6   46.2   58.7]
 [ 262.7   28.8   15.9]
 [ 198.9   49.4   60. ]
 [   7.3   28.1   41.4]
 [ 136.2   19.2   16.6]
 [ 210.8   49.6   37.7]
 [ 210.7   29.5    9.3]
 [  53.5    2.    21.4]
 [ 261.3   42.7   54.7]
 [ 239.3   15.5   27.3]
 [ 102.7   29.6    8.4]
 [ 131.1   42.8   28.9]
 [  69.     9.3    0.9]
 [  31.5   24.6    2.2]
 [ 139.3   14.5   10.2]
 [ 237.4   27.5   11. ]
 [ 216.8   43.9   27.2]
 [ 199.1   30.6   38.7]
 [ 109.8   14.3   31.7]
 [  26.8   33.    19.3]
 [ 129.4    5.7   31.3]
 [ 213.4   24.6   13.1]
 [  16.9   43.7   89.4]
 [  27.5    1.6   20.7]
 [ 120.5   28.5   14.2]
 [   5.4   29.9    9.4]
 [ 116.     7.7   23.1]
 [  76.4   26.7   22.3]
 [ 239.8    4.1   36.9]
 [  75.3   20.3   32.5]
 [  68.4   44.5   35.6]
 [ 213.5   43.    33.8]
 [ 193.2   18.4   65.7]
 [  76.3   27.5   16. ]
 [ 110.7   40.6   63.2]
 [  88.3   25.5   73.4]
 [ 109.8   47.8   51.4]
 [ 134.3    4.9    9.3]
 [  28.6    1.5   33. ]
 [ 217.7   33.5   59. ]
 [ 250.9   36.5   72.3]
 [ 107.4   14.    10.9]
 [ 163.3   31.6   52.9]
 [ 197.6    3.5    5.9]
 [ 184.9   21.    22. ]
 [ 289.7   42.3   51.2]
 [ 135.2   41.7   45.9]
 [ 222.4    4.3   49.8]
 [ 296.4   36.3  100.9]
 [ 280.2   10.1   21.4]
 [ 187.9   17.2   17.9]
 [ 238.2   34.3    5.3]
 [ 137.9   46.4   59. ]
 [  25.    11.    29.7]
 [  90.4    0.3   23.2]
 [  13.1    0.4   25.6]
 [ 255.4   26.9    5.5]
 [ 225.8    8.2   56.5]
 [ 241.7   38.    23.2]
 [ 175.7   15.4    2.4]
 [ 209.6   20.6   10.7]
 [  78.2   46.8   34.5]
 [  75.1   35.    52.7]
 [ 139.2   14.3   25.6]
 [  76.4    0.8   14.8]
 [ 125.7   36.9   79.2]
 [  19.4   16.    22.3]
 [ 141.3   26.8   46.2]
 [  18.8   21.7   50.4]
 [ 224.     2.4   15.6]
 [ 123.1   34.6   12.4]
 [ 229.5   32.3   74.2]
 [  87.2   11.8   25.9]
 [   7.8   38.9   50.6]
 [  80.2    0.     9.2]
 [ 220.3   49.     3.2]
 [  59.6   12.    43.1]
 [   0.7   39.6    8.7]
 [ 265.2    2.9   43. ]
 [   8.4   27.2    2.1]
 [ 219.8   33.5   45.1]
 [  36.9   38.6   65.6]
 [  48.3   47.     8.5]
 [  25.6   39.     9.3]
 [ 273.7   28.9   59.7]
 [  43.    25.9   20.5]
 [ 184.9   43.9    1.7]
 [  73.4   17.    12.9]
 [ 193.7   35.4   75.6]
 [ 220.5   33.2   37.9]
 [ 104.6    5.7   34.4]
 [  96.2   14.8   38.9]
 [ 140.3    1.9    9. ]
 [ 240.1    7.3    8.7]
 [ 243.2   49.    44.3]
 [  38.    40.3   11.9]
 [  44.7   25.8   20.6]
 [ 280.7   13.9   37. ]
 [ 121.     8.4   48.7]
 [ 197.6   23.3   14.2]
 [ 171.3   39.7   37.7]
 [ 187.8   21.1    9.5]
 [   4.1   11.6    5.7]
 [  93.9   43.5   50.5]
 [ 149.8    1.3   24.3]
 [  11.7   36.9   45.2]
 [ 131.7   18.4   34.6]
 [ 172.5   18.1   30.7]
 [  85.7   35.8   49.3]
 [ 188.4   18.1   25.6]
 [ 163.5   36.8    7.4]
 [ 117.2   14.7    5.4]
 [ 234.5    3.4   84.8]
 [  17.9   37.6   21.6]
 [ 206.8    5.2   19.4]
 [ 215.4   23.6   57.6]
 [ 284.3   10.6    6.4]
 [  50.    11.6   18.4]
 [ 164.5   20.9   47.4]
 [  19.6   20.1   17. ]
 [ 168.4    7.1   12.8]
 [ 222.4    3.4   13.1]
 [ 276.9   48.9   41.8]
 [ 248.4   30.2   20.3]
 [ 170.2    7.8   35.2]
 [ 276.7    2.3   23.7]
 [ 165.6   10.    17.6]
 [ 156.6    2.6    8.3]
 [ 218.5    5.4   27.4]
 [  56.2    5.7   29.7]
 [ 287.6   43.    71.8]
 [ 253.8   21.3   30. ]
 [ 205.    45.1   19.6]
 [ 139.5    2.1   26.6]
 [ 191.1   28.7   18.2]
 [ 286.    13.9    3.7]
 [  18.7   12.1   23.4]
 [  39.5   41.1    5.8]
 [  75.5   10.8    6. ]
 [  17.2    4.1   31.6]
 [ 166.8   42.     3.6]
 [ 149.7   35.6    6. ]
 [  38.2    3.7   13.8]
 [  94.2    4.9    8.1]
 [ 177.     9.3    6.4]
 [ 283.6   42.    66.2]
 [ 232.1    8.6    8.7]]
y array (200L, 1L) :
[[ 22.1]
 [ 10.4]
 [  9.3]
 [ 18.5]
 [ 12.9]
 [  7.2]
 [ 11.8]
 [ 13.2]
 [  4.8]
 [ 10.6]
 [  8.6]
 [ 17.4]
 [  9.2]
 [  9.7]
 [ 19. ]
 [ 22.4]
 [ 12.5]
 [ 24.4]
 [ 11.3]
 [ 14.6]
 [ 18. ]
 [ 12.5]
 [  5.6]
 [ 15.5]
 [  9.7]
 [ 12. ]
 [ 15. ]
 [ 15.9]
 [ 18.9]
 [ 10.5]
 [ 21.4]
 [ 11.9]
 [  9.6]
 [ 17.4]
 [  9.5]
 [ 12.8]
 [ 25.4]
 [ 14.7]
 [ 10.1]
 [ 21.5]
 [ 16.6]
 [ 17.1]
 [ 20.7]
 [ 12.9]
 [  8.5]
 [ 14.9]
 [ 10.6]
 [ 23.2]
 [ 14.8]
 [  9.7]
 [ 11.4]
 [ 10.7]
 [ 22.6]
 [ 21.2]
 [ 20.2]
 [ 23.7]
 [  5.5]
 [ 13.2]
 [ 23.8]
 [ 18.4]
 [  8.1]
 [ 24.2]
 [ 15.7]
 [ 14. ]
 [ 18. ]
 [  9.3]
 [  9.5]
 [ 13.4]
 [ 18.9]
 [ 22.3]
 [ 18.3]
 [ 12.4]
 [  8.8]
 [ 11. ]
 [ 17. ]
 [  8.7]
 [  6.9]
 [ 14.2]
 [  5.3]
 [ 11. ]
 [ 11.8]
 [ 12.3]
 [ 11.3]
 [ 13.6]
 [ 21.7]
 [ 15.2]
 [ 12. ]
 [ 16. ]
 [ 12.9]
 [ 16.7]
 [ 11.2]
 [  7.3]
 [ 19.4]
 [ 22.2]
 [ 11.5]
 [ 16.9]
 [ 11.7]
 [ 15.5]
 [ 25.4]
 [ 17.2]
 [ 11.7]
 [ 23.8]
 [ 14.8]
 [ 14.7]
 [ 20.7]
 [ 19.2]
 [  7.2]
 [  8.7]
 [  5.3]
 [ 19.8]
 [ 13.4]
 [ 21.8]
 [ 14.1]
 [ 15.9]
 [ 14.6]
 [ 12.6]
 [ 12.2]
 [  9.4]
 [ 15.9]
 [  6.6]
 [ 15.5]
 [  7. ]
 [ 11.6]
 [ 15.2]
 [ 19.7]
 [ 10.6]
 [  6.6]
 [  8.8]
 [ 24.7]
 [  9.7]
 [  1.6]
 [ 12.7]
 [  5.7]
 [ 19.6]
 [ 10.8]
 [ 11.6]
 [  9.5]
 [ 20.8]
 [  9.6]
 [ 20.7]
 [ 10.9]
 [ 19.2]
 [ 20.1]
 [ 10.4]
 [ 11.4]
 [ 10.3]
 [ 13.2]
 [ 25.4]
 [ 10.9]
 [ 10.1]
 [ 16.1]
 [ 11.6]
 [ 16.6]
 [ 19. ]
 [ 15.6]
 [  3.2]
 [ 15.3]
 [ 10.1]
 [  7.3]
 [ 12.9]
 [ 14.4]
 [ 13.3]
 [ 14.9]
 [ 18. ]
 [ 11.9]
 [ 11.9]
 [  8. ]
 [ 12.2]
 [ 17.1]
 [ 15. ]
 [  8.4]
 [ 14.5]
 [  7.6]
 [ 11.7]
 [ 11.5]
 [ 27. ]
 [ 20.2]
 [ 11.7]
 [ 11.8]
 [ 12.6]
 [ 10.5]
 [ 12.2]
 [  8.7]
 [ 26.2]
 [ 17.6]
 [ 22.6]
 [ 10.3]
 [ 17.3]
 [ 15.9]
 [  6.7]
 [ 10.8]
 [  9.9]
 [  5.9]
 [ 19.6]
 [ 17.3]
 [  7.6]
 [  9.7]
 [ 12.8]
 [ 25.5]
 [ 13.4]]

Отмасштабируйте столбцы матрицы X, вычтя из каждого значения среднее по соответствующему столбцу и поделив результат на стандартное отклонение. Для определенности, используйте методы mean и std векторов NumPy (реализация std в Pandas может отличаться). Обратите внимание, что в numpy вызов функции .mean() без параметров возвращает среднее по всем элементам массива, а не по столбцам, как в pandas. Чтобы произвести вычисление по столбцам, необходимо указать параметр axis.


In [7]:
means, stds = X.mean(axis=0), X.std(axis=0)
print 'Mean values:', means
print 'Std. deviations:', stds


Mean values: [ 147.0425   23.264    30.554 ]
Std. deviations: [ 85.63933176  14.80964564  21.72410606]

In [8]:
X = (X - means) / stds
print 'X array', X.shape, ':'
print X


X array (200L, 3L) :
[[  9.69852266e-01   9.81522472e-01   1.77894547e+00]
 [ -1.19737623e+00   1.08280781e+00   6.69578760e-01]
 [ -1.51615499e+00   1.52846331e+00   1.78354865e+00]
 [  5.20496822e-02   1.21785493e+00   1.28640506e+00]
 [  3.94182198e-01  -8.41613655e-01   1.28180188e+00]
 [ -1.61540845e+00   1.73103399e+00   2.04592999e+00]
 [ -1.04557682e+00   6.43904671e-01  -3.24708413e-01]
 [ -3.13436589e-01  -2.47406325e-01  -8.72486994e-01]
 [ -1.61657614e+00  -1.42906863e+00  -1.36042422e+00]
 [  6.16042873e-01  -1.39530685e+00  -4.30581584e-01]
 [ -9.45155670e-01  -1.17923146e+00  -2.92486143e-01]
 [  7.90028350e-01   4.96973404e-02  -1.22232878e+00]
 [ -1.43908760e+00   7.99208859e-01   1.62704048e+00]
 [ -5.78501712e-01  -1.05768905e+00  -1.07502697e+00]
 [  6.66253447e-01   6.50657027e-01   7.11007392e-01]
 [  5.64664612e-01   1.65000572e+00   1.02862691e+00]
 [ -9.25304978e-01   9.00494200e-01   3.84117072e+00]
 [  1.56887609e+00   1.10306488e+00   1.16211917e+00]
 [ -9.08957349e-01  -1.86635121e-01  -5.64073843e-01]
 [  3.00679600e-03   4.29449843e-02  -5.27248393e-01]
 [  8.33232798e-01   2.99534513e-01   1.05164281e+00]
 [  1.05509347e+00  -1.22649795e+00  -3.24708413e-01]
 [ -1.56286250e+00  -4.97243498e-01   8.76721921e-01]
 [  9.48833887e-01  -4.29719938e-01  -2.00422516e-01]
 [ -9.89527805e-01  -7.20071247e-01  -5.64073843e-01]
 [  1.35285385e+00  -1.33453565e+00  -5.08835667e-01]
 [ -4.83714657e-02   4.07572210e-01  -8.26455181e-01]
 [  1.08662104e+00  -4.43224650e-01  -3.52327501e-01]
 [  1.18820988e+00   2.59020377e-01  -3.52327501e-01]
 [ -8.92609721e-01  -4.90491142e-01   4.71641962e-01]
 [  1.70316018e+00   3.40048650e-01   5.82118314e-01]
 [ -3.98677796e-01  -3.95958157e-01   3.70371972e-01]
 [ -5.82004775e-01  -1.46958277e+00  -2.55016247e-02]
 [  1.38438142e+00  -2.20396901e-01  -1.39264649e+00]
 [ -5.99520091e-01  -1.47633512e+00  -1.06582061e+00]
 [  1.67747105e+00  -1.29402151e+00  -1.01518562e+00]
 [  1.39956136e+00   1.38666383e+00  -1.17629696e+00]
 [ -8.44734522e-01   1.76479577e+00   6.97197848e-01]
 [ -1.21372386e+00   2.32010953e-01   2.09260624e-01]
 [  9.45330823e-01   9.74770116e-01   6.65620024e-02]
 [  6.47570443e-01  -6.50927121e-02   4.81492770e-02]
 [  3.49810063e-01   6.84418807e-01   3.74975153e-01]
 [  1.71133400e+00   2.99534513e-01  -1.32359877e+00]
 [  6.98948705e-01  -1.00367020e+00  -1.91216154e-01]
 [ -1.42390765e+00   1.64487393e-01   5.86721496e-01]
 [  3.27623995e-01  -5.15880000e-02   4.35460956e-02]
 [ -6.69581357e-01  -9.02384859e-01   2.36879713e-01]
 [  1.08428567e+00   1.23135965e+00  -5.54867481e-01]
 [  9.35989321e-01  -5.03995854e-01   8.90531465e-01]
 [ -9.35814168e-01  -7.80842451e-01   2.87514708e-01]
 [  6.16042873e-01  -1.36154507e+00   1.86244718e-01]
 [ -5.44638766e-01  -9.22641928e-01  -1.24074150e+00]
 [  8.09879042e-01   1.24486436e+00   4.16403786e-01]
 [  4.15200577e-01   1.54872038e+00   1.29561142e+00]
 [  1.35051848e+00   3.73810430e-01  -6.74550196e-01]
 [  6.05533683e-01   1.76479577e+00   1.35545278e+00]
 [ -1.63175608e+00   3.26543937e-01   4.99261050e-01]
 [ -1.26606546e-01  -2.74415749e-01  -6.42327927e-01]
 [  7.44488528e-01   1.77830048e+00   3.28943340e-01]
 [  7.43320840e-01   4.21076922e-01  -9.78360166e-01]
 [ -1.09228433e+00  -1.43582099e+00  -4.21375221e-01]
 [  1.33417085e+00   1.31238792e+00   1.11148417e+00]
 [  1.07727954e+00  -5.24252922e-01  -1.49787521e-01]
 [ -5.17781948e-01   4.27829278e-01  -1.01978880e+00]
 [ -1.86158622e-01   1.31914027e+00  -7.61366196e-02]
 [ -9.11292725e-01  -9.42898996e-01  -1.36502740e+00]
 [ -1.34917564e+00   9.02114765e-02  -1.30518604e+00]
 [ -9.04082253e-02  -5.91776482e-01  -9.36931533e-01]
 [  1.05509347e+00   2.86029801e-01  -9.00106083e-01]
 [  8.14549794e-01   1.39341619e+00  -1.54390703e-01]
 [  6.07869059e-01   4.95352838e-01   3.74975153e-01]
 [ -4.34876116e-01  -6.05281194e-01   5.27524584e-02]
 [ -1.40405696e+00   6.57409383e-01  -5.18042030e-01]
 [ -2.06009314e-01  -1.18598381e+00   3.43397329e-02]
 [  7.74848409e-01   9.02114765e-02  -8.03439274e-01]
 [ -1.51965805e+00   1.37991148e+00   2.70878810e+00]
 [ -1.39588315e+00  -1.46283041e+00  -4.53597491e-01]
 [ -3.09933525e-01   3.53553362e-01  -7.52804279e-01]
 [ -1.65394214e+00   4.48086346e-01  -9.73756984e-01]
 [ -3.62479475e-01  -1.05093669e+00  -3.43121138e-01]
 [ -8.24883830e-01   2.32010953e-01  -3.79946589e-01]
 [  1.08311798e+00  -1.29402151e+00   2.92117889e-01]
 [ -8.37728396e-01  -2.00139833e-01   8.95779092e-02]
 [ -9.18298852e-01   1.43393033e+00   2.32276531e-01]
 [  7.76016097e-01   1.33264499e+00   1.49419267e-01]
 [  5.38975481e-01  -3.28434597e-01   1.61783412e+00]
 [ -8.26051518e-01   2.86029801e-01  -6.69947015e-01]
 [ -4.24366926e-01   1.17058844e+00   1.50275459e+00]
 [ -6.85928986e-01   1.50982681e-01   1.97227908e+00]
 [ -4.34876116e-01   1.65675807e+00   9.59579186e-01]
 [ -1.48792614e-01  -1.24000266e+00  -9.78360166e-01]
 [ -1.38303858e+00  -1.46958277e+00   1.12593816e-01]
 [  8.25058983e-01   6.91171163e-01   1.30942097e+00]
 [  1.21273132e+00   8.93741844e-01   1.92164409e+00]
 [ -4.62900623e-01  -6.25538262e-01  -9.04709264e-01]
 [  1.89836839e-01   5.62876398e-01   1.02862691e+00]
 [  5.90353742e-01  -1.33453565e+00  -1.13486833e+00]
 [  4.42057396e-01  -1.52873340e-01  -3.93756133e-01]
 [  1.66579418e+00   1.28537849e+00   9.50372823e-01]
 [ -1.38283424e-01   1.24486436e+00   7.06404211e-01]
 [  8.79940308e-01  -1.28051680e+00   8.85928284e-01]
 [  1.74402926e+00   8.80237132e-01   3.23815396e+00]
 [  1.55486384e+00  -8.88880147e-01  -4.21375221e-01]
 [  4.77088029e-01  -4.09462869e-01  -5.82486569e-01]
 [  1.06443498e+00   7.45190011e-01  -1.16248742e+00]
 [ -1.06755854e-01   1.56222509e+00   1.30942097e+00]
 [ -1.42507534e+00  -8.28108943e-01  -3.93111688e-02]
 [ -6.61407543e-01  -1.55061104e+00  -3.38517957e-01]
 [ -1.56403019e+00  -1.54385868e+00  -2.28041604e-01]
 [  1.26527727e+00   2.45515665e-01  -1.15328106e+00]
 [  9.19641692e-01  -1.01717491e+00   1.19434143e+00]
 [  1.10530405e+00   9.95027184e-01  -3.38517957e-01]
 [  3.34630122e-01  -5.31005278e-01  -1.29597968e+00]
 [  7.30476274e-01  -1.79882765e-01  -9.13915627e-01]
 [ -8.03865450e-01   1.58923451e+00   1.81641536e-01]
 [ -8.40063771e-01   7.92456503e-01   1.01942054e+00]
 [ -9.15759131e-02  -6.05281194e-01  -2.28041604e-01]
 [ -8.24883830e-01  -1.51684926e+00  -7.25185191e-01]
 [ -2.49213762e-01   9.20751268e-01   2.23926360e+00]
 [ -1.49046586e+00  -4.90491142e-01  -3.79946589e-01]
 [ -6.70544700e-02   2.38763309e-01   7.20213755e-01]
 [ -1.49747198e+00  -1.05606848e-01   9.13547372e-01]
 [  8.98623313e-01  -1.40881156e+00  -6.88359740e-01]
 [ -2.79573643e-01   7.65447079e-01  -8.35661544e-01]
 [  9.62846140e-01   6.10142891e-01   2.00910454e+00]
 [ -6.98773552e-01  -7.74090095e-01  -2.14232060e-01]
 [ -1.62591764e+00   1.05579839e+00   9.22753735e-01]
 [ -7.80511695e-01  -1.57086811e+00  -9.82963347e-01]
 [  8.55418865e-01   1.73778635e+00  -1.25915423e+00]
 [ -1.02105537e+00  -7.60585383e-01   5.77515133e-01]
 [ -1.70882347e+00   1.10306488e+00  -1.00597925e+00]
 [  1.37971067e+00  -1.37504978e+00   5.72911952e-01]
 [ -1.61891151e+00   2.65772733e-01  -1.30978922e+00]
 [  8.49580427e-01   6.91171163e-01   6.69578760e-01]
 [ -1.28612050e+00   1.03554132e+00   1.61323094e+00]
 [ -1.15300409e+00   1.60273923e+00  -1.01518562e+00]
 [ -1.41806922e+00   1.06255074e+00  -9.78360166e-01]
 [  1.47896413e+00   3.80562786e-01   1.34164324e+00]
 [ -1.21489154e+00   1.77992105e-01  -4.62803854e-01]
 [  4.42057396e-01   1.39341619e+00  -1.32820195e+00]
 [ -8.59914463e-01  -4.22967582e-01  -8.12645637e-01]
 [  5.44813920e-01   8.19465927e-01   2.07354907e+00]
 [  8.57754241e-01   6.70914095e-01   3.38149702e-01]
 [ -4.95595880e-01  -1.18598381e+00   1.77038355e-01]
 [ -5.93681653e-01  -5.71519414e-01   3.84181516e-01]
 [ -7.87313476e-02  -1.44257334e+00  -9.92169710e-01]
 [  1.08662104e+00  -1.07794612e+00  -1.00597925e+00]
 [  1.12281936e+00   1.73778635e+00   6.32753309e-01]
 [ -1.27327593e+00   1.15033137e+00  -8.58677450e-01]
 [ -1.19504085e+00   1.71239749e-01  -4.58200672e-01]
 [  1.56070228e+00  -6.32290618e-01   2.96721070e-01]
 [ -3.04095087e-01  -1.00367020e+00   8.35293289e-01]
 [  5.90353742e-01   2.43084817e-03  -7.52804279e-01]
 [  2.83251860e-01   1.10981724e+00   3.28943340e-01]
 [  4.75920341e-01  -1.46120984e-01  -9.69153803e-01]
 [ -1.66912209e+00  -7.87594807e-01  -1.14407469e+00]
 [ -6.20538471e-01   1.36640677e+00   9.18150553e-01]
 [  3.21989902e-02  -1.48308748e+00  -2.87882962e-01]
 [ -1.58037782e+00   9.20751268e-01   6.74181942e-01]
 [ -1.79152496e-01  -3.28434597e-01   1.86244718e-01]
 [  2.97264113e-01  -3.48691665e-01   6.72064478e-03]
 [ -7.16288868e-01   8.46475352e-01   8.62912377e-01]
 [  4.82926468e-01  -3.48691665e-01  -2.28041604e-01]
 [  1.92172214e-01   9.13998912e-01  -1.06582061e+00]
 [ -3.48467222e-01  -5.78271770e-01  -1.15788424e+00]
 [  1.02123053e+00  -1.34128800e+00   2.49704176e+00]
 [ -1.50798117e+00   9.68017760e-01  -4.12168859e-01]
 [  6.97781017e-01  -1.21974559e+00  -5.13438849e-01]
 [  7.98202165e-01   2.26879163e-02   1.24497643e+00]
 [  1.60273904e+00  -8.55118367e-01  -1.11185242e+00]
 [ -1.13315340e+00  -7.87594807e-01  -5.59470662e-01]
 [  2.03849092e-01  -1.59625696e-01   7.75451931e-01]
 [ -1.48813048e+00  -2.13644545e-01  -6.23915201e-01]
 [  2.49388915e-01  -1.09145083e+00  -8.17248818e-01]
 [  8.79940308e-01  -1.34128800e+00  -8.03439274e-01]
 [  1.51633014e+00   1.73103399e+00   5.17673775e-01]
 [  1.18353913e+00   4.68343414e-01  -4.72010216e-01]
 [  2.70407294e-01  -1.04418434e+00   2.13863806e-01]
 [  1.51399477e+00  -1.41556392e+00  -3.15502050e-01]
 [  2.16693657e-01  -8.95632503e-01  -5.96296113e-01]
 [  1.11601758e-01  -1.39530685e+00  -1.02439198e+00]
 [  8.34400486e-01  -1.20624088e+00  -1.45184340e-01]
 [ -1.06075676e+00  -1.18598381e+00  -3.93111688e-02]
 [  1.64127273e+00   1.33264499e+00   1.89862818e+00]
 [  1.24659427e+00  -1.32616272e-01  -2.55016247e-02]
 [  6.76762637e-01   1.47444446e+00  -5.04232486e-01]
 [ -8.80728498e-02  -1.42906863e+00  -1.82009791e-01]
 [  5.14454038e-01   3.67058074e-01  -5.68677025e-01]
 [  1.62258973e+00  -6.32290618e-01  -1.23613832e+00]
 [ -1.49863967e+00  -7.53833027e-01  -3.29311594e-01]
 [ -1.25576062e+00   1.20435022e+00  -1.13947151e+00]
 [ -8.35393020e-01  -8.41613655e-01  -1.13026515e+00]
 [ -1.51615499e+00  -1.29402151e+00   4.81492770e-02]
 [  2.30705910e-01   1.26512143e+00  -1.24074150e+00]
 [  3.10313024e-02   8.32970639e-01  -1.13026515e+00]
 [ -1.27094056e+00  -1.32103093e+00  -7.71217005e-01]
 [ -6.17035408e-01  -1.24000266e+00  -1.03359834e+00]
 [  3.49810063e-01  -9.42898996e-01  -1.11185242e+00]
 [  1.59456522e+00   1.26512143e+00   1.64085003e+00]
 [  9.93206022e-01  -9.90165488e-01  -1.00597925e+00]]

Добавьте к матрице X столбец из единиц, используя методы hstack, ones и reshape библиотеки NumPy. Вектор из единиц нужен для того, чтобы не обрабатывать отдельно коэффициент $w_0$ линейной регрессии.


In [9]:
X = np.hstack( (X, np.ones((X.shape[0], 1))) )
print 'X array', X.shape, ':'
print X


X array (200L, 4L) :
[[  9.69852266e-01   9.81522472e-01   1.77894547e+00   1.00000000e+00]
 [ -1.19737623e+00   1.08280781e+00   6.69578760e-01   1.00000000e+00]
 [ -1.51615499e+00   1.52846331e+00   1.78354865e+00   1.00000000e+00]
 [  5.20496822e-02   1.21785493e+00   1.28640506e+00   1.00000000e+00]
 [  3.94182198e-01  -8.41613655e-01   1.28180188e+00   1.00000000e+00]
 [ -1.61540845e+00   1.73103399e+00   2.04592999e+00   1.00000000e+00]
 [ -1.04557682e+00   6.43904671e-01  -3.24708413e-01   1.00000000e+00]
 [ -3.13436589e-01  -2.47406325e-01  -8.72486994e-01   1.00000000e+00]
 [ -1.61657614e+00  -1.42906863e+00  -1.36042422e+00   1.00000000e+00]
 [  6.16042873e-01  -1.39530685e+00  -4.30581584e-01   1.00000000e+00]
 [ -9.45155670e-01  -1.17923146e+00  -2.92486143e-01   1.00000000e+00]
 [  7.90028350e-01   4.96973404e-02  -1.22232878e+00   1.00000000e+00]
 [ -1.43908760e+00   7.99208859e-01   1.62704048e+00   1.00000000e+00]
 [ -5.78501712e-01  -1.05768905e+00  -1.07502697e+00   1.00000000e+00]
 [  6.66253447e-01   6.50657027e-01   7.11007392e-01   1.00000000e+00]
 [  5.64664612e-01   1.65000572e+00   1.02862691e+00   1.00000000e+00]
 [ -9.25304978e-01   9.00494200e-01   3.84117072e+00   1.00000000e+00]
 [  1.56887609e+00   1.10306488e+00   1.16211917e+00   1.00000000e+00]
 [ -9.08957349e-01  -1.86635121e-01  -5.64073843e-01   1.00000000e+00]
 [  3.00679600e-03   4.29449843e-02  -5.27248393e-01   1.00000000e+00]
 [  8.33232798e-01   2.99534513e-01   1.05164281e+00   1.00000000e+00]
 [  1.05509347e+00  -1.22649795e+00  -3.24708413e-01   1.00000000e+00]
 [ -1.56286250e+00  -4.97243498e-01   8.76721921e-01   1.00000000e+00]
 [  9.48833887e-01  -4.29719938e-01  -2.00422516e-01   1.00000000e+00]
 [ -9.89527805e-01  -7.20071247e-01  -5.64073843e-01   1.00000000e+00]
 [  1.35285385e+00  -1.33453565e+00  -5.08835667e-01   1.00000000e+00]
 [ -4.83714657e-02   4.07572210e-01  -8.26455181e-01   1.00000000e+00]
 [  1.08662104e+00  -4.43224650e-01  -3.52327501e-01   1.00000000e+00]
 [  1.18820988e+00   2.59020377e-01  -3.52327501e-01   1.00000000e+00]
 [ -8.92609721e-01  -4.90491142e-01   4.71641962e-01   1.00000000e+00]
 [  1.70316018e+00   3.40048650e-01   5.82118314e-01   1.00000000e+00]
 [ -3.98677796e-01  -3.95958157e-01   3.70371972e-01   1.00000000e+00]
 [ -5.82004775e-01  -1.46958277e+00  -2.55016247e-02   1.00000000e+00]
 [  1.38438142e+00  -2.20396901e-01  -1.39264649e+00   1.00000000e+00]
 [ -5.99520091e-01  -1.47633512e+00  -1.06582061e+00   1.00000000e+00]
 [  1.67747105e+00  -1.29402151e+00  -1.01518562e+00   1.00000000e+00]
 [  1.39956136e+00   1.38666383e+00  -1.17629696e+00   1.00000000e+00]
 [ -8.44734522e-01   1.76479577e+00   6.97197848e-01   1.00000000e+00]
 [ -1.21372386e+00   2.32010953e-01   2.09260624e-01   1.00000000e+00]
 [  9.45330823e-01   9.74770116e-01   6.65620024e-02   1.00000000e+00]
 [  6.47570443e-01  -6.50927121e-02   4.81492770e-02   1.00000000e+00]
 [  3.49810063e-01   6.84418807e-01   3.74975153e-01   1.00000000e+00]
 [  1.71133400e+00   2.99534513e-01  -1.32359877e+00   1.00000000e+00]
 [  6.98948705e-01  -1.00367020e+00  -1.91216154e-01   1.00000000e+00]
 [ -1.42390765e+00   1.64487393e-01   5.86721496e-01   1.00000000e+00]
 [  3.27623995e-01  -5.15880000e-02   4.35460956e-02   1.00000000e+00]
 [ -6.69581357e-01  -9.02384859e-01   2.36879713e-01   1.00000000e+00]
 [  1.08428567e+00   1.23135965e+00  -5.54867481e-01   1.00000000e+00]
 [  9.35989321e-01  -5.03995854e-01   8.90531465e-01   1.00000000e+00]
 [ -9.35814168e-01  -7.80842451e-01   2.87514708e-01   1.00000000e+00]
 [  6.16042873e-01  -1.36154507e+00   1.86244718e-01   1.00000000e+00]
 [ -5.44638766e-01  -9.22641928e-01  -1.24074150e+00   1.00000000e+00]
 [  8.09879042e-01   1.24486436e+00   4.16403786e-01   1.00000000e+00]
 [  4.15200577e-01   1.54872038e+00   1.29561142e+00   1.00000000e+00]
 [  1.35051848e+00   3.73810430e-01  -6.74550196e-01   1.00000000e+00]
 [  6.05533683e-01   1.76479577e+00   1.35545278e+00   1.00000000e+00]
 [ -1.63175608e+00   3.26543937e-01   4.99261050e-01   1.00000000e+00]
 [ -1.26606546e-01  -2.74415749e-01  -6.42327927e-01   1.00000000e+00]
 [  7.44488528e-01   1.77830048e+00   3.28943340e-01   1.00000000e+00]
 [  7.43320840e-01   4.21076922e-01  -9.78360166e-01   1.00000000e+00]
 [ -1.09228433e+00  -1.43582099e+00  -4.21375221e-01   1.00000000e+00]
 [  1.33417085e+00   1.31238792e+00   1.11148417e+00   1.00000000e+00]
 [  1.07727954e+00  -5.24252922e-01  -1.49787521e-01   1.00000000e+00]
 [ -5.17781948e-01   4.27829278e-01  -1.01978880e+00   1.00000000e+00]
 [ -1.86158622e-01   1.31914027e+00  -7.61366196e-02   1.00000000e+00]
 [ -9.11292725e-01  -9.42898996e-01  -1.36502740e+00   1.00000000e+00]
 [ -1.34917564e+00   9.02114765e-02  -1.30518604e+00   1.00000000e+00]
 [ -9.04082253e-02  -5.91776482e-01  -9.36931533e-01   1.00000000e+00]
 [  1.05509347e+00   2.86029801e-01  -9.00106083e-01   1.00000000e+00]
 [  8.14549794e-01   1.39341619e+00  -1.54390703e-01   1.00000000e+00]
 [  6.07869059e-01   4.95352838e-01   3.74975153e-01   1.00000000e+00]
 [ -4.34876116e-01  -6.05281194e-01   5.27524584e-02   1.00000000e+00]
 [ -1.40405696e+00   6.57409383e-01  -5.18042030e-01   1.00000000e+00]
 [ -2.06009314e-01  -1.18598381e+00   3.43397329e-02   1.00000000e+00]
 [  7.74848409e-01   9.02114765e-02  -8.03439274e-01   1.00000000e+00]
 [ -1.51965805e+00   1.37991148e+00   2.70878810e+00   1.00000000e+00]
 [ -1.39588315e+00  -1.46283041e+00  -4.53597491e-01   1.00000000e+00]
 [ -3.09933525e-01   3.53553362e-01  -7.52804279e-01   1.00000000e+00]
 [ -1.65394214e+00   4.48086346e-01  -9.73756984e-01   1.00000000e+00]
 [ -3.62479475e-01  -1.05093669e+00  -3.43121138e-01   1.00000000e+00]
 [ -8.24883830e-01   2.32010953e-01  -3.79946589e-01   1.00000000e+00]
 [  1.08311798e+00  -1.29402151e+00   2.92117889e-01   1.00000000e+00]
 [ -8.37728396e-01  -2.00139833e-01   8.95779092e-02   1.00000000e+00]
 [ -9.18298852e-01   1.43393033e+00   2.32276531e-01   1.00000000e+00]
 [  7.76016097e-01   1.33264499e+00   1.49419267e-01   1.00000000e+00]
 [  5.38975481e-01  -3.28434597e-01   1.61783412e+00   1.00000000e+00]
 [ -8.26051518e-01   2.86029801e-01  -6.69947015e-01   1.00000000e+00]
 [ -4.24366926e-01   1.17058844e+00   1.50275459e+00   1.00000000e+00]
 [ -6.85928986e-01   1.50982681e-01   1.97227908e+00   1.00000000e+00]
 [ -4.34876116e-01   1.65675807e+00   9.59579186e-01   1.00000000e+00]
 [ -1.48792614e-01  -1.24000266e+00  -9.78360166e-01   1.00000000e+00]
 [ -1.38303858e+00  -1.46958277e+00   1.12593816e-01   1.00000000e+00]
 [  8.25058983e-01   6.91171163e-01   1.30942097e+00   1.00000000e+00]
 [  1.21273132e+00   8.93741844e-01   1.92164409e+00   1.00000000e+00]
 [ -4.62900623e-01  -6.25538262e-01  -9.04709264e-01   1.00000000e+00]
 [  1.89836839e-01   5.62876398e-01   1.02862691e+00   1.00000000e+00]
 [  5.90353742e-01  -1.33453565e+00  -1.13486833e+00   1.00000000e+00]
 [  4.42057396e-01  -1.52873340e-01  -3.93756133e-01   1.00000000e+00]
 [  1.66579418e+00   1.28537849e+00   9.50372823e-01   1.00000000e+00]
 [ -1.38283424e-01   1.24486436e+00   7.06404211e-01   1.00000000e+00]
 [  8.79940308e-01  -1.28051680e+00   8.85928284e-01   1.00000000e+00]
 [  1.74402926e+00   8.80237132e-01   3.23815396e+00   1.00000000e+00]
 [  1.55486384e+00  -8.88880147e-01  -4.21375221e-01   1.00000000e+00]
 [  4.77088029e-01  -4.09462869e-01  -5.82486569e-01   1.00000000e+00]
 [  1.06443498e+00   7.45190011e-01  -1.16248742e+00   1.00000000e+00]
 [ -1.06755854e-01   1.56222509e+00   1.30942097e+00   1.00000000e+00]
 [ -1.42507534e+00  -8.28108943e-01  -3.93111688e-02   1.00000000e+00]
 [ -6.61407543e-01  -1.55061104e+00  -3.38517957e-01   1.00000000e+00]
 [ -1.56403019e+00  -1.54385868e+00  -2.28041604e-01   1.00000000e+00]
 [  1.26527727e+00   2.45515665e-01  -1.15328106e+00   1.00000000e+00]
 [  9.19641692e-01  -1.01717491e+00   1.19434143e+00   1.00000000e+00]
 [  1.10530405e+00   9.95027184e-01  -3.38517957e-01   1.00000000e+00]
 [  3.34630122e-01  -5.31005278e-01  -1.29597968e+00   1.00000000e+00]
 [  7.30476274e-01  -1.79882765e-01  -9.13915627e-01   1.00000000e+00]
 [ -8.03865450e-01   1.58923451e+00   1.81641536e-01   1.00000000e+00]
 [ -8.40063771e-01   7.92456503e-01   1.01942054e+00   1.00000000e+00]
 [ -9.15759131e-02  -6.05281194e-01  -2.28041604e-01   1.00000000e+00]
 [ -8.24883830e-01  -1.51684926e+00  -7.25185191e-01   1.00000000e+00]
 [ -2.49213762e-01   9.20751268e-01   2.23926360e+00   1.00000000e+00]
 [ -1.49046586e+00  -4.90491142e-01  -3.79946589e-01   1.00000000e+00]
 [ -6.70544700e-02   2.38763309e-01   7.20213755e-01   1.00000000e+00]
 [ -1.49747198e+00  -1.05606848e-01   9.13547372e-01   1.00000000e+00]
 [  8.98623313e-01  -1.40881156e+00  -6.88359740e-01   1.00000000e+00]
 [ -2.79573643e-01   7.65447079e-01  -8.35661544e-01   1.00000000e+00]
 [  9.62846140e-01   6.10142891e-01   2.00910454e+00   1.00000000e+00]
 [ -6.98773552e-01  -7.74090095e-01  -2.14232060e-01   1.00000000e+00]
 [ -1.62591764e+00   1.05579839e+00   9.22753735e-01   1.00000000e+00]
 [ -7.80511695e-01  -1.57086811e+00  -9.82963347e-01   1.00000000e+00]
 [  8.55418865e-01   1.73778635e+00  -1.25915423e+00   1.00000000e+00]
 [ -1.02105537e+00  -7.60585383e-01   5.77515133e-01   1.00000000e+00]
 [ -1.70882347e+00   1.10306488e+00  -1.00597925e+00   1.00000000e+00]
 [  1.37971067e+00  -1.37504978e+00   5.72911952e-01   1.00000000e+00]
 [ -1.61891151e+00   2.65772733e-01  -1.30978922e+00   1.00000000e+00]
 [  8.49580427e-01   6.91171163e-01   6.69578760e-01   1.00000000e+00]
 [ -1.28612050e+00   1.03554132e+00   1.61323094e+00   1.00000000e+00]
 [ -1.15300409e+00   1.60273923e+00  -1.01518562e+00   1.00000000e+00]
 [ -1.41806922e+00   1.06255074e+00  -9.78360166e-01   1.00000000e+00]
 [  1.47896413e+00   3.80562786e-01   1.34164324e+00   1.00000000e+00]
 [ -1.21489154e+00   1.77992105e-01  -4.62803854e-01   1.00000000e+00]
 [  4.42057396e-01   1.39341619e+00  -1.32820195e+00   1.00000000e+00]
 [ -8.59914463e-01  -4.22967582e-01  -8.12645637e-01   1.00000000e+00]
 [  5.44813920e-01   8.19465927e-01   2.07354907e+00   1.00000000e+00]
 [  8.57754241e-01   6.70914095e-01   3.38149702e-01   1.00000000e+00]
 [ -4.95595880e-01  -1.18598381e+00   1.77038355e-01   1.00000000e+00]
 [ -5.93681653e-01  -5.71519414e-01   3.84181516e-01   1.00000000e+00]
 [ -7.87313476e-02  -1.44257334e+00  -9.92169710e-01   1.00000000e+00]
 [  1.08662104e+00  -1.07794612e+00  -1.00597925e+00   1.00000000e+00]
 [  1.12281936e+00   1.73778635e+00   6.32753309e-01   1.00000000e+00]
 [ -1.27327593e+00   1.15033137e+00  -8.58677450e-01   1.00000000e+00]
 [ -1.19504085e+00   1.71239749e-01  -4.58200672e-01   1.00000000e+00]
 [  1.56070228e+00  -6.32290618e-01   2.96721070e-01   1.00000000e+00]
 [ -3.04095087e-01  -1.00367020e+00   8.35293289e-01   1.00000000e+00]
 [  5.90353742e-01   2.43084817e-03  -7.52804279e-01   1.00000000e+00]
 [  2.83251860e-01   1.10981724e+00   3.28943340e-01   1.00000000e+00]
 [  4.75920341e-01  -1.46120984e-01  -9.69153803e-01   1.00000000e+00]
 [ -1.66912209e+00  -7.87594807e-01  -1.14407469e+00   1.00000000e+00]
 [ -6.20538471e-01   1.36640677e+00   9.18150553e-01   1.00000000e+00]
 [  3.21989902e-02  -1.48308748e+00  -2.87882962e-01   1.00000000e+00]
 [ -1.58037782e+00   9.20751268e-01   6.74181942e-01   1.00000000e+00]
 [ -1.79152496e-01  -3.28434597e-01   1.86244718e-01   1.00000000e+00]
 [  2.97264113e-01  -3.48691665e-01   6.72064478e-03   1.00000000e+00]
 [ -7.16288868e-01   8.46475352e-01   8.62912377e-01   1.00000000e+00]
 [  4.82926468e-01  -3.48691665e-01  -2.28041604e-01   1.00000000e+00]
 [  1.92172214e-01   9.13998912e-01  -1.06582061e+00   1.00000000e+00]
 [ -3.48467222e-01  -5.78271770e-01  -1.15788424e+00   1.00000000e+00]
 [  1.02123053e+00  -1.34128800e+00   2.49704176e+00   1.00000000e+00]
 [ -1.50798117e+00   9.68017760e-01  -4.12168859e-01   1.00000000e+00]
 [  6.97781017e-01  -1.21974559e+00  -5.13438849e-01   1.00000000e+00]
 [  7.98202165e-01   2.26879163e-02   1.24497643e+00   1.00000000e+00]
 [  1.60273904e+00  -8.55118367e-01  -1.11185242e+00   1.00000000e+00]
 [ -1.13315340e+00  -7.87594807e-01  -5.59470662e-01   1.00000000e+00]
 [  2.03849092e-01  -1.59625696e-01   7.75451931e-01   1.00000000e+00]
 [ -1.48813048e+00  -2.13644545e-01  -6.23915201e-01   1.00000000e+00]
 [  2.49388915e-01  -1.09145083e+00  -8.17248818e-01   1.00000000e+00]
 [  8.79940308e-01  -1.34128800e+00  -8.03439274e-01   1.00000000e+00]
 [  1.51633014e+00   1.73103399e+00   5.17673775e-01   1.00000000e+00]
 [  1.18353913e+00   4.68343414e-01  -4.72010216e-01   1.00000000e+00]
 [  2.70407294e-01  -1.04418434e+00   2.13863806e-01   1.00000000e+00]
 [  1.51399477e+00  -1.41556392e+00  -3.15502050e-01   1.00000000e+00]
 [  2.16693657e-01  -8.95632503e-01  -5.96296113e-01   1.00000000e+00]
 [  1.11601758e-01  -1.39530685e+00  -1.02439198e+00   1.00000000e+00]
 [  8.34400486e-01  -1.20624088e+00  -1.45184340e-01   1.00000000e+00]
 [ -1.06075676e+00  -1.18598381e+00  -3.93111688e-02   1.00000000e+00]
 [  1.64127273e+00   1.33264499e+00   1.89862818e+00   1.00000000e+00]
 [  1.24659427e+00  -1.32616272e-01  -2.55016247e-02   1.00000000e+00]
 [  6.76762637e-01   1.47444446e+00  -5.04232486e-01   1.00000000e+00]
 [ -8.80728498e-02  -1.42906863e+00  -1.82009791e-01   1.00000000e+00]
 [  5.14454038e-01   3.67058074e-01  -5.68677025e-01   1.00000000e+00]
 [  1.62258973e+00  -6.32290618e-01  -1.23613832e+00   1.00000000e+00]
 [ -1.49863967e+00  -7.53833027e-01  -3.29311594e-01   1.00000000e+00]
 [ -1.25576062e+00   1.20435022e+00  -1.13947151e+00   1.00000000e+00]
 [ -8.35393020e-01  -8.41613655e-01  -1.13026515e+00   1.00000000e+00]
 [ -1.51615499e+00  -1.29402151e+00   4.81492770e-02   1.00000000e+00]
 [  2.30705910e-01   1.26512143e+00  -1.24074150e+00   1.00000000e+00]
 [  3.10313024e-02   8.32970639e-01  -1.13026515e+00   1.00000000e+00]
 [ -1.27094056e+00  -1.32103093e+00  -7.71217005e-01   1.00000000e+00]
 [ -6.17035408e-01  -1.24000266e+00  -1.03359834e+00   1.00000000e+00]
 [  3.49810063e-01  -9.42898996e-01  -1.11185242e+00   1.00000000e+00]
 [  1.59456522e+00   1.26512143e+00   1.64085003e+00   1.00000000e+00]
 [  9.93206022e-01  -9.90165488e-01  -1.00597925e+00   1.00000000e+00]]

2. Реализуйте функцию mserror - среднеквадратичную ошибку прогноза. Она принимает два аргумента - объекты Series y (значения целевого признака) и y_pred (предсказанные значения). Не используйте в этой функции циклы - тогда она будет вычислительно неэффективной.


In [10]:
def mserror(y, y_pred):
    return np.sum((y - y_pred) ** 2) / y.shape[0]

Какова среднеквадратичная ошибка прогноза значений Sales, если всегда предсказывать медианное значение Sales по исходной выборке? Запишите ответ в файл '1.txt'.


In [11]:
answer1 = mserror(y, np.median(y))
print 'Median sales value:', np.median(y)
print(answer1)
write_answer_to_file(answer1, '1.txt')


Median sales value: 12.9
28.34575

3. Реализуйте функцию normal_equation, которая по заданным матрицам (массивам NumPy) X и y вычисляет вектор весов $w$ согласно нормальному уравнению линейной регрессии.


In [12]:
def normal_equation(X, y):
    return np.dot(np.linalg.pinv(X), y)

In [13]:
norm_eq_weights = normal_equation(X, y)
print(norm_eq_weights)


[[  3.91925365]
 [  2.79206274]
 [ -0.02253861]
 [ 14.0225    ]]

Какие продажи предсказываются линейной моделью с весами, найденными с помощью нормального уравнения, в случае средних инвестиций в рекламу по ТВ, радио и в газетах? (то есть при нулевых значениях масштабированных признаков TV, Radio и Newspaper). Запишите ответ в файл '2.txt'.


In [14]:
answer2 = norm_eq_weights[3]
print(answer2)
write_answer_to_file(answer2, '2.txt')


[ 14.0225]

4. Напишите функцию linear_prediction, которая принимает на вход матрицу X и вектор весов линейной модели w, а возвращает вектор прогнозов в виде линейной комбинации столбцов матрицы X с весами w.


In [15]:
def linear_prediction(X, w):
    return np.dot(X, w)

Какова среднеквадратичная ошибка прогноза значений Sales в виде линейной модели с весами, найденными с помощью нормального уравнения? Запишите ответ в файл '3.txt'.


In [16]:
answer3 = mserror(y, linear_prediction(X, norm_eq_weights))
print(answer3)
write_answer_to_file(answer3, '3.txt')


2.78412631451

5. Напишите функцию stochastic_gradient_step, реализующую шаг стохастического градиентного спуска для линейной регрессии. Функция должна принимать матрицу X, вектора y и w, число train_ind - индекс объекта обучающей выборки (строки матрицы X), по которому считается изменение весов, а также число $\eta$ (eta) - шаг градиентного спуска (по умолчанию eta=0.01). Результатом будет вектор обновленных весов. Наша реализация функции будет явно написана для данных с 3 признаками, но несложно модифицировать для любого числа признаков, можете это сделать.


In [76]:
def stochastic_gradient_step(X, y, w, train_ind, eta=0.01): 
    #массив градиента для каждого параметра
    grad = np.zeros( (X.shape[1], 1) )
    
    #перевод строки параметров в столбец для поэлементного умножения
    xRowToCol = X[train_ind].reshape(-1, 1)
    
    #вычисление градиента для каждого параметра
    for i in xrange(X.shape[1]):
        grad[i] = xRowToCol[i] * ( np.sum(xRowToCol * w) - y[train_ind])
        
    return (2 * eta / X.shape[0]) * grad

6. Напишите функцию stochastic_gradient_descent, реализующую стохастический градиентный спуск для линейной регрессии. Функция принимает на вход следующие аргументы:

  • X - матрица, соответствующая обучающей выборке
  • y - вектор значений целевого признака
  • w_init - вектор начальных весов модели
  • eta - шаг градиентного спуска (по умолчанию 0.01)
  • max_iter - максимальное число итераций градиентного спуска (по умолчанию 10000)
  • max_weight_dist - максимальное евклидово расстояние между векторами весов на соседних итерациях градиентного спуска, при котором алгоритм прекращает работу (по умолчанию 1e-8)
  • seed - число, используемое для воспроизводимости сгенерированных псевдослучайных чисел (по умолчанию 42)
  • verbose - флаг печати информации (например, для отладки, по умолчанию False)

На каждой итерации в вектор (список) должно записываться текущее значение среднеквадратичной ошибки. Функция должна возвращать вектор весов $w$, а также вектор (список) ошибок.


In [77]:
def stochastic_gradient_descent(X, y, w_init, eta=1e-2, max_iter=1e4,
                                max_weight_dist=1e-8, seed=42, verbose=False):
    # Инициализируем расстояние между векторами весов на соседних
    # итерациях большим числом. 
    weight_dist = np.inf
    # Инициализируем вектор весов
    w = w_init
    # Сюда будем записывать ошибки на каждой итерации
    errors = []
    # Счетчик итераций
    iter_num = 0
    # Будем порождать псевдослучайные числа 
    # (номер объекта, который будет менять веса), а для воспроизводимости
    # этой последовательности псевдослучайных чисел используем seed.
    np.random.seed(seed)
        
    # Основной цикл
    while weight_dist > max_weight_dist and iter_num < max_iter:
        # порождаем псевдослучайный 
        # индекс объекта обучающей выборки
        random_ind = np.random.randint(X.shape[0])
        
        #вычисляем значение весов на следующей итерации
        wOld = w
        w = wOld - stochastic_gradient_step(X, y, w, random_ind, eta=eta)
        
        #расстояние между векторами весов на соседних итерациях
        weight_dist = np.linalg.norm(wOld - w)
        
        #среднеквадратичная ошибка для найденного вектора весов
        err = mserror( y, linear_prediction(X, w) )
        errors.append(err)
        
        #счётчик итераций
        iter_num += 1
    
    return w, errors

Запустите $10^5$ итераций стохастического градиентного спуска. Укажите вектор начальных весов w_init, состоящий из нулей. Оставьте параметры eta и seed равными их значениям по умолчанию (eta=0.01, seed=42 - это важно для проверки ответов).


In [78]:
%%time
stoch_grad_desc_weights, stoch_errors_by_iter = stochastic_gradient_descent(X, y, np.zeros( (X.shape[1], 1) ), max_iter=1e5)


Wall time: 5.25 s

Посмотрим, чему равна ошибка на первых 50 итерациях стохастического градиентного спуска. Видим, что ошибка не обязательно уменьшается на каждой итерации.


In [88]:
%pylab inline
plot(range(200), stoch_errors_by_iter[:200])
xlabel('Iteration number')
ylabel('MSE')


Populating the interactive namespace from numpy and matplotlib
Out[88]:
<matplotlib.text.Text at 0x10598fd0>

Теперь посмотрим на зависимость ошибки от номера итерации для $10^5$ итераций стохастического градиентного спуска. Видим, что алгоритм сходится.


In [80]:
%pylab inline
plot(range(len(stoch_errors_by_iter)), stoch_errors_by_iter)
axis([0, 1e5, 0, 250])
grid(True)
xlabel('Iteration number')
ylabel('MSE')


Populating the interactive namespace from numpy and matplotlib
Out[80]:
<matplotlib.text.Text at 0xd275780>

Посмотрим на вектор весов, к которому сошелся метод.


In [81]:
stoch_grad_desc_weights


Out[81]:
array([[  3.91069256e+00],
       [  2.78209808e+00],
       [ -8.10462217e-03],
       [  1.40190566e+01]])

Посмотрим на среднеквадратичную ошибку на последней итерации.


In [82]:
stoch_errors_by_iter[-1]


Out[82]:
2.7844125884067039

Какова среднеквадратичная ошибка прогноза значений Sales в виде линейной модели с весами, найденными с помощью градиентного спуска? Запишите ответ в файл '4.txt'.


In [83]:
answer4 = mserror(y, linear_prediction(X, stoch_grad_desc_weights))
print(answer4)
write_answer_to_file(answer4, '4.txt')


2.78441258841

Ответами к заданию будут текстовые файлы, полученные в ходе этого решения. Обратите внимание, что отправленные файлы не должны содержать пустую строку в конце. Данный нюанс является ограничением платформы Coursera. Мы работаем над исправлением этого ограничения.