変分推論

記事の内容


この記事では, 変分推論法について解説します. 最も有名な平均場近似を用いた手法と, ミニバッチ学習に対応した確率的変分推論法を紹介します. 確率的変分推論法は, 元々[5]で提案された手法ですが, ここではもっと広い意味で用いています. 原論文では確率的最適化と自然勾配の考え方を用いていますが, ここでは「確率的最適化をベースとした変分推論法」程度の意味で用いています.

概要

問題設定

変分推論は, 複雑な事後分布を近似する方法です. MCMCとは違い, 最適化がベースです. まず, どういう状況で用いる手法なのか, 問題設定を確認しておきます.

手元にデータ\(X=\{\boldsymbol{x}_1,\cdots,\boldsymbol{x}_N\}\subset \mathbb{R}^D\)があるとします.このデータを発生させた分布を推定することが目的です. モデルを\(p(\boldsymbol{x}\mid\boldsymbol{w})\)でモデル化したとします. ここで, \(\boldsymbol{w}\in\mathbb{R}^d\)はパラメータです. このパラメータの事前分布を\(p(\boldsymbol{w})\)とします. すると事後分布が計算できます.

\begin{equation} p(\boldsymbol{w}\mid X) \propto p(X\mid\boldsymbol{w}) p(\boldsymbol{w}) \end{equation}

この事後分布の平均を計算したり, モデルを事後分布で重みづけて予測分布を計算するには事後分布が複雑すぎるとします. 解析的に扱えない場合, MCMCや変分推論法が用いられます. 次に, 変分推論法の基本的なアイデアを述べます.

基本的なアイデア

変分推論法では, 事後分布を近似する分布を構成します. 目標は, 次のような分布を, 予め定めた分布族\(\mathcal{R}\)から見つけることです.

\begin{equation} p(\boldsymbol{w}\mid X) \simeq r(\boldsymbol{w}),\quad r\in \mathcal{R} \end{equation}

ここで, 次の2点が問題になります.

  • 近似分布は事後分布にどういう意味で近いのか?
  • 分布族\(\mathcal{R}\)はどのように設定するか?

まず1点目. 分布間の近さは, Kullback-Leiblerダイバージェンスで測るのが自然です. 近似分布と事後分布の間のKullback-Leiblerダイバージェンスが最小になるように近似分布を定めます. もう少し詳しく見てみましょう. 次のような数量を定めます.

\begin{equation} \mathcal{L}[r] = -\int r(\boldsymbol{w})\log \frac{r(\boldsymbol{w})}{p(X,\boldsymbol{w})} d\boldsymbol{w} \end{equation}

すると, \(r(\boldsymbol{w})\)と\(p(\boldsymbol{w}\mid X)\)の間のKullback-Leiblerダイバージェンスは, 次のように書けます.

\begin{equation} D_{\mathrm{KL}}[r(\boldsymbol{w})\|p(\boldsymbol{w}\mid X)] = -\mathcal{L}[r] + \log p(X) \end{equation}

Kullback-Leiblerダイバージェンスが非負であることから, 次式が得られます.

\begin{equation} \mathcal{L}[r] \leq \log p(X) \end{equation}

\(\mathcal{L}\)が周辺尤度(エビデンス)の下界であることから, この量をELBO(Evidence Lower bound)と呼びます. 当面の間の目標はKullback-Leiblerダイバージェンスの最小化ですが, これはELBOの最大化と等価です. なお, ELBOの符号反転は変分自由エネルギーと呼ばれます.

次に2点目です. 分布族の決め方はいくつかあります. 最も有名な方法が平均場近似です. 平均場近似では, 近似分布を具体的には指定せず, 事後分布が都合よく分解できると仮定します. 分解できるという仮定のみでも, 事後分布は"いい感じ"に求まります. 他には, 予め分布を指定する方法があります. 例えば予め正規分布で近似できると仮定して, パラメータを決めてやります. 以下, この2つのアイデアを詳しく掘り下げていきます.

平均場近似

平均場近似の考え方

平均場近似では, 近似分布が属する分布族を次のように定めます.

\begin{equation} \mathcal{R} = \left\{ r(\boldsymbol{w})\mid \text{ある確率(密度)関数\(r_1,\cdots,r_K\)が存在して, }r(\boldsymbol{w}) = \prod_{i=1}^Kr_i(\boldsymbol{w_i})\right\} \end{equation}

このような分布族から持ってきた近似分布\(r\)について, 各\(i=1,\cdots,K\)に対して, ELBOは次のように書けます.

