15. 다중 회귀 분석


In [68]:
from __future__ import division
from collections import Counter
from functools import partial
from linear_algebra import dot, vector_add
from stats import median, standard_deviation, de_mean
from probability import normal_cdf
from gradient_descent import minimize_stochastic
#from simple_linear_regression import total_sum_of_squares
import math, random

In [69]:
def total_sum_of_squares(y):
    """the total squared variation of y_i's from their mean"""
    return sum(v ** 2 for v in de_mean(y))

14장의 내용을 추가 데이터를 사용해 모델의 성능을 높이기 위해,

더 많은 독립 변수를 사용하는 선형 모델을 시험

14장에서 다뤘던 모델
$y_i=\alpha+\beta x_i+\epsilon_i$
여기에 독립 변수를 추가하면
$\Rightarrow$시간(분)= $\alpha+\beta_1*$(친구 수)$+\beta_2*$(근무 시간)$+\beta_3$*(박사 학위 취득 여부)+$\epsilon$

15.1 모델

각 입력값 $x_i$가 숫자 하나가 아니라 $k$개의 숫자인 $x_i1,\cdot\cdot\cdot,x_ik$라고 한다면,
다중 회귀 모델은 다음과 같은 형태를 띈다.
$y_i=\alpha+\beta_1 x_{i1}+\cdot\cdot\cdot+\beta_k x_{ik}+\epsilon_i$

다중 회귀 분석에서는 보통 파라미터 벡터를 $\beta$라고 부름
여기에 상수항 $\alpha$까지 덧붙이려면, 각 데이터 x_i의 앞부분에 1을 덧붙이면 된다.
beta = [alpha, beta_1,$\cdot\cdot\cdot$, x_ik]

그리고 각 데이터는 다음과 같이 된다.
x_i = [1, x_i1, $\cdot\cdot\cdot$, x_ik]
이렇게 하면 모델을 다음과 같이 나타낼 수 있다.


In [70]:
def predict(x_i, beta):
    """각 x_i의 첫 번째 항목은 1이라고 가정"""
    return dot(x_i, beta)

독립 변수 x는 다음과 같은 벡터들의 열로 표현할 수 있다.

[1,  # 상수항
 49, # 친구의 수
 4,  # 하루 근무 시간
 0]  # 박사 학위 취득 여부

15.2 최소자승법에 대한 몇 가지 추가 가정

  1. $x$의 열은 서로 일차독립 해야 한다.
    이 가정이 성립하지 않는다면 $\beta$를 추정할 수 없다.
  2. $x$의 모든 열은 오류 $\epsilon$과 상관관계가 없어야한다.
    이 가정이 위배되면 잘못된 $\beta$가 추정될 것이다.

독립 변수와 오류 사이에 상관관계가 존재한다면, 최소자승법으로 만들어지는 모델은 편향된 $\beta$를 추정해 준다.

15.3 모델 학습하기

오류 함수


In [71]:
def error(x_i, y_i, beta):
    return y_i - predict(x_i, beta)

SGD를 사용하기 위한 오류 제곱 값


In [72]:
def squared_error(x_i, y_i, beta):
    return error(x_i, y_i, beta) ** 2

만약 미적분을 알고 있다면, 오류를 직접 계산할 수도 있다.


In [73]:
def squared_error_gradient(x_i, y_i, beta):
    """i번째 오류 제곱 값의 beta에 대한 기울기"""
    return [-2 * x_ij * error(x_i, y_i, beta)
            for x_ij in x_i]

SGD를 사용해서 최적의 베타를 계산


In [74]:
def estimate_beta(x, y):
    beta_initial = [random.random() for x_i in x[0]]
    return minimize_stochastic(squared_error, 
                               squared_error_gradient, 
                               x, y, 
                               beta_initial, 
                               0.001)

실험 데이터 셋팅 (14장과 동일)


