変分オートエンコーダ(修正版)

記事の内容


以前書いた変分オートエンコーダの記事の修正版です. 説明を詳しくし, 改めて実験を行いました. 変分オートエンコーダを扱った書籍やネット上の記事はうまくいくように, モデルやコードがかなり最適化されている印象がありますが, 本記事ではかなり素朴に実装しています. したがって, 結果もそんなに綺麗ではありませんが...

概要

動機付けと記事の構成

手元に画像データが大量にあるとします. 白黒の画像データの場合, 画像は0から1の間の実数値が各ピクセルに並べられた行列と考えることができます. 例えばMNISTの手書きの文字データであれば, 端の方は黒っぽい色が広がり, 真ん中らへんに白い部分が連続的に位置するので, 行列の各成分はバラバラな値を取るわけではなく, 何かしらの法則性があると考えるのが自然です. こう考えると画像データの自由度はデータそのものよりも極端に低く, もっと低次元の空間上に分布していそうです. そこで, 手元のデータが生成される前に, データの低次元表現が潜在的に生成されていると想定することにします. この想定のもと, 次のような目標を立てておきます.

【目標】

データの生成過程を想像し, 統計モデルを作り, データの低次元表現 (潜在表現)を得たい!

この目標を達成するためのアプローチとして, この記事では, データ生成過程のモデル, 変分オートエンコーダを説明します.

この記事の構成は以下の通りです. 第1節"概要"では, 本記事で扱う生成モデルの大枠を示します. モデルのおおよその形と学習の方針をできるだけ一般的な形になるように述べました. 次のセクション "モデルと方針"では, 変分推論法から出発して, 最適化問題として定式化し, 通常の変分推論と異なる点を明示した上で, 学習アルゴリズムの概要を示しました. その次のセクション "reparametrization trick"では, 目的関数の勾配計算で必要になる近似計算方法を説明しました. 次節 "画像データの応用"の最初のセクション "変分オートエンコーダ"では, 本節で導入した一般的なモデルと学習方法を具体例を通じて再構成することで, 変分オートエンコーダを説明します. 続くセクション "数値実験"では, MNISTデータへの適用結果を示します. 要するに, 抽象から具体へと進むように構成しています.

モデルと方針

それでは, もう少し具体的に定式化しておきましょう. 手元のデータ\(X=\{\boldsymbol{x}_1,\cdots,\boldsymbol{x}_N\}\subset \mathbb{R}^{D_x}\)に対して, 以下のようなモデルを考えます:

【モデル】

for \(n=1,\dots,N\) do
\(\boldsymbol{z}_n \sim \pi_{\theta}(\boldsymbol{z})\)
\(\boldsymbol{x}_n \mid \boldsymbol{z}_n \sim p_{\xi}(\boldsymbol{x}\mid \boldsymbol{z}_n)\)
end

このモデルを用いて, 潜在変数たち\(Z=\{\boldsymbol{z}_1,\dots,\boldsymbol{z}_N\}\subset\mathbb{R}^{D_z}\)の事後分布を求めるのが目標です:

\begin{equation} \pi_{\xi,\theta} (Z\mid X) = \frac{\displaystyle \prod_{n=1}^{N}p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z}_n)\pi_{\theta}(\boldsymbol{z}_n)}{p_{\xi,\theta}(X)}. \end{equation}

あとで見る例もそうですが, 一般に, 上の事後分布を解析的に計算するのは困難です. そこで, 変分推論法の適用を考えます. 次のような近似分布を用意しましょう:

\begin{equation} r_\eta(Z\mid X) = \prod_{n=1}^Nr_\eta^{(n)}(\boldsymbol{z}_n\mid \boldsymbol{x}_n). \end{equation}

変分推論法では, \(r_\eta^{(n)}\)が属する分布族を指定し, 近似事後分布のパラメータは次のような最小化問題を解くことで定めます:

