Dirichlet過程入門

記事の内容


あけましておめでとうございます. 今年もよろしくお願い申し上げます. 最近, 論文を読む過程で必要になったのでまとめておきます. Nonparametric Bayes, 特にDirichlet過程に関する記事です.

Nonparametric Bayesの概要

動機付け

まず, Nonparametric Bayes(ここではDirichlet過程)のモチベーションを確認しておきます. 今, 手元に以下のようなデータ点があるとします. これらのデータをグループ分けするタスク(クラスタリング)を考えます. この場合, 混合正規分布を使うのが適当です. 直感的には, 混合数が4か5くらいですね.

【手元にあるデータ】

とりあえず混合数4と5両方で推定しておいて, モデル比較でより良いモデルを決定できます. 今回はモデルが2つだから簡単です. しかし, トピックモデルのように候補が何通りも出てくる場合にはコストがかかります. 混合数もデータから推定する方法はないでしょうか. 混合モデルは複数の統計モデルを確率的に重みづけしたものですから, 全部で1の確率の分割の仕方を再考する必要がありそうです.

従来のモデル(パラメトリックモデル)は, 長さ(=確率)1の棒を以下のように分割するものと解釈できます. 分割した棒の長さに応じて統計モデルを重み付けします. この棒の長さはモデルのパラメータです. この方法では, 予め分割数を決め打ちする必要があります. 下図は分割数5の場合です.

【従来のモデリング】

分割数の指定を避けるために, 棒を無限に分割してみます. どうせ全体が1ですから, 重要でない統計モデルには短い棒が割り当てられます.

【やってみたいこと】

各棒の長さがモデルのパラメータですから, パラメータが無限個に増えたことになります. このように, 無限次元パラメータ空間上のBayesモデルを扱う枠組みをNonparametric Bayesと言います. ということで, クラスター数(分割数)の指定問題に対する処方箋はパラメータの無限次元化です.

Nonparametric Bayesとは

パラメータの次元を無限にしたときのBayes推定をNonparametric Bayes推定と呼びます. 定義は論文によって様々です. 包括的にまとめたサーベイ論文として, [1]が詳しいです. Nonparametric Bayesの代表格はGauss過程とDirichlet過程です. この記事ではGauss過程は扱いません. 主にDirichlet過程です. Dirichlet過程といえば[2][3][4][5]あたりが基本的な論文っぽいですね.

無限次元拡張の影響で, 主役は確率分布から確率過程に交代します. つまり, Nonparametric Bayesは確率過程の入ったBayesモデルを扱う枠組みです. Gauss過程やDirichlet過程も確率過程です. 一般に, Gauss過程は回帰や分類に用います. Dirichlet過程は累積分布の推定や密度推定に用います.

この記事の構成

この記事は, 以下のように構成されます.

まず次節では, 主役であるDirichlet過程を導入します. Dirichlet過程で想定しているモデルと, 典型的な問題設定を紹介します. 続いて数学的な定義を述べます. 数学的にはただの確率過程に過ぎません. 実際に応用するにあたって, 何かしらのサンプリング方法が必要です. サンプリングの目的に応じて, 2種類のアルゴリズムを紹介します. 節の最後には例題として, 累積分布関数の推定問題を扱います.

その次の節では, Dirichlet過程混合モデルを紹介します. このモデルはクラスタリングに応用でき, 冒頭の問題に対する解決策を提示してくれます. まずは典型的なクラスタリング問題を設定します. 一般には推定方法としてGibbs samplerが用いられますが, この記事では変分推論法を用います. 変分推論法に関しては, 過去の記事も是非参考にしてください.

なお, 以下, 測度, 確率測度, 分布, 分布関数などはほとんど区別しないで用います.

Dirichlet過程の導入

モデル

Dirichlet過程はいわば確率分布の確率分布です. サイコロを振ると目が出てくる感じで, Dirichlet過程をシミュレートすると分布が出てきます. Dirichlet過程を用いると, 冒頭の無限次元の分割の確率モデルを実現できます.

Dirichlet過程を用いた推定での基本的なモデルは以下の通りです. 集中度パラメータ\(\alpha\)と基底測度\(F_0\)は既知とします(これらに関しては後ほど).

【モデル1】(Dirichlet過程モデル)

Sample \(F\sim\mathrm{DP}(\alpha,F_0)\).
for \(n=1,\cdots,N\)
Sample \(\boldsymbol{x}_n\mid F \sim F\)
end

