【深層学習】VAEとは?VAEの理論を徹底的に分かりやすく解説する

Posted: , Category: ベイズ統計 , 機械学習 , 深層学習

VAE(Variational Auto Encoder, 変分オートエンコーダー)は、ニューラルネットワークと取り入れた、生成モデルのアルゴリズムです。

生成モデルと言っても、PCA(主成分分析)やSVD(特異値分解)で行っているような、次元削減もエンコード層でおこなっています。

今回は、VAE(Variational Auto Encoder)についてわかりやすく解説していきます。

VAEの全体モデル(アーキテクチャ)

まず最初に、VAEの全体のアーキテクチャを示します。

VAEが特徴的なのは、Encoderでは、潜在変数$\bm{z}$をニューラルネットで直接出力するのではなく、ガウス分布のパラメータ$\mu$と$\Sigma$を生成し、そのガウス分布に従って潜在変数$\bm{z}$をサンプリングします。

さらに、ガウス分布からサンプリングされた潜在変数$\bm{z}$を、Decoderで元のデータ分布に復元するモデルとなっています。

このような、エンコーダとデコーダの要素が含まれている、VAEと似ている既存手法としてAE(オートエンコーダ)があります。しかし、オートエンコーダは、エンコード層が直接潜在変数$\bm{z}$を出力する一方で、VAEは上記での述べたように、$\bm{z}$がガウス分布に従うと仮定して、そのガウス分布を出力するところにあります。

以下で、VAEのエンコーダとデコーダに対する詳しい解説をしていきます。

VAEにおけるエンコーダー(Encoder)

VAEにおけるエンコーダ(Encoder)のアーキテクチャはこのようになっています。

一番最初に提示したVAEのアーキテクチャにおける、エンコーダ$q_{\phi}(z | x)$の中身に相当する部分です。

VAEのエンコーダは、条件付き確率分布として$q_{\phi}(\bm{z} | \bm{x})$で表現されます。これは、入力ベクトル$\bm{x}$を潜在変数$\bm{z}$に変換することからこのように表現します。

エンコーダ$q_{\phi}(\bm{z} | \bm{x})$は、何かしらのパラメトリックな確率モデルではなく、ニューラルネットワークとサンプリングによって構成されます。この辺りが少し難しいですね..。中身をよく見ていきましょう。

Encoderの出力層はガウス分布のパラメータ

最初にEncoderのアーキテクチャの概観を説明しましたが、ニューラルネットワークの出力層が2つのパラメタになっていることが分かります。これが、ガウス分布のパラメータ$\bm{\mu}$と$\bm{\Sigma}$に相当します。

VAEにおけるデコーダー(Decoder)

続いて、VAEにおけるデコーダ(Decoder)のアーキテクチャはこのようになっています。

潜在変数$\bm{z}$が与えられたときの$\bm{x}$を生成する確率分布 $p_{\theta}(x | z)$ を求めればよさそうです。

VAEのグラフィカルモデル

VAEの確率変数をグラフィカルモデルで表現するとこのようになります。

ここで、$\phi$や$\theta$は先ほどから何度も登場していますが、それぞれ$\phi$がエンコーダのパラメータで、$\theta$がデコーダにおけるパラメータです。エンコーダもデコーダもニューラルネットでモデルを表現しているため、ニューラルネットのパラメータが$\phi$や$\theta$に相当します。

VAEの最適化

では、VAEのアーキテクチャが定まったとして、これらのパラメータをどのように決定するのでしょうか。VAEではその名前のように、変分推論(Variational Inference, VI)を利用します。変分推論の理論的な側面は少し難解ですが、様々な場面で登場するので、一度理解しておくとよいでしょう。

VAEにおける損失関数

まず、最初にVAEにおける損失関数$\mathcal{L}$を最初に提示します。その後に、変分推論により、この損失関数を導きます。VAEにおける損失関数は、大きく、再構成項(reconstruction)正則化項(regularization)で構成されます。

