k近傍法(k-Nearest Neighbor algorithm, k-NN法)

R

機械学習アルゴリズムの中で最も単純なものの一つで、分類を目的に使用されるk近傍法についてまとめています。

1.どんな時に使えるの?

k近傍法は、機械学習アルゴリズムの中で最も単純なものの一つで、分類を目的に使用されます。

訓練用データとして、特徴量がn組(\(\mathbf{x_1},\mathbf{x_2},…,\mathbf{x_n}\))あり、それぞれに対しラベル(\(y_1,y_2,…,,y_n\))が与えられている時、未知のデータ\(\mathbf{x}\)に対して\(y\)を推定する時に使用することができます。

たとえば、食べ物の分類を行う場合であれば、特徴量として「炭水化物」「脂質」「タンパク質」を選び、分類ラベルを「肉」「野菜」「果物」とすると、以下のような\(\mathbf{x},y\)を考えるということです。

\(\mathbf{x}\) =(炭水化物, 脂質, タンパク質)(食べ物100gに含まれる量)、
\(y \in \) (肉、野菜、果物)

\(\mathbf{x_1}=(10.1, 0.1, 1.0)\quad y_1=\)果物

\(\mathbf{x_2}=(0.1, 1.9, 23.8)\quad y_2=\)肉

\(\mathbf{x_3}=(2.5, 0.3, 1.8)\quad y_3=\)野菜

\(\mathbf{x_4}=(0.3, 19.1, 19.1)\quad y_4=\)肉

(ちなみに1はグレープフルーツ、2は鶏胸肉、3はほうれん草、4は豚ロース肉です。)

これに対し、肉なのか、野菜なのか、果物なのかわからない、\(\mathbf{x}=(4.8, 0.1, 0.7)\)というデータが与えられた時、これがどれに分類されるのかを推定する場合にk近傍法を使うことができます(ちなみにこの\(\mathbf{x}\)はトマトです)。

2. k近傍法のロジック

計算のロジックは極めて簡単です。

推定したいデータ\(\mathbf{x}\)と訓練用の各データとの距離を計算し、その中から距離が最も近いk個のデータを選択し、選択されたk個のデータの多数決でデータ\(\mathbf{x}\)の分類推定値を決定します。

3.特記

3.1 距離

通常、ユークリッド距離が使われますが、マンハッタン距離やその他の距離を使うことも可能です。

ユークリッド距離 = \(\sqrt{\sum_{k=1}^p(x_k-x_{ik})^2}\)

マンハッタン距離 = \(\sum_{k=1}^p|x_k-x_{ik}|\)

(\(p\)はデータ\(\mathbf{x}\)の次元(特徴量の個数))

3.2 データの正規化

1項の例にあげた食べ物の分類の訓練用データの例では、炭水化物、脂質、タンパク質の量は100gあたりの量なので、原理上各要素とも0〜100までの範囲になります(実際には水分とかもあるので、100になったりはしませんが)。

モデルの推定精度を高めるために、ここに「酸味」という特徴量を入れることを考えます。

酸味を測る単位として「pH(ペーハー)」がありますが、pHの値は水で7.0、レモンで約2.0という値になります。炭水化物、脂質、タンパク質の場合、食品ごとの差は数十程度はあるのに対し、pH値で評価すると酸味の差は多くても5程度と他の項目よりも差が小さな値になります。

このため、与えられたデータそのままで距離を計算する場合、酸味の大小差の影響は、炭水化物、脂質、タンパク質よりも距離の評価に反映されにくい状態になっています。

そこで、各要素の影響を適切に距離に反映するために、通常、データの正規化が行われます。データの正規化方法としてよく使われるのは、最小最大正規化やZスコア標準化といわれるものです。データを正規化することにより、各特徴量の影響を同程度に距離の評価に反映することができるようになります。

<最小最大正規化>
$$
x_{new,ik} = \frac{x_{ik}-min(x_i)}{max(x_i)-min(x_i)}\\
max(),min()はそれぞれ最大値、最小値を求める関数
$$

Zスコア標準化
$$
x_{new,ik} = \frac{x_{ik}-mean(x_i)}{sd(x_i)}\\
mean(),sd()はそれぞれ平均、標準偏差を求める関数
$$

3.3 kの値

kの値をいくつにするかで、モデルの性能が変わってきます。

k = 1とした場合、推定分類値はもっとも距離が近いデータと同じ値になります。データにノイズ(外れ値、誤り)が多かったりすると、ノイズの影響で分類の性能が落ちてしまいます。逆に、kをデータサイズまで大きくした場合、推定分類値は常に、訓練データの中で1番多い分類の値となります。

kの値はデータに合わせて適切に選択する必要がありますが、訓練データ数の平方根をkとするという方法はよく使われる方法の一つであり、kを設定する際のひとつの目安として使えます。

4.Rでの実装

irisのデータを使って、k近傍法で品種の分類推定を行ってみます。irisのデータには150個のデータが入っているので、このうち、100個を訓練用のデータとして使用し、残りの50個に対して分類の推定を行い、正答率を評価してみました。

k近傍法は単純な計算方法ですが、正答率98%とかなり高い結果が得られました