データを発生させた分布(関数)のモデリングに使えそうです. ただし, Dirichlet過程が生み出すのは離散分布です. まとめておくと, Dirichlet過程は無限次元の離散分布を生み出す分布(確率過程)です.

Dirichlet過程の定義

Dirichlet過程を定義しておきます.

【定義】(Dirichlet過程)

可測空間\((\chi,\mathcal{S})\)上の確率測度\(F_0:\mathcal{S}\to[0,1]\)と正の実数\(\alpha\)に対して, 確率測度\(F:\mathcal{S}\to[0,1]\)がDirichlet過程\(\mathrm{DP}(\alpha,F_0)\)に従うとは, 任意の\(n\in\mathbb{N}\)と\(\chi\)の任意の有限分割\(A_1,\cdots,A_n\in\mathcal{S}\)に対して, 次式が成り立つことである.

\begin{equation} \begin{bmatrix} F(A_1)\\ \vdots\\ F(A_n) \end{bmatrix}\sim \mathrm{Dir}\left(\alpha\begin{bmatrix}F_0(A_1)\\\vdots\\F_0(A_n)\end{bmatrix}\right) \end{equation}

このとき, \(F\sim\mathrm{DP}(\alpha,F_0)\)と書く. \(\alpha\)を集中度パラメータ, \(F_0\)を基底測度と呼ぶ.

この定義だけでは推定に利用しづらいので, 何かしらのサンプリング方法が必要です. ここでは2つ紹介します. 1つは棒折り過程です. 冒頭の棒を無限個に分割する方法とほとんど同じです. 棒折り過程によって, Dirichlet過程からのサンプルを計算できます. 2つ目は中華料理店過程です. これはDirichlet過程からのサンプルに関して周辺化したもので, Dirichlet過程モデルにおけるデータのサンプリングに相当します.

Dirichlet過程には集中度パラメータと基底測度という2つのパラメータを持ちます. Dirichlet過程からのサンプルである分布は, 基底測度のそこそこ似ています. 実は期待値が基底測度に一致します. その一致具合を調整するのは集中度パラメータです. 改めて, Dirichlet過程とは, 基底測度に似た無限次元の離散分布を生成する分布(確率過程)であると言い直すことができます.

アルゴリズム1: 棒折り過程

Dirichlet過程からのサンプルの構成方法として, 棒折り過程を紹介します. 冒頭の, 長さ1の棒を無限個に分割する方法に相当します. アルゴリズムを簡単に述べておくと, まず長さ1の棒を用意します. これをベータ分布からのサンプルの割合だけカットします. カットした分だけ棒は短くなります. 残った棒を, 再びベータ分布からのサンプル分だけカットします. この操作を繰り返します. カットした棒を基底測度からのサンプルの位置に立てていくことで, Dirichlet過程からのサンプルを作ります.

棒折り過程のアルゴリズムを以下に示します. ただし, \(\prod_{j=1}^0\)は\(1\)と約束します.

【アルゴリズム1】(棒折り過程;Stick Breaking Process)

for \(k=1,2,\cdots\) do
Sample \(\boldsymbol{\theta}_k \sim F_0\).
Sample \(v_k \sim \mathrm{Beta}\left(1,\alpha\right)\).
Compute \(\pi_k = v_k\prod_{j=1}^{k-1}(1-v_j)\).
end
Compute a sample measure \(F\) from Dirichlet Process by: \begin{equation} F = \sum_{k=1}^\infty \pi_k\delta_{\boldsymbol{\theta}_k} \end{equation}

棒折り過程を用いてサンプルを構成します. 集中度パラメータを10, 基底測度を標準正規分布としてサンプルを1個発生させ, 下図に示しました. 左図は, 棒を立てた図です. 右図はサンプルから計算した分布関数です. 本来, 棒折り過程ではループは無限回です. 実際のプログラムでは有限回で打ち切ります. 下図は1000回で打ち切っています.

【コード3の実行結果】

アルゴリズム2: 中華料理店過程

Dirichlet過程モデルからのデータのサンプル方法として, 中華料理店過程を紹介します. 中華料理店過程では, データを順次発生させます. 最初は基底測度から始めて, 2番目のデータからは経験分布と基底測度の重みつき平均からサンプルします. アルゴリズムは以下の通りです.

