
KNN stands for K-Nearest Neighbors, roughly "side with the majority." It is easy to confuse with K-Means, but KNN is a supervised learning method, unlike K-Means. The diagram below illustrates the idea.

Suppose we have a green point and want to decide whether it belongs with the red triangles or the blue squares. The simplest approach is to find the k points closest to it (k=3 in the figure). Among those k neighbors, whichever class has the most members wins, and the green point is assigned to that class. In the figure, red 2 > blue 1, so the green point is classified as red. We usually pick an odd k to avoid 2-vs-2 ties, but ties such as 2:2:1 can still occur. In that case we need a smarter (unwritten) rule to decide where the green point goes.

If a 2:2 tie occurs, we first drop the farthest of the k neighbors, then vote again; the green point is assigned to whichever class now has the larger count. This repeats until a unique winner emerges.
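The tie-breaking rule can be sketched as a small recursive function: count the votes, and if more than one label ties for first place, drop the last label (the farthest neighbor, since the labels are ordered nearest to farthest) and vote again. The label values below are hypothetical.

```python
from collections import Counter

def majority_vote(labels):
    """labels are ordered from nearest to farthest."""
    vote_counts = Counter(labels)
    winner, winner_count = vote_counts.most_common(1)[0]
    num_winners = sum(1 for count in vote_counts.values()
                      if count == winner_count)
    if num_winners == 1:
        return winner                      # unique winner
    return majority_vote(labels[:-1])      # drop the farthest, vote again

# 2:2:1 tie among the 5 nearest neighbors: drop the farthest ('blue'),
# leaving red 2, blue 1, green 1, so 'red' wins.
print(majority_vote(['red', 'blue', 'green', 'red', 'blue']))  # → red
```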

Example
Suppose we have a dataset describing data scientists across the United States and whether they prefer Python, R, or Java for data analysis.


We can assume that language preference is geographic: the closer two people are, the more likely they use the same language. The majority-vote approach matches this reality fairly well.

  1. The data structure looks like cities = [([lng1, lat1], language1), ([lng2, lat2], language2), ...]
  2. Using the raw data itself as the training set, the accuracy for different values of k is listed below.

  3. The computed results are shown in the figure.
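The original dataset is not reproduced here, but a fragment in the same format would look like the following. The coordinates and labels are made-up examples for illustration, not the real data.

```python
# Hypothetical sample in the cities format: ([longitude, latitude], language).
# These coordinates and labels are illustrative, not the original dataset.
cities = [
    ([-122.3, 47.5], "Python"),   # Seattle area
    ([-96.8, 32.9],  "Java"),     # Dallas area
    ([-89.4, 43.1],  "R"),        # Madison area
    ([-122.0, 37.3], "Python"),   # Bay Area
    ([-71.1, 42.3],  "R"),        # Boston area
]
```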


Code

## knn

from collections import Counter, defaultdict
import math

import matplotlib.pyplot as plt

def distance(u, v):
    # Euclidean distance between two coordinate vectors
    return math.sqrt(sum((ui - vi) ** 2 for ui, vi in zip(u, v)))

def majority_vote(labels):
    '''assumes that labels are ordered from nearest to farthest'''
    vote_counts = Counter(labels)
    winner, winner_count = vote_counts.most_common(1)[0]
    num_winners = len([count for count
                       in vote_counts.values()
                       if count == winner_count])
    if num_winners == 1:
        return winner
    else:
        return majority_vote(labels[:-1])  # try again without the farthest


def knn_classify(k, labeled_points, new_point):
    # each labeled point is a pair ([lng, lat], label)

    # order the labeled points from nearest to farthest
    by_distance = sorted(labeled_points,
                         key=lambda pl: distance(pl[0], new_point))
    # take the labels of the k closest points
    knn_labels = [label for _, label in by_distance[:k]]
    return majority_vote(knn_labels)


def plot_raw_data():
    # scatter plot of the raw data, one marker/color per favorite language
    group = defaultdict(list)
    for [lng, lat], language in cities:
        group[language].append((lng, lat))

    markers = {'Java': 'o', 'Python': 's', 'R': '^'}
    colors = {'Java': 'r', 'Python': 'b', 'R': 'g'}

    for key in group.keys():
        xaxis, yaxis = zip(*group[key])
        plt.scatter(xaxis, yaxis, color=colors[key],
                    marker=markers[key], label=key, zorder=10)
    plt.legend(loc=0)
    plt.axis([-130, -60, 20, 55])
    
## try different k and measure the accuracy


for k in [1, 3, 5, 7]:
    num_correct = 0
    for city in cities:
        # leave-one-out: classify each city using all the other cities
        other_cities = [other for other in cities if other != city]
        predicted = knn_classify(k, other_cities, city[0])
        if predicted == city[1]:
            num_correct += 1
    print("k={}: {} correct out of {} samples".format(k, num_correct, len(cities)))

## predict favorite language map based on knn algo


def predict_map(k, x_inputs, y_inputs, labeled_points):
    # classify every (x, y) grid point with its k nearest neighbors;
    # x_inputs/y_inputs define the scanning range
    predict_dict = defaultdict(list)
    for x in x_inputs:
        for y in y_inputs:
            label = knn_classify(k, labeled_points, [x, y])
            predict_dict[label].append((x, y))
    return predict_dict

klist = [1, 3, 5, 7]
x = range(-130, -60, 1)
y = range(20, 55, 1)


def draw_predict():
    # 2x2 grid of subplots sharing x across columns and y across rows
    f, axes = plt.subplots(2, 2, sharex='col', sharey='row')

    markers = {'Java': 'o', 'Python': 's', 'R': '^'}
    colors = {'Java': 'r', 'Python': 'b', 'R': 'g'}

    for k, ax in zip(klist, axes.flat):
        pred = predict_map(k, x, y, cities)
        for language in pred.keys():
            xplot, yplot = zip(*pred[language])
            ax.scatter(xplot, yplot, color=colors[language],
                       marker=markers[language], label=language)
        ax.set_title('k={}'.format(k))
        ax.axis([-130, -60, 20, 55])
        print("plotting k = {}".format(k))

    plt.xlabel('green -> R, blue -> Python, red -> Java')
    plt.show()