#-------------------------------------------------------------
#  kNN(d_train, y_train, d_test)
#    d_train : 正規化された訓練データ(matrix)
#    y_train : 訓練データのラベル(vector)
#    d_test  : 正規化されたテストデータ(1データ、vector)
#-------------------------------------------------------------
kNN <- function(d_train, y_train, d_test){
  n_train <- nrow(d_train)
  #testデータがn_train行並んだ行列の作成
  d <- matrix(rep(d_test,n_train), byrow = TRUE, nrow = n_train)

  #testデータと訓練データの差の計算
  d <- d_train - d

  #testデータと訓練データの距離の計算
  dis <- apply(d, MARGIN = 1, function(x){sqrt(sum(x^2))})

  #距離の短い順に並べて、最初からk個を選択して多数決
  # 同数の場合は、決着が着くまで1個ずつ減らして多数決
  for (i in k:2){
    k_nearest <- order(dis)[1:i]
    k_nearest_y <- y_train[k_nearest]

    #k個のデータの中で各ラベルがいくつあるか計算
    count_data <- table(k_nearest_y)
    count_data <- sort(count_data, decreasing = TRUE)
    if(count_data[1] != count_data[2]) {
      break
    } 
  }
  names(count_data)[1]
}

#-----------------------------------------------------------
#   ここから下がMainルーチン
#-----------------------------------------------------------

df <- iris     # irisのデータを使用
n <- nrow(df)  # n = データ数

n_train <- 100  # 訓練用に使うデータの数
k = as.integer(sqrt(n_train))  # kの設定。とりあえずsqrt(n_train)

# 訓練用データの選択(何行目のデータか)
train <- sample(1:n, n_train) 
lg <- rep(FALSE, n)
lg[train] <- TRUE

x_train <- df[lg, -5]
y_train <- df[lg, 5]

# test用データ
x_test <- df[!lg, -5]
y_test <- df[!lg, 5]


#訓練データの平均、標準偏差の計算
train_mean <- sapply(x_train, mean)
train_sd <- sapply(x_train, sd)

#訓練データの正規化
d_train <- as.matrix(x_train)
n_data <- nrow(x_train)
d_mean <- matrix(rep(train_mean, n_data), byrow = TRUE, nrow = n_data)
d_sd <- matrix(rep(train_sd, n_data), byrow = TRUE, nrow = n_data)

d_train <- (d_train - d_mean)/d_sd  # 訓練用正規化データ

#テストデータの正規化(平均、標準偏差は訓練用データ)
d_test <- as.matrix(x_test)
n_data <- nrow(x_test)
d_mean <- matrix(rep(train_mean, n_data), byrow = TRUE, nrow = n_data)
d_sd <- matrix(rep(train_sd, n_data), byrow = TRUE, nrow = n_data)

d_test <- (d_test - d_mean)/d_sd   # test用正規化データ

#test用データに対する推定値を計算
y_pred <- apply(d_test, MARGIN = 1, kNN, d_train = d_train, y_train = y_train)

df_pred <- data.frame(answer = y_test, 
                      pred = y_pred, 
                      is.correct = (y_test == y_pred))

結果

print(df_pred)
##         answer       pred is.correct
## 3       setosa     setosa       TRUE
## 4       setosa     setosa       TRUE
## 5       setosa     setosa       TRUE
## 8       setosa     setosa       TRUE
## 9       setosa     setosa       TRUE
## 11      setosa     setosa       TRUE
## 16      setosa     setosa       TRUE
## 27      setosa     setosa       TRUE
## 30      setosa     setosa       TRUE
## 35      setosa     setosa       TRUE
## 36      setosa     setosa       TRUE
## 41      setosa     setosa       TRUE
## 46      setosa     setosa       TRUE
## 47      setosa     setosa       TRUE
## 49      setosa     setosa       TRUE
## 50      setosa     setosa       TRUE
## 52  versicolor versicolor       TRUE
## 54  versicolor versicolor       TRUE
## 55  versicolor versicolor       TRUE
## 56  versicolor versicolor       TRUE
## 57  versicolor versicolor       TRUE
## 62  versicolor versicolor       TRUE
## 63  versicolor versicolor       TRUE
## 66  versicolor versicolor       TRUE
## 67  versicolor versicolor       TRUE
## 69  versicolor versicolor       TRUE
## 71  versicolor versicolor       TRUE
## 72  versicolor versicolor       TRUE
## 80  versicolor versicolor       TRUE
## 82  versicolor versicolor       TRUE
## 90  versicolor versicolor       TRUE
## 96  versicolor versicolor       TRUE
## 97  versicolor versicolor       TRUE
## 98  versicolor versicolor       TRUE
## 99  versicolor versicolor       TRUE
## 101  virginica  virginica       TRUE
## 103  virginica  virginica       TRUE
## 107  virginica versicolor      FALSE
## 108  virginica  virginica       TRUE
## 109  virginica  virginica       TRUE
## 116  virginica  virginica       TRUE
## 117  virginica  virginica       TRUE
## 122  virginica  virginica       TRUE
## 125  virginica  virginica       TRUE
## 128  virginica  virginica       TRUE
## 133  virginica  virginica       TRUE
## 137  virginica  virginica       TRUE
## 138  virginica  virginica       TRUE
## 139  virginica  virginica       TRUE
## 141  virginica  virginica       TRUE
cat(paste0("正答率 : " ,mean(y_test == y_pred)*100, " [%]"))
## 正答率 : 98 [%]