【アルゴリズム2】(中華料理店過程;Chinese Restaurant Process)

Sample \(\boldsymbol{x}_1\sim F_0\).
for \(n=1,\cdots,N-1\)
Sample \(\boldsymbol{x}_{n+1}\mid\boldsymbol{x}_1,\cdots,\boldsymbol{x}_n \sim \frac{\alpha}{n+\alpha}F_0+\frac{n}{n+\alpha}\frac{1}{n}\sum_{i=1}^n\delta_{\boldsymbol{x}_i}\).
end

中華料理店過程を用いてデータを生成します. 集中度パラメータを10, 基底測度を標準正規分布としてサンプルを1個発生させ, 下図に示しました. 左図はデータのヒストグラムです. 右図は生成したデータから計算した経験分布です. このアルゴリズムではDPからのサンプルに関して周辺化してあるので, 有限回のループで十分です.

【コード4の実行結果】

実験: CDFの推定

Dirichlet過程を用いた例題として, 累積分布関数の推定問題を扱います. 手元にサイズ\(N=10\)のデータが\(\left\{x_n\right\}_{n=1}^N\)あるとします. このデータを生成した分布の分布関数を推定します. まず事前分布にDirichlet過程を設定します. 集中度パラメータを10, 基底測度を標準正規分布とします. このとき, 事後分布は以下のDirichlet過程になります.

\begin{equation} \mathrm{DP}\left( \alpha+N, \frac{\alpha}{\alpha+N}F_0 + \frac{N}{N+\alpha}\frac{1}{N}\sum_{n=1}^{N}\delta_{x_n} \right) \end{equation}

以下が推定結果です. 真の分布関数(赤実線), 事前分布からのサンプル(青点線), 事後分布からのサンプル(橙一点鎖線)を示しました(1サンプルだけでは分かりづらいですね...反省). まあ, 気持ち近づいているのかなって感じです. ちなみに真の分布は正規分布\(\mathrm{N}(-0.1,1.4^2)\)です.

【コード5の実行結果】

Dirichlet過程混合モデル

問題設定

Dirichlet過程を用いた推定として, 累積分布関数の推定問題を扱いました. ここからは, Dirichlet過程の応用として, 分布の密度推定(クラスタリング)を試します.

Dirichlet過程混合モデルでは, 以下のようなモデルを考えます. パラメータ\(\alpha\),と基底測度\(F_0\)は既知とします. また, 基底測度\(F_0\)は既知の密度関数を持つとします. 前半のfor文は棒折り過程, 後半のfor文は潜在変数とデータの生成過程です.

【モデル2】(Dirichlet過程混合モデル)

Compute \(\left\{ (v_k,\boldsymbol{\theta}_k) \right\}_{k=1}^{\infty}\) by Stick Breaking Process, that is:
for \(k=1,2,\cdots\) do
Sample \(\boldsymbol{\theta}_k \sim F_0\).
Sample \(v_k \sim \mathrm{Beta}\left(1,\alpha\right)\).
Compute \(\pi_k = v_k\prod_{j=1}^{k-1}(1-v_j)\).
end
for \(n=1,\cdots,N\) do
Sample \(z_n\mid\boldsymbol{\pi}\sim \mathrm{Cat}(\boldsymbol{\pi})\).
Sample \(\boldsymbol{x}_n\mid z_n\sim p(\boldsymbol{x}_n\mid z_n,\boldsymbol{\theta}_{z_n})\).
end

このモデルは混合数が無限大の混合モデルです. 重要でない因子には低い確率が割り当てられる(と期待できる)ので無視できます.

実験: 変分推論法でクラスタリング

密度推定の方法として, 変分推論法を紹介します. 変分推論に関しては過去の記事でも何度か扱っています. 今回用いるのは平均場近似による方法[6],[7]です. 他の推論方法として, Gibbs samplerあります. こちらは既存の文献[10],[11]で詳しく解説されています. 他にもスライスサンプリングによる方法があります. 変分推論法やGibbs samplerでは無限和を有限で打ち切る必要が生じますが, スライスサンプリングでは変数を追加することで解析的に有限和に置き換えます. スライスサンプリングを使えば有限に帰着するというよりは, \(\pi_k\)の無限級数がほとんど確実に収束することがポイントです(多分).

