変分推論法(修正版)

記事の内容


以前書いた記事の修正版です. モデルやプログラムを修正しました. 変分推論を具体例を用いて説明します. 最も有名な平均場近似を用いた手法と, ミニバッチ学習に対応した確率的変分推論法を紹介します. 確率的変分推論法は, 元々[5]で提案された手法ですが, ここではもっと広い意味で用いています. 原論文では確率的最適化と自然勾配の考え方を用いていますが, ここでは「確率的最適化をベースとした変分推論法」程度の意味で用いています.

概要

問題設定

変分推論は, 複雑な事後分布を近似する方法です. MCMCとは違い, 最適化がベースです. まず, どういう状況で用いる手法なのか, 問題設定を確認しておきます.

手元にデータ\(X=\{\boldsymbol{x}_1,\cdots,\boldsymbol{x}_N\}\subset \mathbb{R}^D\)があるとします.このデータを発生させた分布を推定することが目的です. モデルを\(p(\boldsymbol{x}\mid\boldsymbol{w})\)でモデル化したとします. ここで, \(\boldsymbol{w}\in\mathbb{R}^d\)はパラメータです. このパラメータの事前分布を\(p(\boldsymbol{w})\)とします. すると事後分布が計算できます.

\begin{equation} p(\boldsymbol{w}\mid X) \propto p(X\mid\boldsymbol{w}) p(\boldsymbol{w}) \end{equation}

この事後分布の平均を計算したり, モデルを事後分布で重みづけて予測分布を計算するには事後分布が複雑すぎるとします. 解析的に扱えない場合, MCMCや変分推論法が用いられます. 次に, 変分推論法の基本的なアイデアを述べます.

基本的なアイデア

変分推論法では, 事後分布を近似する分布を構成します. 目標は, 次のような分布を, 予め定めた分布族\(\mathcal{R}\)から見つけることです.

\begin{equation} p(\boldsymbol{w}\mid X) \simeq r(\boldsymbol{w}),\quad r\in \mathcal{R} \end{equation}

ここで, 次の2点が問題になります.

  • 近似分布は事後分布にどういう意味で近いのか?
  • 分布族\(\mathcal{R}\)はどのように設定するか?

まず1点目. 分布間の近さは, Kullback-Leiblerダイバージェンスで測るのが自然です. 近似分布と事後分布の間のKullback-Leiblerダイバージェンスが最小になるように近似分布を定めます. もう少し詳しく見てみましょう. 次のような数量を定めます.

\begin{equation} \mathcal{L}[r] = -\int r(\boldsymbol{w})\log \frac{r(\boldsymbol{w})}{p(X,\boldsymbol{w})} d\boldsymbol{w} \end{equation}

すると, \(r(\boldsymbol{w})\)と\(p(\boldsymbol{w}\mid X)\)の間のKullback-Leiblerダイバージェンスは, 次のように書けます.

\begin{equation} D_{\mathrm{KL}}[r(\boldsymbol{w})\|p(\boldsymbol{w}\mid X)] = -\mathcal{L}[r] + \log p(X) \end{equation}

Kullback-Leiblerダイバージェンスが非負であることから, 次式が得られます.

\begin{equation} \mathcal{L}[r] \leq \log p(X) \end{equation}

\(\mathcal{L}\)が周辺尤度(エビデンス)の下界であることから, この量をELBO(Evidence Lower bound)と呼びます. 当面の間の目標はKullback-Leiblerダイバージェンスの最小化ですが, これはELBOの最大化と等価です. なお, ELBOの符号反転は変分自由エネルギーと呼ばれます.

次に2点目です. 分布族の決め方はいくつかあります. 最も有名な方法が平均場近似です. 平均場近似では, 近似分布を具体的には指定せず, 事後分布が都合よく分解できると仮定します. 分解できるという仮定のみでも, 事後分布は"いい感じ"に求まります. 他には, 予め分布を指定する方法があります. 例えば予め正規分布で近似できると仮定して, パラメータを決めてやります. 以下, この2つのアイデアを詳しく掘り下げていきます.

平均場近似

平均場近似の考え方

平均場近似では, 近似分布が属する分布族を次のように定めます.

\begin{equation} \mathcal{R} = \left\{ r(\boldsymbol{w})\mid \text{ある確率(密度)関数\(r_1,\cdots,r_K\)が存在して, }r(\boldsymbol{w}) = \prod_{i=1}^Kr_i(\boldsymbol{w_i})\right\} \end{equation}

このような分布族から持ってきた近似分布\(r\)について, 各\(i=1,\cdots,K\)に対して, ELBOは次のように書けます.

