スライスサンプリング

記事の内容


今回はスライスサンプリングの記事です. MCMCの中では地味な気はしますが...

基本的なアイデア

概要

スライスサンプリングは, MCMC(Markov連鎖モンテカルロ法)の一種で, Neal(2003)[1]により提案されました. MCMCといえばMetropolis-Hastings法やGibbs samplerが中心で, スライスサンプリングは割と地味です. 一方で, 論文を読んでいると本質的でない部分にちょこちょこ登場するので, 記事にまとめることにしました. アルゴリズム自体はシンプルなアイデアから導出されますが, 細かい調整(stepping outやshrinkage)が必要で意外に厄介です. また, アルゴリズムは文献ごとに少しブレがあるようです. この記事も, 元論文とは少し異なります. また, 以下1変数の場合のみを考えます.

アイデア

スライスサンプリングの基本的なアイデアを述べます. サンプルが欲しい目的の分布を\(p\)とします. この\(p\)は正規化定数\(Z\)を用いて次のように表されるとします. 以下, 定数分違うだけなので, \(p\)と\(f\)を同一視します.

\begin{equation} p(x) = \frac{1}{Z}f(x) \end{equation}

この\(f\)を用いてサンプルを生成します. まず, 次の式に注目します.

\begin{equation} f(x) = \int_{0}^{f(x)}1dl = \int \tilde{f}(x,l) dl \end{equation}

ここで, \begin{equation} \tilde{f}(x,l) = \begin{cases} 1 \quad (\text{if } 0\leq l \leq f(x))\\ \\ 0 \quad (\text{otherwise}) \end{cases} \end{equation} と定めます. もし\(\tilde{f}\)からのサンプルが得られれば, \(x\)の成分のみを取り出して, \(f\)のサンプルが得られます. しかも, \(\tilde{f}\)はある種の一様分布です. なんとかサンプリングできそうです. 密度関数は, 例えば次のような形をしています.

【イメージ】

ということで, ちょっと変わった形の一様分布からのサンプルが得られればうまくいきそうです. そこで, 次の集合が鍵となります.

\begin{equation} S(l) = \{y\in\mathbb{R} \mid 0\leq l \leq f(y))\}, \quad (0\leq l) \end{equation}

これは, \(l\)を固定するごとに, \(\tilde{f}\)が値1をとる集合です. 以降, この集合を高さ\(l\)のスライスと呼びます. この集合は空でなければ, \(f\)のグラフを高さ\(l\)で水平に横切るような集合です. 文字通り, グラフをスライスするイメージです. スライスが特定できれば, 一様分布からのサンプルはできそうです. ということで, 以下の議論は, どうやってスライスを求めるかが問題になります. 単調な関数なら解析的に求まりそうです.

その前に, もしスライスが計算できたとして, その後どのようにサンプリングを行うかを述べておきます. まず, 初期値\(x_0\)を与えておきます. 次に, スライスの高さを決めます.

\begin{equation} l = uf(x_0),\quad u\sim \mathrm{Uni}(0,1) \end{equation}

次に高さ\(l\)のスライスを計算します. このスライスをもとに, 新たなサンプル\(x_1\)を \begin{equation} x_1\sim \mathrm{Uni}(S(l)) \end{equation} と計算します. これを続けてサンプル列が構成できます. 詳細なアルゴリズムは後ほどご紹介します.

シンプルな例

上記のアイデアは本当にうまくいくのでしょうか?試してみます. 次のような確率密度関数をもつ分布からのサンプルを計算します.

\begin{equation} p(x) \propto e^{-\beta x} \mathbb{1}_{[0,1]} \end{equation}
【コードの実行結果】

つまり, 閉区間\([0,1]\)に定義域を制限した指数分布です. 先ほどの記号で言えば, 右辺が\(f\)に相当します. 高さ\(l\)でのスライスは, 以下の通りです. 単調な関数なので, 解析的に計算できます.

\begin{equation} S(l) = \begin{cases} \left[ 0, -\frac{1}{\beta} \log l \right] \quad (\text{if } e^{-\beta} \leq l)\\ \\ \left[ 0, 1 \right] \quad (\text{otherwise}) \end{cases} \end{equation}