次のようなデータ点を考えます. データは全部で\(N=100\)点あります. これらを適当にグループ化(クラスタリング)します. Dirichlet過程混合モデルを用いると, 予めクラスター数を定めずにデータから推定できます. 普通の有限混合モデルの代わりに無限混合モデル(Dirichlet過程混合モデル)を用います. ただし, コンピュータの有限性から無限和を有限和で近似する必要が生じます.

【コード7の実行結果】

観測値のモデル\(p(\boldsymbol{x}\mid z,\boldsymbol{\theta}_z)\)は正規分布\(\mathrm{N}(\boldsymbol{\theta}_z,\sigma^2I_2)\)とします. また, 基底測度(\(\boldsymbol{\theta}\)の事前分布)\(F_0\)は正規分布\(\mathrm{N}(\boldsymbol{0},\sigma_0^2I_2)\)とします.

以上のモデルの下, 近似分布を以下のように設定します[10][11]. Kは既知の整数です. 無限個の因子の扱いは困難です. ここではKで打ち切っています.

\begin{equation} r(\boldsymbol{v},\boldsymbol{\theta},\boldsymbol{z}) = \left\{\prod_{k=1}^{K-1}r(v_k)\right\}\left\{\prod_{k=1}^{K}r(\boldsymbol{\theta}_k)\right\}\left\{\prod_{n=1}^{N}r(z_n)\right\} \end{equation}

これらを用いて(頑張って)計算すると, 以下の更新式が得られます.

\begin{align} r(v_k) &= \mathrm{Beta}(v_k\mid\alpha_k,\beta_k) \\ r(\boldsymbol{\theta}_k) &= \mathrm{N}(\boldsymbol{\theta}\mid \boldsymbol{\mu}_k,\sigma_k^2I_2) \\ r(z_n) &= \mathrm{Cat}(z_n\mid \boldsymbol{\pi_n}) \end{align}

ここで, 各分布のパラメータは以下のようになります. \(\psi\)はディガンマ関数です.

\begin{align} \alpha_k &= \alpha + \sum_{n=1}^N \pi_{nk}\\ \beta_k &= 1 + \sum_{n=1}^N\sum_{j=k+1}^K \pi_{nk}\\ (\sigma_k^2)^{-1} &= \frac{1}{\sigma_0^2} + \frac{1}{\sigma^2}\sum_{n=1}^N\pi_{nk}\\ \boldsymbol{\mu}_k &= \frac{\sigma_k^2}{\sigma^2}\sum_{n=1}^N\pi_{nk}\boldsymbol{x}_n\\ \pi_{nk} &\propto \exp\left\{-\frac{1}{2\sigma^2}\left(\|\boldsymbol{\mu}_k\|^2+2\sigma_k^2-2\boldsymbol{\mu}_k^{\mathrm{T}}\boldsymbol{x}_n\right) + \psi(\alpha_k)-\psi(\alpha_k+\beta_k)+\sum_{j=1}^{k-1}\left(\psi(\beta_j)-\psi(\alpha_j+\beta_j)\right)\right\} \end{align}

これらのパラメータを順に更新していきます. 反復回数を\(1000\)回としました. また, \(K=20\), \(\alpha=0.8\), \(\sigma^2=1\), \(\sigma_0^2=4\)としました. 結果を以下に示します. 最も確率の高いクラスターごとに色分けしました. そこそこうまくいってそうです. 元のデータは混合数5の混合正規分布から生成しています. 下図では1個だけ無視されています. 観測モデルや事前分布, ハイパーパラメーター等は工夫する必要がありそうです.

【コード8の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra
using SpecialFunctions

#dataframe
using DataFrames

#statistics
using Random
using Statistics
using Distributions
using EmpiricalCDFs:EmpiricalCDF,push!,sort!

#visualize
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack
【Juliaコード2; 関数の定義】
#base measure
abstract type AbstractBaseMeasure end
struct BaseMeasure<:AbstractBaseMeasure
    ps::Vector{Float64}
    prior_base::UnivariateDistribution
    emp_dist::DiscreteNonParametric
    function BaseMeasure(prior_base,α,X)
        N = length(X)
        if N==0
            return prior_base
        end
        emp_dist = empirical_dist(X,N)
        return new([α/(N+α),N/(N+α)],prior_base,emp_dist)
    end
end

#empirical distribution
function empirical_dist(X,N)
    data_df = DataFrame(X=X)
    freq_df = combine(groupby(data_df,:X),nrow=>:Freq)
    xs = freq_df.X
    ps = freq_df.Freq/N
    return DiscreteNonParametric(xs,ps)
