LionKing数据科学专栏

购买普通会员高级会员可以解锁网站精华内容且享受VIP服务的优惠

想要查看更多数据科学相关的内容请关注我们的微信公众号知乎专栏

k近邻(k nearest neighbors)原理及算法实现

k近邻的原理

假设你有一个数据集,希望通过一些特征$x_1, \ldots, x_p$预测目标变量$y$,$y$可以是连续型变量(回归问题)或者离散型变量(分类问题)。k近邻算法的思路是在训练数据中寻找距离新数据最近的$k$个观测,对它们的$y$取多数投票(majority vote)或者平均(average)。

算法

k近邻算法无需训练部分。

记训练数据为$(x^{(1)}, y^{(1)}), \ldots, (x^{(n)}, y^{(n)})$。对于新的观测点$x \in \mathbb{R}^p$,寻找与$x$欧式距离(Euclidean distance)最小的$x^{(i_1)}, \ldots, x^{(i_k)}$。

对于分类问题,我们取$y^{(i_1)}, \ldots, y^{(i_k)}$中出现最多的一类作为$x$类别的预测。

对于回归问题,我们取$\frac{1}{k}\sum_{j=1}^{k}y^{(i_j)}$作为$x$对应的目标的预测。

(图片来自wikipedia)

上图是一个分类问题用k近邻算法的实例。其中绿色的点是测试数据。红色的点目标变量为1,蓝色的点目标变量为0。

如果选取$k = 3$,则最近的3个点中有2个1和1个0,测试数据被预测为1。

如果选取$k = 5$,则最近的5个点中有2个1和3个0,测试数据被预测为0。

对于分类问题,如果有数据不平衡(imbalanced dataset)的情况,则传统的做法倾向于将数据分为主流的类别。此时可以通过对于不同类别加权来改善结果。我们可以对于最近的$k$个邻居关于与$x$的欧式距离的倒数$w_j = \frac{1}{\|x - x^{(i_j)}\|_2}$作为权重,

对于分类问题,取权重之和最大的一类作为预测;对于回归问题,取加权平均$$\frac{\sum_{j=1}^{k}w_jy^{(i_j)}}{\sum_{j=1}^{k}w_j}$$作为预测。

Python实现

from sklearn import neighbors
import numpy as np

n_train = 1000
n_test = 200

# 生成X
X_train = np.concatenate([np.random.normal(0, 1, (n_train // 2, 2)), np.random.normal(2, 1, (n_train // 2, 2))])
X_test = np.concatenate([np.random.normal(0, 1, (n_test // 2, 2)), np.random.normal(2, 1, (n_test // 2, 2))])
# 生成y
y_train = np.concatenate([np.ones(n_train // 2), np.zeros(n_train // 2)])
y_test = np.concatenate([np.ones(n_test // 2), np.zeros(n_test // 2)])

# 创建k近邻分类训练器
knn_classifier = neighbors.KNeighborsClassifier(n_neighbors=5)
# 训练模型
knn_classifier.fit(X_train, y_train)
# 测试句预测
yhat_test = knn_classifier.predict(X_test)
# 测试错误率
err_test = np.mean(y_test != yhat_test)
print('测试错误率为%f' % (err_test, ))
        
输出:

测试错误率为0.105000
        

方差偏差均衡(bias variance tradeoff)

随着$k$的增大,k近邻模型的偏差会增加,而方差会减少。

直观上,对于较大的k,很远的训练数据也会用来预测,容易造成结果不准确。另一方面,因为用了更多的训练数据,结果相对稳定故方差更小。

需要购买普通会员高级会员登录后刷新该页面查看

太小的k容易过拟合(overfitting),太大的k容易欠拟合(underfitting)。为了选取最优的k,需要对数据进行交叉验证(cross validation)。

常见面试问题

Q:如果$x$不是连续型特征,而是分类型特征应当如何处理?

需要购买普通会员高级会员登录后刷新该页面查看

Q:k近邻算法需要怎样的数据预处理(data preprocessing)?

需要购买普通会员高级会员登录后刷新该页面查看

更多机器学习相关问题见本网站论坛机器学习理论版面机器学习实践版面

更多面试问题见面试真题汇总

想要查看更多数据科学相关的内容请关注我们的微信公众号知乎专栏