\begin{equation} \mathcal{L}[r] = -D_{\mathrm{KL}}\left[r_i(\boldsymbol{w}_i) \| \exp\left(\mathrm{E}_{\prod_{j\neq i}r_j}\left[ \log p(X,\boldsymbol{w}) \right]\right) \right] + \mathrm{const} \end{equation}

ここで, 上式の期待値は, 次のようにとります.

\begin{equation} \mathrm{E}_{\prod_{j\neq i}r_j}\left[ \log p(X,\boldsymbol{w}) \right] = \int \left(\prod_{j\neq i}r_i(\boldsymbol{w}_i)\right) \log p(X,\boldsymbol{w}) d\boldsymbol{w}^{\setminus i} \end{equation}

2つの分布が一致するときにKLダイバージェンスが最小になることから, 各\(j\neq i\)に対して因子\(r_j\)が定まっているときに, \begin{equation} r_i(\boldsymbol{w}_i) \propto \exp \left( \mathrm{E}_{\prod_{j\neq i}r_i} [\log p(X,\boldsymbol{w})]\right) \end{equation} と因子\(r_i\)を更新することで, ELBOが最大になります. 適当な初期分布から初めて更新を繰り返していけば, ELBOがどこかで頭打ちになるはずです. このとき近似分布は事後分布に十分近いと期待できます.

正規分布の推論

ここでは例として, 正規分布の平均と精度の事後分布を推定します. モデルが正規分布で, パラメータが平均パラメータと精度パラメータにあたります. 次のような生成過程を考えます.

\begin{equation} \mu\sim\mathrm{N}(0,1) , \quad \lambda \sim\mathrm{Gamma}(1,\beta),\quad x\sim\mathrm{N}(\mu,\lambda^{-1}) \end{equation}

\(\beta\)は定数です. このとき事後分布は次のようになります.

\begin{equation} p(\mu,\lambda\mid X) \propto \lambda^{\frac{N}{2}} e^{-\beta\lambda} \exp\left\{ -\frac{1}{2}\mu^2-\frac{\lambda}{2}\sum_{n=1}^N(x_n-\mu)^2\right\} \end{equation}

例えばこの分布の平均を求めるのは困難です. ということで変分近似を行います. 次のように分解されるとします.

\begin{equation} p(\mu,\lambda\mid X) \simeq r(\mu)r(\lambda) \end{equation}

先ほど導出した期待値計算を含む分布を導出すると, 以下のようになります.

\begin{align} r(\mu) &\propto \exp\left\{ -\frac{1}{2}\mu^2-\frac{1}{2}\mathrm{E}_\lambda [\lambda] \sum_{n=1}^N(x_n-\mu)^2 \right\} \\ r(\lambda) &\propto \exp\left\{ \frac{N}{2}\log\lambda-\beta\lambda - \frac{\lambda}{2}\sum_{n=1}^N\mathrm{E}_\mu [(x_n-\mu)^2]\right\} \end{align}

\(\mu\)の近似分布は正規分布, \(\lambda\)の近似分布はガンマ分布であることが分かります. それぞれの分布を次のようにおきます.

\begin{equation} \mathrm{N}(\hat{\mu}, \hat{\lambda}) ,\quad \mathrm{Gamma}(\hat{\alpha}, \hat{\beta}) \end{equation}

すると, 次のように書けます.

\begin{align} \hat{\lambda} &= 1+\frac{N\hat{\alpha}}{\hat{\beta}} \\ \hat{\mu} &= \frac{\hat{\alpha}}{\hat{\lambda}\hat{\beta}}\left(\sum_{n=1}^Nx_n\right) \\ \hat{\alpha} &= 1+\frac{N}{2} \\ \hat{\beta} &= \beta + \frac{1}{2}\left(\sum_{n=1}^Nx_n^2\right)-\hat{\mu}\left( \sum_{n=1}^Nx_n\right) + \frac{N}{2}\hat{\mu}^2 + \frac{N}{2\hat{\lambda}} \end{align}

これを更新式として, 次のようなアルゴリズムが得られます. 実際には, \(\hat{\alpha}\)が上記の更新で変化しないので, ループの外に出して実装します.

【正規分布に対する変分推論アルゴリズム】

Initialize \(\hat{\lambda}^{(0)}, \hat{\mu}^{(0)}, \hat{\alpha}^{(0)}, \hat{\beta}^{(0)}\), and set \(k=0, \epsilon\).
while\( |\mathcal{L}[r^{(k+1)}]-\mathcal{L}[r^{(k)}] | <\epsilon\)
update \(\hat{\lambda}^{(k+1)}, \hat{\mu}^{(k+1)}, \hat{\alpha}^{(k+1)}, \hat{\beta}^{(k+1)}\).
\(k=k+1\)
end