end

function Base.rand(post_base::BaseMeasure)
    u = rand()
    if u<=post_base.ps[1]
        return rand(post_base.prior_base)
    else
        return rand(post_base.emp_dist)
    end
end

#base measure
abstract type AbstractBaseMeasure end
struct BaseMeasure<:AbstractBaseMeasure
    ps::Vector{Float64}
    prior_base::UnivariateDistribution
    emp_dist::DiscreteNonParametric
    function BaseMeasure(prior_base,α,X)
        N = length(X)
        if N==0
            return prior_base
        end
        emp_dist = empirical_dist(X,N)
        return new([α/(N+α),N/(N+α)],prior_base,emp_dist)
    end
end

#empirical distribution
function empirical_dist(X,N)
    data_df = DataFrame(X=X)
    freq_df = combine(groupby(data_df,:X),nrow=>:Freq)
    xs = freq_df.X
    ps = freq_df.Freq/N
    return DiscreteNonParametric(xs,ps)
end

function Base.rand(post_base::BaseMeasure)
    u = rand()
    if u<=post_base.ps[1]
        return rand(post_base.prior_base)
    else
        return rand(post_base.emp_dist)
    end
end

#stick breaking process
function StickBreakingProcess(kmax,prior_base,α,X)
    base = BaseMeasure(prior_base,α,X)
    α += length(X)
    logwghks = zeros(kmax) #weights
    posks = zeros(kmax) #positions
    logwghk = 0.0
    lenk_old = 0.0
    lenk_new = 1/2
    @showprogress for k in 1:kmax
        posk = rand(base)
        posks[k] = posk
        lenk_old = lenk_new
        lenk_new = rand(Beta(1,α))
        logwghk += log(lenk_new)-log(lenk_old)+log(1-lenk_old)
        logwghks[k] = logwghk
    end
    emp_df = combine(groupby(DataFrame(pos=posks,prob=exp.(logwghks)),:pos),:prob=>sum)
    return DiscreteNonParametric(emp_df.pos,emp_df.prob_sum)
end 
StickBreakingProcess(kmax,base,α) = StickBreakingProcess(kmax,base,α,[])

#sample X
function sampleF(ecdf,F₀,α,n)
    u = rand()
    if u<=(n-1)/(n+α-1)
        samp = rand(ecdf)
    else
        samp = rand(F₀)
    end
    return samp
end


#Chinese Restaurant Process
function ChineseRestaurantProcess(N,F₀,α)
    ecdf = EmpiricalCDF()
    samps = zeros(N)
    samps[1] = rand(F₀)
    for n in 2:N
        samp = sampleF(ecdf,F₀,α,n)
        samps[n] = samp
        push!(ecdf,samp)
        sort!(ecdf)
    end
    return samps,ecdf
end

#random number from π₁F₁+π₂F₂
function Base.rand(F₁,F₂,π₁,π₂)
    u = rand()
    if u<=π₁
        return rand(F₁)
    else
        return rand(F₂)
    end
end
【Juliaコード3; 棒折り過程】
#base distiribution and intensity parameter
α = 10
prior_base = Normal(0,1)

#sample via SBP
Random.seed!(42)
kmax = Int(1e3)
@time Fsamp = StickBreakingProcess(kmax,prior_base,α)

#visualize
p1 = plot(Fsamp.support,Fsamp.p,st=:bar,xlabel="X",ylabel="probability and density",
    label="SBP",bar_width=0.05,title="Dirichlet Process via SBP")
plot!(x->pdf(prior_base,x),label="base")
p2 = plot(Fsamp.support,cumsum(Fsamp.p),title="Cumulative",xlabel="X",ylabel="probability and density",
    label="SBP")
plot!(x->cdf(prior_base,x),label="base")
fig1 = plot(p1,p2,size=(1000,400))
savefig(fig1,"figs-NPB/fig1.png")
【Juliaコード4; 中華料理店過程】
#sample from CRP
Random.seed!(42)
N = 100
α = 10
prior_base = Normal(0,1)
@time xsamps,ecdf = ChineseRestaurantProcess(N,prior_base,α)

#visualize
p1 = histogram(xsamps,normed=true,bins=20,xlabel="X",ylabel="relative frequency",title="Sample X via CRP",
    label="CRP")