\begin{equation} \hat{\boldsymbol{\eta}} = \underset{\eta}{\mathrm{argmin}}\ D_{\mathrm{KL}}[r_\eta\| \pi_{\xi,\theta}]. \end{equation}

この他にも, モデルに未知のパラメータ\(\xi\)と\(\theta\)がありますが, これらは周辺尤度最大化により定めます:

\begin{equation} \hat{\boldsymbol{\xi}},\hat{\boldsymbol{\theta}} = \underset{\xi,\theta}{\mathrm{argmax}}\ \log p_{\xi,\theta}(X). \end{equation}

2つの最適化問題が登場しましたが, それぞれの目的関数同士は以下のような関係にあります:

\begin{equation} -D_{\mathrm{KL}}[r_\eta\| \pi_{\xi,\theta}]+\log p_{\xi,\theta}(X) = \mathcal{L}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}). \end{equation}

ここで, \(\mathcal{L}\)は変分推論法でいうところのELBOで, \begin{align} &\mathcal{L}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) \\ &= \sum_{n=1}^N\mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}),\quad \mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) = \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} - \mathrm{D}_{\mathrm{KL}}[r^{(n)}_\eta\| \pi_{\theta}] \end{align} としました. ということで, ELBOを最大化することで2つの最適化問題を (近似的にですが, )同時に解くことができそうです. すなわち, 事後分布の近似も良くなり, かつモデルも良くなると期待できます. 完全なBayesモデルの場合, 周辺尤度は定数ですが, 今回は学習中に変化する点に注意してください.

以上の考え方をまとめます. まずは, 既知の情報と未知の情報をまとめておきます.

既知 or 未知 決め方 変数
既知 データから分かる データ\(X=\{\boldsymbol{x}_1,\cdots,\boldsymbol{x}_N\}\)
データサイズ\(N\)
データの次元\(D_x\)
人間が決める モデル\(p_\xi\)が属する分布族
事前分布\(\pi_\theta\)が属する分布族
近似分布\(r_\eta^{(n)}\)が属する分布族
潜在変数の次元\(D_z\)
未知 分布を推定 潜在変数\(Z=\{\boldsymbol{z}_1,\cdots,\boldsymbol{z}_N\}\)
点推定 近似分布のパラメータ\(\boldsymbol{\eta}\)
モデルのパラメータ\(\boldsymbol{\xi}\)
事前分布のパラメータ\(\boldsymbol{\theta}\)

用意したデータとモデルに対して, 以下の最適化問題を解くことで, 事後分布の近似とモデル選択を同時に実行します:

\begin{equation} \hat{\boldsymbol{\eta}},\hat{\boldsymbol{\xi}},\hat{\boldsymbol{\theta}} = \underset{\eta,\xi,\theta}{\mathrm{argmax}} \ \mathcal{L}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}). \end{equation}

最大化は確率的勾配"上昇"法により実行します. おおよそ次のようなアルゴリズムになります.

【今回使う学習アルゴリズム】

Set the step size \(\alpha_k\).
Initialize the parameters \(\boldsymbol{\eta}^{(0)}\), \(\boldsymbol{\xi}^{(0)}\) and \(\boldsymbol{\theta}^{(0)}\).
for \(k=0,1,\dots\) do
Choose minibatch \(\mathcal{M}\).
Update \(\boldsymbol{\eta}\) by \(\boldsymbol{\eta}^{(k+1)} = \boldsymbol{\eta}^{(k)} + \alpha_k \nabla \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta}^{(k)},\boldsymbol{\xi}^{(k)},\boldsymbol{\theta}^{(k)})\).
Update \(\boldsymbol{\xi}\) by \(\boldsymbol{\xi}^{(k+1)} = \boldsymbol{\xi}^{(k)} + \alpha_k \nabla \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta}^{(k)},\boldsymbol{\xi}^{(k)},\boldsymbol{\theta}^{(k)})\).
Update \(\boldsymbol{\theta}\) by \(\boldsymbol{\theta}^{(k+1)} = \boldsymbol{\theta}^{(k)} + \alpha_k \nabla \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta}^{(k)},\boldsymbol{\xi}^{(k)},\boldsymbol{\theta}^{(k)})\).
end

