ML KNN


KNN is a nonparametric method used for classification and regression. It is a type of instance-based learning, or lazy learning, where the function is only approximated locally and all computation is deferred until classification.

How does it work?

We use an example from Wikipedia to explain the decision-making process of KNN: the test sample (green dot) should be classified either to the blue squares or to the red triangles. If k = 3 (solid line circle), it is assigned to the red triangles because there are 2 triangles and only 1 square inside the inner circle. If k = 5 (dashed line circle), it is assigned to the blue squares (3 squares vs. 2 triangles inside the outer circle).

So as we can see, KNN is based on feature similarity.

The general decision-making process of KNN is (see the sketch after this list):

  • Calculate the distance between the query instance and all the training samples.

  • Sort the distances.

  • Take the K training samples with the smallest distances as the nearest neighbors.

  • Use a simple majority vote over the categories of the nearest neighbors as the prediction for the query instance.
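
The full class-based implementation appears in the next section; as a quick illustration, the following minimal NumPy sketch performs these four steps for a single query point (the names knn_predict, query, x_train, and y_train are placeholders introduced here, not part of the later code):

from collections import Counter

import numpy as np


def knn_predict(query, x_train, y_train, k=3):
    # 1. Euclidean distance from the query to every training sample.
    distances = np.linalg.norm(x_train - query, axis=1)
    # 2 & 3. Sort the distances and keep the indices of the k smallest.
    nearest = np.argsort(distances)[:k]
    # 4. Majority vote among the labels of the k nearest neighbors.
    return Counter(y_train[nearest]).most_common(1)[0][0]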

Show me the code!

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt


class KNN:
    # Initialize the model.
    def __init__(self, k, x_train, y_train):
        self._k = k
        self._x_train = x_train
        self._y_train = y_train

    # Train the model.
    def train(self):
        # KNN does not need training because it is a lazy algorithm.
        pass

    # Make predictions.
    def predict(self, x_test, y_test):
        # Validate the lengths.
        if len(x_test) != len(y_test):
            raise ValueError("length doesn't match")

        y_predict = []
        for sample in x_test:
            # Find the k nearest neighbors.
            neighbors = self._find_k_nearest_neighbors(sample)
            # Let the neighbors vote.
            predict_result = self._vote(neighbors)
            y_predict.append(predict_result)

        return y_predict

    # Evaluate the prediction.
    def evaluate(self, y_test, y_predict):
        correct_count = 0
        for i in range(len(y_test)):
            if y_predict[i] == y_test[i]:
                correct_count += 1
        acc = correct_count / len(y_test)

        return acc

    def _find_k_nearest_neighbors(self, sample):
        # Calculate the distance to every training sample.
        distance_list = []
        for sample_train in self._x_train:
            # Important! Here we calculate the Euclidean distance between two digit images.
            # Tip on "norm":
            # np.linalg.norm([3, 4]) returns 5.0, because sqrt(3^2 + 4^2) = 5.
            dist = np.linalg.norm(sample_train - sample)
            distance_list.append(dist)

        # Find the k nearest neighbors.
        distance_rank = np.argsort(distance_list)
        k_nearest_neighbors = distance_rank[:self._k]

        return k_nearest_neighbors

    def _vote(self, neighbors):
        # Collect the candidate targets.
        target_list = []
        for item in neighbors:
            target_list.append(self._y_train[item])

        # Start the vote.
        result = max(target_list, key=target_list.count)

        return result


if __name__ == '__main__':
    # Import the digits dataset.
    digits = datasets.load_digits()

    # Dataset visualization.
    # pylint: disable = no-member
    images_and_labels = list(zip(digits.images, digits.target))
    # Show each image and its corresponding label on the screen.
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(2, 5, index + 1)
        plt.axis('off')
        plt.imshow(image)
        plt.title('Label: %i' % label)
    plt.show()

    # It contains 1797 (8 by 8) digit images, each flattened into a 64-dimensional vector.
    image_set = digits['data']
    # It contains the corresponding answers to the digits.
    target_set = digits['target']

    # Split the dataset into a train set and a test set.
    train_rate = 0.5
    sample_num = len(image_set)
    x_train_set = image_set[:int(train_rate * sample_num)]
    y_train_set = target_set[:int(train_rate * sample_num)]
    x_test_set = image_set[int(train_rate * sample_num):]
    y_test_set = target_set[int(train_rate * sample_num):]

    # Show the performance for different values of the hyperparameter k.
    acc_list = []
    for neighbor_num in range(1, 20):
        agent = KNN(neighbor_num, x_train_set, y_train_set)
        y_predict = agent.predict(x_test_set, y_test_set)
        acc = agent.evaluate(y_test_set, y_predict)
        acc_list.append(acc)
        print(f'Prediction Accuracy with {neighbor_num} neighbors: {acc}')
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.plot(acc_list, color='r', label='Accuracy')
    ax.legend(loc=1)
    plt.show()

We apply the KNN algorithm to the handwritten digit classification problem, where each input sample is an $8 \times 8$ image. Half of the samples are used as training data and the other half for testing. We then use different values of $K$ as the hyperparameter and compare their prediction performance.

As we can see from the above figure, the choice of $K$ has a great impact on the accuracy.
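
In practice, rather than picking $K$ by looking at test-set accuracy, a common approach is cross-validation on the training data. A brief sketch using scikit-learn's built-in KNeighborsClassifier (an alternative to the hand-written class above, not part of the original post):

from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

digits = datasets.load_digits()
# Use only the first half (the training portion from the script above).
split = len(digits.data) // 2
x_train, y_train = digits.data[:split], digits.target[:split]

for k in range(1, 20):
    model = KNeighborsClassifier(n_neighbors=k)
    # Mean accuracy over 5 cross-validation folds of the training data.
    scores = cross_val_score(model, x_train, y_train, cv=5)
    print(f'k={k}: mean CV accuracy {scores.mean():.3f}')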

Optimization

There are some fast adaptations of the original KNN algorithm that avoid brute-force distance computation, to name a few:

  • k-d trees, which recursively partition the feature space along coordinate axes so that most distance computations can be skipped.

  • Ball trees, which partition the samples into nested hyperspheres and tend to scale better than k-d trees in higher dimensions.
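
For example (not from the original post), scikit-learn's KNeighborsClassifier exposes these index structures through its algorithm parameter; a minimal sketch on the same digits data:

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

digits = datasets.load_digits()
split = len(digits.data) // 2

# algorithm='kd_tree' (or 'ball_tree') builds a spatial index
# instead of doing a brute-force search over all training samples.
model = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree')
model.fit(digits.data[:split], digits.target[:split])
print(model.score(digits.data[split:], digits.target[split:]))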

Resources

Original author: Jiahonzheng

Original link: https://blog.jiahonzheng.cn/2019/09/01/ML KNN/

Published: September 1st, 2019

Last updated: November 26th, 2019

Copyright notice: This article is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