ちなみに, ELBOの値は以下のようになります. 少し長いです. \(\Gamma\)はガンマ関数, \(\psi\)はディガンマ関数です.

\begin{align} \mathcal{L}[r] = -\frac{1}{2}\log\hat{\lambda} + \frac{1}{2\hat{\lambda}}-\hat{\alpha}\log\hat{\beta} + \log\Gamma (\hat{\alpha}) - (\hat{\alpha}-1)(\psi (\hat{\alpha})-\log\hat{\beta}) + \hat{\alpha} + \frac{N}{2}\left( \psi (\hat{\alpha})-\log (2\pi\hat{\beta})\right)\\ - \frac{\hat{\alpha}}{2\hat{\beta}}\left( \sum_{n=1}^Nx_n^2-2\hat{\mu}\sum_{n=1}^Nx_n + N\hat{\mu}^2 + \frac{N}{\hat{\lambda}}\right) - \frac{1}{2}\left( \hat{\mu}^2+\frac{1}{\hat{\lambda}}\right) + \log\beta - \frac{\hat{\alpha}\beta}{\hat{\beta}} \end{align}

以下に実験結果を示します. 真の分布は混合正規分布\(0.5\mathrm{N}(-1.4, 1.1^2)+0.5\mathrm{N}(1.5, 0.9^2)\)としました. この分布からデータを\(N=20\)点発生させました. \(\beta=10, \epsilon=10^{-6}\)とします. また, 初期値は\(\hat{\mu}=0.0, \hat{\lambda}=1.0, \hat{\beta}=\beta\)とします. 左下図の等高線は, 近似事後分布を表します. また, Gibbs samplerによって(真の)事後分布からサンプルし, プロットしました(緑点). 実際の事後分布と近似事後分布は一致しているように見えます. また, 右下図は, ELBOの変化の様子を示しました. 7回目の更新で収束しています.

【コード3の実行結果】

近似事後分布からサンプルを発生させて, 予測分布を計算しました(青線). 緑色が真の分布です. また, Gibbs samplerで得た事後分布からのサンプルを用いて, 真の予測分布を計算しました(赤色). 予測分布同士はほぼ一致しています.

【コード4の実行結果】

確率的変分推論

確率的変分推論の考え方

確率的変分推論は, 冒頭で述べたとおり, 少し広い意味で解説します. 確率的最適化をベースとしたミニバッチ版の変分推論を指します. 平均場近似とは異なり, 予め近似分布が属する分布族\(\mathcal{R}\)の分布形を指定します. すると, 先程のような期待値の指数関数とは異なる形が導出されます. 例として, 次のような正規分布を仮定します.

\begin{equation} \mathcal{R} = \left\{ r_{\boldsymbol{\theta}}(\boldsymbol{w}) \mid r_{\boldsymbol{\theta}}(\boldsymbol{w}) = \prod_{i=1}^d \left( \frac{\lambda_i}{2\pi}\right) ^d\exp\left\{ -\frac{\lambda_i}{2}(w_i-\mu_i)^2\right\} \right\} \end{equation}

ここで, 近似分布のパラメータを\(\boldsymbol{\theta}=\{ (\mu_i, \lambda_i) \}_{i=1}^d\)とおきました. 確率的変分推論では, ELBOを全データではなく一部のデータ(ミニバッチ)で近似します. ここではデータ点1つからなるミニバッチを考えます. より一般に複数データをミニバッチとする場合には各データから計算できるELBOを平均します[3]. データ点\((\boldsymbol{x}_n, \boldsymbol{y}_n)\)を全データからランダムにサンプルします. このデータを用いて, ELBOを次のように近似します.

\begin{equation} \mathcal{L}[r_{\boldsymbol{\theta}}] \simeq \mathcal{L}_n[r_{\boldsymbol{\theta}}] = N\left( \mathrm{E}_{r_{\boldsymbol{\theta}}}\left[ \log p(\boldsymbol{x}_n\mid \boldsymbol{w})\right]\right) - D_{\mathrm{KL}}\left[ r_{\boldsymbol{\theta}(\boldsymbol{w})} \| p(\boldsymbol{w})\right] \end{equation}

この近似値\(\mathcal{L}_n\)を最大化することで, 近似分布のパラメータ\(\boldsymbol{\theta}\)を定めます. \(\mathcal{L}_n\)の微分値が求まれば, 最急"上昇"によって最大化できます. ということで, 上の近似値を微分します. まず第1項の期待値の微分を考えます.