アルゴリズム中に出てくる\(\mathcal{L}_{\mathcal{M}}\)は, ELBOのミニバッチ版で, サイズを\(M\)とするとき, 次式で定義されます:

\begin{equation} \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) = \frac{N}{M}\sum_{x_n\in \mathcal{M}} \mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}). \end{equation}

ELBOは積分を使って定義されるため, (近似事後分布のパラメータに関する)勾配計算は自明ではありません. 次のセクションでは, うまい変数変換をみつけることで勾配を近似する方法を説明します.

reparametrization trick

ここでは, reparametrization tirckを説明します. アルゴリズム中に出てきた次の勾配計算を考えましょう. :

\begin{equation} \nabla_\eta \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) = \frac{N}{M}\sum_{x_n\in \mathcal{M}} \nabla_\eta \mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}). \end{equation}

ということで, 勾配 \begin{equation} \nabla_\eta \mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) = \nabla_\eta \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} - \nabla_\eta \mathrm{D}_{\mathrm{KL}}[r^{(n)}_\eta\| \pi_{\theta}] \end{equation} が計算できれば十分です. 右辺第2項のKLダイバージェンスの勾配はは何とか計算できるとして, 第1項の勾配 \begin{equation} \nabla_\eta \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} \end{equation} を考えましょう. 単純に演算子\(\nabla_\eta\)を積分の中に入れると密度関数が微分されてしまい, 一般にモンテカルロ近似ができません. もし適当な変数変換によって \begin{equation} \boldsymbol{z} = g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n) \end{equation} と変換できれば, 以下のように計算できます:

\begin{align} & \nabla_\eta \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} \\ &= \nabla_\eta \int r_{\eta}^{(n)}(g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n)\mid\boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n) )|\det(J_g(\boldsymbol{\varepsilon}))|\mathrm{d}\boldsymbol{\varepsilon} \\ &= \nabla_\eta \int r_{\eta}^{(n)}(g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n)\mid\boldsymbol{x}_n)|\det(J_g(\boldsymbol{\varepsilon}))|\log p_{\xi}(\boldsymbol{x}_n\mid g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n) )\mathrm{d}\boldsymbol{\varepsilon}. \end{align}

ここで, \(J_g\)はJacobi行列です. もし, \begin{equation} \phi(\boldsymbol{\varepsilon}) = r_{\eta}^{(n)}(g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n)\mid\boldsymbol{x}_n)|\det(J_g(\boldsymbol{\varepsilon}))| \end{equation} が\(\boldsymbol{\eta}\)に依存しない密度関数になれば, 次式のように計算できます:

\begin{align} &= \nabla_\eta \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} \\ &= \nabla_\eta \int \phi(\boldsymbol{\varepsilon}) \log p_{\xi}(\boldsymbol{x}_n\mid g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n) )\mathrm{d}\boldsymbol{\varepsilon}\\ &= \int \phi(\boldsymbol{\varepsilon}) \nabla_\eta \log p_{\xi}(\boldsymbol{x}_n\mid g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n) )\mathrm{d}\boldsymbol{\varepsilon}\\ &\simeq \frac{1}{S} \sum_{s=1}^S \nabla_\eta \log p_{\xi}(\boldsymbol{x}_n\mid g_\eta(\boldsymbol{\varepsilon}^{(s)},\boldsymbol{x}_n) ),\quad \boldsymbol{\varepsilon}^{(s)}\sim \phi(\boldsymbol{\varepsilon}). \end{align}

