仮定密度フィルタリング

記事の内容


今回は仮定密度フィルタリングです. 期待値伝播法のための準備回です.

概要

この記事では, 仮定密度フィルタリング(Assumed Density Filtering; ADF)を扱います. 事後分布を近似するタイプのオンライン学習手法です. 使い勝手は悪そうですが, アイデアはシンプルです. 仮定密度フィルタリングを扱っている文献は結構少なめです. 文献[1],[2],[3]あたりが参考になりそうです.

アルゴリズムの構成

問題設定

データ\(D_1,D_2,\cdots\)が時刻\(1,2,\cdots\)において順に得られるとします. モデルを\(p(\boldsymbol{x}\mid\boldsymbol{\theta})\)とします. ここで, \(\boldsymbol{x}\in\mathbb{R}^D,\boldsymbol{\theta}\in\mathbb{R}^d\)とします. 手元には, データ\(D_{1:t}=\left\{ D_i \right\}_{i=1}^t\)と解析的には扱いづらい事後分布\(\hat{p}(\boldsymbol{\theta}\mid D_{1:t})\)の近似分布\(r_t(\boldsymbol{\theta})\)があるとします. ただし, \(r_1\)は事前分布とします. ここでは, 次の計算を目標とします. " 新たなデータ\(D_{t+1}\)が手に入ったとき, 事後分布\(\hat{p}(\boldsymbol{\theta}\mid D_{1:t+1})\)を計算したい!" ただし, 事後分布が解析的には扱いづらい状況を考えているので, 近似を試みます.

今, 手元にある\(\boldsymbol{\theta}\)に関する最良の情報は事後分布\(\hat{p}(\boldsymbol{\theta}\mid D_{1:t})\)の近似である\(r_t\)です. そこで, 次の事後分布が計算できれば良さそうです. \(Z_{t+1}\)は正規化定数です.

\begin{equation} \hat{p}(\boldsymbol{\theta}\mid D_{1:t+1}) = \frac{1}{Z_{t+1}} p(D_{t+1}\mid \boldsymbol{\theta})r_t(\boldsymbol{\theta}) \end{equation}

ただし, この事後分布も計算できるとは限りません. これを近似します.

なお, 状態空間モデルなどでは, 事後分布\(\hat{p}\)の定義式の\(r_t\)の部分を1期先予測に置き換えることもあるようです[3].

近似分布の導入

事後分布を近似します. すなわち, 以下の近似式が成り立つような\(r_{t+1}\)を考えます.

\begin{equation} r_{t+1}(\boldsymbol{\theta}) \simeq \hat{p}(\boldsymbol{\theta}\mid D_{1:t+1}) \end{equation}

問題は, \(r_{t+1}\)をどの分布のクラスから持ってきて, どの意味での近似にするかです. まず, 分布のクラスを決めます. これは都合の良い分布であれば何でも良いと思います. 共役性を用いたいのであれば, 事前分布と同じクラスにするのも1つの選択です. この分布の確率(密度)関数の族を\(\mathcal{R}\)と表します. また, どの意味での近似にするかですが, Kullback-Leiblerダイバージェンスの意味での近似が最も自然です.

\begin{equation} r_{t+1} = \underset{r\in\mathcal{R}}{\mathrm{argmin}} D_{\mathrm{KL}}(\hat{p}\| r) \end{equation}

このようにして, 新たにデータが得られるごとに, 事後分布の近似分布が求まります. 要するに, データが手に入るごとに事後分布を都合の良い分布のクラスに射影するということです.

指数型分布族による近似

近似分布を具体的に計算するためには, 最小化問題を解く必要があります. 都合の良い分布のクラス\(\mathcal{R}\)が指数型分布族(の部分集合)であれば, もっと単純になります. 一般に指数型分布族の確率(密度)関数は次のように書けます.

\begin{equation} r(\boldsymbol{\theta}) = h(\boldsymbol{\theta})\exp\left( f(\boldsymbol{\eta})^{\mathrm{T}}g(\boldsymbol{\theta})\right) \end{equation}

ここで, 記号は次のように定めました.

\begin{equation} h:\mathbb{R}^d\to\mathbb{R},\quad f:\mathbb{R}^H\to\mathbb{R}^J,\quad g:\mathbb{R}^d\to\mathbb{R}^{J}.\quad \boldsymbol{\eta}\in\mathbb{R}^H \end{equation}

ここでは, 近似分布が指数型分布族の場合を考えます. パラメータ\(\boldsymbol{\eta}\)が決まれば, 近似分布が1つ定まります. すなわち, 上のようなある関数\(f,g,h\)に対して, \begin{equation} \mathcal{R} = \left\{ r(\boldsymbol{\theta}) \mid \text{ある}\boldsymbol{\eta}\text{に対して, }r(\boldsymbol{\theta}) = h(\boldsymbol{\theta})\exp\left( f(\boldsymbol{\eta})^{\mathrm{T}}g(\boldsymbol{\theta})\right) \right\} \end{equation} としておきます.

