LionKing数据科学专栏

购买普通会员高级会员可以解锁网站精华内容且享受VIP服务的优惠

想要查看更多数据科学相关的内容请关注我们的微信公众号知乎专栏

决策树(decision tree)原理及算法实现

决策树的原理

决策树既可以用来解决分类问题,也可以用来解决回归问题。其思想是每次将数据关于某一个特征分成两部分:对于数值型特征,将数据关于某个阈值(threshold)划分成左右两部分;对于离散型特征,将数据关于该特征是否属于某一个集合划分成左右两部分。

决定按照哪一个特征以及关于该特征的哪一个阈值/集合划分的方法是尝试所有可能的划分,并计算该划分的一个指标,选择使得指标最低的划分。决策树的划分策略是贪心算法。

上图中,我们首先将数据按照$X$是否小于5分成了左右两部分,对于左边,我们关于$Y$是否小于3分成了左右两部分。

对于左边,我们预测结果为红色。对于右边,我们再关于$X$是否小于3分成了左右两部分,分别预测为红色/蓝色。

对于根结点的右边,即一开始满足$X \geqslant 5$的样本,我们再关于$Y$是否小于3.2分为左右两部分,分别预测为蓝色/红色。

下图展示了决策的过程,由于该决策过程很像一棵树,我们称该算法为决策树算法。

每次将数据划分为两部分之后,再对两个部分分别重复以上步骤,直到所有的数据都是一类或者划分数据带来的收益可以忽略不计。

为了防止过拟合,还需要对最终生成的树做一个剪枝(pruning)。

训练

决策树的主流算法有C4.5和CART,这里介绍CART算法。C4.5与CART的区别主要在于对连续型特征的处理和不同的打分方式。

决策树的核心步骤是对划分的打分。每一个划分会把现有的数据分成两个部分,我们对于两个部分分别计算一个不纯净度指标(impurity measure)$H_L, H_R$并按照结点个数相加得到$N_LH_L + N_RH_R$,其中$N_L$是左边的结点个数,$N_R$是右边的结点个数。选取使得这个值最小的一种划分。

对于分类问题,考虑一组数据$(x^{(1)}, y^{(1)}), \ldots, (x^{(n)}, y^{(n)})$,假设总共有$K$种可能的类型,并记录第$i$种类型出现的比例为$p_i = \frac{\sum_{j=1}^{n}I\{y^{(j)} = i\}}{n}$。有三种主要的不纯净度指标$H$:

对于回归问题,纯净度为均方误差(Mean Squared Error):$$H = \frac{1}{n}\sum_{i = 1}^{n}(y_i - \overline{y})^2$$其中$$\overline{y} = \frac{1}{n}\sum_{i=1}^{n}y_i$$

注意对于分类型特征,如果不同类型多于15,可能的划分方式超过$2^{14} - 1$,因此我们将该分类型特征进行某种排序将其视作数值型变量,以减少计算复杂度和防止过拟合。

预测

在训练好一棵决策树后,对于新的数据,我们需要对该数据进行预测。

首先将这个数据放在根结点(root),不断地根据当前结点的评判标准移到左子结点或右子结点,直到进入一个叶结点(leaf)。

对于分类问题,在叶结点中,我们知道各个类别的训练样本个数,直接返回最主要的类别即可。

对于回归问题,我们知道训练样本的目标变亮,取平均返回即可。

剪枝

当决策树生成的叶子过多时,训练数据上的效果会接近完美,但是由于得到的模型过于复杂,往往泛化能力不够理想。因此我们希望对树进行剪枝防止模型过拟合。

剪枝的第一种方法是预剪枝(prepruning):使用一部分留出(holdout)的数据作为交叉验证集,在建树阶段计算划分后在验证集上的效果,如果有所下降,则不进行这一步的划分。

预剪枝的缺点是有一些划分虽然导致了泛化性能下降,在此划分基础上的后续划分可以增加泛化能力。

剪枝的第二种方法是后剪枝(postpruning):在整棵树建好之后,从叶子出发,如果一个结点满足某个判定条件,则将此结点换做一个叶子,并且使用该结点的数据计算叶子上的预测值。

常见的判定条件包括:

Python实现

from sklearn import tree
import numpy as np
from matplotlib import pyplot as plt
import subprocess

n_train = 1000
n_test = 200

np.random.seed(10)
# 生成X
X = np.random.rand(n_train + n_test, 2)
# 生成y = (X1 < 0.5 and X2 < 0.5) or (X1 > 0.5 and X2 > 0.5) 并加上10%随机扰动
y = (X[:, 0] < 0.5) ^ (X[:, 1] < 0.5) ^ (np.random.rand(n_train + n_test) < 0.10)
y = y.reshape((n_train + n_test, 1))
X_train = X[:n_train, :]
y_train = y[:n_train, :]
X_test = X[n_train:, ]
y_test = y[n_train:, ]

plt.close()
plt.scatter(X_train[:, 0], X_train[:, 1], color=['r' if col else 'b' for col in y_train])
plt.savefig('decision-tree-data.png')

# 创建决策树分类训练器,限制最大深度为5
decision_tree_classifier = tree.DecisionTreeClassifier(max_depth=3)
# 训练模型
decision_tree_classifier.fit(X_train, y_train)
# 测试集预测
yhat_test = decision_tree_classifier.predict(X_test)
# 测试集错误率
err_test = np.mean(yhat_test ^ y_test.flatten())
print('测试集错误率为%f' % (err_test, ))

with open('tree.dot', 'w') as f:
  tree.export_graphviz(decision_tree_classifier, out_file=f)
subprocess.call(['dot', '-Tpng', 'tree.dot', '-o', 'tree.png'])
        
输出:

测试集错误率为0.110000
        

以上程序生成的训练数据如下:

其中蓝色为0类,红色为1类。

训练的3层树如下:

对于$(X_1, X_2) = (0.2, 0.3)$,根结点的判断为False,我们移到右子树。

第二个判断为True,移动到左子树。

第三个判断为True,再次移动到左子树。

该子树有128个0和17个1,故分类为0。

常见面试问题

Q:决策树有哪些优缺点?

需要购买普通会员高级会员登录后刷新该页面查看

Q:决策树如何处理缺失值?

需要购买普通会员高级会员登录后刷新该页面查看

Q:对于分类问题,如果数据不平衡,是否会对决策树的结果有影响?

需要购买普通会员高级会员登录后刷新该页面查看

Q:CART和C4.5有哪些区别?

需要购买普通会员高级会员登录后刷新该页面查看

Q:决策树如何计算特征重要性(variable importance)?

需要购买普通会员高级会员登录后刷新该页面查看

更多机器学习相关问题见本网站论坛机器学习理论版面机器学习实践版面

更多面试问题见面试真题汇总

想要查看更多数据科学相关的内容请关注我们的微信公众号知乎专栏