\begin{equation} \mathcal{L}[r] = -D_{\mathrm{KL}}\left[r_i(\boldsymbol{w}_i) \| \frac{1}{Z}\exp\left(\mathrm{E}_{\prod_{j\neq i}r_j}\left[ \log p(X,\boldsymbol{w}) \right]\right) \right] + \mathrm{const} \end{equation}

ここで, \(Z\)は正規化定数で, 上式の期待値は次のようにとります.

\begin{equation} \mathrm{E}_{\prod_{j\neq i}r_j}\left[ \log p(X,\boldsymbol{w}) \right] = \int \left(\prod_{j\neq i}r_j(\boldsymbol{w}_j)\right) \log p(X,\boldsymbol{w}) d\boldsymbol{w}^{\setminus i} \end{equation}

2つの分布が一致するときにKLダイバージェンスが最小になることから, 各\(j\neq i\)に対して因子\(r_j\)が定まっているときに, \begin{equation} r_i(\boldsymbol{w}_i) \propto \exp \left( \mathrm{E}_{\prod_{j\neq i}r_i} [\log p(X,\boldsymbol{w})]\right) \end{equation} と因子\(r_i\)を更新することで, ELBOが最大になります. 適当な初期分布から初めて更新を繰り返していけば, ELBOがどこかで頭打ちになるはずです. このとき近似分布は事後分布に十分近いと期待できます.

例: 正規分布の推論

ここでは例として, 正規分布の平均と精度の事後分布を推定します. モデルが正規分布で, パラメータが平均パラメータと精度パラメータにあたります. 次のような生成過程を考えます.

\begin{equation} \mu\sim\mathrm{N}(0,1) , \quad \lambda \sim\mathrm{Gamma}(1,\beta),\quad x\sim\mathrm{N}(\mu,\lambda^{-1}) \end{equation}

\(\beta\)は定数です. このとき事後分布は次のようになります.

\begin{equation} p(\mu,\lambda\mid X) \propto \lambda^{\frac{N}{2}} e^{-\beta\lambda} \exp\left\{ -\frac{1}{2}\mu^2-\frac{\lambda}{2}\sum_{n=1}^N(x_n-\mu)^2\right\} \end{equation}

例えばこの分布の平均を求めるのは困難です. ということで変分近似を行います. 次のように分解されるとします.

\begin{equation} p(\mu,\lambda\mid X) \simeq r(\mu)r(\lambda) \end{equation}

先ほど導出した期待値計算を含む分布を導出すると, 以下のようになります.

\begin{align} r(\mu) &\propto \exp\left\{ -\frac{1}{2}\mu^2-\frac{1}{2}\mathrm{E}_\lambda [\lambda] \sum_{n=1}^N(x_n-\mu)^2 \right\} \\ r(\lambda) &\propto \exp\left\{ \frac{N}{2}\log\lambda-\beta\lambda - \frac{\lambda}{2}\sum_{n=1}^N\mathrm{E}_\mu [(x_n-\mu)^2]\right\} \end{align}

\(\mu\)の近似分布は正規分布, \(\lambda\)の近似分布はガンマ分布であることが分かります. それぞれの分布を次のようにおきます.

\begin{equation} \mathrm{N}(\hat{\mu}, \hat{\lambda}) ,\quad \mathrm{Gamma}(\hat{\alpha}, \hat{\beta}) \end{equation}

すると, 次のように書けます.

\begin{align} \hat{\lambda} &= 1+\frac{N\hat{\alpha}}{\hat{\beta}} \\ \hat{\mu} &= \frac{\hat{\alpha}}{\hat{\lambda}\hat{\beta}}\left(\sum_{n=1}^Nx_n\right) \\ \hat{\alpha} &= 1+\frac{N}{2} \\ \hat{\beta} &= \beta + \frac{1}{2}\left(\sum_{n=1}^Nx_n^2\right)-\hat{\mu}\left( \sum_{n=1}^Nx_n\right) + \frac{N}{2}\hat{\mu}^2 + \frac{N}{2\hat{\lambda}} \end{align}

これを更新式として, 次のようなアルゴリズムが得られます. 実際には, \(\hat{\alpha}\)が上記の更新で変化しないので, ループの外に出して実装します.

【正規分布に対する変分推論アルゴリズム】

Initialize \(\hat{\lambda}^{(0)}, \hat{\mu}^{(0)}, \hat{\alpha}^{(0)}, \hat{\beta}^{(0)}\), and set \(k=0, \epsilon\).
while\( |\mathcal{L}[r^{(k+1)}]-\mathcal{L}[r^{(k)}] | <\epsilon\)
update \(\hat{\lambda}^{(k+1)}, \hat{\mu}^{(k+1)}, \hat{\alpha}^{(k+1)}, \hat{\beta}^{(k+1)}\).
\(k=k+1\)
end