VAEにおける損失関数$\mathcal{L}$
\begin{equation}
\begin{split}
\mathcal{L} = \operatorname{KL}[\operatorname{Encoder}(x)  || \mathcal{N(\mu, \Sigma)}]
+ 
\{ \operatorname{Decoder}(\operatorname{Encoder}(x) ) - x\}^2
\end{split} 
\end{equation}

損失関数$\mathcal{L}$を導出する

先ほど、VAEにおける損失関数$\mathcal{L}$を提示しました。なぜ、最初にこれを提示したかというと、$\mathcal{L}$を導出しますが、先ほども書いたように、この導出には変分推論の考え方が取り入れられており、少し難解だからです。VAEの実装をするだけであれば、このあたりの理解は必要ありませんが、将来的に最先端の論文を読んでいく際には必要であるため、腰を据えて取り組んでみましょう。

それでは、(1)の損失関数を得るために、色々と式変形をしていきます。

VAEのパラメータを決定する際には、まず最尤推定法により、周辺対数尤度$lnp(\bm{X})$を最大にすることを考えます。

ここで、周辺対数尤度$lnp(\bm{X})$は次のように分解することができます。

\begin{equation}
\begin{split}
lnp(\bm{X}) &= ln\int p(\bm{X}, \bm{z}) d\bm{z} \\
&= ln\int q(\bm{z} | \bm{X}) \frac{p(\bm{X}, \bm{z})}{q(\bm{z} | \bm{X})}d\bm{z}  \\
&\geq \int q(\bm{z} | \bm{X}) ln \frac{p(\bm{X}, \bm{z})}{q(\bm{z} | \bm{X})}d\bm{z} \\
&= \mathcal{L}(\bm{X}, \bm{z})
\end{split}
\end{equation}

(3)行目の式変形には、イェンゼンの不等式を利用しました。イェンゼンの不等式については、こちらの記事で解説しています。

機械学習で登場するイェンゼンの不等式を学ぶ
イェンゼンの不等式について学んでいきます。イェンゼンの不等式は、凸関数において成り立つ不等式です。 なかなか抽象的な面もありますが、一般系な凸関数において成り立つので、汎用性が高く、機械学習や統計学の勉強をすると度々登場 […]

また最終行は、得られた値を$\mathcak{L}$で置き換えているだけです。この$\mathcak{L}$で置き換えた値を、よくELBO(Evidenve Lower Bound)や自由エネルギー(Free Energy)と呼んでいます。この辺りは変分推論の理論になってくるため、こちらの関しても詳しく知りたい人は、次の記事をご覧ください。

【変分ベイズ】変分推論やELBOを理解する
変分推論(variational inference)は、ベイズ手法で解析的に事後分布を求めることができない場合に、非常によく利用される近似アルゴリズムです。変分推論は文脈によっては、変分ベイズと呼ばれたりしています。 […]

ここで、(2)式ではよくわからない式変形をして、さらにイェンゼンの不等式なる定理を用いて、周辺対数尤度$lnp(\bm{X})$を下から押さえ込むことをしていました。

ここで、(2)における左辺の対数の周辺対数尤度$lnp(\bm{X})$と変分下限$\mathcal{L}(\bm{X}, \bm{z})$の差分を考えていきます。この差分を得るには、単純に差を取ってみます。

