変分推論法(修正版)

記事の内容


以前書いた記事の修正版です. モデルやプログラムを修正し, 例を追加しました. 変分推論を具体例を用いて説明します. 最も有名な平均場近似を用いた手法と, ミニバッチ学習に対応した確率的変分推論法を紹介します. 確率的変分推論法は, 元々[5]で提案された手法ですが, ここではもっと広い意味で用いています. 原論文では確率的最適化と自然勾配の考え方を用いていますが, ここでは「確率的最適化をベースとした変分推論法」程度の意味で用いています.

変分推論法の概要

問題設定

変分推論は, 複雑な事後分布をより単純な近似する方法です. サンプリング法であるMCMCとは違って最適化がベースであり, 一般に変分推論法の方が高速です. まず, どういう状況で用いる手法なのか, 問題設定を確認しておきましょう.

手元にデータ\(X=\{\boldsymbol{x}_1,\cdots,\boldsymbol{x}_N\}\subset \mathbb{R}^D\)があるとします.このデータを発生させた分布を推定することが目的です. モデルを\(p(\boldsymbol{x}\mid\boldsymbol{w})\)でモデル化したとします. ここで, \(\boldsymbol{w}\in\mathbb{R}^d\)はパラメータです. このパラメータの事前分布を\(p(\boldsymbol{w})\)とすると以下の事後分布が計算できます:

\begin{equation} p(\boldsymbol{w}\mid X) \propto p(X\mid\boldsymbol{w}) p(\boldsymbol{w}). \end{equation}

この事後分布の平均を計算したり, モデルを事後分布で重みづけて予測分布を計算するには事後分布が複雑すぎる場合, MCMCや変分推論法が用いられます. 次に, 変分推論法の基本的なアイデアを述べます.

基本的なアイデア

変分推論法では, 事後分布を近似する分布を構成します. 目標は, 次のような分布を, 予め定めた分布族\(\mathcal{R}\)から見つけることです:

\begin{equation} p(\boldsymbol{w}\mid X) \simeq r(\boldsymbol{w}),\quad r\in \mathcal{R}. \end{equation}

ここで, 次の2点が問題になります.

  • 近似分布は事後分布にどういう意味で近いのか?
  • 分布族\(\mathcal{R}\)はどのように設定するか?

まず1点目. 分布間の近さは, Kullback-Leiblerダイバージェンスで測るのが自然です. 近似分布と事後分布の間のKullback-Leiblerダイバージェンスが最小になるように近似分布を定めます. もう少し詳しく見てみましょう. 次のような数量を定めます:

\begin{equation} \mathcal{L}[r] = -\int r(\boldsymbol{w})\log \frac{r(\boldsymbol{w})}{p(X,\boldsymbol{w})} d\boldsymbol{w}. \end{equation}

すると, \(r(\boldsymbol{w})\)と\(p(\boldsymbol{w}\mid X)\)の間のKullback-Leiblerダイバージェンスは, 次のように書けます:

\begin{equation} D_{\mathrm{KL}}[r(\boldsymbol{w})\|p(\boldsymbol{w}\mid X)] = -\mathcal{L}[r] + \log p(X). \end{equation}

Kullback-Leiblerダイバージェンスが非負であることから, 次式が得られます:

\begin{equation} \mathcal{L}[r] \leq \log p(X). \end{equation}

\(\mathcal{L}\)が周辺尤度(エビデンス)の下界であることから, この量をELBO(Evidence Lower bound)と呼びます. 当面の間の目標はKullback-Leiblerダイバージェンスの最小化ですが, これはELBOの最大化と等価です. なお, ELBOの符号反転は変分自由エネルギーと呼ばれます.

次に2点目です. 分布族の決め方はいくつかあります. 最も有名な方法が平均場近似です. 平均場近似では, 近似分布を具体的には指定せず, 事後分布が都合よく分解できると仮定します. 分解できるという仮定のみでも, 事後分布は"いい感じ"に求まります. 他には, 予め分布を指定する方法があります. 例えば予め正規分布で近似できると仮定して, パラメータを決めてやります. 以下, この2つのアイデアを詳しく掘り下げていきます.

平均場近似

平均場近似の考え方

平均場近似では, 近似分布が属する分布族を次のように定めます. ただし, \(K\geq 2\)とします:

\begin{equation} \mathcal{R} = \left\{ r(\boldsymbol{w})\mid \text{ある確率(密度)関数\(r_1,\cdots,r_K\)が存在して, }r(\boldsymbol{w}) = \prod_{i=1}^Kr_i(\boldsymbol{w_i})\right\}. \end{equation}

このような分布族から持ってきた近似分布\(r\)について, 各\(i=1,\cdots,K\)に対して, ELBOは次のように書けます:

\begin{equation} \mathcal{L}[r] = -D_{\mathrm{KL}}\left[r_i(\boldsymbol{w}_i) \| \frac{1}{Z}\exp\left(\mathrm{E}_{\prod_{j\neq i}r_j}\left[ \log p(X,\boldsymbol{w}) \right]\right) \right] + \mathrm{const}. \end{equation}

ここで, \(Z\)は正規化定数で, 上式の期待値は次のようにとります:

\begin{equation} \mathrm{E}_{\prod_{j\neq i}r_j}\left[ \log p(X,\boldsymbol{w}) \right] = \int \left(\prod_{j\neq i}r_j(\boldsymbol{w}_j)\right) \log p(X,\boldsymbol{w}) d\boldsymbol{w}^{\setminus i}. \end{equation}

2つの分布が一致するときにKLダイバージェンスが最小になることから, 各\(j\neq i\)に対して因子\(r_j\)が定まっているときに, \begin{equation} r_i(\boldsymbol{w}_i) \propto \exp \left( \mathrm{E}_{\prod_{j\neq i}r_i} [\log p(X,\boldsymbol{w})]\right) \end{equation} と因子\(r_i\)を更新することで, ELBOが最大になります. 適当な初期分布から初めて更新を繰り返していけば, ELBOがどこかで頭打ちになるはずです. このとき近似分布は事後分布に十分近いと期待します.

例: 正規分布のパラメータ推論

ここでは例として, 正規分布の平均と精度の事後分布を推定します. モデルが正規分布で, パラメータが平均パラメータと精度パラメータにあたります. 次のような生成過程を考えます:

\begin{equation} \mu\sim\mathrm{N}(0,1) , \quad \lambda \sim\mathrm{Gamma}(1,\beta),\quad x\sim\mathrm{N}(\mu,\lambda^{-1}). \end{equation}

\(\beta\)は定数です. このとき事後分布は次のようになります:

\begin{equation} p(\mu,\lambda\mid X) \propto \lambda^{\frac{N}{2}} e^{-\beta\lambda} \exp\left\{ -\frac{1}{2}\mu^2-\frac{\lambda}{2}\sum_{n=1}^N(x_n-\mu)^2\right\}. \end{equation}

例えばこの分布の平均を求めるのは困難です. ということで変分推論法を行います. 次のように分解されるとします:

\begin{equation} p(\mu,\lambda\mid X) \simeq r(\mu)r(\lambda). \end{equation}

先ほど導出した期待値計算を含む分布を導出すると, 以下のようになります:

\begin{align} r(\mu) &\propto \exp\left\{ -\frac{1}{2}\mu^2-\frac{1}{2}\mathrm{E}_\lambda [\lambda] \sum_{n=1}^N(x_n-\mu)^2 \right\} \\ r(\lambda) &\propto \exp\left\{ \frac{N}{2}\log\lambda-\beta\lambda - \frac{\lambda}{2}\sum_{n=1}^N\mathrm{E}_\mu [(x_n-\mu)^2]\right\}. \end{align}

\(\mu\)の近似分布は正規分布, \(\lambda\)の近似分布はガンマ分布であることが分かります. それぞれの分布を次のようにおきます:

\begin{equation} \mathrm{N}(\hat{\mu}, \hat{\lambda}) ,\quad \mathrm{Gamma}(\hat{\alpha}, \hat{\beta}). \end{equation}

すると, 次のように書けます:

\begin{align} \hat{\lambda} &= 1+\frac{N\hat{\alpha}}{\hat{\beta}} \\ \hat{\mu} &= \frac{\hat{\alpha}}{\hat{\lambda}\hat{\beta}}\left(\sum_{n=1}^Nx_n\right) \\ \hat{\alpha} &= 1+\frac{N}{2} \\ \hat{\beta} &= \beta + \frac{1}{2}\left(\sum_{n=1}^Nx_n^2\right)-\hat{\mu}\left( \sum_{n=1}^Nx_n\right) + \frac{N}{2}\hat{\mu}^2 + \frac{N}{2\hat{\lambda}}. \end{align}

これを更新式として, 次のようなアルゴリズムが得られます. 実際には, \(\hat{\alpha}\)が上記の更新で変化しないので, ループの外に出して実装します.

【正規分布に対する変分推論アルゴリズム】

Initialize \(\hat{\lambda}^{(0)}, \hat{\mu}^{(0)}, \hat{\alpha}^{(0)}, \hat{\beta}^{(0)}\).
Set \(k=0\) and \(\varepsilon>0\).
while\( |\mathcal{L}[r^{(k+1)}]-\mathcal{L}[r^{(k)}] | \geq \epsilon\)
update \(\hat{\lambda}^{(k+1)}, \hat{\mu}^{(k+1)}, \hat{\alpha}^{(k+1)}, \hat{\beta}^{(k+1)}\).
\(k=k+1\)
end

ちなみに, ELBOの値は以下の通りです. ただし, \(\Gamma\)はガンマ関数, \(\psi\)はディガンマ関数です:

\begin{align} \mathcal{L}[r] = -\frac{1}{2}\log\hat{\lambda} + \frac{1}{2}-\hat{\alpha}\log\hat{\beta} + \log\Gamma (\hat{\alpha}) - (\hat{\alpha}-1)(\psi (\hat{\alpha})-\log\hat{\beta}) + \hat{\alpha} + \frac{N}{2}\left( \psi (\hat{\alpha})-\log (2\pi\hat{\beta})\right)\\ - \frac{\hat{\alpha}}{2\hat{\beta}}\left( \sum_{n=1}^Nx_n^2-2\hat{\mu}\sum_{n=1}^Nx_n + N\hat{\mu}^2 + \frac{N}{\hat{\lambda}}\right) - \frac{1}{2}\left( \hat{\mu}^2+\frac{1}{\hat{\lambda}}\right) + \log\beta - \frac{\hat{\alpha}\beta}{\hat{\beta}}. \end{align}