以上の近似計算をreparametrization trickといいます. ポイントはいい感じの変換\(g\)を見つけられるかどうかですが, 近似分布が正規分布ならうまくいきます. なお, 上の説明では\(\boldsymbol{\eta}\)に関する勾配計算を考えていますが, \(\boldsymbol{\xi}\)と\(\boldsymbol{\theta}\)に関する勾配も上の近似式中の\(\nabla_\eta\)を置き換えて計算できます.

画像データへの応用

それでは, いよいよ画像データに前節の枠組みを適用します. ここまで色々説明してきましたが, データ以外に用意するものは以下の通りです.

  • モデル\(p_\xi\)が属する分布族
  • 潜在変数の事前分布\(\pi_\theta\)が属する分布族
  • 近似分布の各因子\(r_\eta^{(n)}\)が属する分布族
  • 潜在変数の次元\(D_z\)
  • その他, 変数変換\(g_\eta\), ミニバッチサイズ\(M\)とreparametrization trickのサンプル数\(S\)

これらを用意しながら, 画像データへの応用を説明します.

変分オートエンコーダ

この節では変分オートエンコーダ(Variational Auto-Encoder)を導入します. いま, 手元に画像データ\(X=\{\boldsymbol{x}_1,\dots,\boldsymbol{x}_N\}\subset\mathbb{R}^{D_x}\)があるとします. ここで, 各\(\boldsymbol{x}_n\)は画像をベクトル化したものとします. この画像の生成過程を以下のように想定します.

【モデル:変分オートエンコーダ(VAE)】

for \(n=1,\dots,N\) do
\(\boldsymbol{z}_n \sim \mathrm{N}(\boldsymbol{0},I_{D_z})\)
\(\boldsymbol{x}_n \mid \boldsymbol{z}_n \sim \mathrm{N}(F_\xi(\boldsymbol{z}_n),I_{D_x})\)
end

ここで, \(F_\xi\colon \mathbb{R}^{D_z}\to\mathbb{R}^{D_x}\)は\(\boldsymbol{\xi}\)を重みパラメータとするニューラルネットワークで, デコーダと呼ばれる.

事前分布を標準正規分布にしたので, 前節の\(\theta\)をここでは無視します. 低次元の潜在表現から高次元の画像データへの変換の過程は複雑であることが予想されるため, モデルにニューラルネットワークを組み込むのは自然な選択です. しかし, 副作用として推論が複雑になるため, 前節のような近似的な枠組みが有効に機能します.

【モデルのイメージ図】

次に, 事後分布の近似分布を定義しましょう. 近似分布の各因子は以下のような 正規分布としましょう:

\begin{equation} r_\eta^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n) = \mathrm{N}\left( \boldsymbol{z}\mid \boldsymbol{m}_\eta(\boldsymbol{x}_n), \mathrm{diag} (\boldsymbol{s}_\eta(\boldsymbol{x}_n))^2\right). \end{equation}

ここで, 近似事後分布の平均と標準偏差パラメータは, 以下のように, \(\boldsymbol{\eta}\)を重みとするニューラルネットワーク\(G_\eta\colon\mathbb{R}^{D_x}\to\mathbb{R}^{2D_z}\)の出力とします:

\begin{equation} \begin{bmatrix} \boldsymbol{m}_\eta(\boldsymbol{x}_n)\\ \log \boldsymbol{s}_\eta(\boldsymbol{x}_n) \end{bmatrix} = G_\eta(\boldsymbol{x}_n). \end{equation}

すべての\(n\)で共通なパラメータ\(\boldsymbol{\eta}\)を用いることで, データサイズに依らず効率的に学習できます. ちなみにこのニューラルネットワークはエンコーダと呼ばれます. このとき, ELBOは以下のように計算できます:

\begin{align} & \mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) \\ &= \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} - \mathrm{D}_{\mathrm{KL}}[r^{(n)}_\eta\| \pi_{\theta}] \\ &= \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} -\frac{1}{2}\sum_{j=1}^{D_z}\left\{\boldsymbol{m}_\eta(\boldsymbol{x}_n)_j^2+ \boldsymbol{s}_\eta(\boldsymbol{x}_n)_j^2 -2\log\boldsymbol{s}_\eta(\boldsymbol{x}_n)_j-1\right\}. \end{align}