Kullback-Leiblerダイバージェンスを最小化したいので, \(\boldsymbol{\eta}\)に関して微分します. 微分すると次のようになります. 行列\(J_f\)は\(f\)のJacobi行列です.

\begin{equation} \nabla_{\eta} D_{\mathrm{KL}}(\hat{p}\|r) = J_f(\boldsymbol{\eta})^{\mathrm{T}}\mathrm{E}_{\hat{p}}[g(\boldsymbol{\theta})] = \boldsymbol{0} \end{equation}

一方, 規格化条件\(\int r(\boldsymbol{\theta})d\boldsymbol{\theta}=1\)を微分して, 次を得ます.

\begin{equation} \nabla_\eta \int r(\boldsymbol{\theta})d\boldsymbol{\theta} = J_f(\boldsymbol{\eta})^{\mathrm{T}} E_r[g(\boldsymbol{\theta})] = \boldsymbol{0} \end{equation}

以上より, 次式を得ます.

\begin{equation} J_f(\boldsymbol{\eta})^{\mathrm{T}}\left(\mathrm{E}_{\hat{p}}[g(\boldsymbol{\theta})] - \mathrm{E}_r[g(\boldsymbol{\theta})]\right) = \boldsymbol{0} \end{equation}

Jacobi行列が正則性ならば, 次式を得ます.

\begin{equation} \mathrm{E}_{\hat{p}}[g(\boldsymbol{\theta})] = \mathrm{E}_r[g(\boldsymbol{\theta})] \end{equation}

逆に, これを満たすように\(r_{t+1}\)を取れば, Kullback-Leiblerダイバージェンスは最小になると期待できます. この方法をモーメントマッチングと言います.

数値実験

以下に簡単な実験結果を示します. ほとんど近似の意味のない自明な例になってしまいました...

真の分布は混合正規分布とします.

\begin{equation} q(x) = 0.5\mathrm{N}(x\mid \mu_1, \sigma_1) + 0.5\mathrm{N}(x\mid\mu_2, \sigma_2) \end{equation}

また, モデルは正規分布とします. \(\sigma_0\)は既知とします.

\begin{equation} p(x\mid \theta) = \mathrm{N}(x\mid\mu,\sigma_0) \end{equation}

1組100個の(真の分布からの)サンプルからなるデータ集合が全部で30個, 逐次的に手に入るとします. 平均パラメータ\(\theta\)の事前分布は標準正規分布とします. モデルが正規分布, その平均の事前分布が正規分布ですから, 事後分布は解析的に計算できて, 正規分布です. ここでは, あえてこの事後分布を近似することにします. 事後分布は, 次のような正規分布の族により近似します.

\begin{equation} \mathcal{R} = \left\{ \mathrm{N}(\mu, \sigma^2) \mid \mu\in\mathbb{R}, \sigma>0\right\} \end{equation}

以上の設定から分かるように, 正規分布を正規分布で近似するだけのつまらない例です. また別の記事で複雑な例を扱います. 下図に, データ, 真の分布, モデルの3つを示しました. モデルには平均パラメータの初期値0を放り込んでいます. 学習がうまくいけば, これが真の分布側にスライドしていくと考えられます. ここでは, \(\sigma_0^2=2, \mu_1=0.9, \mu_2=3.9 , \sigma_1^2=1.3, \sigma_2^2=1.3 \)としています.

【コード2の実行結果】

それでは近似計算をしてみましょう. 今回のケースは簡単です. まず, 正規分布族は指数型分布族の部分集合ですから, モーメントマッチングによりパラメータの更新式が得られます. 正規分布を指数型分布族の定義式に合わせて書いておきます.

\begin{equation} r(\theta) = h(\theta)\exp\left( f(\boldsymbol{\eta})^{\mathrm{T}}g(\theta)\right),\quad \boldsymbol{\eta} = \begin{bmatrix} \mu\\ \sigma^2\end{bmatrix} \end{equation} \begin{equation} h(\theta) = \frac{1}{\sqrt{2\pi}},\quad f(\boldsymbol{\eta}) = \begin{bmatrix} -\frac{1}{2\sigma^2}\\ \frac{\mu}{\sigma^2} \\ \frac{\mu}{2\sigma^2}+\log\sigma^2\end{bmatrix} , \quad g(\theta) = \begin{bmatrix} \theta^2\\ \theta\\ -1\end{bmatrix} \end{equation}

