Scikit-Learn 线性回归

线性回归是最常见的统计建模方法之一,本节将向大家介绍 Scikit-Learn 中线性回归工具的使用方法,以及一些常用的操作技巧。

首先,我们导入需要用到的其他工具库,并对构建的示例数据进行展示:


In [7]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np

In [8]:
rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = 2 * x - 5 + rng.randn(50)
plt.scatter(x, y);


可以看到,这些随机样本的两个特征在二维空间中表现为较明显的正相关关系:

通过导入 Scikit-Learn 的 LinearRegression 函数,我们可以很轻松地构建一个线性回归模型,模型的构建过程及拟合效果如下:


In [9]:
from sklearn.linear_model import LinearRegression
model = LinearRegression(fit_intercept=True)

model.fit(x[:, np.newaxis], y)

xfit = np.linspace(0, 10, 1000)
yfit = model.predict(xfit[:, np.newaxis])

plt.scatter(x, y)
plt.plot(xfit, yfit);



In [ ]: