期待値伝播法, てこずったので記録として残しておきます.
期待値伝播法[1]は, 事後分布を近似する近似推論法です. 変分推論法と少し似ています. 事後分布の積の形を近似分布で再現して, KLダイバージェンスの意味で近づけます. 今回使ってみた感想としては, ちょっと不便かなって感じです. 以下に, メリットとデメリットを書いておきます.
アルゴリズムの概要を述べます. モデル\(p(\boldsymbol{x}\mid\boldsymbol{\theta})\), 事前分布\(p(\boldsymbol{\theta})\), データ\(X=\{\boldsymbol{x}_n\}_{n=1}^N\)に対して, 以下の事後分布を考えます.
\begin{align} p(\boldsymbol{\theta}\mid X) = \frac{1}{p(X)}\prod_{n=0}^Nt^{(n)}(\boldsymbol{\theta}),\quad t^{(n)}(\boldsymbol{\theta}) = \begin{cases}p(\boldsymbol{x}_n\mid\boldsymbol{\theta})\quad &(n=1,\cdots,N)\\p(\boldsymbol{\theta})\quad &(n=0)\end{cases} \end{align}この事後分布を近似する方法を考えます. 以下のような形の近似分布は自然な選択です.
\begin{equation} r(\boldsymbol{\theta}) \propto \prod_{n=0}^Nr^{(n)}(\boldsymbol{\theta}) \end{equation}各因子\(\)は予め選択した分布の確率密度関数とします. 期待値伝播法では, この因子を反復的に更新していきます. アルゴリズムの大まかな手順は以下の通りです.
以下にアルゴリズムを示します. 実際には近似分布のパラメータだけ調節すれば計算できます.
Initialize approximation distribution \(\{r^{(n)}\}_{n=0}^N\) and \(r\).
for \(k=1,2,\cdots,\)
for \(n=0,\cdots,N\)
Compute \(r^{\setminus n}\)
Compute \(Z_n\) and \(\hat{p}^{(n)}\)
Compute \(r=\underset{r\in\mathcal{R}}{\mathrm{argmin}}D_{\mathrm{KL}}[\hat{p}^{(n)}\| r]\)
Compute \(r^{(n)}\)
end
end
近似分布が指数型分布族の場合には, KLダイバージェンス最小化の部分は解析的に更新式が導出できます. \begin{equation} r_{\eta}(\boldsymbol{\theta}) = h(\boldsymbol{\theta})\exp(f(\boldsymbol{\eta})^{\mathrm{T}}g(\boldsymbol{\theta})) \end{equation} とするとき, 以下の最適化問題を解くことになります.
\begin{equation} \boldsymbol{\eta}=\underset{\eta}{\mathrm{argmin}}D_{\mathrm{KL}}[\hat{p}^{(n)}\| r_\eta] \quad {\mathrm{s.t.}} \int r_{\eta}(\boldsymbol{\theta})\mathrm{d}\boldsymbol{\theta}=1 \end{equation}これをLagrangeの未定乗数法で解くと, 次式を満たすようにパラメータをとることで最小値が達成できると期待できます.
\begin{equation} \mathbb{E}_{\hat{p}^{(n)}} [g(\boldsymbol{\eta})] = \mathbb{E}_{\hat{r}_{\eta}} [g(\boldsymbol{\eta})] \end{equation}\(r\)が1次元正規分布ならば, \begin{equation} g(\theta) = [\theta^2,\ \theta,\ 1]^{\mathrm{T}} \end{equation} とできるので, \(r\)の平均と分散を\(\hat{p}\)のものに合わせればOKです.
実験を行います. 次のような設定を考えます. 以下の混合正規分布から, データ\(X=\{x_n\}_{n=1}^N\)を発生させます.
\begin{equation} a\mathrm{N}(2,1) + (1-a)\mathrm{N}(-2,1),\quad a=0.6 \end{equation}モデルは以下の混合正規分布とします.
\begin{equation} t^{(n)}(\theta) = p(x_n\mid\theta) = a\mathrm{N}(x_n\mid \theta,1) + (1-a)\mathrm{N}(x_n\mid-2,1),\quad a=0.6 \end{equation}\(\theta\)を推定します. \(\theta\)の事前分布\(t^{(0)}(\theta)\)として, 平均が\(0\), 分散が\(100\)の正規分布を用います. 近似分布の各因子\(r^{(n)}\)を正規分布とします. 各因子の平均と分散パラメータを反復的に調節します. このとき, \(r\)は正規分布, \(\hat{p}\)は混合正規分布になります. また, \(r^{(0)}\)は, 事前分布をそのまま利用します. \(r^{(0)}\)は無視して学習を行います. ヒストグラムと真の分布の様子を下図に示します. \(N=10\)とします.
具体的な更新式は, コードをご覧ください. 導出は結構面倒です.
実験結果を示します. 100反復ほど計算しました. 近似事後分布からのサンプルを1000サンプルほど用いて予測分布を計算しました. 下図の赤線が真の分布, 青線が推測結果です. 割と近くなっていますね. 学習はうまくいってそうです.
#mathematics
using LinearAlgebra
#statistics
using Random
using Statistics
using Distributions
#visualize
using Plots
pyplot()
#macros
using ProgressMeter
using UnPack
#set the random seed
Random.seed!(42)
#create data
a = 0.6
N = 10
μ2 = -2
σsq2 = 1
true_dist = MixtureModel(Normal[Normal(2,1), Normal(μ2,σsq2)], [a,1-a])
X = rand(true_dist, N)
data = (X=X, N=N)
#visualize
fig1 = plot(-5:0.1:5,x->pdf(true_dist,x),xlabel="x",ylabel="prob_dens",title="true distribution",
label="true pdf", color=:red)
plot!(X, st=:histogram, bins=15, normed=true, label="data", alpha=0.5)
savefig(fig1, "figs-EP/fig1.png")
#normal pdf
npdf(x,μ,σsq) = pdf(Normal(μ,sqrt(σsq)), x)
#update mean and var of approximation distribution except one factor rn
function update_params_except_n(m, ssq, mn, ssqn)
ssqtmp = 1/(1/ssq-1/ssqn)
mtmp = m + ssqtmp*(m-mn)/ssqn
return mtmp, ssqtmp
end
#normalizing
calcZn(xn,mtmp,ssqtmp,a,μ2,σsq2) = a*npdf(xn,mtmp,ssqtmp+1) + (1-a)*npdf(xn,μ2,σsq2)
function moment_matching(xn,mtmp,ssqtmp,a,Zn,μ2,σsq2)
m = (a*npdf(xn,mtmp,ssqtmp+1)*((ssqtmp*xn+mtmp)/(ssqtmp+1)) + (1-a)*npdf(xn,μ2,σsq2)*mtmp)/Zn
ssq = (a*npdf(xn,mtmp,ssqtmp+1)*(ssqtmp/(ssqtmp+1)) + (1-a)*npdf(xn,μ2,σsq2)*ssqtmp)/Zn
return m,ssq
end
#update mean and var of factor rn
function update(m,ssq,mtmp,ssqtmp)
ssqn = 1/(1/ssq-1/ssqtmp)
mn = mtmp + ssqn*(m-mtmp)/ssq
return mn,ssqn
end
#expectation propagation
function myEP(data, model_params, n_train)
@unpack X,N = data
@unpack μ2,σsq2,a = model_params
#initialize
mns = zeros(N); ssqns = ones(N);
m = 0; ssq=1/(sum(1 ./ssqns)+1/100);
#train loop
@showprogress for k in 1:n_train
#update r1,...,rN
for n in 1:N
mtmp, ssqtmp = update_params_except_n(m, ssq, mns[n], ssqns[n])
Zn = calcZn(X[n],mtmp,ssqtmp,a,μ2,σsq2)
m,ssq = moment_matching(X[n],mtmp,ssqtmp,a,Zn,μ2,σsq2)
mns[n],ssqns[n] = update(m,ssq,mtmp,ssqtmp)
end
end
return m,ssq
end
#posterior mean and variation
n_train = 100
model_params = (μ2=μ2, σsq2=σsq2, a=a)
@time m,ssq = myEP(data, model_params, n_train)
println("posterior mean=$(round(m,digits=3)), posterior var=$(round(sqrt(ssq),digits=3))")
#model pdf
function pmodel(x, θ, model_params)
@unpack μ2,σsq2,a = model_params
pdf(MixtureModel(Normal[Normal(θ,1), Normal(μ2,σsq2)], [a,1-a]),x)
end
#predictive distribution
n_samps = 1000
θsamps = rand(Normal(m,sqrt(ssq)), n_samps)
function pred(x, θsamps, n_samps, model_params)
preds = zeros(n_samps)
for s in 1:n_samps
preds[s] = pmodel(x, θsamps[s], model_params)
end
return mean(preds)
end
#visualize
fig2 = plot(-5:0.1:5,x->pdf(true_dist,x),xlabel="x",ylabel="prob_dens",title="true and predictive pdf",
label="true pdf", color=:red)
plot!(x->pred(x, θsamps, n_samps, model_params), color=:blue, label="predictive")
savefig(fig2, "figs-EP/fig2.png")