問題はreparametrization trickですが, 以下の変数変換を考えます:

\begin{equation} \boldsymbol{z} = g_\eta(\boldsymbol{\varepsilon},\boldsymbol{x}_n) = \boldsymbol{m}_\eta(\boldsymbol{x}_n) + \boldsymbol{s}_\eta(\boldsymbol{x}_n) \odot \boldsymbol{\varepsilon}. \end{equation}

このとき, 変換後の\(\boldsymbol{\varepsilon}\)の分布は標準正規分布になります. したがって, ELBOの勾配は以下のように近似計算できます:

\begin{align} & \nabla_\eta \mathcal{L}^{(n)}(\boldsymbol{\eta},\boldsymbol{\xi},\boldsymbol{\theta}) \\ &= \nabla_\eta \int r_{\eta}^{(n)}(\boldsymbol{z}\mid \boldsymbol{x}_n)\log p_{\xi}(\boldsymbol{x}_n\mid \boldsymbol{z})\mathrm{d}\boldsymbol{z} -\frac{1}{2}\sum_{j=1}^{D_z}\nabla_\eta \left\{\boldsymbol{m}_\eta(\boldsymbol{x}_n)_j^2+ \boldsymbol{s}_\eta(\boldsymbol{x}_n)_j^2 -2\log\boldsymbol{s}_\eta(\boldsymbol{x}_n)_j-1\right\}\\ &\simeq \frac{1}{S} \sum_{s=1}^S \nabla_\eta \log p_{\xi}(\boldsymbol{x}_n\mid g_\eta(\boldsymbol{\varepsilon}^{(s)},\boldsymbol{x}_n) ) - \frac{1}{2}\sum_{j=1}^{D_z}\nabla_\eta\left\{\boldsymbol{m}_\eta(\boldsymbol{x}_n)_j^2+ \boldsymbol{s}_\eta(\boldsymbol{x}_n)_j^2 -2\log\boldsymbol{s}_\eta(\boldsymbol{x}_n)_j-1\right\}\\ &\boldsymbol{\varepsilon}^{(s)} \sim \mathrm{N}(\boldsymbol{0},I_{D_z}). \end{align}

以上をまとめて, 次のような学習アルゴリズムを用います.

【VAEの学習アルゴリズム】

Set the step size \(\alpha_k\).
Initialize the parameters \(\boldsymbol{\eta}^{(0)}\) and \(\boldsymbol{\xi}^{(0)}\).
for \(k=0,1,\dots\) do
Choose minibatch \(\mathcal{M}\).
Update \(\boldsymbol{\eta}\) by \(\boldsymbol{\eta}^{(k+1)} = \boldsymbol{\eta}^{(k)} + \alpha_k \nabla \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta}^{(k)},\boldsymbol{\xi}^{(k)})\).
Update \(\boldsymbol{\xi}\) by \(\boldsymbol{\xi}^{(k+1)} = \boldsymbol{\xi}^{(k)} + \alpha_k \nabla \mathcal{L}_{\mathcal{M}}(\boldsymbol{\eta}^{(k)},\boldsymbol{\xi}^{(k)})\).
end

ただし, 以下の実装ではELBOの符号反転の最小化として訓練を行っています.

実験結果

それでは, MNISTを用いた実験結果を示します. MNISTの訓練データ\(N=60,000\)枚を用いて, 変分オートエンコーダを訓練しました. 前処理として, 各画像データはベクトル値に変換し, \(D_x=28^2\)としています. また, 潜在変数の次元は可視化しやすいように\(D_z=2\)としました. ミニバッチのサイズは\(M=100\), reparametrization trickのサンプル数は\(S=1\)としました. 最大エポック数は200とし, 確率的勾配降下法に学習率を0.02に設定したAdaGradを用いて, ELBOの符号反転最小化を目指します. また, ELBOには, デコーダの重みパラメータの二乗ノルムの2乗\(\|\boldsymbol{\xi}\|^2\)に0.01を乗じた正則化項を付加しています.

