【損失関数】交差エントロピー誤差の概要と関数系を理解する

Posted: , Category: 分類問題 , 機械学習 , 深層学習

交差エントロピー誤差(Cross Entropy Error, Cross Entropy Loss)は、深層学習の分類問題で非常によく利用される損失関数です。

交差エントロピーと聞くと、初めて遭遇した人にとっては、ものすごく仰々しい名前に思われるかもしれませんが、実際の内容は大したことないので安心してください。

高校数学のlogと$\Sigma$(シグマ)について理解していれば、きっと簡単に理解できると思います。

今回は、たびたびニューラルネットワークにおける分類問題の損失関数で利用される、交差エントロピー誤差について解説していきます。

交差エントロピー誤差の定義式

まず最初に、交差エントロピー誤差の定義式を提示します。

今、モデルの予測値を$\bm{t}$、正解ラベルのone-hot-encoding表現$を\bm{y_k}$としたとき、その1つの予測値に対するモデルの交差エントロピー誤差は次のように定義されます。

交差エントロピー誤差(クロスエントロピー)
\begin{equation}
\operatorname{CrossEntropyLoss} = - \sum_{k=1}^{K}  t_k ln y_k
\end{equation}

(1)の式はかなり仰々しく感じるかもしれませんが、中身としては非常に簡単です。

こちらの図を見てください。

左側がニューラルネットワークの出力層を示しており、その出力値$\bm{y}$にログを取ったものと、正解ラベル$\bm{t}$をかけた値が、交差エントロピー $\sum_{k=1}^{K} t_k ln y_k$の値となっています。

しかし、上手の右側の青い背景の箇所を見るとわかりますが、基本的に$t$の値は、正解ラベルだけ1になり、それ以外の値は0なので、基本的に $t_k ln y_k$の値はゼロになり、正解ラベルの値だけ計算すれば良いことになります。

交差エントロピー誤差の関数系を可視化

正解ラベルの$t_k$は1なので、交差エントロピー誤差の関数系は、$- log(x)$となります。

$- log(x)$の形式は次のようになっています。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 1, 100)
y = -1 * np.log(x)

fig, ax = plt.subplots(figsize=(6,3), dpi=120)
ax.plot(x, y, lw=0.5, label=r"$y = -log(x)$")
ax.axhline(y=0, xmin=0, xmax=1,color='red', lw=0.5, ls='--', alpha=0.6)
ax.legend()

上のグラフを見ると、入力値が0に近いほど$\infin$になり、1に近づくほど0になることがわかります。

つまり、ニューラルネットワークの正解ラベルにおける出力値が、0に近いほどペナルティ(損失)がとても大きくなり、正解値に近い値を出力するほど、ペナルティはほとんどなくなります。

確かにこの損失関数を用いると、正しく誤差を表現できていることがわかりますね。

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術