\begin{equation} \nabla_{\boldsymbol{\theta}}\mathrm{E}_{r_{\boldsymbol{\theta}}}[\log p(\boldsymbol{x}_n \mid \boldsymbol{w})] = \mathrm{E}_{r_{\boldsymbol{\theta}}}\left[ \log p(\boldsymbol{x}_n\mid \boldsymbol{w}) \nabla_{\boldsymbol{\theta}}\log r_{\boldsymbol{\theta}}(\boldsymbol{w})\right] \simeq \frac{1}{S}\sum_{s=1}^S \log p\left(\boldsymbol{x}_n\mid \boldsymbol{w}^{(s)}\right) \nabla_{\boldsymbol{\theta}}\log r_{\boldsymbol{\theta}}\left(\boldsymbol{w}^{(s)}\right) \end{equation}

ここで, \(\boldsymbol{w}^{(s)}\)は\(r_{\boldsymbol{\theta}}\)からのサンプルとします. さらに, 期待値内の微分計算については次のようになります.

\begin{equation} \frac{\partial}{\partial\mu_j} \log r_{\boldsymbol{\theta}}(\boldsymbol{w}) = \lambda_j(w_j-\mu_j) ,\quad \frac{\partial}{\partial\lambda_j}\log r_{\boldsymbol{\theta}}(\boldsymbol{w}) = \frac{1}{2}\left\{ \frac{1}{\lambda}_j - (w_j-\mu_j)^2\right\}\quad (j=1,\cdots,d) \end{equation}

次に, 第2項の微分を求めます. ここでは, 事前分布を次のような正規分布に限定します.

\begin{equation} \boldsymbol{w} \sim \mathrm{N}(\boldsymbol{0} , \lambda_w^{-1}I_d) \end{equation}

ここで, \(\lambda_w\)は定数とします. すると, 第2項は解析的に計算できて, 次のようになります.

\begin{equation} D_{\mathrm{KL}}\left[r_{\boldsymbol{\theta}}(\boldsymbol{w}) \| p(\boldsymbol{w})\right] = \frac{1}{2}\sum_{i=1}^d\log\left(\frac{\lambda_i}{2\pi}\right) - \frac{d}{2} - \frac{d}{2}\log\left( \frac{\lambda_w}{2\pi}\right) + \frac{\lambda_w}{2}\sum_{i=1}^d\left( \mu_i^2 + \frac{1}{\lambda_i}\right) \end{equation}

これをパラメータ値に関して微分することで, 次式を得ます.

\begin{equation} \frac{\partial}{\partial\mu_j}D_{\mathrm{KL}}\left[ r_{\boldsymbol{\theta}}(\boldsymbol{w}) \|p(\boldsymbol{w})\right] = \lambda_w\mu_j,\quad \frac{\partial}{\partial\lambda_j}D_{\mathrm{KL}}\left[ r_{\boldsymbol{\theta}}(\boldsymbol{w}) \|p(\boldsymbol{w})\right] = \frac{1}{2\lambda_j} + \frac{\lambda_w}{2\lambda_j^2} \quad (j=1,\cdots,d) \end{equation}

以上の結果から, パラメータ値を次のように更新すればELBOが最大になると期待できます.

\begin{equation} \boldsymbol{\theta}^{(k+1)} = \boldsymbol{\theta}^{(k)} + \alpha_k \nabla_{\boldsymbol{\theta}} \mathcal{L}_n[r_{\boldsymbol{\theta}}^{(k)}] \end{equation}

更新幅\(\alpha_k\)は例えば\(\alpha_k=\frac{1}{k}\)のようにとります.

Bayesian Neural Networkの分類問題

上のような考え方を用いて実験します. ここではBayesian Neural Networkによる二値分類を行います. まず, 次のようなNeural Networkを考えます.

\begin{equation} \Phi(\boldsymbol{x}, \boldsymbol{w} ) = \sigma \left( W_3\sigma\left( W_2 \boldsymbol{x} + \boldsymbol{b}_2\right)+ \boldsymbol{b}_3 \right) \end{equation}

\(\sigma\)はシグモイド関数とします. また, 中間層のユニット数を\(D_0=5\)とします. さらに, 次のようなモデルを考えます.

\begin{equation} p(\boldsymbol{y} \mid \boldsymbol{x},\boldsymbol{w} ) = \boldsymbol{y}^{\mathrm{T}}\Phi(\boldsymbol{x}, \boldsymbol{w}) \end{equation}

ここで, \(\boldsymbol{w}\)は次のようにパラメータをまとめたベクトルとします.