次に, デコーダ\(F_\xi\)とエンコーダ\(G_\eta\)の構成は以下の通りです:

\begin{align} F_\xi(\boldsymbol{z}) &= W_F^{(3)}\tanh \left( W_F^{(2)}\boldsymbol{z} + \boldsymbol{b}_F^{(2)}\right) + \boldsymbol{b}_F^{(3)} \\ G_\eta(\boldsymbol{x}) &= W_G^{(3)}\tanh \left( W_G^{(2)}\boldsymbol{x} + \boldsymbol{b}_G^{(2)}\right) + \boldsymbol{b}_G^{(3)}. \end{align}

行列に関しては平均0, 標準偏差0.1の正規分布からの乱数で初期化し, バイアス項はゼロベクトルで初期化しています.

下図は, 学習中のELBOの符号反転の変化です. 最初にストンと落ちてからはかなり緩やかに減少しています. 200エポックも要らないかも...

【コード5の実行結果】

今回は\(D_z=2\)としていますので, 潜在変数を2次元平面にプロットして可視化できます. 下図に, 各潜在変数の期待値を最初の5000個分だけプロットし, 正解ラベルごとに色分けしています. うまく分離できていない部分もありますが, 何となく数字ごとに分かれているように見えます.

【コード6の実行結果】

先ほどは各画像データの潜在表現を平面上に図示しましたが, 逆に平面上の適当な点から新たに画像を生成することができます. 下図は, 平面上に等間隔で打った代表点から生成した画像を並べたものです. このように, 潜在空間内では, 各数字画像が連続的に変化しています.

【コード7の実行結果】

さらに, 潜在空間内で連続的に動いた場合の画像の変化をアニメーションで示します. 一番左は潜在空間内を放物線に沿って動き, 中央は潜在空間内を直線に沿って動き, 一番右は潜在空間内を単位円周に沿って動いたときに生成される画像の変化です.

【コード8の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using Zygote
using Flux
using LinearAlgebra

#statistics
using Random
using Statistics
using Distributions

#dataset
using MLDatasets

#visualize
using Images
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack

ProgressMeter.ijulia_behavior(:clear)
【Juliaコード2; モデルと関数の定義】
#vectorize image data
function vectorize_images(X_train)
    #d:dimension, N:train sample size
    d,d,N = size(X_train)

    #vectorize the image
    X = zeros(d*d,N)
    @showprogress for n in 1:N
        X[:,n] = X_train[:,:,n][:]
    end
    return X,N,d*d
end

#create decoder and encoder
function create_NN(z_dim,n_hidden,x_dim)
    #decoder
    decoder = Chain(
        Dense(0.1*randn(n_hidden,z_dim),zeros(n_hidden),tanh),
    )
    x_mean_decoder = Chain(decoder,Dense(0.1*randn(x_dim,n_hidden),zeros(x_dim)))
    
    #encoder 
    encoder = Chain(
        Dense(0.1*randn(n_hidden,x_dim),zeros(n_hidden),tanh),
    )
    z_mean_encoder = Chain(encoder,Dense(0.1*randn(z_dim,n_hidden),zeros(z_dim)))
    z_logstd_encoder = Chain(encoder,Dense(0.1*randn(z_dim,n_hidden),zeros(z_dim)))
    return x_mean_decoder,z_mean_encoder,z_logstd_encoder
end

#reparametrize
function reparametrize(z_mean,z_logstd,z_dim)
    z_mean + exp.(z_logstd) .* randn(z_dim)
end

#KL divergence between variational posterior and prior
function KL_var_prior(z_mean,z_logstd,z_dim)
    0.5f0 * sum(@. (exp(2f0*z_logstd) + z_mean^2f0 - 2f0*z_logstd - 1f0))