plot!(x->pdf(prior_base,x),label="base")
p2 = plot(-3:0.1:3,x->ecdf(x),title="Cumulative",xlabel="X",ylabel="probability and density",label="CRP")
plot!(x->cdf(prior_base,x),label="base")
fig2 = plot(p1,p2,size=(1000,400))
savefig(fig2,"figs-NPB/fig2.png")
【Juliaコード5; CDFの推定】
#set the random seed
Random.seed!(42)

#create data
N = 10
true_dist = Normal(-0.1,1.4)
X = rand(true_dist,N)

#model and algorithm parameters
α = 10
prior_base = Normal(0,1)

#prior sampling
kmax = Int(1e3)
prior_F_samp = StickBreakingProcess(kmax,prior_base,α)
fig3 = plot(prior_F_samp.support,cumsum(prior_F_samp.p),title="Cumulative Estimation",
    xlabel="X",ylabel="CDF",label="SBP(prior sample)",legend=:topleft,ls=:dash)

#posterior sampling
post_F_samp = StickBreakingProcess(kmax,prior_base,α,X)
plot!(post_F_samp.support,cumsum(post_F_samp.p),title="Cumulative",label="SBP(posterior sample)",ls=:dashdot)
plot!(x->cdf(true_dist,x),label="true CDF",color=:red)
savefig(fig3,"figs-NPB/fig3.png")
【Juliaコード6; 関数の定義】
#variational parameters and distirbutions
abstract type AbstractVariationalParameters end
abstract type AbstractVariationalDistributions end
mutable struct VariationalParameters<:AbstractVariationalParameters
    αks::Vector{Float64}
    βks::Vector{Float64}
    σksqs::Vector{Float64}
    μks::Matrix{Float64}
    πns::Matrix{Float64}
    function VariationalParameters(α,σ₀sq,d,K,N)
        αks = α*ones(K)
        βks = ones(K)
        σksqs = σ₀sq*ones(K)
        μks = zeros(d,K)
        πns = ones(K,N)/K
        return new(αks,βks,σksqs,μks,πns)
    end
end
mutable struct VariationalDistributions<:AbstractVariationalDistributions
    vk_dists::Vector{Distribution}
    θk_dists::Vector{Distribution}
    zn_dists::Vector{Distribution}
    function VariationalDistributions(params::VariationalParameters)
        @unpack αks,βks,σksqs,μks,πns = params
        vk_dist = []; θk_dist = []; πn_dist = [];
        for k in 1:K-1
            push!(vk_dist,Beta(αks[k],βks[k]))
        end
        for k in 1:K
            push!(θk_dist,MvNormal(μks[:,k],σksqs[k]))
        end
        for n in 1:N
            push!(πn_dist,Categorical(πns[:,n]))
        end
        return new(vk_dist,θk_dist,πn_dist)
    end
end

#data and model parameters
abstract type AbstractData end
abstract type AbstractModelParameters end
struct Data<:AbstractData
    X::Matrix{Float64}
    N::Int64
end
struct ModelParameters<:AbstractModelParameters
    α::Float64
    σsq::Float64
    σ₀sq::Float64
    d::Int64
    K::Int64
end

#update v's distriibution
function update_vks!(var_params,α,K)
    for k in 1:K-1
        var_params.αks[k] = α + sum(var_params.πns[k,:])
        var_params.βks[k] = 1 + sum(var_params.πns[k+1:end,:])
    end
end

#update θ's distribution
function update_θks!(var_params,X,σsq,σ₀sq,K)
    for k in 1:K
        var_params.σksqs[k] = 1/(1/σ₀sq+sum(var_params.πns[k,:])/σsq)
        var_params.μks[:,k] = var_params.σksqs[k]*(X*var_params.πns[k,:])/σsq
    end
end

#logτk(log unnormalized probability of zk)
logτk(var_params,k,n,X,σsq,d) = (
    -(norm(var_params.μks[:,k])^2+d*var_params.σksqs[k]-2*dot(var_params.μks[:,k],X[:,n]))/2/σsq 
    + digamma(var_params.αks[k])-digamma(var_params.αks[k]+var_params.βks[k])
    +sum(digamma.(var_params.βks[1:k-1])-digamma.(var_params.αks[1:k-1]+var_params.βks[1:k-1]))
)