\begin{equation}
\begin{split}
lnp(\bm{X}) - \mathcal{L}(\bm{X}, \bm{z}) &= 
lnp(\bm{X}) -  \int q(\bm{z} | \bm{X}) ln \frac{p(\bm{X}, \bm{z})}{q(\bm{z} | \bm{X})}d\bm{z} \\
&= lnp(\bm{X}) \int q(\bm{z} | \bm{X}) d\bm{z} - 
\int q(\bm{z} | \bm{X}) ln \frac{p(\bm{z} | \bm{X}) p(\bm{X})}{q(\bm{z} | \bm{X})}d\bm{z} \\
&= \int lnp(\bm{X}) q(\bm{z} | \bm{X}) d\bm{z} - \int q(\bm{z} | \bm{X}) 
\{ p(\bm{z} | \bm{X}) + p(\bm{X}) - q(\bm{z} | \bm{X})\} d\bm{z} \\
&= \int q(\bm{z} | \bm{X}) ln\frac{q(\bm{z} | \bm{X})}{p(\bm{z} | \bm{X})} d\bm{z} \\
&= \operatorname{KL}[q(\bm{z} | \bm{X})  || p(\bm{z} | \bm{X}) ]
\end{split}
\end{equation}

となり、周辺対数尤度と変分下限の差は、確率分布$q(\bm{z} | \bm{X}) $と$p(\bm{z} | \bm{X})$のKL情報量となることがわかりました。

ここで、(3)式をよりわかりやすく書くとこのようになります。

\begin{equation}
\begin{split}
lnp(\bm{X}) = \mathcal{L}(\bm{X}, \bm{z})  + \operatorname{KL}[q_{\phi}(\bm{z} | \bm{X})  || p_{\theta}(\bm{z} | \bm{X}) ]
\end{split}
\end{equation}

ここで、周辺対数尤度関数$lnp(\bm{X})$は定数であるため、$\mathcal{L}(\bm{X}, \bm{z})$を最大化することは、カルバックライブラー情報量$\operatorname{KL}[q(\bm{z} | \bm{X}) || p(\bm{z} | \bm{X}) ]$を最小化することと同じことになります。

KL情報量$\operatorname{KL}[q(\bm{z} | \bm{X}) || p(\bm{z} | \bm{X}) ]$を最小化する

ここで、KL情報量を最小化するために、$\operatorname{KL}[q(\bm{z} | \bm{X}) || p(\bm{z} | \bm{X}) ]$を式変形していきます。

\begin{equation}
\begin{split}
\operatorname{KL}[q_{\phi}(\bm{z} | \bm{X}) || p_{\theta}(\bm{z} | \bm{X}) ] &= 
\int q_{\phi}(\bm{z} | \bm{X}) ln \frac{q_{\phi}(\bm{z} | \bm{X})}{p(\bm{z} | \bm{X})} d \bm{X} \\
&= 
\int q_{\phi}(\bm{z} | \bm{X}) ln \{ q_{\phi}(\bm{z} | \bm{X}) - p_{\theta}(\bm{z} | \bm{X}) \} d \bm{X} \\
&= \langle ln q_{\phi}(\bm{z} | \bm{X}) - lnp_{\theta}(\bm{z} | \bm{X}) \rangle_{q_{\phi}(\bm{z} | \bm{X})} \\
&= \langle lnq_{\phi}(\bm{z} | \bm{X}) - ln p_{\theta}(\bm{X} | \bm{z}  )- lnp(\bm{z} ) + lnp(\bm{X} ) \rangle_{q_{\phi}(\bm{z} | \bm{X})}  \\
&= \langle ln q_{\phi}(\bm{z} | \bm{X})  -ln p(\bm{z} )
\rangle_{q_{\phi}(\bm{z} | \bm{X} )} 
-  \langle ln p_{\theta}(\bm{X} | \bm{z}  ) \rangle_{q_{\phi}(\bm{z} | \bm{X})} + ln p(\bm{X} ) \\
&= 
\operatorname{KL}[q_{\phi}(\bm{z} | \bm{X} ) || p(\bm{z} )] - 
\langle ln p_{\theta}(\bm{X} | \bm{z}  ) \rangle_{q_{\phi}(\bm{z} | \bm{X})} + ln p(\bm{X} )
\end{split}
\end{equation}

KL情報量が3つの項に分解されました。1つは、何やらまたKL情報量の形で表現できています。

以降、まだ、(1)の損失関数を導出するための式変形は続きます…。

参考文献

本記事で参考にした論文・サイト

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術