end

#negative reconstruction error
function neg_reconst_error(xn,x_mean)
    logpdf(MvNormal(x_mean,1),xn)
end

#ELBO 
function ELBO_minibatch(
        x_mean_decoder,z_mean_encoder,z_logstd_encoder,
        X_minibatch,N,z_dim,minibatch_size
    )
    
    #encode
    z_means,z_logstds = z_mean_encoder(X_minibatch),z_logstd_encoder(X_minibatch)
    zs = reparametrize(z_means,z_logstds,z_dim)
    
    #compute ELBO
    L = neg_reconst_error(X_minibatch[:],x_mean_decoder(zs)[:])
    L -= KL_var_prior(z_means,z_logstds,z_dim)
    return N*L/minibatch_size
end

#create VAE model
function create_model(X,z_dim,n_hidden,x_dim,minibatch_size)
    x_mean_decoder,z_mean_encoder,z_logstd_encoder = create_NN(z_dim,n_hidden,x_dim)
    loss_func = (
        M -> -ELBO_minibatch(
                x_mean_decoder,z_mean_encoder,z_logstd_encoder,M,N,z_dim,minibatch_size
            )
            + 0.01f0*sum(x->sum(x.^2), Flux.params(x_mean_decoder))   
        )
    ps = Flux.params(x_mean_decoder,z_mean_encoder,z_logstd_encoder)
    data_loader = Flux.DataLoader(X,batchsize=minibatch_size, shuffle=true, partial=false)
    return x_mean_decoder,z_mean_encoder,z_logstd_encoder,loss_func,ps,data_loader
end

#train VAE
function train_VAE(data,model_params,n_epochs,minibatch_size)
    #data and model parameters
    @unpack X,N,x_dim = data
    @unpack z_dim,n_hidden = model_params
    
    #decoder, encoder, loss function, parameters and data loader
    x_mean_decoder,z_mean_encoder,z_logstd_encoder,loss_func,ps,data_loader = create_model(
        X,z_dim,n_hidden,x_dim,minibatch_size
    )
    
    #define optimizer
    opt = ADAGrad(0.02)
    
    #ELBO for each epochs
    loss_avg = zeros(n_epochs)
    
    #train by SGD
    for k in 1:n_epochs
        #progress bar
        pb = Progress(length(data_loader), 1, "epoch $(k): ")
        
        #compute ELBO for each minibatchs
        for X_minibatch in data_loader
            loss,back = pullback(ps) do 
                loss_func(X_minibatch)
            end
            gradients = back(1f0)
            Flux.Optimise.update!(opt, ps, gradients)
            loss_avg[k] += loss
            
            #update progress bar
            next!(pb; showvalues=[(:loss, loss)])
        end
        loss_avg[k] = loss_avg[k] / length(data_loader)
    end
    NNs = [x_mean_decoder,z_mean_encoder,z_logstd_encoder]
    return NNs,ps,loss_avg
end

#return z_mean
function latent_estimate(data,model_params,NNs)
    @unpack X,N,x_dim = data
    @unpack z_dim = model_params
    z_mean_encoder = NNs[2]
    z_est = z_mean_encoder(X)
    return z_est
end

#learned manifold
function visualize_manifold(n_imgs,NNs)
    #decoder
    decoder = NNs[1]
    
    #big picture
    img_size = 28*n_imgs
    manifold = zeros(img_size,img_size)
    
    #points in latent space
    xs = range(-2,2,length=n_imgs)
    ys = range(-2,2,length=n_imgs)
    
    #decode
    for j in 1:n_imgs
        for i in 1:n_imgs
            x = xs[j]
            y = ys[i]
            manifold[28(i-1)+1:28*i, 28*(j-1)+1:28*j] = decoder([x,y])
        end
    end
    return manifold
end

