Dirichlet分布の意味とは?

記事の内容


今回はDirichlet分布について解説します. 今回はとりあえず視覚化のみ. また時間があればサンプリングについても解説したいな(未定).

定義と性質

定義

【Dirichlet分布の定義】

\(\boldsymbol{x}=[x_1,\cdots,x_K]^{\mathrm{T}}\)が \begin{equation} \sum_{k=1}^{K}x_k=1 ,\quad x_k\geq 0 \ (k=1,\cdots,K) \end{equation} を満たすとする. 正の成分からなるパラメータ\(\boldsymbol{\alpha}\in \mathbb{R}^K\)に対して, 確率密度関数が \begin{equation} p(\boldsymbol{x}\mid \boldsymbol{\alpha}) = \frac{\Gamma\left( \sum_{k=1}^{K} \alpha_i\right)}{\prod_{k=1}^{K}\Gamma\left( \alpha_i\right)} \prod_{k=1}^{K} x^{\alpha_i-1} \end{equation} で定義されるとき, この分布をDirichlet分布という.

密度関数の観察

Dirichlet分布の密度関数を観察します. まずは等方的な場合を観察します. つまり, αの成分が全て等しい場合です. プロットのタイトルのαの値は, その共通の値です.

【コード3の実行結果】

次はヒートマップを見てみましょう. 上とほとんど同じだけど...

【コード4の実行結果】

次に, 必ずしも等方的でない場合を見てみましょう. Dirichlet分布からのサンプルを散布図で描きます. プロットのタイトルはαの値を示しています. このようにサンプルは単体上に分布します. αの値が大きい方にサンプルが集中します. また, αの値が大きいほど, 真ん中に集中します.

【コード5の実行結果】

この単体の表面のみを取り出してみましょう. 図では少し(√2倍)拡大しています.

【コード6の実行結果】

この表面を取り出す作業にはちょっとした座標変換が必要です. 計算は面倒なので省略します. 変換の概要は次の通りです.

直感的な意味

Dirichlet分布はその条件から, 確率ベクトルの従う分布であることが分かります. パラメータαは, 観測数のような役割を果たします. 観測数に偏りがあれば, その方向にサンプルが偏ります. 一方, どの観測数も同じくらいであれば, 真ん中ら辺に散らばって, 曖昧な状態になります.

ソースコード

共通のコード

【Juliaコード1; インポート】
using Plots
using Random
using Distributions
using SpecialFunctions
using Formatting
【Juliaコード2; 密度関数の定義】
#密度関数
function dirichlet_pdf(x::Array{Float64,1}, α::Array{Float64,1})
    if x[end] >= 0.0
        return gamma(sum(α))/prod(gamma.(α)) * prod(x.^(α.-1))
    else
        return NaN
    end
end

密度関数のプロット

【Juliaコード3; 3Dプロットのアニメーション】
#プロット範囲
x₁ = x₂ = range(0,1,step=0.02);

#初期化
anim = Animation();

#αを変化させる
αinit = ones(3);
for αabs in 0.05:0.05:4.0
    α = αabs * αinit;
    d = [dirichlet_pdf([xx; yy; 1-(xx+yy)], α) for yy in x₁, xx in x₂];
    pdf_plot = surface(x₁, x₂, d, camera=(40,65), title="α = $(format(αabs, precision=2))", xlabel="x₁", ylabel="x₂", zlim=[0,8], c=:thermal, colorbar=false);
    frame(anim, pdf_plot);
end

#保存
gif(anim, "figs-dirichlet/anim1.gif", fps=20);
【Juliaコード4; ヒートマップのアニメーション】
#プロット範囲
x₁ = x₂ = range(0,1,step=0.02);

#初期化
anim = Animation();

#αを変化させる
αinit = ones(3);
for αabs in 0.05:0.05:4.0
    α = αabs * αinit;
    d = [dirichlet_pdf([xx; yy; 1-(xx+yy)], α) for yy in x₁, xx in x₂];
    pdf_plot = heatmap(x₁, x₂, d, title="α = $(format(αabs, precision=2))", xlabel="x₁", ylabel="x₂", c=:thermal, colorbar=false, size=(400,400));
    frame(anim, pdf_plot);
end

#保存
gif(anim, "figs-dirichlet/anim2.gif", fps=20);
【Juliaコード5; 3次元の散布図】
#αの候補
αarray = [
    [0.1, 0.1, 0.1],
    [1.0, 1.0, 1.0],
    [5.0, 5.0, 5.0],
    [0.1, 0.1, 0.9],
    [1.0, 1.0, 3.0],
    [3.0, 3.0, 1.0]
];

#散布図
n_samps = 3000;
scatter_list3d = [];
for i in 1:6
    p3d = rand(Dirichlet(αarray[i]), n_samps);
    s3d = scatter(p3d[1,:], p3d[2,:], p3d[3,:], markersize=1, title="$(αarray[i])", titlefont=font(10), xlim=[0,1], ylim=[0,1], zlim=[0,1]);
    append!(scatter_list3d, [s3d]);
end
scatter_samps3d = plot(
    scatter_list3d[1],
    scatter_list3d[2],
    scatter_list3d[3],
    scatter_list3d[4],
    scatter_list3d[5],
    scatter_list3d[6],
    layout=(3,2), camera=(70,50), legend=false, size=(600,800));

#保存
savefig(scatter_samps3d, "figs-dirichlet/fig1.png");
【Juliaコード6; 2次元の散布図】
#αの候補
αarray = [
    [0.1, 0.1, 0.1],
    [1.0, 1.0, 1.0],
    [5.0, 5.0, 5.0],
    [0.1, 0.1, 0.9],
    [1.0, 1.0, 3.0],
    [3.0, 3.0, 1.0]
];

#散布図
n_samps = 3000;
scatter_list2d = [];
for i in 1:6
    p2d = rand(Dirichlet(αarray[i]), n_samps);
    s2d = scatter(p2d[2,:].-p2d[1,:], sqrt(3)*p2d[3,:], markersize=1, title="$(αarray[i])", titlefont=font(10),  xlim=[-1,1], ylim=[0,sqrt(3)]);
    append!(scatter_list2d, [s2d]);
end
scatter_samps2d = plot(
    scatter_list2d[1],
    scatter_list2d[2],
    scatter_list2d[3],
    scatter_list2d[4],
    scatter_list2d[5],
    scatter_list2d[6],
    layout=(3,2), legend=false, size=(600,800));

#保存
savefig(scatter_samps2d, "figs-dirichlet/fig2.png");
参考文献
      [1]佐藤一誠, トピックモデルによる統計的潜在意味解析, コロナ社, 初版第4刷, 2018