\begin{equation} \boldsymbol{w} = \mathrm{vec}(\{ W_2, W_3, \boldsymbol{b}_2, \boldsymbol{b}_3\} ) \in\mathbb{R}^d \end{equation}

\(\boldsymbol{x}\)は平面上の点座標を表します. \(\boldsymbol{y}\)はラベルデータで, one-hot表現のベクトルです. クラス1(赤色)に分類される場合には\([1,0]^{\mathrm{T}}\)とし, クラス2(青色)に分類される場合には\([0,1]^{\mathrm{T}}\)とします. 次のようなデータが得られたとします.

【コード6の実行結果】

これらの点と対応するラベルをデータサイズ\(N=14\)のデータセット,\(X=\{\boldsymbol{x}_n\}_{n=1}^N, Y=\{\boldsymbol{y}_n\}_{n=1}^N\)で表します. 事前分布は次の正規分布とします.

\begin{equation} \boldsymbol{w} \sim \mathrm{N}\left( \boldsymbol{0}, \lambda_w^{-1}I_d\right) \end{equation}

近似する分布族\(\mathcal{R}\)は先ほど定めた正規分布の族とします. 前セクションのアイデアから, 次のようなアルゴリズムが得られます.

【BNNの確率的変分推論アルゴリズム】

Initialize weights \(\boldsymbol{\theta}^{(0)}\), step size \(\alpha^{(k)}\) and \(k=0, \epsilon\).
while\( |\mathcal{L}[r_{\boldsymbol{\theta}}^{(k+1)}]-\mathcal{L}[r_{\boldsymbol{\theta}}^{(k)}] | <\epsilon\)
Select \(\boldsymbol{x}_n, \boldsymbol{y_n}\) uniformly at random.
\(\boldsymbol{\theta}^{(k+1)} = \boldsymbol{\theta}^{(k)} + \alpha^{(k)}\nabla_{\boldsymbol{\theta}} \mathcal{L}_{n}[r_{\boldsymbol{\theta}}^{(k)}]\)
\( k=k+1\)
end

ELBOの値は厳密には計算できないので, 適当な回数だけ更新することにします. 今回の実験では1000回ほど更新しています. また, 更新幅は\(\alpha_k=\frac{0.1}{k}\)とし, \(\lambda_w\)はWAICが最も小さくなるような値に設定しました. 計算の結果を以下に示します. クラス1(赤点)に分類される予測確率が高い部分が赤色で塗りつぶされます. 実験はうまくいっていないようです. カラーバーを見ると0.5付近の値しか出ていないので, ほとんど全域に渡って赤とも青とも言えない状態です.

【コード6の実行結果】

実験がうまくいかなかった原因は様々あると思います. 近似が数カ所入っていることや, ステップサイズ, 更新回数などです. プログラムにバグがある可能性もありますね... また何か分かったら実験を追加します.

コード

【Juliaコード1; インポート】
using Random
using Statistics
using Distributions
using LinearAlgebra
using SpecialFunctions
using ProgressMeter
using Plots
pyplot()
【Juliaコード2; 平均場近似: 関数定義】
#ELBO
function ELBO1(μhat, λhat, αhat, βhat, X, β)
    SN = sum(X)
    SsqN = sum(X.^2)
    return (-log(λhat)/2+1/λhat/2-αhat*log(βhat)+log(gamma(αhat))-(αhat-1)*(digamma(αhat)-log(βhat))+αhat 
        + N*(digamma(αhat)-log(2*pi*βhat))/2-αhat*(SsqN-2*μhat*SN+N*μhat^2+N/λhat)/βhat/2-(μhat^2+1/λhat)/2
        + log(β)-β*αhat/βhat)
end

#variational inference : mean field approximation
function my_VI1(μhat₀, λhat₀, βhat₀, max_iter, ϵ, X, β)
    SN = sum(X)
    SsqN = sum(X.^2)
    μhat = μhat₀
    λhat = λhat₀
    αhat = N/2+1
    βhat = βhat₀
    
    #ELBO
    elbos = zeros(max_iter+1)
    elbos[1] = ELBO1(μhat, λhat, αhat, βhat, X, β)
    
    for i in 1:max_iter
        λhat = 1 + N*αhat/βhat
        μhat = SN * αhat/λhat/βhat
        βhat = β + SsqN/2 - μhat*SN + N*μhat^2/2 + N/λhat/2
        elbo = ELBO1(μhat, λhat, αhat, βhat, X, β)
        elbos[i+1] = elbo
        if abs(elbos[i+1]-elbos[i])<ϵ
            return μhat, λhat, αhat, βhat, elbos[1:i+1]
            break
        end
    end
    return μhat, λhat, αhat, βhat, elbos