#across the latent space
function across_latent_space(xs,ys,NNs)
    #decoder
    decoder = NNs[1]
    
    #decode latent representations along the line
    L = length(xs)
    imgs = zeros(28,28,L)
    for l in 1:L
        imgs[:,:,l] = decoder([xs[l],ys[l]])
    end
    return imgs
end
【Juliaコード3; 画像データの用意】
#load the training data
X_train,y_train = MNIST.traindata()

#vectorize image data
X,N,x_dim = vectorize_images(X_train)
【Juliaコード4; VAEの訓練】
#data and training parameter
data = (X=X,N=N,x_dim=x_dim)
model_params = (z_dim=2,n_hidden=500)
n_epochs = 200
minibatch_size = 100

#train VAE
@time NNs,ps,loss_avg = train_VAE(data,model_params,n_epochs,minibatch_size)
【Juliaコード5; 訓練損失のプロット】
fig1 = plot(1:n_epochs,loss_avg,title="Training loss",xlabel="epoch",ylabel="training loss",label=false)
savefig(fig1,"figs-VAE/fig1.png")
【Juliaコード6; 潜在空間の可視化】
n_test = 5000
z_est = latent_estimate(data,model_params,NNs)
fig2 = plot(
    z_est[1,1:n_test],z_est[2,1:n_test],st=:scatter,zcolor=y_train,markerstrokewidth=0,label=false,
    xlabel="z₁", ylabel="z₂", title="2D latent space", c=palette(:tab10)
)
savefig("figs-VAE/fig2.png")
【Juliaコード7; 画像の生成】
#learned manifold
n_imgs = 15
manifold = visualize_manifold(n_imgs,NNs)

#show the result
fig3 = plot(
    st=:heatmap,reverse(manifold',dims=1), 
    color=:grays, aspect_ratio=1, colorbar=false, xticks=false, yticks=false
    ) 
savefig(fig3,"figs-VAE/fig3.png")
【Juliaコード8; アニメーションで表示】
#a line in latent space and images
xs = -2:0.05:2
ys = @. xs^2 + xs - 4
imgs = across_latent_space(xs,ys,NNs)

#animation 1
idx = 1:length(xs)
anim1 = @animate for l in vcat(idx,reverse(idx))
    plot(
        st=:heatmap,reverse(imgs[:,:,l]',dims=1),size=(200,200),
        color=:grays,aspect_ratio=1, colorbar=false, xticks=false, yticks=false
    )
end
gif(anim1,"figs-VAE/anim1.gif")

#a line in latent space and images
xs = -2:0.05:2
ys = xs
imgs = across_latent_space(xs,ys,NNs)

#animation 2
idx = 1:length(xs)
anim2 = @animate for l in vcat(idx,reverse(idx))
    plot(
        st=:heatmap,reverse(imgs[:,:,l]',dims=1),size=(200,200),
        color=:grays,aspect_ratio=1, colorbar=false, xticks=false, yticks=false
    )
end
gif(anim2,"figs-VAE/anim2.gif")

#a line in latent space and images
ts = 0:0.05:2*π
xs = cos.(ts)
ys = sin.(ts)
imgs = across_latent_space(xs,ys,NNs)

#animation 3
idx = 1:length(xs)
anim3 = @animate for l in 1:length(idx)
    plot(
        st=:heatmap,reverse(imgs[:,:,l]',dims=1),size=(200,200),
        color=:grays,aspect_ratio=1, colorbar=false, xticks=false, yticks=false
    )
end
gif(anim3,"figs-VAE/anim3.gif")
参考文献

      [1]C.Doersch, Tutorial on Variational Autoencoders, arXiv:1606.05908, 2016.
      [2]D.P.Kingma, M.Welling, Auto-Encoding Variational Bayes, 2nd International Conference on Learning Representations, 2014.
      [3]D.P.Kingma, M.Welling, An Introduction to Variational Autoencoders, Foundations and Trends in Machine Learning, 12(4), pp.307-392, 2019.
      [4]須山敦志, ベイズ深層学習, 講談社, 2020