Introduction and Implementation of the KNN Algorithm

Algorithm Principles

1. Algorithm description

KNN (k-nearest neighbors) works as follows: given a training set, for each new input instance, find the k training instances closest to it under some distance metric, and assign the input to the class that the majority of those k neighbors belong to. There is no explicit training phase; the model simply memorizes the training data.

2. The three elements of KNN

(1) Choice of k

A small k yields a more complex model that is sensitive to noise and prone to overfitting; a large k yields a simpler model that smooths over local structure and may underfit (in the extreme, k = N always predicts the overall majority class). In practice, k is usually chosen by cross-validation.

(2) Distance metric

The distance between two instances in feature space measures how similar they are. The general form is the Lp (Minkowski) distance: for $x_i, x_j \in \mathbf{R}^n$,

$$L_p(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^p \right)^{\frac{1}{p}}, \quad p \ge 1$$

When p = 2 this is the Euclidean distance (the most common choice, and the one used in the implementation below); when p = 1 it is the Manhattan distance; as p → ∞ it tends to the Chebyshev distance, the maximum difference over coordinates.
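As a quick sketch (the vectors and numbers here are just illustrative), the three metrics can be computed directly with NumPy:

import numpy as np

x, y = np.array([1.0, 2.0, 3.0]), np.array([4.0, 0.0, 3.0])

# Euclidean distance (p = 2)
l2 = np.sqrt(np.sum(np.square(x - y)))    # same as np.linalg.norm(x - y)
# Manhattan distance (p = 1)
l1 = np.sum(np.abs(x - y))                # same as np.linalg.norm(x - y, ord=1)
# Chebyshev distance (p -> inf)
linf = np.max(np.abs(x - y))              # same as np.linalg.norm(x - y, ord=np.inf)

print(l2, l1, linf)  # 3.605..., 5.0, 3.0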

(3) Classification decision rule

The decision rule in KNN is usually majority voting: the input instance is assigned to the class most common among its k nearest training instances.
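In symbols, for an input $x$ whose k nearest training instances form the set $N_k(x)$, the predicted class is

$$y = \arg\max_{c_j} \sum_{x_i \in N_k(x)} I(y_i = c_j)$$

where $I(\cdot)$ is the indicator function.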

Implementation and Testing of the KNN Algorithm

1. Implementing the KNN classifier

import numpy as np
from collections import Counter


class KnnClassifier:
    """
    A KNN classifier with L2 distance.
    """
    def __init__(self, k):
        """
        :param k: number of nearest neighbors
        """
        self.k = k
        self.X_train = None
        self.y_train = None

    def fit(self, X_train, y_train):
        """
        Train the KNN classifier. For KNN, training just means memorizing the training set.
        :param X_train: training data, shape ==> (N, D), N is the number of training samples, D is the feature dimension
        :param y_train: training labels, shape ==> (N,)
        """
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X_test):
        """
        Predict labels for the test set.
        :param X_test: test data, shape ==> (n, D), n is the number of test samples, D is the feature dimension
        :return: predictions, shape ==> (n,)
        """
        n = X_test.shape[0]
        y_pred = np.zeros(n)

        dists = self._compute_dists(X_test)  # (n, N)

        for i in range(n):
            # labels of the k training samples closest to the i-th test sample
            top_k = self.y_train[np.argsort(dists[i])[:self.k]]

            # Counter(top_k).most_common(1) lists the most common element and
            # its count, e.g. Counter('abcdeabcdabcaba').most_common(3)
            # ==> [('a', 5), ('b', 4), ('c', 3)]
            y_pred[i] = Counter(top_k).most_common(1)[0][0]

        return y_pred

    def _compute_dists(self, X_test):
        """
        Compute the distance between each test sample and the whole training set.
        :param X_test: test data, shape ==> (n, D), n is the number of test samples, D is the feature dimension
        :return: distances, shape ==> (n, N); row i holds the distances between
                 the i-th test sample and every training sample
        """
        n = X_test.shape[0]
        N = self.X_train.shape[0]
        dists = np.zeros((n, N))

        for i in range(n):
            dists[i, :] = np.sqrt(np.sum(np.square(self.X_train - X_test[i, :]), axis=1))

        return dists
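One remark on `_compute_dists`: the per-sample loop can be removed entirely with broadcasting, using the expansion $\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2\,a \cdot b$. A minimal sketch, written here as a free function (`compute_dists_vectorized` is a name introduced for illustration, not part of the original class):

def compute_dists_vectorized(X_test, X_train):
    """Pairwise L2 distances with no Python loop."""
    # (n, 1) + (N,) - (n, N) broadcasts to an (n, N) matrix of squared distances
    sq = (np.sum(X_test ** 2, axis=1, keepdims=True)
          + np.sum(X_train ** 2, axis=1)
          - 2 * X_test @ X_train.T)
    # clip tiny negatives caused by floating-point error before the sqrt
    return np.sqrt(np.maximum(sq, 0))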

2. Testing KNN on sklearn's built-in wine dataset

Imports

from knn import KnnClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

Explore the data and split it into training and test sets

data = datasets.load_wine()
print(data.data.shape, data.target.shape, np.unique(data.target))
# output: (178, 13) (178,) [0 1 2]

X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=1)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
# output:((142, 13), (36, 13), (142,), (36,))

A quick test of KNN

knn = KnnClassifier(k=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)

num_correct = np.sum(y_pred == y_test)
acc = float(num_correct) / len(y_test)

print("Got {}/{}, acc = {:0.3f}".format(num_correct, len(y_test), acc))

# output: Got 24/36, acc = 0.667
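One thing the original test glosses over: the wine features live on very different scales (proline runs into the hundreds while most other features are single digits), so the largest-scale feature dominates the L2 distance, which largely explains the modest accuracy above. A sketch of standardizing first, using `sklearn.preprocessing.StandardScaler` (the improvement is typical for distance-based models, though not guaranteed):

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(X_train)  # fit on training data only
knn_scaled = KnnClassifier(k=5)
knn_scaled.fit(scaler.transform(X_train), y_train)
y_pred_scaled = knn_scaled.predict(scaler.transform(X_test))
print(np.mean(y_pred_scaled == y_test))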

Testing different values of k

accs = []

for k in range(1, 16):
    knn = KnnClassifier(k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)

    num_correct = np.sum(y_pred == y_test)
    acc = float(num_correct) / len(y_test)

    accs.append(acc)

print("max_acc: {}, best_k = {}".format(max(accs), accs.index(max(accs)) + 1))

# output: max_acc: 0.75, best_k = 1

plt.figure(figsize=[20, 5])
plt.plot(range(1,16), accs)
plt.xlim(0,16)
plt.xticks(range(0,17))
plt.xlabel("value of k", fontsize=16, fontweight='normal')
plt.ylim(0,1.0)
plt.yticks(np.arange(0.0,1.1,0.1))
plt.ylabel("value of acc", fontsize=16, fontweight='normal')
plt.show()

(Figure: test-set accuracy plotted against k, for k from 1 to 15)

On this dataset, the highest accuracy is reached at k = 1.

Testing sklearn's built-in KNN classifier

Interface: `sklearn.neighbors.KNeighborsClassifier` (see the official documentation on the scikit-learn site).

Imports

from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

Testing with cross-validation and grid search

param_grid = {'n_neighbors': np.arange(1, 10)}
knn = KNeighborsClassifier(n_neighbors=1)
GS = GridSearchCV(knn, param_grid, cv=10)
GS.fit(data.data, data.target)

print(GS.best_params_, GS.best_score_)
# output: {'n_neighbors': 1} 0.7471910112359551
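`cross_val_score` was imported above but never used; for completeness, here is how a single value of k could be scored with 10-fold cross-validation (a sketch along the same lines as the grid search):

knn = KNeighborsClassifier(n_neighbors=1)
scores = cross_val_score(knn, data.data, data.target, cv=10)
print(scores.mean())  # should match GS.best_score_ for n_neighbors=1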

Pros and Cons of KNN

KNN is simple to implement and is a discriminative, non-linear model that works well in low-dimensional spaces. When the training set is large and the vectors are high-dimensional, however, the brute-force distance computation becomes very expensive, so large datasets need extra treatment; this is where structures like the kd-tree come in (a topic for further study).
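Following up on that pointer (a direction for further study, not covered in this post): scikit-learn's `KNeighborsClassifier` can already use a kd-tree for the neighbor search via its `algorithm` parameter:

from sklearn.neighbors import KNeighborsClassifier

# algorithm='kd_tree' builds a kd-tree instead of brute-force search;
# the default 'auto' picks a strategy based on the data
knn_kd = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree')
knn_kd.fit(X_train, y_train)
print(knn_kd.score(X_test, y_test))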

References

  1. 《统计学习方法》 (Statistical Learning Methods)
  2. "KNN", from the 木东居士 WeChat official account
  3. 机器学习-KNN算法 (Machine Learning: the KNN Algorithm)