In [75]:
x = [[1,49,4,0],[1,41,9,0],[1,40,8,0],[1,25,6,0],[1,21,1,0],[1,21,0,0],[1,19,3,0],[1,19,0,0],[1,18,9,0],[1,18,8,0],[1,16,4,0],[1,15,3,0],[1,15,0,0],[1,15,2,0],[1,15,7,0],[1,14,0,0],[1,14,1,0],[1,13,1,0],[1,13,7,0],[1,13,4,0],[1,13,2,0],[1,12,5,0],[1,12,0,0],[1,11,9,0],[1,10,9,0],[1,10,1,0],[1,10,1,0],[1,10,7,0],[1,10,9,0],[1,10,1,0],[1,10,6,0],[1,10,6,0],[1,10,8,0],[1,10,10,0],[1,10,6,0],[1,10,0,0],[1,10,5,0],[1,10,3,0],[1,10,4,0],[1,9,9,0],[1,9,9,0],[1,9,0,0],[1,9,0,0],[1,9,6,0],[1,9,10,0],[1,9,8,0],[1,9,5,0],[1,9,2,0],[1,9,9,0],[1,9,10,0],[1,9,7,0],[1,9,2,0],[1,9,0,0],[1,9,4,0],[1,9,6,0],[1,9,4,0],[1,9,7,0],[1,8,3,0],[1,8,2,0],[1,8,4,0],[1,8,9,0],[1,8,2,0],[1,8,3,0],[1,8,5,0],[1,8,8,0],[1,8,0,0],[1,8,9,0],[1,8,10,0],[1,8,5,0],[1,8,5,0],[1,7,5,0],[1,7,5,0],[1,7,0,0],[1,7,2,0],[1,7,8,0],[1,7,10,0],[1,7,5,0],[1,7,3,0],[1,7,3,0],[1,7,6,0],[1,7,7,0],[1,7,7,0],[1,7,9,0],[1,7,3,0],[1,7,8,0],[1,6,4,0],[1,6,6,0],[1,6,4,0],[1,6,9,0],[1,6,0,0],[1,6,1,0],[1,6,4,0],[1,6,1,0],[1,6,0,0],[1,6,7,0],[1,6,0,0],[1,6,8,0],[1,6,4,0],[1,6,2,1],[1,6,1,1],[1,6,3,1],[1,6,6,1],[1,6,4,1],[1,6,4,1],[1,6,1,1],[1,6,3,1],[1,6,4,1],[1,5,1,1],[1,5,9,1],[1,5,4,1],[1,5,6,1],[1,5,4,1],[1,5,4,1],[1,5,10,1],[1,5,5,1],[1,5,2,1],[1,5,4,1],[1,5,4,1],[1,5,9,1],[1,5,3,1],[1,5,10,1],[1,5,2,1],[1,5,2,1],[1,5,9,1],[1,4,8,1],[1,4,6,1],[1,4,0,1],[1,4,10,1],[1,4,5,1],[1,4,10,1],[1,4,9,1],[1,4,1,1],[1,4,4,1],[1,4,4,1],[1,4,0,1],[1,4,3,1],[1,4,1,1],[1,4,3,1],[1,4,2,1],[1,4,4,1],[1,4,4,1],[1,4,8,1],[1,4,2,1],[1,4,4,1],[1,3,2,1],[1,3,6,1],[1,3,4,1],[1,3,7,1],[1,3,4,1],[1,3,1,1],[1,3,10,1],[1,3,3,1],[1,3,4,1],[1,3,7,1],[1,3,5,1],[1,3,6,1],[1,3,1,1],[1,3,6,1],[1,3,10,1],[1,3,2,1],[1,3,4,1],[1,3,2,1],[1,3,1,1],[1,3,5,1],[1,2,4,1],[1,2,2,1],[1,2,8,1],[1,2,3,1],[1,2,1,1],[1,2,9,1],[1,2,10,1],[1,2,9,1],[1,2,4,1],[1,2,5,1],[1,2,0,1],[1,2,9,1],[1,2,9,1],[1,2,0,1],[1,2,1,1],[1,2,1,1],[1,2,4,1],[1,1,0,1],[1,1,2,1],[1,1,2,1],[1,1,5,1],[1,1,3,1],[1,1,10,1],[1,1,6,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,4,1],[1,1,9,1],[1,1,9,1],[1,1,4,1],[1,1,2,1],[1,1,9,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,1,1],[1,1,1,1],[1,1,5,1]]
num_friends_good = [49,41,40,25,21,21,19,19,18,18,16,15,15,15,15,14,14,13,13,13,13,12,12,11,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,8,8,8,8,8,8,8,8,8,8,8,8,8,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
daily_minutes_good = [68.77,51.25,52.08,38.36,44.54,57.13,51.4,41.42,31.22,34.76,54.01,38.79,47.59,49.1,27.66,41.03,36.73,48.65,28.12,46.62,35.57,32.98,35,26.07,23.77,39.73,40.57,31.65,31.21,36.32,20.45,21.93,26.02,27.34,23.49,46.94,30.5,33.8,24.23,21.4,27.94,32.24,40.57,25.07,19.42,22.39,18.42,46.96,23.72,26.41,26.97,36.76,40.32,35.02,29.47,30.2,31,38.11,38.18,36.31,21.03,30.86,36.07,28.66,29.08,37.28,15.28,24.17,22.31,30.17,25.53,19.85,35.37,44.6,17.23,13.47,26.33,35.02,32.09,24.81,19.33,28.77,24.26,31.98,25.73,24.86,16.28,34.51,15.23,39.72,40.8,26.06,35.76,34.76,16.13,44.04,18.03,19.65,32.62,35.59,39.43,14.18,35.24,40.13,41.82,35.45,36.07,43.67,24.61,20.9,21.9,18.79,27.61,27.21,26.61,29.77,20.59,27.53,13.82,33.2,25,33.1,36.65,18.63,14.87,22.2,36.81,25.53,24.62,26.25,18.21,28.08,19.42,29.79,32.8,35.99,28.32,27.79,35.88,29.06,36.28,14.1,36.63,37.49,26.9,18.58,38.48,24.48,18.95,33.55,14.24,29.04,32.51,25.63,22.22,19,32.73,15.16,13.9,27.2,32.01,29.27,33,13.74,20.42,27.32,18.23,35.35,28.48,9.08,24.62,20.12,35.26,19.92,31.02,16.49,12.16,30.7,31.22,34.65,13.13,27.51,33.2,31.57,14.1,33.42,17.44,10.12,24.42,9.82,23.39,30.93,15.03,21.67,31.09,33.29,22.61,26.89,23.48,8.38,27.81,32.35,23.84]

In [76]:
random.seed(0)
beta = estimate_beta(x, daily_minutes_good)

In [77]:
beta # 분 = 30.63 + 0.972 친구 수 - 1.868 근무 시간 + 0.911 박사 학위 취득 여부


Out[77]:
[30.619881701311712,
 0.9702056472470465,
 -1.8671913880379478,
 0.9163711597955347]

15.4 모델 해석하기

모델의 계수는 해당 항목의 영향력을 나타낸다.

친구 수 증가 => 대략 1분 증가  
근무 시간 증가 => 대략 2분 감소  
박사 학위 취득 => 대략 1분 증가  

이러한 해석은 변수간의 관계를 직접적으로 설명해 주지 못한다.
예를 들어, 친구의 수가 다른 사용자들의 근무 시간은 서로 다를 수 있다.
이 모델은 이러한 관계를 잡아내지 못한다.
이러한 문제는 친구 수와 근무 시간을 곱한 새로운 변수로 해결할 수 있다.

변수가 점점 추가되기 시작하면, 각 계수가 유의미한지 살펴봐야 한다.
변수끼리 곱한 값, 변수의 log값, 변수의 제곱 값 등 수 많은 변수를 추가할 수 있기 때문이다.

15.5 적합성(Goodness of fit)

모델의 R 제곱 값을 다시 계산해 보자.

14장에서의 결과

alpha 22.94755241346903
beta 0.903865945605865
r-squared 0.3291078377836305

In [78]:
def multiple_r_squared(x, y, beta):
    sum_of_squared_errors = sum(error(x_i, y_i, beta) ** 2
                                for x_i, y_i in zip(x, y))
    return 1.0 - sum_of_squared_errors / total_sum_of_squares(y)

In [79]:
multiple_r_squared(x, daily_minutes_good, beta)


Out[79]:
0.6800074955952597

회귀 분석 모델에 새로운 변수를 추가하면 R 제곱 값이 어쩔 수 없이 증가한다.
따라서 다중 회귀 분석 모델은 언제나 단순 회귀 분석 모델보다 작은 오류를 갖게 된다.
이러한 이유로 다중 회귀 분석 모델에서는 각 계수의 표준 오차를 살펴 봐야 한다.
계수의 표준 오차는 추정된 $\beta_1$의 계수가 얼마나 확실한지 알려준다.

표준오차 : 모집단 전체를 알 수 없는 상황에서, 모집단이 정상분포(정규분포)를 이루고 있다는 가정 하에, 여러번의 샘플링을 통해 각 표본 집단의 평균들로 이루어진 표준 평균 분포를 얻고, 이 표준 평균 분포의 표준 편차가 표준 오차가 된다.
표준오차는 모평균과 표본평균 사이의 오차를 알려주므로, 모집단의 표준편차가 클수록 표준오차도 커지고, 사례가 많을 수록 작아진다.


오차를 측정하기 위해서는 각 오류 $\epsilon_1$는 독립이며, 평균은 0이고 표준편차는 $\sigma$인 정규분포의 확률변수라는 가정이 필요하다.
표준오차가 클수록 해당 계수는 무의미해 진다.

15.6 여담: bootstrap

알 수 없는 분포에서 생성된 표본 데이터가 주어졌을 때, 이 표본 데이터의 중앙값을 찾으려면?

만약 표본 데이터가 모두 100 근처에 위치하고 있다면, 중앙값 또한 100 근처에 위치할 것이지만,
표본 데이터의 반은 0, 나머지 반은 200 근처에 위치하고 있다면, 추정된 중앙값을 신뢰하기 힘들다.

bootstrap은 중복이 허용된 재추출을 통해 새로운 데이터의 각 항목을 생성한다.
이를 통해 만들어진 데이터로 중앙값을 계산해 볼 수 있다.


In [80]:
def bootstrap_sample(data):
    """len(data)개의 항목을 중복을 허용한 무작위 추출"""
    return [random.choice(data) for _ in data]

In [81]:
def bootstrap_statistic(data, stats_fn, num_samples):
    """num_samples개의 bootstrap 샘플에 대해 stats_fn을 적용"""
    return [stats_fn(bootstrap_sample(data)) 
            for _ in range(num_samples)]

예를 들어, 다음과 같은 두 가지 데이터를 살펴보자.


In [82]:
# 101개의 데이터가 모두 100에 인접
close_to_100 = [99.5 + random.random() for _ in range(101)]

# 101개의 데이터 중 50개는 0에 인접, 50개는 200에 인접
far_from_100 = ([99.5 + random.random()] +
               [random.random() for _ in range(50)] +
               [200 + random.random() for _ in range(50)])

In [83]:
print(close_to_100)
print(far_from_100)


[99.575022280022, 100.13015471695907, 100.3098947777773, 100.02634537429101, 100.38218103100003, 99.6971220346973, 99.52071572817076, 100.08732646013036, 100.26861005947644, 100.0340458696825, 100.16675086620182, 99.53277295299847, 99.90497388381739, 100.11133261277095, 99.76775564923551, 99.52798582548432, 99.84973748230965, 99.85212794947584, 100.4812988797111, 100.41570521787581, 99.79350418402947, 100.49592295733221, 100.43402789747296, 100.34956975596565, 99.57412241274703, 99.98693719710383, 99.92933404675189, 100.43593518681631, 99.51448924693598, 99.52651243800432, 99.90918525295109, 99.83464437188773, 99.6812568193253, 100.06586584030461, 100.25813894897857, 100.40043150749628, 99.60072489313731, 100.24951314567744, 100.20039187360025, 99.84545826146449, 100.32928949104371, 100.38179466422953, 100.35410868832848, 100.3783065088296, 100.48581172718745, 99.54083837500995, 99.580978843698, 100.04153611616675, 99.60941912058524, 99.84127649439404, 99.64651625272637, 99.82534150579184, 99.82446822437431, 99.95728292174014, 100.07674830805912, 100.31036457325315, 99.6732137127586, 100.06627711829937, 100.35406905959398, 99.88515914699941, 99.95763666335068, 100.32814252691253, 100.3596632260573, 100.23329252282423, 99.74322409376589, 100.02048509159759, 100.18315165595742, 100.08022923777746, 99.88829562117861, 100.47017008563648, 100.29464504045112, 100.40637051114227, 100.24722036575075, 99.60711248956025, 100.32771020179416, 100.4416097768272, 100.49545030736934, 99.83731253741817, 100.15354631416392, 100.41146963253381, 100.33060424399922, 99.73124420451033, 99.63742911848779, 100.47361050279841, 99.57864554239072, 99.8494530070752, 99.65147846498758, 100.26208497004184, 99.94166876350928, 99.97928565593025, 100.37841608958033, 100.27752011876781, 100.26212372144197, 100.19440587880706, 99.66245302650196, 100.06850285070729, 100.25038426285329, 100.0251392489555, 99.84220178835625, 100.27812724899424, 99.92424918869057]
[100.37472916630605, 0.5969512867241137, 0.7946061542121747, 0.4402272456987374, 0.39871783720721043, 0.4720727570723431, 0.7721419649902795, 0.4748469465999495, 0.5786709051115894, 0.3774246047836889, 0.9229610720300252, 0.8082392572934448, 0.828944660511035, 0.19152762619443708, 0.6620212838172704, 0.9632764195933914, 0.0339453758660897, 0.9746929609488131, 0.9679946120377706, 0.8513362921107412, 0.4727575806096489, 0.12953453508753676, 0.2914439042067456, 0.00790453987366957, 0.9740127081653535, 0.8380713386204371, 0.5146761012374821, 0.3740325537023109, 0.5394340744161044, 0.5176545722729882, 0.2722659882372157, 0.5885797061035571, 0.3239093241359847, 0.5787412466587266, 0.9170498145371947, 0.4489727468726433, 0.5860139471941539, 0.44810296853773723, 0.6367533387520619, 0.9085650263234569, 0.8907248275013403, 0.24529506985206173, 0.34478359417656645, 0.9413891648644847, 0.7090486809725813, 0.0050807692558083595, 0.7433478919602433, 0.45392101713131405, 0.6640860574999732, 0.838505962431085, 0.9819299666825538, 200.94851612262863, 200.81580916988426, 200.29136858527113, 200.67075321295815, 200.6339324977142, 200.0439362129153, 200.7731274510282, 200.3207159857896, 200.7821726429382, 200.88155319382838, 200.7313687034986, 200.2899506852126, 200.03611456820246, 200.2693132616282, 200.01173496311017, 200.82807229847933, 200.55841769115833, 200.90463843974587, 200.6838185033596, 200.07839823892192, 200.171028817719, 200.0254523027627, 200.27491788892695, 200.40323925960584, 200.27043909541388, 200.29496033604522, 200.11842771777464, 200.42981161992734, 200.63517020519018, 200.21899599520737, 200.4003645785021, 200.4755646399593, 200.89659751386117, 200.3681851127953, 200.02972914757345, 200.043346407288, 200.03723940374888, 200.46286770862417, 200.4969583761535, 200.46520673354394, 200.63664887406284, 200.28936045149442, 200.72457669185047, 200.24684041759872, 200.50317125734537, 200.34435154326192, 200.60995335029207, 200.07626904987512, 200.2325068223646, 200.06556682467556]

만약 두 데이터의 중앙값을 계산해 보면 둘 다 대략 100에 가까운 것을 확인할 수 있다.


In [84]:
print(median(close_to_100))
print(median(far_from_100))


100.06586584030461
100.37472916630605

하지만 다음과 같이 bootstrap을 적용해 보면,
close_to_100은 100에 대부분 가깝지만, far_from_100은 0 또는 200에 가까운 것을 확인할 수 있다.


In [85]:
print(bootstrap_statistic(close_to_100, median, 100))


[100.02634537429101, 100.06586584030461, 100.06850285070729, 100.06586584030461, 100.08022923777746, 100.02634537429101, 100.0251392489555, 100.06586584030461, 100.0251392489555, 100.18315165595742, 100.06627711829937, 100.06586584030461, 100.06627711829937, 100.06627711829937, 100.08022923777746, 100.08022923777746, 99.97928565593025, 99.95763666335068, 100.06627711829937, 99.95763666335068, 100.06586584030461, 100.06627711829937, 100.06850285070729, 100.08732646013036, 100.0340458696825, 100.02634537429101, 100.16675086620182, 100.11133261277095, 100.02634537429101, 100.0340458696825, 100.02048509159759, 100.06586584030461, 100.0340458696825, 100.02048509159759, 100.06850285070729, 100.06850285070729, 99.95763666335068, 100.08022923777746, 100.02048509159759, 100.07674830805912, 100.06586584030461, 100.06850285070729, 100.06627711829937, 100.11133261277095, 100.08022923777746, 100.06627711829937, 100.02634537429101, 100.0251392489555, 100.06850285070729, 100.06627711829937, 100.06586584030461, 100.0251392489555, 100.08022923777746, 100.11133261277095, 100.06586584030461, 100.0251392489555, 100.02634537429101, 100.04153611616675, 100.08022923777746, 100.06627711829937, 100.19440587880706, 100.04153611616675, 99.98693719710383, 100.06850285070729, 100.07674830805912, 100.08022923777746, 100.08732646013036, 100.06627711829937, 100.04153611616675, 100.08732646013036, 100.11133261277095, 100.06586584030461, 100.0251392489555, 100.13015471695907, 100.02634537429101, 99.95728292174014, 100.08732646013036, 100.06850285070729, 99.92424918869057, 100.02048509159759, 99.97928565593025, 99.92424918869057, 100.13015471695907, 100.06586584030461, 100.08022923777746, 100.06850285070729, 100.04153611616675, 100.06627711829937, 100.06586584030461, 100.06850285070729, 100.02048509159759, 100.02634537429101, 100.06627711829937, 100.04153611616675, 100.02634537429101, 100.02634537429101, 100.06627711829937, 100.18315165595742, 100.06850285070729, 100.06627711829937]

In [86]:
print(bootstrap_statistic(far_from_100, median, 100))


[0.9746929609488131, 0.9746929609488131, 0.9632764195933914, 200.0254523027627, 200.03723940374888, 0.9740127081653535, 200.0254523027627, 200.03723940374888, 0.9740127081653535, 200.043346407288, 200.06556682467556, 0.9679946120377706, 0.9229610720300252, 0.9170498145371947, 0.9746929609488131, 200.02972914757345, 100.37472916630605, 200.0439362129153, 0.9170498145371947, 0.9746929609488131, 0.9819299666825538, 0.9740127081653535, 200.02972914757345, 0.9819299666825538, 200.0439362129153, 200.03611456820246, 200.0439362129153, 200.02972914757345, 200.03611456820246, 200.0254523027627, 200.03723940374888, 0.9819299666825538, 0.9229610720300252, 0.9413891648644847, 0.9413891648644847, 200.01173496311017, 0.9746929609488131, 200.02972914757345, 200.043346407288, 0.9679946120377706, 200.07626904987512, 200.01173496311017, 0.9819299666825538, 200.01173496311017, 0.9632764195933914, 0.9746929609488131, 0.9746929609488131, 200.01173496311017, 0.9632764195933914, 0.9746929609488131, 200.02972914757345, 0.8082392572934448, 0.9413891648644847, 200.11842771777464, 0.9229610720300252, 0.9229610720300252, 100.37472916630605, 0.9819299666825538, 0.9740127081653535, 200.01173496311017, 200.03611456820246, 0.9819299666825538, 200.01173496311017, 0.9632764195933914, 200.0254523027627, 200.02972914757345, 200.02972914757345, 200.03611456820246, 0.9413891648644847, 100.37472916630605, 0.9170498145371947, 200.03723940374888, 0.9632764195933914, 0.8082392572934448, 0.9170498145371947, 200.01173496311017, 0.9746929609488131, 0.9746929609488131, 200.0439362129153, 100.37472916630605, 200.01173496311017, 200.03611456820246, 200.0254523027627, 0.9632764195933914, 200.02972914757345, 0.9740127081653535, 200.0254523027627, 100.37472916630605, 200.07626904987512, 0.9679946120377706, 200.02972914757345, 200.03723940374888, 200.01173496311017, 200.07626904987512, 0.9746929609488131, 0.9740127081653535, 0.9819299666825538, 200.01173496311017, 100.37472916630605, 100.37472916630605]

첫번째는 표준편차가 0에 가깝지만, 두번째는 표준편차가 100에 가까운 것을 확인할 수 있다.


In [87]:
standard_deviation(bootstrap_statistic(close_to_100, median, 100))


Out[87]:
0.05120259628781462

In [88]:
standard_deviation(bootstrap_statistic(far_from_100, median, 100))


Out[88]:
97.02219385327744

데이터가 이렇게 극단적인 경우에는 데이터를 직접 살펴보면 문제를 쉽게 파악할 수 있지만,
대부분의 경우 데이터만 살펴보는 것으로는 부족하다.

15.7 계수의 표준 오차

계수의 표준 오차를 추정할 때도 bootstrap을 적용할 수 있다.
bootstrap을 할 때에는, 하나의 데이터에 속하는 x와 y를 (x_i, y_i) 형태로 묶어줘야 하고,
반환된 데이터를 다시 x_sample, y_sample로 나눠줘야 한다.


In [89]:
def estimate_sample_beta(sample):
    x_sample, y_sample = zip(*sample) # magic unzipping trick
    return estimate_beta(x_sample, y_sample)

In [90]:
random.seed(0) # 예시와 동일한 결과를 얻기 위해 설정

In [91]:
bootstrap_betas = bootstrap_statistic(list(zip(x, daily_minutes_good)),
                                     estimate_sample_beta,
                                     100)

In [92]:
bootstrap_betas


Out[92]:
[[29.939753924432026,
  1.0766987811171587,
  -1.9072694245979773,
  1.229759631083634],
 [28.278404991817176,
  1.0720783947084744,
  -1.8519835485752294,
  1.9135597872325683],
 [29.72096256558336,
  1.0592372902094562,
  -1.832302380438743,
  1.406080921432442],
 [31.29493113749666,
  0.9294712939609966,
  -1.9206097459949676,
  0.17543757346238917],
 [30.665444261614393,
  0.9606538039494921,
  -2.005395609415964,
  0.8552712377676644],
 [31.20395103980187,
  1.0276933591376252,
  -2.1224488742611167,
  1.4100379189697114],
 [28.388430805950506,
  1.1636474982427323,
  -1.7264260057963396,
  2.7482808985335714],
 [30.208189266271106,
  1.1198757677431075,
  -1.9774588115681568,
  1.502278363322292],
 [30.133658456229913,
  1.050576038253023,
  -1.8446966431005063,
  1.079177746046518],
 [30.675215735007107,
  1.0045006785935724,
  -1.9352671318816448,
  0.8555767976704393],
 [30.890885319911387,
  0.9767230930847294,
  -1.917782273386021,
  1.0784038714005937],
 [30.057539026380102,
  1.0126346712121064,
  -1.8796275838125631,
  1.5540498966901921],
 [30.623907011986216,
  1.004223611468741,
  -1.771726683620326,
  0.796237709124627],
 [30.401744262239784,
  0.9489554751025814,
  -1.8159196735423577,
  0.6533668650302189],
 [31.799239379579472,
  0.9526495863690917,
  -1.9859546158453727,
  0.780751076368855],
 [30.58569107273958,
  0.9701040497068303,
  -1.8725813555636341,
  1.252915653229729],
 [28.879439883714806,
  1.0734769954485173,
  -1.7359814587573414,
  1.7677207310928087],
 [29.626713954582574,
  0.9869504295802896,
  -1.738516563411458,
  0.8712974052368523],
 [29.139341270591096,
  0.9688591310976062,
  -1.4642782347614445,
  0.7440716132440629],
 [31.3422007303681,
  0.9710263541178166,
  -1.8756242704186294,
  0.3428333406993962],
 [30.142138176390134,
  1.0438656128502766,
  -1.8767869219047477,
  2.0687426352870686],
 [30.865363137104843,
  1.0361759722107264,
  -2.05362273936756,
  0.47121306706410077],
 [29.28923762832063, 1.1433028180250218, -1.79313504977106, 0.958705985824458],
 [30.044412551862735,
  0.9539149631698808,
  -1.8050356239668175,
  1.91615701841226],
 [30.458728278823862,
  0.9315015710988496,
  -1.654331784676905,
  1.5610576184224607],
 [29.27117755274291,
  1.1052665725726942,
  -1.8267515243534385,
  1.6925600866545762],
 [30.436764359149393,
  0.9858227427767813,
  -1.7820989390270756,
  -0.32942562983455803],
 [31.23422415012503,
  0.9687998109467523,
  -1.8795314473233462,
  0.6115094428497695],
 [30.23012443157135,
  0.9804449615271451,
  -1.9119885894178368,
  1.013002899591167],
 [31.42082515618449,
  1.014155515693022,
  -1.998397890347172,
  -0.15360732980811093],
 [31.343056984945832,
  0.9073946988050725,
  -1.9360825111026645,
  1.0414148892172634],
 [30.64438916270812,
  1.0487045264293338,
  -2.0454966449784333,
  1.8490230739884252],
 [30.297718764085502,
  0.9313073315017633,
  -1.9492867868700605,
  1.294924088657591],
 [30.42083221252522,
  0.9707959028177727,
  -1.7475635933805391,
  0.5777509006054096],
 [28.968710893077578,
  1.0559056857722307,
  -1.7970028326114307,
  1.124627168381006],
 [31.019123828788175,
  0.9380106813823185,
  -1.7462920436369487,
  -0.14910230232356975],
 [29.991059809906233,
  0.9465933769322489,
  -1.8376633836352743,
  1.5665008284896442],
 [30.750723345457015,
  0.9529394291501507,
  -1.787606028926137,
  0.48644433932053177],
 [31.83269612593045,
  0.8035520871161486,
  -1.760478693016914,
  -0.823825326513922],
 [30.93544391823118,
  0.9006273828340532,
  -1.7197550961124592,
  0.2294706582188784],
 [28.95673336164893,
  1.027446164378787,
  -1.8858309945085343,
  2.3742736814611103],
 [30.169131635903277,
  0.9548043247012215,
  -1.903913381787953,
  2.247587789582005],
 [29.911538491020362,
  1.0334281797508842,
  -1.7912082535151153,
  1.8880175327967876],
 [28.544777078843367,
  1.0242774392110523,
  -1.6559495713813936,
  1.193267047095495],
 [31.51678153407387,
  0.8944947742335687,
  -1.838204829292607,
  -0.7676791709888592],
 [30.617977968032367,
  0.9874319252899987,
  -1.6802688478996486,
  0.40869911200454656],
 [30.92856471311331,
  0.9689431528282529,
  -1.8236528298377968,
  0.3811627181838413],
 [30.143408603826657,
  1.0289436966463812,
  -1.9215946211035626,
  0.6443679160373688],
 [32.87640340639101,
  0.9082882039188649,
  -1.9858587181192655,
  -0.8426075759984648],
 [29.20169341388181, 1.019981766027615, -1.808733047279635, 2.216915152820034],
 [31.195787687381777,
  0.9080188444641079,
  -1.8645077160524712,
  0.35021857606746887],
 [32.03649988671767,
  0.9028909288601623,
  -1.972343157494047,
  -0.7088714004899028],
 [30.796238365725735,
  1.034284844109321,
  -1.9425713105569087,
  0.40451350188939966],
 [31.620215176474982,
  0.957601008130083,
  -1.9719039493420996,
  0.7311637358673352],
 [29.461415656901483,
  0.9658332725923695,
  -1.7849522741250785,
  1.4139096562518223],
 [29.64865631358927,
  0.9723026515725409,
  -1.8670922753394064,
  1.9522077402353004],
 [30.163908957232156,
  0.868158083916477,
  -1.6301752274649433,
  1.526118904207664],
 [30.81902707371584,
  0.9699072401581521,
  -1.909448359244763,
  -0.14803799919531274],
 [30.89907423251804,
  1.0091734203978702,
  -2.113462469839564,
  1.398923756817477],
 [31.11549366120244,
  0.9390298184128496,
  -1.9864885107333226,
  1.7755927379888556],
 [32.099181234558564,
  0.9350154892863141,
  -2.103953813593749,
  0.4057147210473868],
 [30.28960511801617,
  0.9922202945383488,
  -1.7410169106507627,
  0.8661299904834323],
 [31.041302334355397,
  0.9741811813421851,
  -1.8508141694180666,
  1.3210274248870244],
 [32.23967916274526,
  0.9112244789273377,
  -1.9188490646041159,
  0.23504332473576237],
 [30.70314665295726,
  0.8940980775153179,
  -1.7891112851395463,
  0.8075537368398492],
 [31.84343649566934,
  0.9717249613451591,
  -2.0974664693805782,
  0.15486325631630113],
 [32.13731695010414,
  0.8879794084368686,
  -1.8583735434160746,
  -0.6865155666174229],
 [31.023576280562768,
  0.879147272369567,
  -1.6702390270415832,
  0.5906314748702747],
 [30.28491625163244,
  0.996029322515991,
  -1.873987188674591,
  -0.009473373138443552],
 [31.382093432582764,
  0.9216282859948381,
  -1.896595333609623,
  0.027166344890374713],
 [30.921614392839796,
  0.9329651964043207,
  -1.9314309500594924,
  0.46760753906171765],
 [30.275803859425874,
  1.030767731879785,
  -1.8104063032324842,
  1.6653584658436331],
 [32.83486459447948,
  0.9399689454834547,
  -2.000454837329984,
  -0.5539452939735333],
 [30.417264186725262,
  0.9904220890087907,
  -1.9367706347955331,
  1.1379193558461553],
 [30.22943868919392,
  0.9483590871856169,
  -1.9191131409231468,
  1.510239565252049],
 [29.71979231340431,
  1.0067306470563924,
  -1.933415405816016,
  2.307372679349379],
 [30.38360464355746,
  1.0092768805613839,
  -1.8105024901337479,
  0.4053632157265963],
 [30.428922288135567,
  0.9647743853551398,
  -1.829386915089297,
  0.20424097766422025],
 [30.918914778499012,
  0.9127089981401707,
  -1.8210613948513952,
  0.7094233889678359],
 [31.25777599583735,
  0.9481284245089976,
  -1.9668874637414853,
  0.010616227403205227],
 [29.052740112148168,
  1.0959371628289478,
  -1.904301016766296,
  2.838516156403951],
 [30.19052353587443,
  1.062007537932147,
  -1.9970574547250417,
  1.2364151497028262],
 [29.877293811664337,
  1.0730743003298784,
  -1.943360025906381,
  1.5397809899002801],
 [30.514220677328172,
  0.9360086465109465,
  -1.708832493430305,
  -0.2576363640294057],
 [31.967527009629382,
  0.8763887529794432,
  -1.9081408923144763,
  -0.378551525080891],
 [29.410500759159707,
  0.9773322842976048,
  -1.6286317806513302,
  0.24945379193720651],
 [30.116102377546778,
  0.9706502151204733,
  -1.8328062091740902,
  0.8637418373664759],
 [30.50813837992598,
  0.9669962169528152,
  -1.7798454554922487,
  1.18423301442585],
 [29.278622616895476,
  1.0987046393662585,
  -1.8965490806012397,
  2.523924636992208],
 [31.986271061027182,
  0.965323882042117,
  -1.9865947474771022,
  -0.2608214613254954],
 [30.225285403476956,
  0.9675414931845573,
  -1.8393172706991456,
  0.7273001956043006],
 [32.411194504402154,
  0.9955723054413659,
  -2.130042260803471,
  -0.5361269322987706],
 [31.511768964117973,
  0.9179084575554802,
  -1.9559138187257954,
  -0.6626425342415486],
 [30.62670609297478,
  0.9231046554569917,
  -1.838908765965435,
  1.4487769867891944],
 [30.21410849174244,
  0.9100946586610947,
  -1.714222481064959,
  -0.038633096503328516],
 [30.165686040883838,
  0.991053038735986,
  -1.8507585194427618,
  0.021587688971951104],
 [29.35765742454494,
  0.9823543746359569,
  -1.7871314326319478,
  2.1254327743039294],
 [30.013030359453495,
  1.0368815798678321,
  -1.850358037351077,
  0.9194524130731508],
 [31.94186523226061,
  0.9715835443018485,
  -2.0274757401206243,
  0.24131790127712038],
 [31.4159031698155,
  0.9834888002736525,
  -1.9094259285035475,
  0.8174140324046765]]

그리고 각 계수의 표준 오차를 추정할 수 있다.


In [93]:
bootstrap_standard_errors = [
    standard_deviation([beta[i] for beta in bootstrap_betas])
    for i in range(4)
]

In [94]:
bootstrap_standard_errors


Out[94]:
[0.953551702104508,
 0.06288763616183773,
 0.11722269488203318,
 0.8591786495949066]

이제 '과연 $\beta_1$는 0일까?' 같은 가설을 검증해 볼 수 있다.

p-value(유의 확률)은 귀무가설이 맞을 경우 대립가설 쪽의 값이 나올 확률을 나타내는 값. 확률 값이라고도 한다. 표본 평균이 귀무가설 값에서 멀수록 작아지게 된다.

In [95]:
def p_value(beta_hat_j, sigma_hat_j):
    if beta_hat_j > 0:
        return 2 * (1 - normal_cdf(beta_hat_j / sigma_hat_j))
    else:
        return 2 * normal_cdf(beta_hat_j / sigma_hat_j)

In [96]:
print(p_value(30.63, 1.174))
print(p_value(0.972, 0.079))
print(p_value(-1.868, 0.131))
print(p_value(0.911, 0.990))


0.0
0.0
0.0
0.35746719881669264

대부분의 계수들은 0이 아닌 것으로 검증되었으나,
박사 학위 취득 여부에 대한 계수에 의미가 없을 수 있다는 것을 암시

15.8 Regularization

  1. 변수가 많아질수록 오버피팅
  • 0이 아닌 계수가 많을수록 모델 해석이 어려움

Regularization은 beta가 커지면 커질수록 해당 모델에게 패널티를 주는 방법이다.
예를 들어, ridge regression의 경우, beta_i를 제곱한 값의 합에 비례하는 패널티를 추가한다.
하지만 상수에 대한 패널티는 주지 않는다.


In [97]:
# alpha는 패널티의 강도를 조절하는 하이퍼 파라미터
# 보통 "lamda"라고 표현하지만 파이썬에서는 이미 사용 중인 키워드이다.
def ridge_penalty(beta, alpha):
  return alpha * dot(beta[1:], beta[1:])

def squared_error_ridge(x_i, y_i, beta, alpha):
    """beta를 사용할 때 오류와 패널티의 합을 추정"""
    return error(x_i, y_i, beta) ** 2 + ridge_penalty(beta, alpha)

그리고 이전과 동일하게 경사 하강법을 적용할 수 있다.


In [98]:
def ridge_penalty_gradient(beta, alpha):
    """패널티의 기울기"""
    return [0] + [2 * alpha * beta_j for beta_j in beta[1:]]

def squared_error_ridge_gradient(x_i, y_i, beta, alpha):
    """i번 오류 제곱 값과 패널티의 기울기"""
    return vector_add(squared_error_gradient(x_i, y_i, beta),
                      ridge_penalty_gradient(beta, alpha))

def estimate_beta_ridge(x, y, alpha):
    """패널티가 alpha인 리지 회귀를 경사 하강법으로 학습"""
    beta_initial = [random.random() for x_i in x[0]]
    return minimize_stochastic(partial(squared_error_ridge, alpha=alpha), 
                               partial(squared_error_ridge_gradient, 
                                       alpha=alpha), 
                               x, y, 
                               beta_initial, 
                               0.001)

만약 alpha가 0이라면 패널티는 전혀 없으며, 이전과 동일한 모델이 학습될 것이다.


In [99]:
random.seed(0)
beta_0 = estimate_beta_ridge(x, daily_minutes_good, alpha=0.0)

beta_0


Out[99]:
[30.619881701311712,
 0.9702056472470465,
 -1.8671913880379478,
 0.9163711597955347]

In [100]:
dot(beta_0[1:], beta_0[1:])


Out[100]:
5.267438780018153

In [101]:
multiple_r_squared(x, daily_minutes_good, beta_0)


Out[101]:
0.6800074955952597

그리고 alpha를 증가시킬수록 적합성은 감소하고, beta의 크기도 감소한다.


In [102]:
beta_0_01 = estimate_beta_ridge(x, daily_minutes_good, alpha=0.01)
print(beta_0_01)
print(dot(beta_0_01[1:], beta_0_01[1:]))
print(multiple_r_squared(x, daily_minutes_good, beta_0_01))


[30.55985204967343, 0.9730655363505671, -1.8624424625144256, 0.9317665551046306]
5.2837373774215655
0.680010213297079

In [103]:
beta_0_1 = estimate_beta_ridge(x, daily_minutes_good, alpha=0.1)
print(beta_0_1)
print(dot(beta_0_1[1:], beta_0_1[1:]))
print(multiple_r_squared(x, daily_minutes_good, beta_0_1))


[30.894860179735474, 0.9490275238632391, -1.8501720889216575, 0.5325129720515789]
4.607360065077926
0.6797276241305292

In [104]:
beta_1 = estimate_beta_ridge(x, daily_minutes_good, alpha=1)
print(beta_1)
print(dot(beta_1[1:], beta_1[1:]))
print(multiple_r_squared(x, daily_minutes_good, beta_1))


[30.666778908554885, 0.908635996761392, -1.6938673046100265, 0.09370161190283018]
3.7035858123105934
0.6757061537631815

In [105]:
beta_10 = estimate_beta_ridge(x, daily_minutes_good, alpha=10)
print(beta_10)
print(dot(beta_10[1:], beta_10[1:]))
multiple_r_squared(x, daily_minutes_good, beta_10)


[28.372861060795607, 0.7307660860322116, -0.9212163182015426, -0.018495551723207087]
1.3830006628491893
Out[105]:
0.5752138470466858

패널티가 증가하면 박사 학위 취득 여부 변수는 사라진다.

다른 형태의 패널티를 사용하는 lasso regression도 있다.


In [106]:
def lasso_penalty(beta, alpha):
    return alpha * sum(abs(beta_i) for beta_i in beta[1:])

리지 회귀의 패널티는 총 계수의 합을 줄여 주지만,
라쏘 회귀의 패널티는 모든 계수를 최대한 0으로 만들어 주며,
보다 희소한(sparse) 모델을 학습하게 해준다.

하지만 라쏘 회귀는 경사 하강법으로는 학습할 수 없다.