LionKing数据科学专栏

购买普通会员高级会员可以解锁网站精华内容且享受VIP服务的优惠

想要查看更多数据科学相关的内容请关注我们的微信公众号知乎专栏

梯度提升树(GBDT)原理及算法实现

梯度提升树的原理

梯度提升树(Gradient Boosting Decision Tree)又叫MART(Multiple Additive Regression Tree),是一种基于决策树的集成学习算法。

经典的GBDT只能用来解决回归问题,改进后可以解决分类问题。

回归问题较容易理解,其思路是学习一系列的树$T_1, T_2, \ldots, T_B$,使得$T_1 + T_2 + \ldots + T_B$尽可能接近目标变量$y$。另一种解读是每次用一棵决策树学习当前模型的残差(residual)并且加入当前模型。

分类问题同样使用回归树,但是学习的目标从残差变成了交叉熵损失函数的负梯度。

算法

首先给出回归问题的算法:

对于一般的损失函数$L(y, \hat{y})$,我们将回归问题的算法推广如下:

上述算法较为繁琐,需要记住的核心思想是我们用损失函数的负梯度来替代原来使用的残差。残差等于平方损失函数下的负梯度。这个推广后的算法可以用来解决分类问题。

对于二分类问题,只需要选择交叉熵损失函数带入上述算法即可:

$$L(y, c) = -y\log{p} - (1 - y)\log{(1 - p)}$$

其中$p = \sigma(c) = \frac{1}{1 + e^{-c}}$

GBDT还可以推广到多分类问题。限于篇幅不予讨论。

Python实现

不同于以往的机器学习模型,我们使用xgboost这个库来训练梯度提升树模型。xgboost对原有的GBDT模型进行了一系列的优化,从而在计算效率和模型精度上都有所提高。

import xgboost
from sklearn import datasets
import numpy as np

n_train = 1000
n_test = 200
p = 10

# 生成X, y
X, y = datasets.make_regression(n_samples=n_train + n_test, n_features=p, n_informative=2, random_state=0)

# 拆分训练和测试数据集
X_train = X[:n_train, :]
y_train = y[:n_train]
X_test = X[n_train:, :]
y_test = y[n_train:]

# 创建梯度提升树回归器
gbdt_regressor = xgboost.XGBRegressor(max_depth=4, n_estimators=200)
# 训练模型
gbdt_regressor.fit(X_train, y_train)
# 训练集预测
yhat_train = gbdt_regressor.predict(X_train)
# 训练均方误差
residual_train = yhat_train - y_train
mse_train = np.mean(residual_train ** 2)
print('训练均方误差为%f' % (mse_train, ))
# 测试集预测
yhat_test = gbdt_regressor.predict(X_test)
# 测试均方误差
residual_test = yhat_test - y_test
mse_test = np.mean(residual_test ** 2)
print('测试均方误差为%f' % (mse_test, ))
# 特征重要性
variable_importance = gbdt_regressor.feature_importances_
print('特征重要性为%s' % (variable_importance, ))
        
输出:

训练均方误差为4.743532
测试均方误差为115.646368
特征重要性为[0.10170162 0.07479224 0.06133755 0.05817175 0.24693312 0.24416304
 0.05263158 0.06015038 0.05184013 0.04827859]
        

常见面试问题

Q:随机森林和GBDT有什么异同?

需要购买普通会员高级会员登录后刷新该页面查看

Q:xgboost是什么?xgboost和传统GBDT有什么不同?

需要购买普通会员高级会员登录后刷新该页面查看

更多机器学习相关问题见本网站论坛机器学习理论版面机器学习实践版面

更多面试问题见面试真题汇总

想要查看更多数据科学相关的内容请关注我们的微信公众号知乎专栏