そして, 先ほどのアイデアを素直に実装した結果が以下の通りです. \(\beta=2.5\)とし, サンプル数は\(10000\), burn-in期間として\(1000\)サンプルを除去しました.

【コードの実行結果】

指数分布が再現できました. また, 各サンプル点(赤点)とスライス(青点線)の変化をアニメーションで示しました.

【コードの実行結果】

アルゴリズムの構成

方針

ここからは, もう少し複雑な場合を考えます. 先ほどの例では, 逆関数が解析的に計算できました. しかし, 逆関数は常に計算できるとは限りません. ここからは, 逆関数が簡単には求まらない場合に, スライスを近似的に計算する方法を考えます. 方針としては, まず現在のサンプル点を含む小さい区間から出発して, 少しずつ区間を広げていきます. 区間がスライスから少しはみ出たところで, 逆に区間を少しずつ小さくして, 次のサンプル点を計算します. 区間を広げていく過程をstepping out, 狭くする過程をshrinkage procedureと呼びます.

ここまでの話をまとめておきましょう. まず, 冒頭に紹介したアイデア(スライスする)がアルゴリズムの根幹をなします. したがってスライスサンプリングの基本的な流れは次のようになります.

  1. \(u\)を一様分布からサンプルする.
  2. \(f\)のグラフを\(l=uf(x)\)の高さでスライスする.
  3. \(x\)をスライスした区間上の一様分布からサンプルする.

このうち, 2番目のスライスに伴う区間の計算が困難な場合を考えます. これを実現する方法がstepping outとshrinkage procedureです. この2つで区間計算を代用します.

  1. \(u\)を一様分布からサンプルする.
  2. stepping out.
  3. shrinkage procedure.
  4. \(x\)をスライスした区間上の一様分布からサンプルする.

ということで, この2つの操作を簡単に解説します.

stepping out

いま, 手元にサンプル\(x\)があるとします. このサンプルをもとに, 目的の分布からの新たなサンプルを生成する方法を考えます. 基本的には冒頭のアイデア, スライス\(S(l)\)上の一様分布からのサンプルです. そして, 今はこのスライスを近似的に計算することが目的です. スライスの高さ\(l\)も手元にあるとします. 高さ\(l\)のスライスは, \begin{equation} S(l) = \{y\in\mathbb{R} \mid l \leq f(y)\} \end{equation} と表されます. 典型的には, スライスは\(x\)の両側に幅を持って広がっています. 最初に小さめの区間をとって, \(S(l)\)とほぼ一致するまで広げていくことで, スライスを近似できると期待できます. これをstepping outと呼びます. 以上のアイデアを定式化しておきます.

まず, 幅\(w\)と最大反復回数\(n_{\max}\)を予め定めておきます. 既にサンプル\(x\)と高さ\(l\)が計算してあるとします. 初期区間\([L,R]\)を以下のように, 幅\(w\)の区間としてとります.

\begin{equation} L = x-\tilde{u} w,\quad R = L + w,\quad \tilde{u} \sim \mathrm{Uni}(0,1) \end{equation}

その後, \(L\)は\(w\)ずつ減らしていきます. 逆に, \(R\)は\(w\)ずつ増やしていいきます. 最大で\(n_{\max}\)回反復計算します. 端点がスライスからはみ出るタイミング, すなわち, \begin{equation} f(L)\leq l, \quad f(R)\leq l \end{equation} で計算を止めるのが適当です. 以上が区間を広げる過程です.

shrinkage

次に, 広げた区間から新たなサンプルを生成します. そもそもスライス上の一様分布からのサンプルが欲しいので, 近似区間がスライスからはみ出ていては困ります. ということで, 広げすぎた区間を, 必要ならば縮めながらサンプルします. 具体的には, 近似した区間上の一様分布からのサンプルを計算し, 新たな\(x\)の候補とします. これがスライスからはみ出ているならば, 区間を縮めます. 逆に, スライス内ならば縮める必要はありません. 新たなサンプルとして採択します. もう少し詳しく述べます.

まず, stepping outで計算した近似区間から, 候補サンプル\(\tilde{x}\)を \begin{equation} \tilde{x} \sim \mathrm{Uni}(L,R) \end{equation} と計算します. これがスライスからはみ出ているか判定します. もし\(f(\tilde{x})\leq l\)ならばスライス外です. このとき, 端点の一方を\(\tilde{x} \)で置き換えて, 区間を縮めます. 区間を縮めて候補を再計算します. これをスライス内のサンプルが得られるまで繰り返します.