end

#variational posterior sample
function post_sample1(μhat, λhat, αhat, βhat, n_samps)
    μsamps = rand(Normal(μhat, 1/√λhat), n_samps)
    λsamps = rand(Gamma(αhat, 1/βhat), n_samps)
    return μsamps, λsamps
end

#predictive distribution
function pred1(x, μsamps, λsamps)
    n_samps = length(μsamps)
    preds = zeros(n_samps)
    for i in 1:n_samps
        preds[i] = pdf(Normal(μsamps[i], 1/√λsamps[i]), x)
    end
    return mean(preds)
end

#Gibbs sampler : used for comparing the true posterior and variational posterior
function my_Gibbs1(μ₀, λ₀, n_samps, n_burnin, X, β)
    SN = sum(X)
    μsamps = zeros(n_samps)
    λsamps = zeros(n_samps)
    μsamps[1] = μ₀
    λsamps[1] = λ₀
    for i in 2:n_samps
        μsamps[i] = rand(Normal(λsamps[i-1]*SN/(λsamps[i-1]*N+1), 1/√(λsamps[i-1]*N+1)))
        λsamps[i] = rand(Gamma(N/2+1, 1/(β+sum((X.-μsamps[i-1]).^2)/2)))
    end
    return μsamps[n_burnin:end], λsamps[n_burnin:end]
end
【Juliaコード3; 平均場近似】
Random.seed!(42)

#create the data
μ₁ = -1.4
μ₂ = 1.5
σ₁ = 1.1
σ₂ = 0.9
N = 20
mixture_normal = MixtureModel([Normal(μ₁, σ₁), Normal(μ₂, σ₂)])
mixture_normal_pdf(x) = pdf(mixture_normal, x)
X = rand(mixture_normal, N)

#infernce 
β = 1e1
μhat₀ = 0.0
λhat₀ = 1.0
βhat₀ = β
max_iter = 1000
ϵ = 1e-6
μhat, λhat, αhat, βhat, elbos = my_VI1(μhat₀, λhat₀, βhat₀, max_iter, ϵ, X, β)

#approximated posterior distribution
r(μ, λ, μhat, λhat, αhat, βhat) = pdf(Normal(μhat, 1/√λhat), μ) * pdf(Gamma(αhat, 1/βhat), λ)
r(μ, λ) = r(μ, λ, μhat, λhat, αhat, βhat)

#Gibbs sampler : true posterior
μ₀ = 0.0
λ₀ = 1.0
n_samps = 5000
n_burnin = div(n_samps, 10)
μsamps, λsamps = my_Gibbs1(μ₀, λ₀, n_samps, n_burnin, X, β)

#plot the posterior
p1 = plot(-2:0.1:2, 0.01:0.01:0.8, r, st=:contour, xlabel="μ", ylabel="λ", title="approximated posterior", xlim=[-2,2], ylim=[0.01, 0.8])
plot!(μsamps, λsamps, st=:scatter, alpha=0.2, label="Gibbs sampler", markerstrokewidth=0.2, color=:green)
p2 = plot(elbos, title="change of ELBO", xlabel="iter", ylabel="ELBO", label=false, marker=:circle, markerstrokewidth=0, markersize=6)
fig1 = plot(p1, p2, size=(1000, 400))
savefig(fig1, "figs-VI/fig1.png")
【Juliaコード4; 予測分布の計算】
#variational prediction
n_samps = 5000
μsamps_var, λsamps_var = post_sample1(μhat, λhat, αhat, βhat, n_samps)
pred1_var(x) = pred1(x, μsamps_var, λsamps_var)

#true predictive distribution 
pred1(x) = pred1(x, μsamps, λsamps)

#plot
xs = -3:0.1:3
fig2 = plot(xs, pred1_var, title="variational prediction", label="predictive")
plot!(xs, pred1, label="predictive(Gibbs)")
plot!(xs, mixture_normal_pdf, label="true")

savefig(fig2, "figs-VI/fig2.png")
【Juliaコード5; 確率的変分推論: 関数定義】
#initialize the parameter
function init_params(Dx, D₀, Dy)
    W₂ = rand(D₀, Dx)
    W₃ = rand(Dy, D₀)
    b₂ = zeros(D₀)
    b₃ = zeros(Dy)
    return W₂, W₃, b₂, b₃
end

