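The k-nearest neighbors (KNN) algorithm is a supervised learning algorithm. The basic idea is to take the K training points closest to the test point; whichever class is most common among those K points is the class assigned to the test point.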
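The distance between two points x and y with n features can be measured as:
1. Euclidean distance: $d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}$
2. Manhattan distance: $d(x, y) = \sum_{i=1}^{n} |x_i - y_i|$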
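Both distances are a one-liner in NumPy. The following is a minimal sketch; the function names `euclidean_distance` and `manhattan_distance` are chosen here for illustration and are not part of the script later in this post.

```python
import numpy as np


def euclidean_distance(a, b):
    """Straight-line (L2) distance: sqrt(sum_i (a_i - b_i)^2)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.sqrt(np.sum((a - b) ** 2)))


def manhattan_distance(a, b):
    """City-block (L1) distance: sum_i |a_i - b_i|."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.sum(np.abs(a - b)))


p, q = [0, 0], [3, 4]
print(euclidean_distance(p, q))  # 5.0
print(manhattan_distance(p, q))  # 7.0
```

For the same pair of points the Manhattan distance is never smaller than the Euclidean one, so the two metrics can rank neighbors differently.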
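Algorithm:
1) Compute the distance between the test point and every training point;
2) Sort the training points by increasing distance;
3) Select the K points with the smallest distances;
4) Count how often each class occurs among those K points;
5) Return the most frequent class among the K points as the predicted class of the test point.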
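The five steps map onto a handful of NumPy and standard-library calls. The sketch below is one compact way to write them, assuming the training points are given as a list `X` with a parallel list of labels `y`; the name `knn_predict` and this data layout are illustrative assumptions and differ from the dict-based script in this post.

```python
from collections import Counter

import numpy as np


def knn_predict(X, y, query, k=3):
    X = np.asarray(X, dtype=float)
    # Step 1: Euclidean distance from the query to every training point
    dists = np.linalg.norm(X - np.asarray(query, dtype=float), axis=1)
    # Steps 2-3: indices of the k smallest distances, in increasing order
    nearest = np.argsort(dists)[:k]
    # Step 4: count how often each class appears among the k neighbors
    votes = Counter(y[i] for i in nearest)
    # Step 5: the most frequent class is the prediction
    return votes.most_common(1)[0][0]


X = [[1, 2], [2, 3], [3, 1], [6, 5], [7, 7], [8, 6]]
y = ["black", "black", "black", "red", "red", "red"]
print(knn_predict(X, y, [3.5, 5.2], k=3))  # red
```

With two classes, an odd K avoids tied votes; when a tie does occur, `most_common` returns the tied label that was counted first.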
# -*- coding: utf-8 -*-
import warnings
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt


# k-nearest-neighbor classification
#   data:    training data, a dict mapping class label -> list of feature vectors
#   predict: the test point to classify
#   k:       number of neighbors that vote
def k_nearest_neighbors(data, predict, k=5):
    print("k=%d" % k)
    if len(data) >= k:
        # k should be larger than the number of classes for the vote to be meaningful
        warnings.warn("k is too small: it should exceed the number of classes")
    # Compute the distance from predict to every training point
    distances = []
    for group in data:
        for features in data[group]:
            # Equivalent to np.sqrt(np.sum((np.array(features) - np.array(predict)) ** 2))
            euclidean_distance = np.linalg.norm(np.array(features) - np.array(predict))
            distances.append([euclidean_distance, group])  # record distance and class
    sorted_distance = sorted(distances)  # sort by increasing distance
    print(sorted_distance)
    sorted_distance_kind = [i[1] for i in sorted_distance]  # keep only the class labels
    top_nearest = sorted_distance_kind[:k]  # classes of the k nearest samples
    print(top_nearest)
    # Counter tallies label frequencies, e.g. Counter(top_nearest).most_common(1)[0] == ('red', 2)
    group_res, count = Counter(top_nearest).most_common(1)[0]
    # confidence is how certain this classification is: (red, red, red) and
    # (red, red, black) are both classified as red, but the first is more confident
    confidence = count * 1.0 / k  # multiply by 1.0 to force float division
    # The third return value is the distance to the k-th nearest neighbor
    return group_res, confidence, sorted_distance[k - 1][0]


if __name__ == '__main__':
    dataset = {'black': [[1, 2], [2, 3], [3, 1]], 'red': [[6, 5], [7, 7], [8, 6]]}
    new_features = [3.5, 5.2]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    k_val = 3
    # Plot each training point in the color named by its class label
    for i in dataset:
        for ii in dataset[i]:
            plt.scatter(ii[0], ii[1], s=50, color=i)
    which_group, confidence, radius_val = k_nearest_neighbors(dataset, new_features, k_val)
    print(which_group, confidence, radius_val)
    plt.scatter(new_features[0], new_features[1], s=100, color="g")
    # Draw a circle around the test point that reaches the k-th nearest neighbor
    cir = plt.Circle(xy=new_features, radius=radius_val, color='b', fill=False)
    ax.add_patch(cir)
    plt.show()
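The result of running the script is (the debugging output and the printed radius are omitted here):
red 0.6666666666666666
That is, the test point is classified as red.
The reason is that the k=3 neighborhood contains two red training points and one black training point,
so the confidence is 2/3, about 0.667.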
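The vote and its confidence can be checked in isolation. This is a small sketch using `collections.Counter`; the helper name `vote` is hypothetical and introduced only for this example.

```python
from collections import Counter


def vote(labels):
    """Return (winning label, fraction of the votes it received)."""
    label, count = Counter(labels).most_common(1)[0]
    return label, count / len(labels)


print(vote(["red", "black", "red"]))  # ('red', 0.6666666666666666)
print(vote(["red", "red", "red"]))    # ('red', 1.0)
```

Both neighbor lists produce the label red, but only the unanimous one reaches a confidence of 1.0.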