#update z's distribution
function update_zns!(var_params)
    for n in 1:N
        for k in 1:K
            var_params.πns[k,n] = exp(logτk(var_params,k,n,X,σsq,d))
        end
        var_params.πns[:,n] = var_params.πns[:,n]/sum(var_params.πns[:,n])
    end
end

#mean field variational inference
function VariationalInference(data,n_train,model_params)
    #initialize
    @unpack X,N = data
    @unpack α,σsq,σ₀sq,d,K = model_params
    var_params = VariationalParameters(α,σ₀sq,d,K,N)
    
    #update variational parameters
    @showprogress for s in 1:n_train
        update_vks!(var_params,α,K)
        update_θks!(var_params,X,σsq,σ₀sq,K)
        update_zns!(var_params)
    end
    
    #variational distirbutions
    return VariationalDistributions(var_params)
end
【Juliaコード7; データの作成】
#set the random seed
Random.seed!(40)

#Guassian mixture 
θ = 2*π/5
μ_func(i, θ) = 2*[cos(π/10+i*θ), sin(π/10+i*θ)]
Σ = I(2)/2
μ₀ = μ_func(0, θ)
μ₁ = μ_func(1, θ)
μ₂ = μ_func(2, θ)
μ₃ = μ_func(3, θ)
μ₄ = μ_func(4, θ)
true_dist = MixtureModel(
    [MvNormal(μ₀,Σ), MvNormal(μ₁,Σ), MvNormal(μ₂,Σ), MvNormal(μ₃,Σ), MvNormal(μ₄,Σ)]
)

#data
N = 100
X = rand(true_dist, N)
data = Data(X,N)

#visualize
fig4 = plot(X[1,:],X[2,:],st=:scatter,label="data",xlabel="x₁",ylabel="x₂",title="data",aspect_ratio=1,
markerstrokewidth=0.5,markersize=8)
savefig(fig4,"figs-NPB/fig4.png")
【Juliaコード8; 推論の実行】
#initialize model parameters
α = 0.8
σsq = 1
σ₀sq = 4
d = 2
K = 30
model_params = ModelParameters(α,σsq,σ₀sq,d,K)

#mean field variational inference for Dirichlet Process Mixture Model
n_train = 1000
@time var_dists = VariationalInference(data,n_train,model_params)

#visualize
zns = [argmax(var_dists.zn_dists[n].p) for n in 1:N]
fig5 = plot(X[1,:],X[2,:],group=zns,st=:scatter,markerstrokewidth=0.5,markersize=8,label=false,aspect_ratio=1,
title="clustering by Dirichlet Process Mixture",xlabel="x₁",ylabel="x₂")
savefig(fig5,"figs-NPB/fig5.png")
参考文献

      [1] J.Xuan, J.Lu, G.Zhang, A Survey on Bayesian Nonparametric Learning, ACM Computing Surveys, 52(1), pp.1-36, 2019.
      [2] T.S.Ferguson, A Bayesian Analysis of some nonparametric problems, Ann.Stat, 1(2), pp.209-230, 1973.
      [3] D.Blackwell, J.B.MacQueen, Ferguson Distributions Via Polya Urn Schemes, Ann.Stat, 1(2), pp.353-355, 1973.
      [4] J.Sethuraman, A constructive definition of Dirichlet priors, Stat.Sinica, 4(2), pp.639-650, 1994.
      [5] H.Ishwaran, M.Zarepour, Exact and Approximate Sum Representations for the Dirichlet Process, The Canadian Journal of Statistics , 30(2), pp.269-283,2002.
      [6] D.Blei, M.Jordan, Variational methods for the Dirichlet process, Proceedings of the twenty-first international conference on Machine learning, 2004.
      [7] D.Blei, M.Jordan, Variational inference for Dirichlet process mixtures, Bayesian Anal, 1(1), pp.121-143, 2006.
      [8] S.Ghosal, A.van der Varrt, Fundamentals of Nonparametric Bayesian Inference, Cambridge University Press, 2017.
      [9] L.Wsserman, Nonparametric Bayesian Methods, Available: https://www.stat.cmu.edu/~larry/=sml/nonparbayes, (最終アクセス: 2022/01/01).
      [10] 石井健一郎, 上田修功, 続・わかりやすいパターン認識ー教師なし学習入門ー, オーム社, 2014.
      [11] 佐藤一誠, ノンパラメトリックベイズ 点過程と統計的機械学習の数理, 講談社, 2016.