交差エントロピー誤差(Cross Entropy Error, Cross Entropy Loss)は、深層学習の分類問題で非常によく利用される損失関数です。
交差エントロピーと聞くと、初めて遭遇した人にとっては、ものすごく仰々しい名前に思われるかもしれませんが、実際の内容は大したことないので安心してください。
高校数学のlogと$\Sigma$(シグマ)について理解していれば、きっと簡単に理解できると思います。
今回は、たびたびニューラルネットワークにおける分類問題の損失関数で利用される、交差エントロピー誤差について解説していきます。
交差エントロピー誤差の定義式
まず最初に、交差エントロピー誤差の定義式を提示します。
今、モデルの予測値を$\bm{t}$、正解ラベルのone-hot-encoding表現$を\bm{y_k}$としたとき、その1つの予測値に対するモデルの交差エントロピー誤差は次のように定義されます。
\begin{equation} \operatorname{CrossEntropyLoss} = - \sum_{k=1}^{K} t_k ln y_k \end{equation}
(1)の式はかなり仰々しく感じるかもしれませんが、中身としては非常に簡単です。
こちらの図を見てください。
左側がニューラルネットワークの出力層を示しており、その出力値$\bm{y}$にログを取ったものと、正解ラベル$\bm{t}$をかけた値が、交差エントロピー $\sum_{k=1}^{K} t_k ln y_k$の値となっています。
しかし、上手の右側の青い背景の箇所を見るとわかりますが、基本的に$t$の値は、正解ラベルだけ1になり、それ以外の値は0なので、基本的に $t_k ln y_k$の値はゼロになり、正解ラベルの値だけ計算すれば良いことになります。
交差エントロピー誤差の関数系を可視化
正解ラベルの$t_k$は1なので、交差エントロピー誤差の関数系は、$- log(x)$となります。
$- log(x)$の形式は次のようになっています。
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 1, 100)
y = -1 * np.log(x)
fig, ax = plt.subplots(figsize=(6,3), dpi=120)
ax.plot(x, y, lw=0.5, label=r"$y = -log(x)$")
ax.axhline(y=0, xmin=0, xmax=1,color='red', lw=0.5, ls='--', alpha=0.6)
ax.legend()
上のグラフを見ると、入力値が0に近いほど$\infin$になり、1に近づくほど0になることがわかります。
つまり、ニューラルネットワークの正解ラベルにおける出力値が、0に近いほどペナルティ(損失)がとても大きくなり、正解値に近い値を出力するほど、ペナルティはほとんどなくなります。
確かにこの損失関数を用いると、正しく誤差を表現できていることがわかりますね。