以下に実験結果を示します. 真の分布は混合正規分布\(0.5\mathrm{N}(-2, 1.5^2)+0.5\mathrm{N}(2, 1.5^2)\)としました. この分布からデータを\(N=20\)点発生させました. \(\beta=10, \epsilon=10^{-6}\)とします. また, 初期値は\(\hat{\mu}=0.0, \hat{\lambda}=1.0, \hat{\beta}=\beta\)とします. 左下図の等高線は, 近似事後分布です. また, Gibbs samplerによって(真の)事後分布からサンプルし, プロットしました(緑点). 実際の事後分布と近似事後分布は一致しているように見えます. また, 右下図は, ELBOの変化の様子を示しました. 5回目の更新で収束しています.

【コード6の実行結果】

近似事後分布からサンプルを発生させて, 予測分布を計算し(橙実線), 真の分布(赤実線)と比較しました. さらに, Gibbs samplerで計算した予測分布(緑破線)を計算しました(青色).

【コード7の実行結果】

確率的変分推論法

確率的変分推論の考え方

平均場近似による方法は基本的に高速なのですが, 更新式が常に導出できるとは限りません. そこで, 予め分布を指定して, ELBOを勾配上昇により増加させる方法を考えます. 例えば正規分布などの簡単な分布を仮定して, 近似分布のパラメータをいい感じに調節し, 近似精度向上を目指します.

近似分布を\(r_{\boldsymbol{\eta}}\)と表示しておきましょう. パラメータ\(\boldsymbol{\eta}\)を調節して, KLダイバージェンスの最小化, もしくはそれと等価なELBOの最大化を目指します. ELBOを具体的に書き出すと, 以下のようになります:

\begin{equation} \mathcal{L}(\boldsymbol{\eta}) = \sum_{n=1}^N\int r_{\boldsymbol{\eta}}(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid\boldsymbol{w})\mathrm{d}\boldsymbol{w} + \int r_{\boldsymbol{\eta}}(\boldsymbol{w})\log p(\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \int r_{\boldsymbol{\eta}}(\boldsymbol{w})\log r_{\boldsymbol{\eta}}(\boldsymbol{w})\mathrm{d}\boldsymbol{w}. \end{equation}

これを, 確率的勾配降下法と同じ要領で次のように近似します:

\begin{equation} \mathcal{L}_n(\boldsymbol{\eta}) \simeq \mathcal{L}(\boldsymbol{\eta}) = N\int r_{\boldsymbol{\eta}}(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid\boldsymbol{w})\mathrm{d}\boldsymbol{w} + \int r_{\boldsymbol{\eta}}(\boldsymbol{w})\log p(\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \int r_{\boldsymbol{\eta}}(\boldsymbol{w})\log r_{\boldsymbol{\eta}}(\boldsymbol{w})\mathrm{d}\boldsymbol{w}. \end{equation}

そして, 最急"上昇"法により, 次のように最大化していきます:

\begin{equation} \boldsymbol{\eta}^{(k+1)} = \boldsymbol{\eta}^{(k)} + \alpha_k \nabla_{\boldsymbol{\eta}} \mathcal{L}_n(\boldsymbol{\eta}^{(k)}). \end{equation}

以上の流れをアルゴリズムとしてまとめておきます.

【確率的変分推論アルゴリズム】

Initialize weights \(\boldsymbol{\eta}^{(0)}\), step size \(\alpha^{(k)}\).
Set \(k=0\) and \(\varepsilon>0\).
while\( |\mathcal{L}(\boldsymbol{\eta}^{(k+1)})-\mathcal{L}(\boldsymbol{\eta}^{(k)}) | \geq \epsilon\)
Select \(\boldsymbol{x}_n\) uniformly at random.
\(\boldsymbol{\eta}^{(k+1)} = \boldsymbol{\eta}^{(k)} + \alpha^{(k)}\nabla_{\boldsymbol{\eta}} \mathcal{L}_{n}(\boldsymbol{\eta}^{(k)})\)
\( k=k+1\)
end

以下に, 確率的変分推論の例を2つほど挙げます.

例2: ロジスティック回帰

例として, 以下のようなロジスティックモデルを考えましょう:

\begin{align} \boldsymbol{w} &\sim \mathrm{N}(\boldsymbol{m}_0, s_0^2I_2)\\ y\mid x,\boldsymbol{w} &\sim \mathrm{Bernoulli}\left( \sigma(w_1+w_2 x) \right). \end{align}

ただし, \(\sigma\)はシグモイド関数です. データ\(X=\{x_n\}_{n=1}^{N}\), \(\{y_n\}_{n=1}^{N}\)が手に入ったときの事後分布を計算します. 事後分布の解析的計算は難しいので, 変分推論法を利用してみましょう. \(\boldsymbol{w}\)の近似事後分布\(r_\eta\)として, 以下のような2次元正規分布を使います. 近似分布の平均と標準偏差を近似事後分布のパラメータとします:

\begin{equation} \mathrm{N}\left( \boldsymbol{m}, \mathrm{diag}(\boldsymbol{s})^2\right), \quad \boldsymbol{\eta} = \begin{bmatrix} \boldsymbol{m}\\ \log\boldsymbol{s}\end{bmatrix}. \end{equation}

対数をとることで, 調整したいパラメータ\(\boldsymbol{\eta}\)の制約を取り払うことができます. このとき, 最大化したいELBOは以下のように計算できます:

\begin{equation} \mathcal{L}(\boldsymbol{\eta}) = \sum_{n=1}^{N}\int r_\eta(\boldsymbol{w})\log p(y_n\mid x_n,\boldsymbol{w})\mathrm{d}\boldsymbol{w} + \sum_{i=1}^{2}(\log s_i-\log s_0)+ 1 - \frac{1}{2s_0^2}\left(\|\boldsymbol{s}\|^2+\|\boldsymbol{m}-\boldsymbol{m}_0\|^2 \right). \end{equation}

残った積分は, モンテカルロ近似で計算できます. 以下に実験結果を示します. 本来未知のパラメータ\(w_1\), \(w_2\)の真値をそれぞれ-4,4とし, サイズ\(N=30\)の入力データ\(X\)を一様分布, 対応する出力データ\(Y\)をロジスティックモデルから作っておきました. これらをデータとみなした時の, 変分推論法の結果と比較用のHMCの結果を示します. 左図には, 真値(赤点), 近似事後分布(等高線), HMCで計算した真の事後分布からのサンプル(緑点)を示し, 右図には, 反復中のELBOの変化を示しました. 変分推論法とHMCの結果は概ね一致している一方で, 真値は外しているようです.

【コード12の実行結果】

変分推論法の反復回数は1000回, \(\boldsymbol{\eta}\)の初期値をゼロベクトルとし, 事前分布のパラメータ\(\boldsymbol{m}_0\), \(s_0\)をそれぞれゼロベクトルと1に設定しました. さらに, 予測分布\(p(y=1\mid x,\boldsymbol{w})\)の計算結果も示します. 図には, 真値(赤実線), 変分推論法の結果(橙実線), HMCの結果(緑破線)の3つを示しました.

【コード13の実行結果】

真値を外しているものの, 予測性能は良さそうです.

例3: Bayesian Neural Network

続いて, 分類問題を考えます. 下の画像のようにいくつかの点とそのラベルが与えられているとします. 点1個が1つのデータであり, 対応するラベルが赤丸と青バツです. 全部で16点あります. この2つのクラスの間を仕切る境界の推定が目標です. 要するに二値分類です.

【コード15の実行結果】

記号を導入します. 平面上の各点\(\boldsymbol{x}_n\in\mathbb{R}^{2}\)に対して, 対応するラベルデータ\(y_n\in\{1,0\}\)が与えられているとします. サンプルサイズは\(N\)として, データを次のようにまとめておきます:

\begin{equation} X = \{ \boldsymbol{x}_n\}_{n=1}^N ,\quad Y=\{ y_n\}_{n=1}^N. \end{equation}

今の場合, \(N=16\)です. 赤い点のラベルを\(y=1\)とし, 青いバツのラベルを\(y=0\)で表しておき, モデルを導入します. 今回はBayesian Neural Networkを考え, 赤い点に分類される確率をモデル化します. まず, 次のNeural Network\(\Phi_w:\mathbb{R}^{2}\to[0,1]\)を定義します:

\begin{equation} \Phi_w \left( \boldsymbol{x}\right) = \sigma\left( W^{(3)}\sigma\left( W^{(2)}\boldsymbol{x} + \boldsymbol{b}^{(2)} \right) + \boldsymbol{b}^{(3)}\right). \end{equation}

ここで, \(\sigma\)はシグモイド関数, 中間層のユニット数を5とし, \(\boldsymbol{w}\in\mathbb{R}^{d_w}\)はパラメータをまとめたベクトルとします. このNeural Networkを用いて, 次のようなBernoulliモデルを仮定します:

\begin{equation} p(y\mid \boldsymbol{x},\boldsymbol{w} ) = \Phi_w \left( \boldsymbol{x}\right)^{y}\left\{ 1-\Phi_w \left( \boldsymbol{x}\right) \right\}^{1-y}. \end{equation}

事前分布として, 以下の正規分布を仮定しておきます:

\begin{equation} \boldsymbol{w} \sim \mathrm{N}\left( \boldsymbol{0}, \lambda_w^{-1}I_{d_w}\right). \end{equation}

重みパラメータの事後分布を求めたいのですが, Neural Networkが絡むことでモデルへの寄与が複雑になります. このまま解析的に扱うのは難しそうなので, 変分推論法を用います. ここでは, 次のように近似分布\(r_\eta\)を定めます:

\begin{equation} \mathrm{N}\left( \boldsymbol{m}, \mathrm{diag}(\boldsymbol{s})^2 \right),\quad \boldsymbol{\eta} = \begin{bmatrix} \boldsymbol{m}\\ \log\boldsymbol{s}\end{bmatrix}. \end{equation}

このとき, 最大化したいELBOは以下のように計算できます:

\begin{equation} \mathcal{L}(\boldsymbol{\eta}) = \int r_{\eta}(\boldsymbol{w})\log p(y_n\mid x_n,\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \frac{\lambda_w}{2}(\|\boldsymbol{m}\|^2+\|\boldsymbol{s}\|^2)+\sum_{i=1}^{d_w}\log s_i + \frac{d_w}{2}(1+\log \lambda_w). \end{equation}

右辺の積分はモンテカルロ近似で計算できます. 実験結果を示します. 比較用に通常のニューラルネットワークに訓練結果も用意しました. 下図は, コスト関数を二乗誤差, 反復回数を\(10^5\)とし, Adamを用いた際の予測確率(左図)と反復中のコスト関数の値の変化(右図)です.

【コード16の実行結果】

次に変分推論法の適用結果を示しておきます. 反復回数は1000回, 事前分布の精度は\(\lambda_w=0.1\)としました. 下図は予測確率(左図)と反復中のELBOの値の変化(右図)です.

【コード19の実行結果】

ELBOは増加していますが, 予測はうまくいっていませんね...

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra
using ForwardDiff
using SpecialFunctions
using Flux

#statistics
using Random 
using Statistics
using Distributions

#visualize
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack
【Juliaコード2; 平均場近似用の関数定義】
"""
    Gibbs sampler
"""
#Gibbs sampler
function Gibbs_sampler(data,model_params,n_samps)
    @unpack X,N = data
    @unpack β = model_params
    n_burnin = div(n_samps,10)
    SN = sum(X)
    μsamps = zeros(n_samps)
    λsamps = zeros(n_samps)
    μsamps[1] = 0
    λsamps[1] = 1
    @showprogress for s in 2:n_samps
        μsamps[s] = rand(Normal(λsamps[s-1]*SN/(λsamps[s-1]*N+1), 1/√(λsamps[s-1]*N+1)))
        λsamps[s] = rand(Gamma(N/2+1, 1/(β+sum((X.-μsamps[s-1]).^2)/2)))
    end
    return μsamps[n_burnin:end],λsamps[n_burnin:end]
end

"""
    Mean field approximation
"""
#computing ELBO
function ELBO(X,SN,SsqN,β,μhat,λhat,αhat,βhat)
    (
        -log(λhat)/2+1/2-αhat*log(βhat)+loggamma(αhat)-
        (αhat-1)*(digamma(αhat)-log(βhat))+αhat+
        N*(digamma(αhat)-log(βhat)-log(2*π))/2-
        αhat*(N/λhat+SsqN-2*μhat*SN+N*μhat^2)/βhat/2-
        (μhat^2+1/λhat)/2+log(β)-αhat*β/βhat
    )
end

#initialize the variational parameters
function initialize(β,N)
    return 0,1,1+N/2,β
end

#mean filed variational inference
function my_mf_VI(data,model_params,n_train,tol)
    @unpack X,N = data
    @unpack β = model_params
    μhat,λhat,αhat,βhat = initialize(β,N)
    SN = sum(X); SsqN = sum(X.^2)
    history = zeros(n_train+1); history[1] = ELBO(X,SN,SsqN,β,μhat,λhat,αhat,βhat)
    @showprogress for k in 1:n_train
        λhat = 1 + N*αhat/βhat
        μhat = SN * αhat/λhat/βhat
        βhat = β + SsqN/2 - μhat*SN + N*μhat^2/2 + N/λhat/2
        history[k+1] = ELBO(X,SN,SsqN,β,μhat,λhat,αhat,βhat)
        if abs(history[k+1]-history[k]) < tol
            return μhat,λhat,αhat,βhat,history[1:k+1]
            break
        end
    end
    return μhat,λhat,αhat,βhat,history
end

#variational posterior
function r(μ, λ, var_params)
    @unpack μhat,λhat,αhat,βhat = var_params
    return exp(logpdf(Normal(μhat, 1/√λhat), μ) + logpdf(Gamma(αhat, 1/βhat), λ))
end

#variational posterior sampling
function var_post_samps(var_params,n_samps)
    @unpack μhat,λhat,αhat,βhat = var_params
    μsamps = rand(Normal(μhat,1/√λhat),n_samps)
    λsamps = rand(Gamma(αhat,1/βhat),n_samps)
end
【Juliaコード3; データ作成】
#set the random seed
Random.seed!(42)

#true distribution and parameters
μ₁ = -2
μ₂ = 2
σ₁ = 1.5
σ₂ = 1.5
N = 20
mixture_normal = MixtureModel([Normal(μ₁, σ₁), Normal(μ₂, σ₂)])
X = rand(mixture_normal, N)
data = (X=X, N=N)

function true_pdf(x)
    pdf(mixture_normal, x)
end
【Juliaコード4; Gibbs sampler利用結果】
#Gibbs sampler for true posterior
β = 1e1
model_params = (β=β,)
n_samps = 5000
@time μsamps, λsamps = Gibbs_sampler(data,model_params,n_samps)
mcmc_samps = (μsamps=μsamps, λsamps=λsamps)
【Juliaコード5; 平均場近似利用結果】
#mean field variational inference
n_train = 1000
tol = 1e-6
@time μhat,λhat,αhat,βhat,history = my_mf_VI(data,model_params,n_train,tol)
var_params = (μhat=μhat, λhat=λhat, αhat=αhat, βhat=βhat)
【Juliaコード6; 結果の可視化】
#visualize the variational and true posterior
p1 = plot(-2:0.1:2, 0.01:0.01:0.8, (μ,λ)->r(μ,λ,var_params), st=:contour, xlabel="μ", ylabel="λ",
    title="approximated and true posterior", xlim=[-2,2], ylim=[0.01, 0.8])
plot!(μsamps, λsamps, st=:scatter, alpha=0.2, label="Gibbs sampler", markerstrokewidth=0.2, color=:green)
p2 = plot(history, title="ELBO", xlabel="iter", label=false, marker=:circle, markerstrokewidth=0, markersize=6)
fig1 = plot(p1, p2, size=(1000, 350))
savefig(fig1, "figs-VI/fig1.png")
【Juliaコード7; 予測分布の計算】
#predictive distribution
function pred(x,post_samps)
    @unpack μsamps,λsamps = post_samps
    n_samps = length(μsamps)
    preds = zeros(n_samps)
    for s in 1:n_samps
        preds[s] = pdf(Normal(μsamps[s],1/√λsamps[s]),x)
    end
    return mean(preds)
end

#visualize variational prediction
n_samps = 5000
μsamps, λsamps = var_post_samps(var_params, n_samps)
var_samps = (μsamps=μsamps, λsamps=λsamps)

#visualize the predictive pdf
xs = -3:0.1:3
fig2 = plot(xs, x->pred(x,var_samps), color=:orange, ls=:solid, lw=3, label="predictive(VI)")
plot!(xs, x->pred(x,mcmc_samps), color=:green, ls=:dash, lw=3,label="predictive(Gibbs)")
plot!(xs, true_pdf, label="true", title="true and predictive pdf", color=:red)
savefig(fig2, "figs-VI/fig2.png")
【Juliaコード8; 関数定義】
"""
    Hamiltonian Monte Carlo
"""
#one step of Störmer-Verlet method
function myStörmerVerlet(qvec, pvec, h, f)
    p_mid = pvec + h * f(qvec)/2;
    q_new = qvec + h * p_mid;
    p_new = p_mid + h * f(q_new)/2;
    return q_new, p_new
end

#update the position
function update(T, h, f, qvec, pvec)
    qvec_new = qvec
    pvec_new = pvec
    for t in 1:T
        qvec_new, pvec_new = myStörmerVerlet(qvec_new, pvec_new, h, f)
    end
    return qvec_new, pvec_new
end

#MH acceptance and rejection
function accept_or_reject(xvec, xvec_old, pvec, pvec_old, H)
    ΔH = H(xvec, pvec)-H(xvec_old, pvec_old)
    α = min(1.0, exp(-ΔH))
    u = rand()
    if u≤α
        return xvec, pvec
    else
        return xvec_old, pvec_old
    end
end

#log posterior
function logppost(wvec,data,model_params)
    @unpack X,Y,N = data
    @unpack m₀vec,s₀ = model_params
    sum([logpdf(Bernoulli(sigmoid(wvec[1]+wvec[2]*X[n])),Y[n]) for n in 1:N]) + logpdf(MvNormal(m₀vec,s₀),wvec)
end

#Hamiltonian Monte Carlo
function myHMC(data, model_params,n_samps, T, h)
    #initialization
    dw=2; wvec₀ = zeros(dw)
    n_burnin = div(n_samps,10)
    wsamps = zeros(dw, n_samps)
    wsamps[:,1] = wvec₀
    wvec = zeros(dw)
    pvec = zeros(dw)
    
    #Hamiltonian and potential
    U(wvec) = -logppost(wvec, data, model_params)
    ∇Uneg(wvec) = -ForwardDiff.gradient(U, wvec)
    H(wvec, pvec) = U(wvec) + norm(pvec)^2/2
    
    #sample
    wvec_old = wvec₀
    pvec_old = randn(dw)
    @showprogress for s in 2:n_samps
        pvec = randn(dw)
        wvec, pvec = update(T, h, ∇Uneg, wvec, pvec)
        wvec, pvec = accept_or_reject(wvec, wvec_old, pvec, pvec_old, H)
        wsamps[:,s] = wvec
        wvec_old = wvec
        pvec_old = pvec
    end
    return wsamps[:,n_burnin:end]
end  

"""
    Stochastic Vatiational Inference
"""
#split parameters
function split_params(vec)
    return vec[1:2], vec[3:end]
end

#reparametrize
function reparameterize(var_mean,var_logstd)
    var_mean + exp.(var_logstd) .* randn(2)
end

#logpmodel
function logpmodel(y,x,wvec)
    logpdf(Bernoulli(sigmoid(wvec[1]+wvec[2]*x)),y)
end

#ELBO
function ELBO(X,Y,ηvec,minibatch,m₀vec,s₀,N)
    val = 0
    for n in minibatch
        var_mean,var_logstd = split_params(ηvec)
        val += logpmodel(Y[n],X[n],reparameterize(var_mean,var_logstd))
    end
    return N*val/length(minibatch)-(norm(ηvec[1:2]-m₀vec)^2+norm(exp.(ηvec[3:end]))^2)/2/s₀^2+sum(ηvec[3:end])+(1-2*log(s₀))
end
ELBO(X,Y,ηvec,m₀vec,s₀,N) = ELBO(X,Y,ηvec,1:length(X),m₀vec,s₀,N)

#create model
function create_model(X,Y,m₀vec,s₀,N)
    ηvec = zeros(4)
    ps = Flux.params(ηvec)
    loss_func = minibatch->(-ELBO(X,Y,ηvec,minibatch,m₀vec,s₀,N))
    return ηvec,ps,loss_func
end

#stochastic variational inference
function stochastic_variational_inference(data,model_params,n_train,minibatch_size)
    @unpack X,N = data
    @unpack m₀vec,s₀ = model_params
    opt = ADAM(0.01)
    history = zeros(n_train)
    ηvec,ps,loss_func = create_model(X,Y,m₀vec,s₀,N)
    @showprogress for k in 1:n_train
        minibatch = sample(1:N,minibatch_size)
        Flux.train!(loss_func,ps,minibatch,opt)
        history[k] = ELBO(X,Y,ηvec,m₀vec,s₀,N)
    end
    return ηvec,history
end

#variational poseterior
function r(wvec,ηvec)
    pdf(MvNormal(ηvec[1:2],exp.(ηvec[3:4])),wvec)
end

#variational posterior sample
function var_post_samps(ηvec,n_samps)
    wsamps = rand(MvNormal(ηvec[1:2],exp.(ηvec[3:end])),n_samps)
    return wsamps
end
【Juliaコード9; データの作成】
#create data
Random.seed!(42)
w₁ = -4.0
w₂ = 4.0
w_true = (w₁=w₁,w₂=w₂)
N = 30
X = sort(rand(-10:10,N))
Y = [rand(Bernoulli(sigmoid(w₁+w₂*X[n]))) for n in 1:N]

function true_pdf(y,x,w_true)
    @unpack w₁,w₂ = w_true
    pdf(Bernoulli(sigmoid(w₁+w₂*x)),y)
end
【Juliaコード10; HMC利用結果】
#data and model parameters
data = (X=X,Y=Y,N=N)
model_params = (m₀vec=ones(2),s₀=1)

#HMC parameters
n_samps = 10000
T = 100
h = 0.1
@time wsamps = myHMC(data, model_params,n_samps, T, h)
【Juliaコード11; 確率的変分推論法利用結果】
#data and model parameters
data = (X=X,Y=Y,N=N)
model_params = (m₀vec=zeros(2),s₀=1)

#training
n_train = 1000
minibatch_size = N
@time ηvec,history = stochastic_variational_inference(data,model_params,n_train,minibatch_size)
【Juliaコード12; 結果の可視化】
#true posterior and variational posterior
p1 = plot(-5:0.1:3,-2:0.1:6,(w₁,w₂)->r([w₁,w₂],ηvec), st=:contour, xlabel="w₁", ylabel="w₂",
    title="approximated and true posterior", xlim=[-5,3], ylim=[-2,6])
plot!(wsamps[1,:],wsamps[2,:],st=:scatter,alpha=0.2,label="HMC",markerstrokewidth=0.2,color=:green)
plot!([w₁,],[w₂,],st=:scatter,color=:red,label="true",markersize=10)
p2 = plot(history,xlabel="iter",ylabel="ELBO",title="ELBO",label=false)
fig3 = plot(p1, p2, size=(1000, 350))
savefig(fig3,"figs-VI/fig3.png")
【Juliaコード13; 予測分布の計算】
#predictive pdf
function pred(y,x,wsamps)
    n_samps = size(wsamps,2)
    preds = zeros(n_samps)
    for s in 1:n_samps
        preds[s] = pdf(Bernoulli(sigmoid(wsamps[1,s]+wsamps[2,s]*x)),y)
    end
    return mean(preds)
end

#variational posterior samples
n_samps = 5000
var_wsamps = var_post_samps(ηvec,n_samps)

#visualize the predictive pdf
xs = -10:0.1:10
fig4 = plot(xs,x->pred(1,x,var_wsamps),color=:orange,ls=:solid,lw=3,label="predictive(VI)",legend=:bottomright)
plot!(xs,x->pred(1,x,wsamps),color=:green,ls=:dash,lw=3,label="predictive(HMC)")
plot!(xs,x->true_pdf(1,x,w_true),label="true",title="true and predictive (y=1)",color=:red)
savefig(fig4, "figs-VI/fig4.png")
【Juliaコード14; 可視化用関数定義】
#plot the data and return the figure
function plot_data(X, Y)
    _,N = size(X)
    fig = plot(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, title="data", legend=false)
    for k in 1:N
        if Y[k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

function plot_data(fig, X, Y)
    _,N = size(X)
    fig = plot!(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, legend=false)
    for k in 1:N
        if Y[k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end
【Juliaコード15; データ作成】
#create the data
N = 16
X = [
    0.3 0.52 0.3 0.50 0.60 0.7 0.70 0.55  0.85 0.10 0.05 0.20 0.39 0.63 0.86 0.97;
    0.1 0.35 0.9 0.15 0.95 0.2 0.80 0.75  0.55 0.76 0.15 0.45 0.56 0.50 0.80 0.20;
]
Y = vcat(zeros(div(N,2)), ones(div(N,2)))
data = (X=X,Y=Y,N=N)

#plot the data
fig5 = plot_data(X,Y)
savefig(fig5, "figs-VI/fig5.png")
【Juliaコード16; ニューラルネットワーク利用結果】
#set the random seed
Random.seed!(42)

#initialize the model
nn = Chain(
        Dense(2,10,sigmoid),
        Dense(10,10,sigmoid),
        Dense(10,1,sigmoid)
        )

#cost function 
function cost(x,y)
    (y-nn(x)[1])^2
end

#training
function my_train(data,nn,n_train)
    @unpack X,Y,N = data
    costs = zeros(n_train)
    opt = ADAM(0.01)
    n = 1
    @showprogress for k in 1:n_train
        n = rand(1:N)
        Flux.train!(cost,Flux.params(nn),[(X[:,n],Y[n])],opt)
        costs[k] = mean([cost(X[:,n],Y[n]) for n in 1:N])
    end
    return costs
end

#train the model
n_train = Int(1e5)
@time costs = my_train(data,nn,n_train)

#prediction 
function pred(x₁,x₂)
    nn([x₁,x₂])[1]
end

#visualize
p1 = plot(0:0.02:1, 0:0.02:1, pred, st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
p1 = plot_data(p1,X,Y)
p2 = plot(costs,xlabel="iter",ylabel="costs",title="costs",xscale=:log10,yscale=:log10,label=false)
fig6 = plot(p1,p2,size=(1000,400))
savefig(fig6, "figs-VI/fig6.png")
【Juliaコード17; 関数定義】
#create neural network
function create_nn()
    nn = Chain(
        Dense(2,10,sigmoid),
        Dense(10,10,sigmoid),
        Dense(10,1,sigmoid)
        )
    wvec,re_nn = Flux.destructure(nn)
    dw = length(wvec)
    return re_nn,zeros(2*dw),dw
end

#split parameters
function split_params(vec,dw)
    return vec[1:dw], vec[dw+1:end]
end

#reparametrize
function reparameterize(var_mean,var_logstd,dw)
    var_mean + exp.(var_logstd) .* randn(dw)
end

#logpmodel
function logpmodel(y,xvec,wvec,re_nn)
    logpdf(Bernoulli(re_nn(wvec)(xvec)[1]),y)
end

#ELBO
function ELBO(X,Y,ηvec,minibatch,w_prec,N,dw,re_nn)
    val = 0
    for n in minibatch
        var_mean,var_logstd = split_params(ηvec,dw)
        val += logpmodel(Y[n],X[:,n],reparameterize(var_mean,var_logstd,dw),re_nn)
    end
    return N*val/length(minibatch)-w_prec*(norm(ηvec[1:dw])^2+norm(exp.(ηvec[dw+1:end]))^2)/2+sum(ηvec[dw+1:end])+dw*(1+log(w_prec))/2
end
ELBO(X,Y,ηvec,w_prec,N,dw,re_nn) = ELBO(X,Y,ηvec,1:size(X,2),w_prec,N,dw,re_nn)

#create model
function create_model(X,Y,w_prec,N)
    re_nn,ηvec,dw = create_nn()
    ps = Flux.params(ηvec)
    loss_func = minibatch->(-ELBO(X,Y,ηvec,minibatch,w_prec,N,dw,re_nn))
    return re_nn,dw,ηvec,ps,loss_func
end

#stochastic variational inference
function stochastic_variational_inference(data,model_params,n_train,minibatch_size)
    @unpack X,N = data
    @unpack w_prec = model_params
    opt = ADAM(0.01)
    history = zeros(n_train)
    re_nn,dw,ηvec,ps,loss_func = create_model(X,Y,w_prec,N)
    @showprogress for k in 1:n_train
        minibatch = sample(1:N,minibatch_size)
        Flux.train!(loss_func,ps,minibatch,opt)
        history[k] = ELBO(X,Y,ηvec,w_prec,N,dw,re_nn)
    end
    return ηvec,history,re_nn
end

#variational posterior sample
function var_post_samps(ηvec,n_samps)
    dw = div(length(ηvec),2)
    wsamps = rand(MvNormal(ηvec[1:dw],exp.(ηvec[dw+1:end])),n_samps)
    return wsamps
end

#predictive dsitribution
function pred(xvec,ηvec,re_nn)
    n_samps = 1000
    wsamps = var_post_samps(ηvec,n_samps)
    preds = zeros(n_samps)
    @showprogress for s in 1:n_samps
        preds[s] = pdf(Bernoulli(re_nn(wsamps[:,s])(xvec)[1]),1)
    end
    return mean(preds)
end
【Juliaコード18; 確率的変分推論法利用結果】
#data and model parameters
data = (X=X,N=N)
model_params = (w_prec=0.1,)

#training
n_train = 1000
minibatch_size = 1
@time ηvec,history,re_nn = stochastic_variational_inference(data,model_params,n_train,minibatch_size)
【Juliaコード19; 予測結果】
#visualize
p1 = plot(0:0.1:1, 0:0.1:1, (x₁,x₂)->pred([x₁,x₂],ηvec,re_nn), 
    st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
p1 = plot_data(p1,X,Y)
p2 = plot(history,xlabel="iter",ylabel="ELBO",title="ELBO",label=false)
fig7 = plot(p1,p2,size=(1000,400))
savefig(fig7, "figs-VI/fig7.png")
参考文献

      [1] ビショップ, パターン認識と機械学習, 丸善出版, 2019
      [2] 渡辺澄夫, ベイズ統計の理論と方法, コロナ社, 2019
      [3] 須山敦志, ベイズ深層学習, 講談社, 2020
      [4] A.Graves, Practical variational inference for neural networks, Advances in Neural Information Processing Systems, pp.2348-2356, 2011
      [5] M.D.Hoffman, D.M.Blei, C.Wang, J.Paisley, Stochastic variational inference, The Journal of Machine Learning Research, 14(1), pp.1303-1347, 2013
      [6] D.P.Kingma, M.Willing, Auto-encoding variational Bayes, International Conference on Learning Representation, 2014