#stick the weights and biases to a large matrix
function stick_params(W₂, W₃, b₂, b₃)
    tmp1 = vcat(b₂', W₂')
    tmp2 = hcat(tmp1, zeros(Dx+1))
    tmp3 = hcat(W₃, b₃)
    return vcat(tmp2, tmp3)
end

#devide the paramters vector to weights and biases
function reshape_params(params_vec, Dx, D₀, Dy)
    W = reshape(params_vec, (Dx+Dy+1, D₀+1))
    W₂ = view(W, 2:Dx+1, 1:D₀)'
    W₃ = view(W, Dx+2:Dx+Dy+1, 1:D₀)
    b₂ = view(W, 1, 1:D₀)
    b₃ = view(W, Dx+2:Dx+Dy+1, D₀+1)
    return W₂, W₃, b₂, b₃
end 

#sigmoid function
σ(ξ) = 1/(1+exp(-ξ)) 

#Neural Network
function Φ(x, params_vec, Dx, D₀, Dy)
    W₂, W₃, b₂, b₃ = reshape_params(params_vec, Dx, D₀, Dy)
    return σ.(W₃*σ.(W₂*x+b₂) + b₃)
end

#plot the data and return the figure
function plot_data(X, Y)
    _,N = size(X)
    fig = plot(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, title="data", legend=false)
    for k in 1:N
        if Y[1,k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

function plot_data(fig, X, Y)
    _,N = size(X)
    fig = plot!(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, legend=false)
    for k in 1:N
        if Y[1,k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

#plot the probability to classify new data to class 1(red point)
function plot_prob_1(X, Y, pred_func, title)
    T = 50
    X1s = range(0,1,length=T)
    X2s = range(0,1,length=T)
    preds = zeros(T,T)
    @showprogress for j in 1:T
        x2 = X2s[j]
        for i in 1:T
            x1 = X1s[i]
            preds[i,j] = pred_func([x1,x2])
        end
    end
    fig = heatmap(X1s, X2s, preds, c=cgrad(:coolwarm), alpha=0.6, title=title)
    return plot_data(fig, X, Y)
end

#prior for each element
prior_dist(λw) = Normal(0, 1/√λw)
logpprior(w, λw) = logpdf(prior_dist(λw), w)
pprior(w, λw) = exp(logpprior(w, λw))

#model
logpmodel(y, x, wvec, Φ) = log(dot(y, Φ(x,wvec))/sum(Φ(x,wvec)))
pmodel(y, x, wvec, Φ) = exp(logpmodel(y, x, wvec, Φ))

#liklihood
loglik(wvec, Φ, X, Y, N) = sum([logpmodel(Y[:,n], X[:,n], wvec, Φ) for n in 1:N])

#posterior
logppost(wvec, λw, Φ, X, Y, N) = sum(logpprior(wvec,λw)) + loglik(wvec, Φ, X, Y, N)
ppost(wvec, λw, Φ, X, Y, N) = exp(logppost(wvec, λw, Φ, X, Y, N))

#predictive: returns the probability to new data classify to class 1
function ppred(x, wsamps, pmodel)
    _,n_samps = size(wsamps)
    preds = zeros(n_samps)
    for i in 1:n_samps
        preds[i] = pmodel([1,0], x, wsamps[:,i])
    end
    return mean(preds)
end

#∇logr
function ∇logr(wvec, params)
    d = length(wvec)
    μs = params[1:d]
    λs = params[d+1:end]
    ∇μs = λs .* (wvec - μs)
    ∇λs = (1 ./λs - (wvec - μs).^2)/2
    return vcat(∇μs, ∇λs)
end

#function : Er[f]
f(x, y, params, wvec, logpmodel) = logpmodel(y, x, wvec) * ∇logr(wvec, params)

#sample func
function sample_from_r(params)
    d = div(length(params), 2)
    samp = zeros(d)
    for j in 1:d
        μj = params[j]
        λj = params[d+j]
        samp[j] = rand(Normal(μj, 1/√λj))
    end
    return samp
end
    
#MC approximation : expectation with respect to r(w)
function Er(x, y, params, n_samps, f)
    d = div(length(params), 2)
    Esamps = zeros(2*d, n_samps)
    samp = zeros(d)
    for i in 1:n_samps
        samp = sample_from_r(params)
        Esamps[:,i] = f(x, y, params, n_samps, logpmodel)
    end
    return mean(Esamps, dims=2)
end

#gradient Kullback-Leibler divergence with respect to variational parameter
function ∇DKL(params, λw)
    d = div(length(params), 2)
    μs = params[1:d]
    λs = params[d+1:end]
    ∇μDKL = λw * μs
    ∇λDKL = (1 ./λs - λw./λs.^2)/2
    return vcat(∇μDKL, ∇λDKL)
end

#variational infernce
function myVI(X, Y, λw, n_train, n_samps, d, ϵ)
    _, N = size(X)
    params = vcat(zeros(d), ones(d))
    @showprogress for k in 1:n_train
        idx = rand(1:N)
        x = X[:,idx]
        y = Y[:,idx]
        params += ϵ*(N * Er(x, y, params, n_samps, f) - ∇DKL(params, λw))/k
    end
    return params
end

#sample from posterior
function sample_from_posterior(params, n_samps)
    d = div(length(params), 2)
    wsamps = zeros(d, n_samps)
    for j in 1:d
        μj = params[j]
        λj = params[d+j]
        wsamps[j,:] = rand(Normal(μj, 1/√λj), n_samps)
    end
    return wsamps
end

#calculate WAIC
function calc_WAIC(X, Y, wsamps, pmodel, logpmodel)
    _,N = size(X)
    _,n_samps = size(wsamps)
    logpreds = zeros(N)
    vars = zeros(N)
    for n in 1:N
        logpreds[n] = -log(mean([pmodel(Y[:,n], X[:,n], wsamps[:,i]) for i in 1:n_samps]))
        vars[n] = var([logpmodel(Y[:,n], X[:,n], wsamps[:,i]) for i in 1:n_samps])
    end
    return mean(logpreds) + mean(vars)
end

#model selection
function model_selection(X, Y, λws, n_samps, ϵ, pmodel, logpmodel)
    n_train = 1000
    M = length(λws)
    WAICs = zeros(M)
    for i in 1:M
        params = myVI(X, Y, λs[i]w, n_train n_samps, d, ϵ)    
        wsamps = sample_from_posterior(params, n_samps)
        WAICs[i] = calc_WAIC(X, Y, wsamps, pmodel, logpmodel)
    end
    return WAICs
end
【Juliaコード6; 確率的変分推論】
#create the data
Random.seed!(46)
N = 14
X = hcat(rand(Beta(1.1,3), 2, div(N,2)), rand(Beta(1.9,1.1), 2, div(N,2)))
Y = vcat(vcat(ones(div(N,2)), zeros(div(N,2)))', vcat(zeros(div(N,2)), ones(div(N,2)))')

#plot the data
fig3 = plot_data(X, Y)
savefig(fig3, "figs-VI/fig3.png")

#initialize the parameters
Random.seed!(42)
Dx, N = size(X)
Dy, N = size(Y)
D₀ = 5
W₂, W₃, b₂, b₃ = init_params(Dx, D₀, Dy)
Ws = stick_params(W₂, W₃, b₂, b₃)
wvec₀ = Ws[:]
d = length(wvec₀)

#Neural Network 
Φ(x, wvec) = Φ(x, wvec, Dx, D₀, Dy) 

#model, posterior, predictive
logpmodel(y, x, wvec) = logpmodel(y, x, wvec, Φ)
pmodel(y,x,wvec) = exp(logpmodel(y, x, wvec))
ppred(x, wsamps) = ppred(x, wsamps, pmodel)

#model selection
ϵ = 0.1
λws = 1 ./ collect(0.1:0.1:1.0) .^2
WAICs = model_selection(X, Y, λws, 5000, ϵ, pmodel, logpmodel)
λw = λws[argmin(WAICs)]
println("λw=$(λw) minimize WAIC (std=$(1/sqrt(λw))). ")

#calculate the variational parameters
n_train = 1000
params = myVI(X, Y, λw, n_train,n_samps, d, ϵ)

#posterior sample
n_samps = 5000
wsamps = sample_from_posterior(params, n_samps)

#plot the result
fig4 = plot_prob_1(X, Y, x->ppred(x, wsamps), "prediction: variational inference")
savefig(fig4, "figs-VI/fig4.png")
参考文献

      [1]ビショップ, パターン認識と機械学習, 丸善出版, 2019
      [2]渡辺澄夫, ベイズ統計の理論と方法, コロナ社, 2019
      [3]須山敦志, ベイズ深層学習, 講談社, 2020
      [4]A.Graves, Practical variational inference for neural networks, Advances in Neural Information Processing Systems, pp.2348-2356, 2011
      [5]M.D.Hoffman, D.M.Blei, C.Wang, J.Paisley, Stochastic variational inference, The Journal of Machine Learning Research, 14(1), pp.1303-1347, 2013
      [6]D.P.Kingma, M.Willing, Auto-encoding variational Bayes, International Conference on Learning Representation, 2014