期待値伝播法

記事の内容


  • 概要
    • アルゴリズムの概要
    • アルゴリズム導出の手順
  • 実験
    • 問題設定
    • 実験結果
  • コード

期待値伝播法, てこずったので記録として残しておきます.

概要

アルゴリズムの概要

期待値伝播法[1]は, 事後分布を近似する近似推論法です. 変分推論法と少し似ています. 事後分布の積の形を近似分布で再現して, KLダイバージェンスの意味で近づけます. 今回使ってみた感想としては, ちょっと不便かなって感じです. 以下に, メリットとデメリットを書いておきます.

期待値伝播法のメリットとデメリット(個人の感想)

【メリット】
  • 実行時間は早い.
【デメリット】
  • 事後分布と近似分布の間のKLダイバージェンスが最小になるという保証はない.
  • 更新式の導出は面倒.
  • 適用できるモデルは多くはなさそう?.

アルゴリズムの概要を述べます. モデル\(p(\boldsymbol{x}\mid\boldsymbol{\theta})\), 事前分布\(p(\boldsymbol{\theta})\), データ\(X=\{\boldsymbol{x}_n\}_{n=1}^N\)に対して, 以下の事後分布を考えます.

\begin{align} p(\boldsymbol{\theta}\mid X) = \frac{1}{p(X)}\prod_{n=0}^Nt^{(n)}(\boldsymbol{\theta}),\quad t^{(n)}(\boldsymbol{\theta}) = \begin{cases}p(\boldsymbol{x}_n\mid\boldsymbol{\theta})\quad &(n=1,\cdots,N)\\p(\boldsymbol{\theta})\quad &(n=0)\end{cases} \end{align}

この事後分布を近似する方法を考えます. 以下のような形の近似分布は自然な選択です.

\begin{equation} r(\boldsymbol{\theta}) \propto \prod_{n=0}^Nr^{(n)}(\boldsymbol{\theta}) \end{equation}

各因子\(\)は予め選択した分布の確率密度関数とします. 期待値伝播法では, この因子を反復的に更新していきます. アルゴリズムの大まかな手順は以下の通りです.

【期待値伝播法の概要】

  1. 更新したい因子を除いた, 残りの因子だけを取り出す. 実際には, 両辺のパラメータを比較して更新式を導出する.
  2. \begin{equation} r^{\setminus n}(\boldsymbol{\theta}) \propto \frac{r(\boldsymbol{\theta})}{r^{(n)}(\boldsymbol{\theta})} \end{equation}
  3. モデルの情報を取り込んだ事後分布の代わりの分布を用意する. \(\hat{p}\)の分布を特定できれば良いが, モーメントが分かれば十分なケースもある.
  4. \begin{equation} \hat{p}^{(n)}(\boldsymbol{\theta}) = \frac{1}{Z_n} r^{\setminus n}(\boldsymbol{\theta})t^{(n)}(\boldsymbol{\theta}) \end{equation}
  5. 近似分布を用意した分布にKLダイバージェンスの意味で近づける. 近似分布が指数型分布族の場合には, この後述べるモーメントマッチングを用いて更新式を導出する.
  6. \begin{equation} r=\underset{r\in\mathcal{R}}{\mathrm{argmin}}D_{\mathrm{KL}}[\hat{p}^{(n)}\| r] \end{equation}
  7. 因子を更新する. 実際には, 両辺のパラメータを比較して更新式を導出する.
  8. \begin{equation} r^{(n)}(\boldsymbol{\theta}) \propto \frac{r(\boldsymbol{\theta})}{r^{\setminus n}(\boldsymbol{\theta})} \end{equation}

以下にアルゴリズムを示します. 実際には近似分布のパラメータだけ調節すれば計算できます.

【期待値伝播法】

