In [1]:
# %load /Users/facai/Study/book_notes/preconfig.py
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set(font='SimHei')
plt.rcParams['axes.grid'] = False
from IPython.display import SVG
def show_image(filename, figsize=None):
    if figsize:
        plt.figure(figsize=figsize)
    plt.imshow(plt.imread(filename))
Version information of the code under analysis:
~/W/spark ❯❯❯ git log -n 1
commit 2eedc00b04ef8ca771ff64c4f834c25f835f5f44
Author: Wenchen Fan <wenchen@databricks.com>
Date: Mon Aug 1 17:54:41 2016 -0700
[SPARK-16828][SQL] remove MaxOf and MinOf
## What changes were proposed in this pull request?
These 2 expressions are not needed anymore after we have `Greatest` and `Least`. This PR removes them and related tests.
## How was this patch tested?
N/A
Author: Wenchen Fan <wenchen@databricks.com>
Closes #14434 from cloud-fan/minor1.
In [3]:
SVG("./res/uml/gbdt_spark.svg")
Out[3]:
Spark wraps things at a fine granularity; at first glance the profusion of classes is bewildering. But the decomposition is clean, so the relationships are easy to untangle, and the high level of abstraction makes the code very readable.
To keep the narrative clear, we introduce the classes involved top-down, in call order, and sketch what each one does.
Training call order:
GBTRegressor.train and GBTClassifier.train are the entry points for training. GradientBoostedTrees kicks off training through run, which calls the boost method; this is the main body of the GBDT algorithm. Training is governed by BoostingStrategy, whose main components are Loss, concretely implemented as logistic, L1, and L2, and Strategy, which guides how each DecisionTreeRegressor is fit. The end result is a GBTRegressionModel or GBTClassificationModel.
Prediction call order:
Prediction is straightforward: the GBT*Model classes simply sum each DecisionTreeRegressor's output weighted by its learner weight.
In addition, GBT*Model.featureImportances is an interesting member: it estimates the importance of each feature, and is computed mainly by TreeEnsembleModel. We will look at it in detail later.
Since many of the classes here exist only as wrappers, we will cover only the concrete classes and functions involved in the GBDT algorithm rather than the whole codebase. Spark's decision-tree implementation deserves its own analysis and is not examined in depth in this article.
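To make the entry points concrete, here is a minimal usage sketch through the public ML API; the DataFrame train with "label" and "features" columns is an assumption for illustration, not something from this analysis:
import org.apache.spark.ml.regression.GBTRegressor

// Hypothetical input: a DataFrame `train` with "label" and "features" columns.
val gbt = new GBTRegressor()
  .setMaxIter(20)       // numIterations in BoostingStrategy
  .setStepSize(0.1)     // learningRate
  .setLossType("squared")
// fit() leads into GBTRegressor.train -> GradientBoostedTrees.run -> boost
val model = gbt.fit(train)
val predictions = model.transform(train)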
In [5]:
show_image("./res/gbdt.png", figsize=(10,5))
The figure above comes from Friedman, "Greedy function approximation: A gradient boosting machine"; line 6 of the algorithm introduces the learning rate to control overfitting.
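In symbols, with learning rate $\nu$, line 6 of the algorithm updates the model as
$$F_m(x) = F_{m-1}(x) + \nu\, \rho_m h(x; a_m)$$
where $h(x; a_m)$ is the tree fit at iteration $m$ and $\rho_m$ is the line-search weight.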
Below is the implementation in Spark:
243 def boost(
244 //+-- 5 lines: input: RDD[LabeledPoint],-----------------------------------------------------
249 //+-- 7 lines: val timer = new TimeTracker()-------------------------------------------------
256 val numIterations = boostingStrategy.numIterations
257 val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
258 val baseLearnerWeights = new Array[Double](numIterations)
259 val loss = boostingStrategy.loss
260 val learningRate = boostingStrategy.learningRate
261 // Prepare strategy for individual trees, which use regression with variance impurity.
262 val treeStrategy = boostingStrategy.treeStrategy.copy
263 val validationTol = boostingStrategy.validationTol
264 treeStrategy.algo = OldAlgo.Regression
265 treeStrategy.impurity = OldVariance
266 treeStrategy.assertValid()
267
268 //+-- 20 lines: Cache input-------------------------------------------------------------------
288 // Initialize tree
289 timer.start("building tree 0")
290 val firstTree = new DecisionTreeRegressor().setSeed(seed)
291 val firstTreeModel = firstTree.train(input, treeStrategy)
292 val firstTreeWeight = 1.0
293 baseLearners(0) = firstTreeModel
294 baseLearnerWeights(0) = firstTreeWeight
295
296 var predError: RDD[(Double, Double)] =
297 computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
298 //+-- 11 lines: predErrorCheckpointer.update(predError)---------------------------------------
309
310 var m = 1
311 var doneLearning = false
312 while (m < numIterations && !doneLearning) {
313 // Update data with pseudo-residuals
314 val data = predError.zip(input).map { case ((pred, _), point) =>
315 LabeledPoint(-loss.gradient(pred, point.label), point.features)
316 }
317
318 //+-- 4 lines: timer.start(s"building tree $m")----------------------------------------------
322 val dt = new DecisionTreeRegressor().setSeed(seed + m)
323 val model = dt.train(data, treeStrategy)
324 //+-- 2 lines: timer.stop(s"building tree $m")-----------------------------------------------
326 baseLearners(m) = model
327 // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
328 // Technically, the weight should be optimized for the particular loss.
329 // However, the behavior should be reasonable, though not optimal.
330 baseLearnerWeights(m) = learningRate
331
332 predError = updatePredictionError(
333 input, predError, baseLearnerWeights(m), baseLearners(m), loss)
334 //+-- 21 lines: predErrorCheckpointer.update(predError)---------------------------------------
355 m += 1
356 }
357 //+-- 15 lines: timer.stop("total")-----------------------------------------------------------
372 }
Next, let's connect the code to the algorithm steps:
Algorithm line 1, the base model: 290L-294L;
Algorithm line 2, the termination condition: 312L;
Algorithm line 3, the gradient of the loss function: 315L;
Algorithm line 4, fitting the tree model: 323L;
Algorithm line 5, computing the weight: 330L;
Algorithm line 6, adding the newly trained tree: 332L-333L.
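As a quick illustration of 314L-315L with toy numbers of my own (not Spark output): the new training target is the negative gradient of the loss at the current prediction.
// Self-contained mirror of SquaredError.gradient (shown at the end of this article).
def gradient(prediction: Double, label: Double): Double = -2.0 * (label - prediction)

val (pred, label) = (0.3, 1.0)
// The pseudo-residual becomes the label of the new LabeledPoint in 315L.
val pseudoResidual = -gradient(pred, label)  // 2 * (1.0 - 0.3) = 1.4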
It is easy to notice that the weight computation at 330L looks odd: the learning rate is used directly as the tree's weight. The accompanying comment admits the implementation is imperfect, but is not very explicit about why. My reading of the actual intent is as follows:
$\beta h(x_i; a)$ is the tree being trained. The impurity used when growing the tree is OldVariance at 265L; tracing through the code shows $\beta = 1$.
For the L2 loss, $\rho_m = \beta_m = 1$; the derivation is simple, see section 4.1 "Least-squares regression" of Friedman's paper, "Greedy function approximation: A gradient boosting machine".
Since $\rho = 1$, the weight multiplied in at algorithm line 6 is $\rho \nu = \nu$, i.e. just the learning rate.
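Spelling out the one-line line search for the squared loss:
$$\rho_m = \arg\min_\rho \sum_i \big(\tilde y_i - \rho\, h(x_i; a_m)\big)^2 = \frac{\sum_i \tilde y_i\, h(x_i; a_m)}{\sum_i h(x_i; a_m)^2}$$
Because each leaf of a variance-impurity regression tree already predicts the mean of the $\tilde y_i$ that fall into it, the numerator equals the denominator, so the ratio is exactly 1.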
This is exactly what the comment means by the weight being correct only for SquaredError (i.e. L2).
GBT*.predict
Prediction is simple: each tree's output is accumulated with its weight, implemented here as a BLAS dot product. Classification adds a binarization step at 212L.
207 override protected def predict(features: Vector): Double = {
208 // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
209 // Classifies by thresholding sum of weighted tree predictions
210 val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
211 val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
212 if (prediction > 0.0) 1.0 else 0.0
213 }
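For intuition, here is a plain-Scala equivalent of 210L-212L with made-up values; the real code delegates the weighted sum to blas.ddot:
// Toy values: the first tree has weight 1.0, later trees carry the learning rate.
val treePredictions = Array(0.8, -0.2, 0.5)
val treeWeights     = Array(1.0, 0.1, 0.1)
// Weighted sum, equivalent to blas.ddot(numTrees, treePredictions, 1, treeWeights, 1).
val margin = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum  // 0.8 - 0.02 + 0.05 = 0.83
val prediction = if (margin > 0.0) 1.0 else 0.0  // 212L: binarize for classification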
TreeEnsembleModel.featureImportances
The idea is intuitive: compute each feature's total impurity gain across every tree. Two details matter: the gain at each node is weighted by the number of samples in that node, and the results are normalized, first per tree and then across the ensemble.
The code is as follows:
155 def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
156 val totalImportances = new OpenHashMap[Int, Double]()
157 trees.foreach { tree =>
158 // Aggregate feature importance vector for this tree
159 val importances = new OpenHashMap[Int, Double]()
160 computeFeatureImportance(tree.rootNode, importances)
161 // Normalize importance vector for this tree, and add it to total.
162 // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
163 val treeNorm = importances.map(_._2).sum
164 if (treeNorm != 0) {
165 importances.foreach { case (idx, impt) =>
166 val normImpt = impt / treeNorm
167 totalImportances.changeValue(idx, normImpt, _ + normImpt)
168 }
169 }
170 }
171 // Normalize importances
172 normalizeMapValues(totalImportances)
173 // Construct vector
174 val d = if (numFeatures != -1) {
175 numFeatures
176 } else {
177 // Find max feature index used in trees
178 val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
179 maxFeatureIndex + 1
180 }
181 //+-- 4 lines: if (d == 0) {-----------------------------------------------------------------
185 val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
186 Vectors.sparse(d, indices.toArray, values.toArray)
187 }
Within it, computeFeatureImportance aggregates each feature's gain by recursively traversing the nodes:
217 def computeFeatureImportance(
218 node: Node,
219 importances: OpenHashMap[Int, Double]): Unit = {
220 node match {
221 case n: InternalNode =>
222 val feature = n.split.featureIndex
223 val scaledGain = n.gain * n.impurityStats.count
224 importances.changeValue(feature, scaledGain, _ + scaledGain)
225 computeFeatureImportance(n.leftChild, importances)
226 computeFeatureImportance(n.rightChild, importances)
227 case n: LeafNode =>
228 // do nothing
229 }
230 }
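To see the two normalization passes in action, here is a toy walk-through with assumed per-tree gain sums (not real Spark numbers):
// featureIndex -> sum over that tree's nodes of (gain * node sample count)
val tree1 = Map(0 -> 20.0, 2 -> 5.0)
val tree2 = Map(0 -> 3.0, 1 -> 1.0)

def normalize(m: Map[Int, Double]): Map[Int, Double] = {
  val s = m.values.sum
  m.map { case (k, v) => k -> v / s }
}

// Pass 1: normalize within each tree, then accumulate (as in 163L-168L).
val total = (normalize(tree1).toSeq ++ normalize(tree2).toSeq)
  .groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
// Pass 2: normalize across the ensemble (normalizeMapValues at 172L).
val overall = normalize(total)
// overall == Map(0 -> 0.775, 1 -> 0.125, 2 -> 0.1)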
Loss.gradient
There is not much to say here: gradient is simply the algebraic form of the loss function's first derivative. Here is the L2 (SquaredError) example:
23 //+-- 4 lines: *-----------------------------------------------------------------------------
27 //* The squared (L2) error is defined as:
28 //* (y - F(x))**2
29 //+-- 4 lines: * where y is the label and F(x) is the model prediction for features x.-------
33 object SquaredError extends Loss {
34
35 //+-- 3 lines: *-----------------------------------------------------------------------------
38 //* The gradient with respect to F(x) is: - 2 (y - F(x))
39 //+-- 5 lines: * @param prediction Predicted label.------------------------------------------
44 override def gradient(prediction: Double, label: Double): Double = {
45 - 2.0 * (label - prediction)
46 }
47
48 override private[spark] def computeError(prediction: Double, label: Double): Double = {
49 val err = label - prediction
50 err * err
51 }
52 }
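For comparison, the other two losses follow the same pattern; the formulas below are my paraphrase of the corresponding mllib loss objects, so treat the exact constants as such. AbsoluteError (L1) uses $L = |y - F(x)|$ with gradient $-\mathrm{sign}(y - F(x))$; LogLoss (logistic) uses $L = 2\log(1 + e^{-2 y F(x)})$ with gradient $-\dfrac{4y}{1 + e^{2 y F(x)}}$, where the label $y \in \{-1, +1\}$.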