• 一、概述

         K近邻算法是对事物进行分类的机器学习算法。简单来说就是给定一组样本,对新的输入样本,在已有样本中找到与该实例最邻近的K个例子, 当这K个例子的大多属于某个类别时,就把该输入样本分类到这个类之中。

        举例来说,小说有许多不同种类,暂且分成言情小说和玄幻小说,通过统计小说中的恋爱情节和打斗情节得出一组数据

        

 


当我们把该数据通过图表显示

此时,当有一部小说X的打斗情节为6,恋爱情节为2,会属于哪一类小说?(五角星代表小说X)


       计算小说X的数据到样本数据的距离,设k为3,则需要找到距离最小的三个点,当某一类更多时将小说X归于该类小说。这就是通过计算距离总结k个最邻近类并按照少数服从多数原则分类的K近邻算法。

  • 二、算法原理介绍

        以上面数据为例,打斗情节为x,恋爱情节为y,通过欧氏距离公式 l=\sqrt{\left ( x_{2}-x_{1} \right )^{2}+\left (y _{2}-y_{1} \right )^{2}} 计算该点与各个样本数据的距离,当然实际需要的数据有很多是多维的,但如三维的话可以在根号中加入z轴差的平方,也可以计算距离

        现在再计算各个数据的距离,可以得到四个点的距离,将k定为3,则取其中最小的三个数,显而易见这三个数中有两个属于玄幻类,那么小说X属于玄幻小说。

  • 三、算法实现

 

    trainData - 样本数据
    testData - X数据
    labels - 分类

def knn(trainData, testData, labels, k):
   
    rowSize = trainData.shape[0] # 计算样本数据的行数
   
    diff = tile(testData, (rowSize, 1)) - trainData 
    # 计算样本数据和X数据的差值
   
    sqrDiff = diff ** 2
    sqrDiffSum = sqrDiff.sum(axis=1) # 计算差值的平方和
   
    distances = sqrDiffSum ** 0.5 # 计算距离
   
    sortDistance = distances.argsort() 
    
    count = {}
    
    for i in range(k):
        vote = labels[sortDistance[i]]
        count[vote] = count.get(vote, 0) + 1
    # 对所得距离从低到高进行排序
    sortCount = sorted(count.items(), key=operator.itemgetter(1), reverse=True)
    # 对类别出现的频数从高到低进行排序
    
    
    return sortCount[0][0] # 返回出现频数最高的类别
  • 四、测试结果

        

trainData = np.array([[0, 10], [1, 8], [10, 1], [7, 4]])
labels = ['言情', '言情', '玄幻', '玄幻']
testData = [6, 2]
X = knn(trainData, testData, labels, 3)
print(X)

        执行上列代码后得到结果为玄幻,与计算的相同。

        当然实际上该算法不可能100%准确,还是会有不少的偏差,要更接近准确还需要更多的实验样本来减少可能的偏差。

 

 

 

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