Initialize approximation distribution \(\{r^{(n)}\}_{n=0}^N\) and \(r\).
for \(k=1,2,\cdots,\)
for \(n=0,\cdots,N\)
Compute \(r^{\setminus n}\)
Compute \(Z_n\) and \(\hat{p}^{(n)}\)
Compute \(r=\underset{r\in\mathcal{R}}{\mathrm{argmin}}D_{\mathrm{KL}}[\hat{p}^{(n)}\| r]\)
Compute \(r^{(n)}\)
end
end

近似分布が指数型分布族の場合には, KLダイバージェンス最小化の部分は解析的に更新式が導出できます. \begin{equation} r_{\eta}(\boldsymbol{\theta}) = h(\boldsymbol{\theta})\exp(f(\boldsymbol{\eta})^{\mathrm{T}}g(\boldsymbol{\theta})) \end{equation} とするとき, 以下の最適化問題を解くことになります.

\begin{equation} \boldsymbol{\eta}=\underset{\eta}{\mathrm{argmin}}D_{\mathrm{KL}}[\hat{p}^{(n)}\| r_\eta] \quad {\mathrm{s.t.}} \int r_{\eta}(\boldsymbol{\theta})\mathrm{d}\boldsymbol{\theta}=1 \end{equation}

これをLagrangeの未定乗数法で解くと, 次式を満たすようにパラメータをとることで最小値が達成できると期待できます.

\begin{equation} \mathbb{E}_{\hat{p}^{(n)}} [g(\boldsymbol{\eta})] = \mathbb{E}_{\hat{r}_{\eta}} [g(\boldsymbol{\eta})] \end{equation}

\(r\)が1次元正規分布ならば, \begin{equation} g(\theta) = [\theta^2,\ \theta,\ 1]^{\mathrm{T}} \end{equation} とできるので, \(r\)の平均と分散を\(\hat{p}\)のものに合わせればOKです.

実験

問題設定

実験を行います. 次のような設定を考えます. 以下の混合正規分布から, データ\(X=\{x_n\}_{n=1}^N\)を発生させます.

\begin{equation} a\mathrm{N}(2,1) + (1-a)\mathrm{N}(-2,1),\quad a=0.6 \end{equation}

モデルは以下の混合正規分布とします.

\begin{equation} t^{(n)}(\theta) = p(x_n\mid\theta) = a\mathrm{N}(x_n\mid \theta,1) + (1-a)\mathrm{N}(x_n\mid-2,1),\quad a=0.6 \end{equation}

\(\theta\)を推定します. \(\theta\)の事前分布\(t^{(0)}(\theta)\)として, 平均が\(0\), 分散が\(100\)の正規分布を用います. 近似分布の各因子\(r^{(n)}\)を正規分布とします. 各因子の平均と分散パラメータを反復的に調節します. このとき, \(r\)は正規分布, \(\hat{p}\)は混合正規分布になります. また, \(r^{(0)}\)は, 事前分布をそのまま利用します. \(r^{(0)}\)は無視して学習を行います. ヒストグラムと真の分布の様子を下図に示します. \(N=10\)とします.

【コード2の実行結果】

具体的な更新式は, コードをご覧ください. 導出は結構面倒です.

実験結果

実験結果を示します. 100反復ほど計算しました. 近似事後分布からのサンプルを1000サンプルほど用いて予測分布を計算しました. 下図の赤線が真の分布, 青線が推測結果です. 割と近くなっていますね. 学習はうまくいってそうです.

【コード4の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra

#statistics
using Random
using Statistics
using Distributions

#visualize
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack
【Juliaコード2; データの生成】
#set the random seed
Random.seed!(42)

#create data
a = 0.6
N = 10
μ2 = -2
σsq2 = 1
true_dist = MixtureModel(Normal[Normal(2,1), Normal(μ2,σsq2)], [a,1-a])
X = rand(true_dist, N)
data = (X=X, N=N)

#visualize
fig1 = plot(-5:0.1:5,x->pdf(true_dist,x),xlabel="x",ylabel="prob_dens",title="true distribution",
    label="true pdf", color=:red)