ちなみに, ELBOの値は以下のようになります. 少し長いです. \(\Gamma\)はガンマ関数, \(\psi\)はディガンマ関数です.

\begin{align} \mathcal{L}[r] = -\frac{1}{2}\log\hat{\lambda} + \frac{1}{2}-\hat{\alpha}\log\hat{\beta} + \log\Gamma (\hat{\alpha}) - (\hat{\alpha}-1)(\psi (\hat{\alpha})-\log\hat{\beta}) + \hat{\alpha} + \frac{N}{2}\left( \psi (\hat{\alpha})-\log (2\pi\hat{\beta})\right)\\ - \frac{\hat{\alpha}}{2\hat{\beta}}\left( \sum_{n=1}^Nx_n^2-2\hat{\mu}\sum_{n=1}^Nx_n + N\hat{\mu}^2 + \frac{N}{\hat{\lambda}}\right) - \frac{1}{2}\left( \hat{\mu}^2+\frac{1}{\hat{\lambda}}\right) + \log\beta - \frac{\hat{\alpha}\beta}{\hat{\beta}} \end{align}

以下に実験結果を示します. 真の分布は混合正規分布\(0.5\mathrm{N}(-1.4, 1.1^2)+0.5\mathrm{N}(1.5, 0.9^2)\)としました. この分布からデータを\(N=20\)点発生させました. \(\beta=10, \epsilon=10^{-6}\)とします. また, 初期値は\(\hat{\mu}=0.0, \hat{\lambda}=1.0, \hat{\beta}=\beta\)とします. 左下図の等高線は, 近似事後分布を表します. また, Gibbs samplerによって(真の)事後分布からサンプルし, プロットしました(緑点). 実際の事後分布と近似事後分布は一致しているように見えます. また, 右下図は, ELBOの変化の様子を示しました. 5回目の更新で収束しています.

【コード3の実行結果】

近似事後分布からサンプルを発生させて, 予測分布を計算しました(橙線). 赤色が真の分布です. また, Gibbs samplerで得た事後分布からのサンプルを用いて, 真の予測分布を計算しました(青色). 予測分布同士はほぼ一致しています.

【コード4の実行結果】

概要

確率的変分推論の考え方

事後分布の近似方法として, 予め分布を指定する方法が考えられます. 例えば正規分布などの簡単な分布を仮定して, パラメータを調節して近似精度向上を目指します. この方法でも, 近似分布と真の事後分布のKLダイバージェンスを最小化します.

近似分布を\(r_{\boldsymbol{\theta}}\)と表示しておきます. パラメータ\(\boldsymbol{\theta}\)を調節して, KLダイバージェンスの最小化, もしくはそれと等価なELBOの最大化を目指します. ELBOを具体的に書き出すと, 以下のようになります.