アルゴリズムのまとめ

以上をまとめて, 以下のアルゴリズムが得られます.

【スライスサンプリングのアルゴリズム】

Initialize \(x_0\) and set \(w\), \(\mathrm{n_{\max}}\)
for \(k=0,1,\cdots\)
\(u\sim\mathrm{Uni}(0,1)\)
\(l = uf(x_k)\)
#stepping out
\(\tilde{u} \sim \mathrm{Uni}(0,1)\)
\(L = x_k-\tilde{u} w\)
\(R = L + w\)
for \(k=1,\cdots,n_{\max}\)
\(L=L-w\)
if \(f(L)\leq l\); break
end
for \(k=1,\cdots,n_{\max}\)
\(R=R+w\)
if \(f(R)\leq l\); break
end
#shrinkage procedure
\(\tilde{x} \sim \mathrm{Uni}(L,R)\)
while \(f(\tilde{x})\leq l\)
if \(\tilde{x}< x\); \(L=x\)
else \(R=x\)
\(\tilde{x} \sim \mathrm{Uni}(L,R)\)
end
\(x_{k+1}=\tilde{x}\)
end

例: 混合正規分布からのサンプル

上記のアルゴリズムでサンプルできるか, 実験で確かめます. 混合正規分布からのサンプルを計算します.

\begin{equation} p(x) = a_1\mathrm{N}(\mu_1, \sigma_1^2) + a_2\mathrm{N}(\mu_2, \sigma_2^2) \end{equation}

以下の実験では, \begin{equation} a_ 1 = 0.4,\quad a_2 = 0.6,\quad \mu_1 = -1.0,\quad \mu_2 = 1.0,\quad \sigma_1 = 0.6,\quad \sigma_2 = 0.5 \end{equation} を既知とします.

【コードの実行結果】

スライスサンプリングを用いて計算した結果を以下に示します. サンプル数は\(10000\)とし, burn-in期間として\(1000\)サンプルを除去しました.

【コードの実行結果】

また, 各サンプル点(赤点)とスライスの近似(青点線)の変化をアニメーションで示しました.

【コードの実行結果】

コード

【Juliaコード1; インポート】
using Plots
using LinearAlgebra
using Distributions
pyplot()
【Juliaコード2; インポート】
#密度関数の主要項
β = 2.5
func1(x,β) = 0. ≤ x ≤ 1.0 ? exp(-β*x) : 0
f1(x) = func1(x, β)

fig1 = plot(0:0.001:1.25, f1, xlabel="x", ylabel="prob dens", title="β=$(β)", legend=false)
savefig(fig1, "figs-slice/fig1.png")
【Juliaコード3; インポート】
#区間を計算する
function calc_interval(x, β, l)
    if exp(-β) ≤ l
        x = -1/β*log(l) * rand()
    else
        x = rand()
    end
    return x
end

#slice sampling
function my_slice_sampling1(x₀, n_iter, n_burnin, β, f)
    #初期値と保存用配列
    x = x₀
    xsamps = zeros(n_iter)
    lsamps = zeros(n_iter)
    xsamps[1] = x

    for k in 2:n_iter
        #uのサンプル
        l = f(x)*rand()
        lsamps[k] = l

        #xのサンプル
        x = calc_interval(x, β, l)
        xsamps[k] = x
    end
    return lsamps[n_burnin:end], xsamps[n_burnin:end]
end
【Juliaコード4; インポート】
n_samps = 10000
n_burnin = 1000
β = 2.5
lsamps1,xsamps1 = my_slice_sampling1(0.0, n_samps, n_burnin, β, f1) 
【Juliaコード5; インポート】
fig2 = plot(xsamps1, st=:histogram, xlabel="x", ylabel="prob dens", normed=true, legend=false, title="Exp Dist")
savefig(fig2, "figs-slice/fig2.png")
【Juliaコード6; インポート】
anim1 = @animate for i in 1:100
    x = xsamps1[i]
    l = lsamps1[i+1]
    plot(0:0.001:1.25, f1, xlabel="x", ylabel="prob dens", legend=false)
    plot!([x, x], [0, l], ls=:dot, color=:red)
    plot!([x], [l], st=:scatter, color=:red)
    if exp(-β) ≤ l
        R = -1/β*log(l)
    else
        R = 1
    end
    plot!([0,R], [l,l], ls=:dot, color=:blue)