plot!(X, st=:histogram, bins=15, normed=true, label="data", alpha=0.5)
savefig(fig1, "figs-EP/fig1.png")
【Juliaコード3; 関数定義】
#normal pdf
npdf(x,μ,σsq) = pdf(Normal(μ,sqrt(σsq)), x)

#update mean and var of approximation distribution except one factor rn
function update_params_except_n(m, ssq, mn, ssqn)
    ssqtmp = 1/(1/ssq-1/ssqn)
    mtmp = m + ssqtmp*(m-mn)/ssqn
    return mtmp, ssqtmp
end

#normalizing
calcZn(xn,mtmp,ssqtmp,a,μ2,σsq2) = a*npdf(xn,mtmp,ssqtmp+1) + (1-a)*npdf(xn,μ2,σsq2)

function moment_matching(xn,mtmp,ssqtmp,a,Zn,μ2,σsq2)
    m = (a*npdf(xn,mtmp,ssqtmp+1)*((ssqtmp*xn+mtmp)/(ssqtmp+1)) + (1-a)*npdf(xn,μ2,σsq2)*mtmp)/Zn
    ssq = (a*npdf(xn,mtmp,ssqtmp+1)*(ssqtmp/(ssqtmp+1)) + (1-a)*npdf(xn,μ2,σsq2)*ssqtmp)/Zn
    return m,ssq
end

#update mean and var of factor rn
function update(m,ssq,mtmp,ssqtmp)
    ssqn = 1/(1/ssq-1/ssqtmp)
    mn = mtmp + ssqn*(m-mtmp)/ssq
    return mn,ssqn
end

#expectation propagation
function myEP(data, model_params, n_train)
    @unpack X,N = data
    @unpack μ2,σsq2,a = model_params
    
    #initialize
    mns = zeros(N); ssqns = ones(N);
    m = 0; ssq=1/(sum(1 ./ssqns)+1/100);
    
    #train loop
    @showprogress for k in 1:n_train
        #update r1,...,rN
        for n in 1:N
            mtmp, ssqtmp = update_params_except_n(m, ssq, mns[n], ssqns[n])
            Zn = calcZn(X[n],mtmp,ssqtmp,a,μ2,σsq2)
            m,ssq = moment_matching(X[n],mtmp,ssqtmp,a,Zn,μ2,σsq2)
            mns[n],ssqns[n] = update(m,ssq,mtmp,ssqtmp)
        end
    end
    return m,ssq
end
【Juliaコード4; 学習】
#posterior mean and variation
n_train = 100
model_params = (μ2=μ2, σsq2=σsq2, a=a)
@time m,ssq = myEP(data, model_params, n_train)
println("posterior mean=$(round(m,digits=3)), posterior var=$(round(sqrt(ssq),digits=3))")

#model pdf
function pmodel(x, θ, model_params)
    @unpack μ2,σsq2,a = model_params
    pdf(MixtureModel(Normal[Normal(θ,1), Normal(μ2,σsq2)], [a,1-a]),x)
end

#predictive distribution
n_samps = 1000
θsamps = rand(Normal(m,sqrt(ssq)), n_samps)
function pred(x, θsamps, n_samps, model_params)
    preds = zeros(n_samps)
    for s in 1:n_samps 
        preds[s] = pmodel(x, θsamps[s], model_params)
    end
    return mean(preds)
end

#visualize
fig2 = plot(-5:0.1:5,x->pdf(true_dist,x),xlabel="x",ylabel="prob_dens",title="true and predictive pdf",
    label="true pdf", color=:red)
plot!(x->pred(x, θsamps, n_samps, model_params), color=:blue, label="predictive")
savefig(fig2, "figs-EP/fig2.png")
参考文献

      [1]T.P.Minka, Expectation Propagation for approximate Bayesian inference, Proceedings of the Seventeenth Conference on Uncertainty in Artificial Intelligence, 2001