\begin{equation} \mathcal{L}(\boldsymbol{\theta}) = \sum_{n=1}^N\int r_{\boldsymbol{\theta}}(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid\boldsymbol{w})\mathrm{d}\boldsymbol{w} + \int r_{\boldsymbol{\theta}}(\boldsymbol{w})\log p(\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \int r_{\boldsymbol{\theta}}(\boldsymbol{w})\log r_{\boldsymbol{\theta}}(\boldsymbol{w})\mathrm{d}\boldsymbol{w} \end{equation}

これを, 確率的勾配降下法と同じ要領で, 次のように近似します.

\begin{equation} \mathcal{L}_n(\boldsymbol{\theta}) \simeq \mathcal{L}(\boldsymbol{\theta}) = N\int r_{\boldsymbol{\theta}}(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid\boldsymbol{w})\mathrm{d}\boldsymbol{w} + \int r_{\boldsymbol{\theta}}(\boldsymbol{w})\log p(\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \int r_{\boldsymbol{\theta}}(\boldsymbol{w})\log r_{\boldsymbol{\theta}}(\boldsymbol{w})\mathrm{d}\boldsymbol{w} \end{equation}

これを用いて, 最急"上昇"法で次のように最大化していきます.

\begin{equation} \boldsymbol{\theta}^{(k+1)} = \boldsymbol{\theta}^{(k)} + \alpha_k \nabla_{\boldsymbol{\theta}} \mathcal{L}_n(\boldsymbol{\theta}^{(k)}) \end{equation}

アルゴリズムとしてまとめておきます.

【確率的変分推論アルゴリズム】

Initialize weights \(\boldsymbol{\theta}^{(0)}\), step size \(\alpha^{(k)}\) and \(k=0, \epsilon\).
while\( |\mathcal{L}(\boldsymbol{\theta}^{(k+1)})-\mathcal{L}(\boldsymbol{\theta}^{(k)}) | <\epsilon\)
Select \(\boldsymbol{x}_n\) uniformly at random.
\(\boldsymbol{\theta}^{(k+1)} = \boldsymbol{\theta}^{(k)} + \alpha^{(k)}\nabla_{\boldsymbol{\theta}} \mathcal{L}_{n}(\boldsymbol{\theta}^{(k)})\)
\( k=k+1\)
end

例: Bayesian Neural Network

次のような分類問題を考えます. 画像のようにいくつかの点とそのラベルが与えられているとします. 点1個が1つのデータであり, 対応するラベルが赤丸と青バツです. 全部で16点あります. この2つのクラスの間を仕切る境界の推定が目標です. 要するに二値分類です.

【コード6の実行結果】

記号を導入します. 平面上の各点\(\boldsymbol{x}_n\in\mathbb{R}^{D_x}\)に対して, 対応するラベルデータ\(y_n\in\{1,0\}\)が与えられているとします. サンプルサイズは\(N\)として, データを次のようにまとめておきます.

\begin{equation} X = \{ \boldsymbol{x}_n\}_{n=1}^N ,\quad Y=\{ y_n\}_{n=1}^N \end{equation}

今の場合, \(N=16\)です. 赤い点のラベルを\(y=1\)とし, 青いバツのラベルを\(y=0\)とします. モデルを導入します. 今回はBayesian Neural Networkを考えます. 赤い点に分類される確率をモデル化します. まず, 次のNeural Network\(\Phi:\mathbb{R}^{D_x}\to\mathbb{R}\)を定義します.

\begin{equation} \Phi \left( \boldsymbol{x}, \boldsymbol{w}\right) = \sigma\left( W^{(3)}\sigma\left( W^{(2)}\boldsymbol{x} + \boldsymbol{b}^{(2)} \right) + \boldsymbol{b}^{(3)}\right) \end{equation}

ここで, \(\sigma\)はシグモイド関数, 中間層の幅を\(D_0=5\)とし, \(\boldsymbol{w}\in\mathbb{R}^{d_w}\)はパラメータをまとめたベクトルとします. このNeural Networkを用いて, 次のようなBernoulliモデルを仮定します.

\begin{equation} p(y\mid \boldsymbol{x},\boldsymbol{w} ) = \Phi \left( \boldsymbol{x}, \boldsymbol{w}\right)^{y}\left\{ 1-\Phi \left( \boldsymbol{x}, \boldsymbol{w} \right) \right\}^{1-y} \end{equation}

事前分布として, 正規分布を仮定しておきます.

\begin{equation} \boldsymbol{w} \sim \mathrm{N}\left( \boldsymbol{0}, \lambda_w^{-1}I_{d_w}\right) \end{equation}

重みパラメータの事後分布を求めたいのですが, Neural Networkが絡むことでモデルへの寄与が複雑になります. このまま解析的に扱うのは難しそうなので, 変分推論法を用います. ここでは, 次のように近似分布の族を定めます.

\begin{equation} \mathcal{R} = \left\{ r_{\boldsymbol{\theta}}(\boldsymbol{w}) \mid r_{\boldsymbol{\theta}}(\boldsymbol{w}) = \prod_{i=1}^{d_w} \left( \frac{1}{2\pi\sigma_i^2}\right)^\frac{1}{2} \exp\left\{ -\frac{1}{2\sigma_i^2}(w_i-\mu_i)^2\right\} \right\} \end{equation}

ここで, 近似分布のパラメータを\(\boldsymbol{\theta}=[\mu_1, \cdots,\mu_{d_w},\log\sigma_1,\cdots,\log\sigma_{d_w}]^{\mathrm{T}}\)とおきました. 以上の仮定の下, ELBOは以下のように計算できます.

\begin{equation} \mathcal{L}(\theta) = \sum_{n=1}^N\int r_{\boldsymbol{\theta}}(\boldsymbol{\theta})\log p(y_n\mid\boldsymbol{x}_n,\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \frac{\lambda_w}{2}\sum_{i=1}^{d_w} \left(\theta_i^2+e^{2\theta_{dw+i}}\right) + \sum_{i=1}^{d_w}\theta_{dw+i} + \frac{d_w}{2} + \frac{d_w}{2}\log\lambda_w \end{equation}

データ点\((\boldsymbol{x}_n, y_n)\)を全データからランダムにサンプルして, ELBOを次のように近似します.

\begin{equation} \mathcal{L}(\boldsymbol{\theta}) \simeq \mathcal{L}_n(\boldsymbol{\theta}) = N\int r_{\boldsymbol{\theta}}(\boldsymbol{\theta})\log p(y_n\mid\boldsymbol{x}_n,\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \frac{\lambda_w}{2}\sum_{i=1}^{d_w} \left(\theta_i^2+e^{2\theta_{dw+i}}\right) + \sum_{i=1}^{d_w}\theta_{dw+i} + \frac{d_w}{2} + \frac{d_w}{2}\log\lambda_w \end{equation}

この近似値\(\mathcal{L}_n\)を最大化することで, 近似分布のパラメータ\(\boldsymbol{\theta}\)を定めます. \(\mathcal{L}_n\)の勾配を求めます. 各\(j=1,\cdots,d_w\)に対して, 次式が成り立ちます. \(\{\boldsymbol{w}^{(s)}\}_{s=1}^S\)は近似分布からのサンプルです.

\begin{align} \frac{\partial}{\partial\theta_j}\mathcal{L}_n(\boldsymbol{\theta}) &= N\int \log p(y_n\mid\boldsymbol{x}_n,\boldsymbol{w})\left( \frac{\partial}{\partial\theta_j} \log r_{\boldsymbol{\theta}}(\boldsymbol{w})\right)r_{\boldsymbol{\theta}}(\boldsymbol{w})\mathrm{d}\boldsymbol{w} - \lambda_w\theta_j \\ &\simeq Ne^{-2\theta_{dw+j}}\frac{1}{S}\sum_{s=1}^S (w_j^{(s)}-\theta_j)\log p(y_n\mid\boldsymbol{x},\boldsymbol{w}^{(s)}) - \lambda_w\theta_j \\ \frac{\partial}{\partial\theta_{d_w+j}}\mathcal{L}_n(\boldsymbol{\theta}) &= N\int \log p(y_n\mid\boldsymbol{x}_n,\boldsymbol{w})\left( \frac{\partial}{\partial\theta_{d_w+j}} \log r_{\boldsymbol{\theta}}(\boldsymbol{w})\right)r_{\boldsymbol{\theta}}(\boldsymbol{w})\mathrm{d}\boldsymbol{w} +(1 - \lambda_we^{2\theta_j}) \\ &\simeq N\frac{1}{S}\sum_{s=1}^S \left\{(w_j^{(s)}-\theta_j)^2e^{-2\theta_{dw+j}}-1\right\}\log p(y_n\mid\boldsymbol{x},\boldsymbol{w}^{(s)}) +(1 - \lambda_we^{2\theta_j}) \\ \end{align}

以下に実験結果を示します. \(\lambda_w=10^{-3}\), 反復回数を\(2000\)としました. また, ステップサイズは\(\alpha_k=\frac{0.4}{k}\)としました. 直前の反復でのELBOとの差が\(10^{-6}\)を下回った時点で反復を終了します. 結果は以下の通りです. 予測確率が0.5付近です. 境界は得られませんでした. .

【コード8の実行結果】

ELBOの変化を下図に示します. 思惑通り, 増加していますね. なお, ELBOの解析的に計算できない積分の項はモンテカルロ法で近似しています.

【コード9の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra
using ForwardDiff
using SpecialFunctions

#statistics
using Random 
using Statistics
using Distributions

#visualize
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack
【Juliaコード;2 平均場近似の関数定義】
#ELBO
ELBO(X, SN, SsqN, β, μhat, λhat, αhat, βhat) = (-log(λhat)/2+1/2-αhat*log(βhat)+
    log(gamma(αhat))-(αhat-1)*(digamma(αhat)-log(βhat))+αhat 
    + N*(digamma(αhat)-log(2*pi*βhat))/2-αhat*(SsqN-2*μhat*SN+N*μhat^2+N/λhat)/βhat/2-(μhat^2+1/λhat)/2
    + log(β)-β*αhat/βhat)

#variational inference using mean field approximation
function myVI(data, model_params, inits, n_train, tol)
    @unpack X,N = data
    @unpack β = model_params
    @unpack μhat₀, λhat₀, βhat₀ = inits
    SN = sum(X)
    SsqN = sum(X.^2)
    μhat = μhat₀
    λhat = λhat₀
    αhat = N/2+1
    βhat = βhat₀
    Lvec = zeros(n_train+1)
    Lvec[1] = ELBO(X, SN, SsqN, β, μhat, λhat, αhat, βhat)
    
    @showprogress for k in 1:n_train
        λhat = 1 + N*αhat/βhat
        μhat = SN * αhat/λhat/βhat
        βhat = β + SsqN/2 - μhat*SN + N*μhat^2/2 + N/λhat/2
        Lvec[k+1] = ELBO(X, SN, SsqN, β, μhat, λhat, αhat, βhat)
        if abs(Lvec[k+1]-Lvec[k])≤tol
            return μhat, λhat, αhat, βhat, Lvec[1:k+1]
            break
        end
    end
    return μhat, λhat, αhat, βhat, Lvec
end

#variational posterior sample
function post_samps(var_prams, n_samps)
    μsamps = rand(Normal(μhat, 1/√λhat), n_samps)
    λsamps = rand(Gamma(αhat, 1/βhat), n_samps)
    return μsamps, λsamps
end

#predictive distribution
function pred(x, samps)
    @unpack μsamps, λsamps = samps
    n_samps = length(μsamps)
    preds = zeros(n_samps)
    for s in 1:n_samps
        preds[s] = pdf(Normal(μsamps[s], 1/√λsamps[s]),x)
    end 
    return mean(preds)
end

#Gibbs sampler for comparing the posterior to variational posterior
function myGibbs(data, model_params, inits, n_samps, n_burnin)
    @unpack X,N = data
    @unpack β = model_params
    @unpack μ₀,λ₀ = inits
    SN = sum(X)
    μsamps = zeros(n_samps)
    λsamps = zeros(n_samps)
    μsamps[1] = μ₀
    λsamps[1] = λ₀
    @showprogress for s in 2:n_samps
        μsamps[s] = rand(Normal(λsamps[s-1]*SN/(λsamps[s-1]*N+1), 1/√(λsamps[s-1]*N+1)))
        λsamps[s] = rand(Gamma(N/2+1, 1/(β+sum((X.-μsamps[s-1]).^2)/2)))
    end
    return μsamps[n_burnin:end], λsamps[n_burnin:end]
end
【Juliaコード3; 平均場近似の実行】
#set the seed
Random.seed!(42)

#create data
μ₁ = -1.4
μ₂ = 1.5
σ₁ = 1.1
σ₂ = 0.9
N = 20
mixture_normal = MixtureModel([Normal(μ₁, σ₁), Normal(μ₂, σ₂)])
true_pdf(x) = pdf(mixture_normal, x)
X = rand(mixture_normal, N)
data = (X=X, N=N)

#variational inference
β = 1e1
model_params = (β=β,)
inits = (μhat₀=0.0 , λhat₀=1.0, βhat₀=β)
n_train = 1000
tol = 1e-6
μhat, λhat, αhat, βhat, Lvec = myVI(data, model_params, inits, n_train, tol)
var_params = (μhat=μhat, λhat=λhat, αhat=αhat, βhat=βhat)

#approximated posterior pdf
function r(μ, λ, var_params)
    @unpack μhat,λhat,αhat,βhat = var_params
    return exp(logpdf(Normal(μhat, 1/√λhat), μ) + logpdf(Gamma(αhat, 1/βhat), λ))
end

#Gibbs sampler
n_samps = 5000
n_burnin = div(n_samps, 10)
inits = (μ₀=0.0, λ₀=1.0)
μsamps, λsamps = myGibbs(data, model_params, inits, n_samps, n_burnin)
samps = (μsamps=μsamps, λsamps=λsamps)

#visualize the variational and true posterior
p1 = plot(-2:0.1:2, 0.01:0.01:0.8, (μ,λ)->r(μ,λ,var_params), st=:contour, xlabel="μ", ylabel="λ", title="approximated posterior", xlim=[-2,2], ylim=[0.01, 0.8])
plot!(μsamps, λsamps, st=:scatter, alpha=0.2, label="Gibbs sampler", markerstrokewidth=0.2, color=:green)
p2 = plot(Lvec, title="ELBO", xlabel="iter", ylabel="ELBO", label=false, marker=:circle, markerstrokewidth=0, markersize=6)
fig1 = plot(p1, p2, size=(800, 350))
savefig(fig1, "figs-VI/fig1.png")
【Juliaコード4; 平均場近似の予測分布】
#variational prediction
n_samps = 5000
μsamps_var, λsamps_var = post_samps(var_params, n_samps)
var_samps = (μsamps=μsamps_var, λsamps=λsamps_var)

#visualize the predictive pdf
xs = -3:0.1:3
fig2 = plot(xs, x->pred(x,var_samps), color=:orange, ls=:solid, lw=3, label="predictive(VI)")
plot!(xs, x->pred(x,samps), color=:green, ls=:dash, lw=3,label="predictive(Gibbs)")
plot!(xs, true_pdf, label="true", title="predictive pdf", color=:red)
savefig(fig2, "figs-VI/fig2.png")
【Juliaコード5; 確率的変分推論の関数定義】
#plot the data and return the figure
function plot_data(X, Y)
    _,N = size(X)
    fig = plot(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, title="data", legend=false)
    for k in 1:N
        if Y[k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

function plot_data(fig, X, Y)
    _,N = size(X)
    fig = plot!(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, legend=false)
    for k in 1:N
        if Y[k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

#initialize the parameter
function init_params(st)
    @unpack  Dx, Dy, D₀ = st
    W₂ = rand(D₀, Dx)
    W₃ = rand(Dy, D₀)
    b₂ = zeros(D₀)
    b₃ = zeros(Dy)
    return W₂, W₃, b₂, b₃
end

#stick the weights and biases to a large matrix
function stick_params(W₂, W₃, b₂, b₃, Dx)
    tmp1 = vcat(b₂', W₂')
    tmp2 = hcat(tmp1, zeros(Dx+1))
    tmp3 = hcat(W₃, b₃)
    return vcat(tmp2, tmp3)
end

#devide the paramters vector to weights and biases
function reshape_params(wvec, st)
    @unpack  Dx, Dy, D₀ = st
    W = reshape(wvec, (Dx+Dy+1, D₀+1))
    W₂ = view(W, 2:Dx+1, 1:D₀)'
    W₃ = view(W, Dx+2:Dx+Dy+1, 1:D₀)
    b₂ = view(W, 1, 1:D₀)
    b₃ = view(W, Dx+2:Dx+Dy+1, D₀+1)
    return W₂, W₃, b₂, b₃
end 

#sigmoid function
σ(ξ) = 1/(1+exp(-ξ)) 

#Neural Network
function nn(x, wvec, st)
    W₂, W₃, b₂, b₃ = reshape_params(wvec, st)
    return σ.(W₃*σ.(W₂*x+b₂) + b₃)
end

#loss function 
Ln(y_pred, y_data) = norm(y_pred-y_data)^2

#train the neural network
function train_nn(data, n_train, α, wvec₀, st)
    wvec = wvec₀
    @unpack X,Y,N = data
    @unpack Dx,Dy,D₀ = st
    ∇Ln(wvec, idx, X, Y, st) = ForwardDiff.gradient(wvec->Ln(nn(X[:,idx], wvec, st)[1], Y[idx]), wvec)
    @showprogress for k in 1:n_train
        #choose the sample uniformaly
        idx = rand(1:N)
        
        #gradient descent
        wvec = wvec - α*∇Ln(wvec, idx, X, Y, st)
    end
    return wvec
end

#log pdf of prior, model, posterior
logpprior(wvec, λ, dw) = logpdf(MvNormal(zeros(dw),1/sqrt(λ)), wvec)
logpmodel(y, x, wvec, st) = logpdf(Bernoulli(nn(x,wvec,st)[1]), y)
loglik(X, Y, N, wvec, st) = sum([logpmodel(Y[n], X[:,n], wvec, st) for n in 1:N])

function logppost(wvec, data, model_params)
    @unpack X,Y,N = data
    @unpack λw,dw,st = model_params
    return loglik(X, Y, N, wvec, st) + logpprior(wvec, λw, dw)
end

function logppost(wvec, λw, data, model_params)
    @unpack X,Y,N = data
    @unpack dw,st = model_params
    return loglik(X, Y, N, wvec, st) + logpprior(wvec, λw, dw)
end

#predictive: returns the probability to new data classified to class 1
function ppred(x, wsamps, st)
    _, n_samps = size(wsamps)
    preds = zeros(n_samps)
    for j in 1:n_samps
        preds[j] = exp(logpmodel(1, x, wsamps[:,j], st))
    end
    return mean(preds)
end

#∇θLn(θ)
function ∇θLn(x, y, N, wsamps, θvec, λw, st)
    dw,S = size(wsamps)
    centwvec = wsamps-θvec[1:dw]*ones(S)'
    logpmodels = zeros(S)
    for s in 1:S
        logpmodels[s] = logpmodel(y, x, wsamps[:,s], st)
    end
    arr1 = centwvec .* (ones(dw)*logpmodels')
    arr2 = centwvec .* arr1
    ∇θLnvec = zeros(2*dw)
    for j in 1:dw
        ∇θLnvec[j] = N*exp(-2*θvec[dw+j])*mean(arr1[j,:])-λw*θvec[j]
        ∇θLnvec[dw+j] = N*exp(-2*θvec[dw+j])*mean(arr2[j,:])-N*mean(logpmodels)+1-λw*exp(2*θvec[dw+j])
    end
    return ∇θLnvec
end

#calculate ELBO
function ELBO(X, Y, N, wsamps, θvec, λw, st)
    dw,S = size(wsamps)
    logpmodels = zeros(S)
    for s in 1:S
        for n in 1:N
            logpmodels[s] += logpmodel(Y[n], X[:,n], wsamps[:,s], st)
        end
    end
    return mean(logpmodels)-λw*sum(θvec[1:dw].^2)/2-λw*sum(exp.(2*θvec[dw+1:end]))/2+sum(θvec[dw+1:end])+dw/2+dw*log(λw)/2
end

#sample from approximation distribution r
post_samps(θvec, n_samps, dw) = rand(MvNormal(θvec[1:dw], exp.(θvec[dw+1:2*dw])), n_samps)

#variational inference
function myVI(data, model_params, α, n_train, tol)
    @unpack X,Y,N = data
    @unpack λw, dw, st = model_params
    θvec = vcat(zeros(dw), ones(dw))
    n_samps = 5000
    wsamps = zeros(dw, n_samps)
    history = zeros(n_train)
    history[1] = ELBO(X, Y, N, wsamps, θvec, λw, st)
    @showprogress for k in 2:n_train
        idx = rand(1:N)
        x = X[:,idx]
        y = Y[idx]
        wsamps = post_samps(θvec, n_samps, dw)
        θvec = θvec + α*∇θLn(x, y, N, wsamps, θvec, λw, st)/k
        history[k] = ELBO(X, Y, N, wsamps, θvec, λw, st)
        if abs(history[k]-history[k-1])<tol
            return θvec, history[1:k]
        end
    end
    return θvec, history
end
【Juliaコード6; データ作成】
#create the data
N = 16
X = [
    0.3 0.52 0.3 0.50 0.60 0.7 0.70 0.55  0.85 0.10 0.05 0.20 0.39 0.63 0.86 0.97;
    0.1 0.35 0.9 0.15 0.95 0.2 0.80 0.75  0.55 0.76 0.15 0.45 0.56 0.50 0.80 0.20;
]
Y = vcat(zeros(div(N,2)), ones(div(N,2)))
data = (X=X,Y=Y,N=N)

#size
Dx,N = size(X)
Dy = 1
D₀ = 5
st = (Dx=Dx, Dy=Dy, D₀=D₀)

#plot the data
fig3 = plot_data(X, Y)
savefig(fig3, "figs-VI/fig3.png")
【Juliaコード7; 通常のNeural Networkの訓練】
#initialize NN
Random.seed!(42)
W₂, W₃, b₂, b₃ = init_params(st)
Ws = stick_params(W₂, W₃, b₂, b₃, st.Dx)
wvec₀ = Ws[:]

#train NN
n_train = Int(1e6)
α = 0.1
@time wvec = train_nn(data, n_train, α, wvec₀, st)

#predict
fig4 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->nn([x1,x2], wvec, st)[1], st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig4 = plot_data(fig4,X,Y)
plot!(title="prediction")
savefig(fig4, "figs-VI/fig4.png")
【Juliaコード8; 確率的変分推論による予測分布】
#initialize NN
W₂, W₃, b₂, b₃ = init_params(st)
Ws = stick_params(W₂, W₃, b₂, b₃, Dx)
wvec₀ = Ws[:]
dw = length(wvec₀)

#model params
λw = 1e-3
model_params = (λw=λw, dw=dw, st=st)

#calculate the variational parameters
α = 0.4
tol = 1e-6
n_train = 2000
@time θvec, history = myVI(data, model_params, α, n_train, tol)

#posterior sample
n_samps = 5000
wsamps = post_samps(θvec, n_samps, dw)

#visualize predictive
fig5 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->ppred([x1,x2], wsamps, st), st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig5 = plot_data(fig5, X, Y)
plot!(title="prediction: Variational Inference")
savefig(fig5, "figs-VI/fig5.png")
【Juliaコード9; ELBOの変化】
fig6 = plot(history, xlabel="iter", title="ELBO", label=false)
savefig(fig6, "figs-VI/fig6.png")
参考文献

      [1] ビショップ, パターン認識と機械学習, 丸善出版, 2019
      [2] 渡辺澄夫, ベイズ統計の理論と方法, コロナ社, 2019
      [3] 須山敦志, ベイズ深層学習, 講談社, 2020
      [4] A.Graves, Practical variational inference for neural networks, Advances in Neural Information Processing Systems, pp.2348-2356, 2011
      [5] M.D.Hoffman, D.M.Blei, C.Wang, J.Paisley, Stochastic variational inference, The Journal of Machine Learning Research, 14(1), pp.1303-1347, 2013
      [6] D.P.Kingma, M.Willing, Auto-encoding variational Bayes, International Conference on Learning Representation, 2014