時刻\(t\)での近似分布\(r\)が手元にあるとして, 時刻\(t+1\)における近似分布\(r_{t+1}\)を計算します. といっても, 近似分布が正規分布なので, 平均パラメータ\(\mu_{t+1}\)と分散パラメータ\(\sigma_{t+1}^2\)が分かればOKです. 近似分布\(r_{t+1}\)に関する\(g(\theta)\)の期待値計算から \begin{equation} \mathrm{E}_{r_{t+1}}[g(\theta)] = \begin{bmatrix} \mu_{t+1}^2+\sigma_{t+1}^2 \\ \mu_{t+1} \\ -1\end{bmatrix} \end{equation} を得ます. 一方, 近似したい分布\(\hat{p}\)に関する\(g(\theta)\)の期待値計算から, \begin{equation} \mathrm{E}_{\hat{p}}[g(\theta)] = \begin{bmatrix} \hat{\mu}^2 + \hat{\sigma}^2 \\ \hat{\mu} \\ -1\end{bmatrix} \end{equation} を得ます. ここで, \begin{equation} \hat{\mu} = \frac{\left( \sum_{i+1}^{N_{t+1}} x_i\right)\sigma_t^2 + \sigma_0^2\mu_t}{N_{t+1}\sigma_t^2 + \sigma_0^2},\quad \hat{\sigma}^2 = \frac{\sigma_0^2\sigma_t^2}{N_{t+1}\sigma_t^2 + \sigma_0^2} \end{equation} とします. よって, 近似分布のパラメータ更新式は次式で与えられます.

\begin{equation} \mu_{t+1} = \hat{\mu},\quad \sigma_{t+1}^2 = \hat{\sigma}^2 \end{equation}

正規分布を正規分布で近似するので, 自明な更新式です. この更新式に従ってデータが手に入るごとに更新していきます. データが30組あるので, 30回更新します. 結果を以下に表示します. モデルに学習後の平均パラメータを代入しました. 当然ですが, うまくいってますね.

【コード3の実行結果】

コード

【Juliaコード1; インポート】
using Distributions
using Statistics
using Plots
pyplot()
【Juliaコード2; データの作成等】
#true distributions
function true_F(μ₁, μ₂, σ₁_sq, σ₂_sq)
    normal1 = Normal(μ₁, σ₁_sq)
    normal2 = Normal(μ₂, σ₂_sq)
    return MixtureModel([normal1, normal2])
end

#create data
function create_data(n_data, N, true_F)
    Xs = zeros(N, n_data)
    for t in 1:n_data
        Xs[:,t] = rand(true_F, N)
    end
    return Xs
end

#create the data and plot them
n_data = 30 #there are the data D1,D1,...,D30
N = 100
μ₁ = 0.9
μ₂ = 3.9
σ₁_sq = 1.3
σ₂_sq = 1.3
mixed_normal = true_F(μ₁, μ₂, σ₁_sq, σ₂_sq)
Xs = create_data(n_data, N, mixed_normal)

#true pdf
q(x) = pdf(mixed_normal, x)

#model parameter and the pdf
σ₀_sq = 2.0
p(x, μ) = pdf(Normal(μ, σ₀_sq), x)


fig1 = plot(-3:0.1:7, q, color=:red, label="true", xlim=[-3,7], xlabel="x")
plot!(-3:0.1:7, x->p(x,0), color=:green, label="model(initial)")
plot!(Xs[:], st=:histogram, bins=30, normed=true, color=:gray, alpha=0.5, label="data")
savefig(fig1, "figs-ADF/fig1.png")
【Juliaコード3; 学習と結果】
#update the mean and variance
function update_params(X, N, μ, σ_sq, σ₀_sq)
    μ_new = (σ_sq * sum(X) + σ₀_sq * μ)/(N * σ_sq + σ₀_sq)
    σ_sq_new = σ₀_sq * σ_sq / (N * σ_sq + σ₀_sq)
    return μ_new, σ_sq_new
end

#ADF
function my_ADF(Xs, σ₀_sq)
    N, n_data= size(Xs)
    
    #パラメータ保存用
    μ = 0.0
    σ_sq = 1.0
    params = zeros(2, n_data+1)
    params[1,1] = μ
    params[2,1] = σ_sq
    
    for t in 1:n_data
        μ, σ_sq = update_params(Xs[:,t], N, μ, σ_sq, σ₀_sq)
        params[1,t+1] = μ
        params[2,t+1] = σ_sq
    end
    return params
end

#show the result
params = my_ADF(Xs, σ₀_sq)
println("sample mean = $(mean(Xs[:]))")
println("approximated posterior mean = $(params[1,end])")
println("approximated posterior std = $(params[2,end])")

fig2 = plot(-3:0.1:7, q, color=:red, label="true", xlim=[-3,7], xlabel="x", title="mean plug-in")
plot!(-3:0.1:7, x->p(x,params[1,end]), color=:green, label="estimated")
plot!(Xs[:], st=:histogram, bins=30, normed=true, color=:gray, alpha=0.5, label="data")
savefig(fig2, "figs-ADF/fig2.png")
参考文献

      [1]T.P.Minka, Expectation Propagation for approximate Bayesian inference, Proceedings of the Seventeenth Conference on Uncertainty in Artificial Inteligence, pp.332-369, 2001
      [2]須山敦志, ベイズ深層学習, 講談社, 2020
      [3]K.Murphy, Machine Learning: A Probabilistic Perspective , The MIT Press, 2012