end
gif(anim1, "figs-slice/anim1.gif", fps=5)
【Juliaコード7; インポート】
#混合正規分布
function f2(x)
    return 0.4*pdf(Normal(-1.0, 0.6),x) + 0.6*pdf(Normal(1.0, 0.5),x)
end

fig3 = plot(-2.5:0.1:2.5, f2, xlabel="x", ylabel="prob dens", title="Mixed Gauss", legend=false)
savefig(fig3, "figs-slice/fig3.png")
【Juliaコード8; インポート】
#区間の端の更新
function update_end(e, w, l, f, nmax)
    for k in 1:nmax
        e = e + w #e=Lの場合はwを-wにする
        if f(e) ≤ l
            return e
            break
        end
    end
    return e
end

#stepping out
function stepping_out(x, w, l, f, nmax)
    #LとRの初期値
    u = rand()
    L = x - u*w
    R = L + w

    #LとRの更新
    L = update_end(L, -w, l, f, nmax)
    R = update_end(R, w, l, f, nmax)
    return L, R
end

#shrinkage procedure
function shrinkage(x, L, R, l, f)
    #区間の端点と候補点
    Lnew = L
    Rnew = R
    xnew = Lnew+(Rnew-Lnew)*rand()

    while f(xnew) ≤ l
        #区間の端点の更新
        if xnew < x
            Lnew = xnew
        else
            Rnew = xnew
        end

        #候補点の更新
        xnew = Lnew + (Rnew-Lnew)*rand()
    end
    return xnew, Lnew, Rnew
end

#slice sampling
function my_slice_sampling2(x₀, n_samps, n_burnin, w, nmax, f)
    #初期値と保存用配列
    xsamps = zeros(n_samps)
    lsamps = zeros(n_samps)
    Ls = zeros(n_samps)
    Rs = zeros(n_samps)
    xsamps[1] = x₀

    #サンプリング
    x = x₀
    for k in 2:n_samps
        #スライスする高さ
        l = f(x)*rand()

        #区間を定める
        L,R = stepping_out(x, w, l, f, nmax)

        #区間からサンプルする
        x, Lnew, Rnew = shrinkage(x, L, R, l, f)
        Ls[k] = Lnew
        Rs[k] = Rnew

        xsamps[k] = x
        lsamps[k] = l
    end
    return  lsamps[n_burnin:end], xsamps[n_burnin:end], Ls[n_burnin:end], Rs[n_burnin:end]
end
【Juliaコード9; インポート】
n_samps = 10000
n_burnin = 1000
w = 0.1
nmax = 100
lsamps2, xsamps2, Ls2, Rs2 = my_slice_sampling2(0.0, n_samps, n_burnin, w, nmax, f2)
【Juliaコード10; インポート】
fig4 = plot(xsamps2, st=:histogram, xlabel="x", ylabel="prob dens", normed=true, legend=false, title="Mixed Gauss")
savefig(fig4, "figs-slice/fig4.png")
【Juliaコード11; インポート】
anim2 = @animate for i in 1:100
    x = xsamps2[i]
    l = lsamps2[i+1]
    L = Ls2[i+1]
    R = Rs2[i+1]
    plot(-2.5:0.1:2.5, f2, xlabel="x", ylabel="prob dens", legend=false, xlim=[-2.5, 2.5])
    plot!([x, x], [0, l], ls=:dot, color=:red)
    plot!([x], [l], st=:scatter, color=:red)
    plot!([L,R], [l,l], ls=:dot, color=:blue)
    plot!([L,R], [l,l], st=:scatter, color=:blue)
end
gif(anim2, "figs-slice/anim2.gif", fps=5)
参考文献

      [1]R.Neal, Slice Sampling (with discussion), The Annals of Statistics, 31(3), pp.705-767, 2003
      [2]R.Robert, G.Casella, Monte Carlo Statistical Methods, Springer, 2004