WAICを使おう!

記事の内容


  • WAICの概要
    • いつ使うのか
    • 基礎概念の定義
    • 使い方
  • 1次元の例
    • 正規分布の例
    • t分布の例
    • 混合正規分布の例
  • 2次元の例
    • 真の分布とモデル
    • Gibbs saplerによる学習
    • Bayes推定
    • モデル選択
  • 注意点: モデル選択の失敗
    • 逆相関の観察
  • コード
    • 共通のコード
    • 正規分布の例
    • t分布の例
    • 1次元混合正規分布の例
    • 2次元混合正規分布の例
    • 逆相関の例

コードと本文で記号が少し違います. 許してください.

WAICの概要

いつ使うのか

データ解析では, データ解析者がモデルを設定します. 一度に複数のモデルを考えることもできます. これらのうち, どのモデルが最も良いモデルなのか. それを判定する指標の1つがWAICです. もう少し数学的に背景を述べます.

基礎概念の定義

いま, 手元にデータの組 \begin{equation} D = \left\{ x_1, \cdots, x_n \right\} \subset \mathbb{R}^N \end{equation} があるとします. このデータたちは, ある分布qに独立に従う確率変数\( X _1, \cdots, X_n \)の実現値と考えます. この分布qを真の分布と呼びます. 統計学や機械学習は, この真の分布を知りたいというモチベーションがあります. 一般に真の分布は未知です. そこで, モデルを用意して真の分布を推測するという方針を採ります. つまり, 真の分布\(q(x)\)に対して, データを解析する人間は確率モデル\(p(x|\theta)\)を用意します. ここで, \(\theta\in \mathbb{R}^d\)はモデルのパラメータです. また, Bayes統計ではパラメータ\(\theta\)の事前分布\(p(\theta)\)を用意します. 確率モデルと事前分布が用意できれば, パラメータ\(\theta\)の事後分布\(p(\theta|D)\)が計算できます.

\begin{equation} p(\theta|D) \propto p(x|\theta) p(\theta) \end{equation}

真の分布\(q(x)\)に近い分布を作る方法はいくつかあります. 例えば, \(\hat{\theta}\)を最尤推定値とするとき, \(p(x|\hat{\theta})\)は真の分布に近くなると期待できます. これを最尤推測といいます. また, \(\tilde{\theta}\)を事後分布の平均とするとき, \(p(x|\tilde{\theta})\)は真の分布に近くなると期待できます. これを平均プラグイン推測といいます. 一方, Bayes推測では, 次の予測分布を用います. 予測分布はモデルを事後分布で重みづけた平均です. 事後分布の情報をフルに活用しているので, 真の分布に近いと期待できます.

【定義: 予測分布】

データ\(D\), 確率モデル\((x|\theta)\), 事前分布\(p(\theta)\)に対して, 予測分布を次式で定義する:

\begin{equation} p^*(x) = \int p(x|\theta) p(\theta|D)d\theta \end{equation}

当然, 予測分布はデータ\(D\)に依存します. こうして構成された予測分布は真の分布にどれほど近いでしょうか. 分布間の近さを測る指標が必要です. この役割をKullback-Leiblerダイバージェンスが担います.

【定義: Kullback-Leiblerダイバージェンス】

2つの確率密度関数\(q(x)\), \(p(x)\)に対して, \(q(x)\)に関する\(p(x)\)のKullback-Leiblerダイバージェンスを次式で定める.

\begin{equation} \mathrm{D}(q\| p) = \int q(x)\log \frac{q(x)}{p(x)}dx \end{equation}

真の分布と予測分布の近さは, \(\mathrm{D}(q\| p^*)\)で測ります. 値が小さいほど予測分布は真の分布に近いと考えます. ここで, \begin{equation} \mathrm{D}(q\| p^*) = \int q(x)\log q(x)dx - \int q(x)\log p^*(x)dx \end{equation} であり, 第1項は真の分布にのみ依存します. 第2項はモデルと真の分布に依存します. 真の分布に関するモデルのKullback-Leiblerダイバージェンス\(\mathrm{D}(q\| p^*)\)によりモデルを比較する上では, 第2項のみを比較すれば良い訳です. そこで第2項のみを汎化損失と定義します.

【定義: 汎化損失】

真の分布\(q(x)\)と予測分布\(p^*(x)\)に対して, 汎化損失を次式で定める.

\begin{equation} \mathrm{G}_n = -\int q(x)\log p^*(x)dx \end{equation}

こうして, 汎化損失\(\mathrm{G}_n\)を用いてモデル比較ができるようになりました. Kullback-Leiblerダイバージェンス\(\mathrm{D}(q\|p)\)が小さい方が, 予測分布は真の分布に近く, 汎化損失\(\mathrm{G}_n\)の小さい方がKullback-Leiblerダイバージェンスは小さいです. したがって, 汎化損失\(\mathrm{G}_n\)の小さいモデルほど良いモデルであると期待できます. しかし, 汎化損失の計算には真の分布が必要です. そこで, 真の分布を知らなくても計算できる汎化損失の推定値があると便利です.

情報量規準WAICを用いると, 汎化損失\(\mathrm{G}_n\)を推測することができます. そして, WAICは真の分布を知らなくても計算できます.

【定義: WAIC】

確率モデル\((x|\theta)\), 事前分布\(p(\theta)\)に対して, WAICを次式で定義する:

\begin{equation} \mathrm{WAIC} = -\frac{1}{n}\sum_{i=1}^{n} \log \mathrm{E}[p(X_i|\theta)] + \frac{1}{n} \sum_{i=1}^{n} \left\{ \mathrm{E}[(\log p(X_i|\theta))^2]-\mathrm{E}[\log p(X_i|\theta)]^2 \right\} \end{equation}

ただし, 期待値は事後分布に関する期待値である.

WAICは汎化損失\(\mathrm{G}_n\)と平均的に一致します.

【定理: WAICと汎化損失の関係】

WAICと汎化損失\(\mathrm{G}_n\)の間に次の関係式が成り立つ:

\begin{equation} \mathrm{E}[\mathrm{WAIC}] = \mathrm{E}[\mathrm{G}_n] + o_p\left(\frac{1}{n}\right) \end{equation}

ただし, 期待値は真の分布に関する期待値である. また, Landauのスモールオーダー表記\(o_{p}\)は確率収束に関するものである.

なお, 情報量基準としてはAIC(Akaike Information Criterion)が有名です. しかし, AICが使えるのはうまくパラメータ値を調節すればモデルによって真の分布を当てることができる(=モデルにより真の分布が実現可能である)場合で, かつ事後分布が正規分布で近似できる場合です. AICが使える状況では, WAICはAICに, 漸近的に一致します. AICは次式で定義されます.

【定義: AIC】

確率モデル\((x|\theta)\)に対して, AICを次式で定義する:

\begin{equation} \mathrm{AIC} = -\frac{1}{n}\sum_{i=1}^{n}\log p(X_i\mid \hat{\theta}) + \frac{d}{n} \end{equation}

ここで, \(\hat{\theta} \in \mathbb{R}^d\)は最尤推定量である.

使い方

いくつかモデルがあるとき, 汎化損失\(\mathrm{G}_n\)の小さいモデルが良いモデルです(少なくとも期待はできます). そして, WAICは(データの出方に対して)平均的には汎化損失\(\mathrm{G}_n\)と一致します. したがって, WAICの小さいモデルの方が良いモデルであると期待できます. そして, WAICは真の分布を知らなくても計算できます. 次のように使います. まずデータ解析者が複数のモデルを用意します. これらのうち, どのモデルが良いか分からないとします. それぞれのモデルに対して, WAICが計算できます. そして, WAICの値が最も小さいモデルを最良のモデルとして選択します. もしくは, WAICがほとんど変化しなくなったとき, 最も単純なモデルを選びます. なお, WAICを解析的に計算するのは難しいので, 期待値計算にはMCMCを用いることが多いです. すなわち, \(\theta\)のMCMCサンプルを\(\left\{ \theta^{(m)}\right\}_{m=1}^M\)として, 次のように計算できます:

\begin{equation} \mathrm{WAIC} \simeq -\frac{1}{n}\sum_{i=1}^n \log \left\{ \frac{1}{M}\sum_{m=1}^{M} p(X_i\mid\theta^{(m)}) \right\} + \frac{1}{n}\sum_{i=1}^n \left\{ \frac{1}{M-1}\sum_{m=1}^{M}\left( \log p(X_i\mid \theta^{(m)}) - \frac{1}{M}\sum_{l=1}^{M} \log p(X_i\mid \theta^{(l)})\right)^2\right\} \end{equation}

要するに標本平均と標本不偏分散から計算できます.

1次元の例

以下に続く3つの小節で, Bayes推定を行います. ここでの目的は, 次の通りです. "少なくとも正規分布など, 綺麗な場合には, WAICがAICや汎化損失とほぼ同じ値になることを確かめる". それぞれの小節で, "真の分布, 確率モデル, 事前分布"は以下の通りとします. もちろん現実では真の分布は未知です. ここでは真の分布を設定した仮想的な状況を考えます. 確率モデルと事前分布は全て共通で, データの発生源である真の分布だけを変えて, モデルの評価を行います. なお1次元の例では, モデル選択は扱いません.

真の分布 確率モデル 事前分布
N(2, 1/2) N(μ, 1/2) N(0, 1)
t(4) N(μ, 1/2) N(0, 1)
0.5N(-1, 1) + 0.5N(1, 1) N(μ, 1/2) N(0, 1)

モデルは全て共通なので, 事後分布や予測分布も全て共通です. 逆に言うと, 真の分布が何であろうと, モデルと事前分布とデータだけで推定は可能であり, だからこそモデルの評価が必要であるとも言えます. ここでは導出しませんが, パラメータμの事後分布と予測分布は以下の通りです. Nはデータのサイズです.

  • μの事後分布
  • \begin{equation} \mathrm{N}\left( \hat{\mu}_0, \hat{\lambda}_0\right),\quad \hat{\mu}_0 = \frac{\lambda_0\mu_0+\lambda\sum_{i=1}^{n}x_i}{\hat{\lambda}_0}, \quad \hat{\lambda}_0 = \lambda_0+N\lambda \end{equation}
  • 予測分布
  • \begin{equation} \mathrm{N}\left( \mu^*, \lambda^* \right),\quad \mu^*=\hat{\mu}_0,\quad \lambda^* = \frac{\lambda\lambda_0 + N\lambda^2}{\lambda_0+(n+1)\lambda} \end{equation}

以下の各小節では, 真の分布の推測方法として, 予測分布の計算と, 最尤推測を用います. また, モデルの評価指標として, AICとWAICを用います. 今回の人工的な実験では真の分布も分かっているので, 汎化損失も計算します. 当然, 現実では真の分布も汎化損失も未知です.

また, 最尤推測を行う場合, μの最尤推定値は以下の標本平均です.

\begin{equation} \hat{\mu} = \frac{1}{n} \sum_{i=1}^{n} x_i \end{equation}

正規分布の例

まずは正規分布です. うまくパラメータμを調節すれば, モデルによって真の分布を当てることができます. つまり実現可能です. 事後分布も正規分布なので, 最尤推測やAIC的には最高の条件です.

実験では, 真の分布である正規分布からデータをN=50個ほど発生させ, 上述の事後分布や予測分布を計算しました. 左下図では, データのヒストグラム, 真の分布(赤実線), 予測分布(青実線), 最尤推測(緑点線)をプロットしました. 右下図では, 事前分布(水色実線), 事後分布(橙色実線), 最尤推定値(緑色点線)をプロットしました.

予測分布も最尤推測も, 真の分布とほぼ一致しています.

【コード2の実行結果】

今回のケースではWAICは手計算で求めました. 計算ミスがなければ次式になるはずです(ちょっと不安...).

\begin{equation} \mathrm{WAIC} = \frac{1}{2}\left(\frac{\lambda}{\hat{\lambda}_0}\right)^2 -\frac{1}{2}\log \left\{ \frac{1}{2\pi}\left( \frac{\lambda\hat{\lambda}_0}{\lambda+\hat{\lambda}_0}\right)\right\} + \left\{ \frac{\lambda^2}{\hat{\lambda}_0} + \frac{1}{2}\left(\frac{\lambda\hat{\lambda}_0}{\lambda+\hat{\lambda}_0} \right)\right\} \frac{1}{n}\sum_{i=1}^{n} (x_i-\hat{\mu}_0)^2 \end{equation}

汎化損失\(\mathrm{G}_n\)も手計算があっていれば以下の通りです.

\begin{equation} \mathrm{G}_n = -\frac{1}{2} \log \left( \frac{\lambda^*}{2\pi}\right) + \frac{\lambda^*}{2\lambda} + \frac{1}{2}\lambda^* (m-\mu^*) \end{equation}

汎化損失と各種情報量規準の値は以下の通りです. (ちょっと不安な値です. どこかで計算間違いとかあるかも?)

汎化損失\(\mathrm{G}_n\) AIC WAIC
1.0733 1.1466 1.1264

t分布の例

次に, 正規分布に似ていると言われるt分布です(今思えば, 精度を学習させるタスクの方が面白かったかも...どなたかやってみてください). 確率モデル, 事前分布は先ほどと同様. 真の分布のみが変わっています. データサイズやグラフの見方も同様. 真の分布(赤の実線)からは結構ずれていますね.

【コード3の実行結果】

しかし現実には, このずれさえも認識できないのがちょっと怖いところ. 右の推定結果だけを眺めるとうまくいってそうに見えます. そこでモデル評価. 以下に, 汎化損失\(\mathrm{G}_n\)と各種情報量規準の値を示します. 汎化損失とWAICに関しては, モンテカルロ法で積分を近似計算しています.

汎化損失\(\mathrm{G}_n\) AIC WAIC
1.7478 1.7527 1.7432

混合正規分布の例

最後に混合正規分布です(これも精度の学習の方が良かったかも?). 確率モデル, 事前分布, データサイズ, グラフの見方は今までと同様です. やはり今回も, 真の分布と推測結果にそれなりにずれがあります.

【コード4の実行結果】

汎化損失\(\mathrm{G}_n\)と各種情報量規準の値を以下に示します. 汎化損失とWAICに関しては, モンテカルロ法で積分を近似計算しています.

汎化損失\(\mathrm{G}_n\) AIC WAIC
1.7542 1.7599 1.7504

2次元の例

真の分布とモデル

この節では, 2次元の混合正規分布を用いて推定・モデル評価します. ここでの目的を端的に言えば, "混合数3の混合正規分布を, 混合数5の混合正規分布で推定したときの, 推定の様子とモデル評価について調べる"ことです. まずは真の分布を以下に示します.

\begin{equation} q(\boldsymbol{x}) = \sum_{k=1}^{K} \pi_k \mathrm{N}\left(\boldsymbol{x}\mid \boldsymbol{\mu}_k, (\lambda_k)^{-1}I_N\right) \end{equation} ここで, \(K\)=3, \(N\)=2で, \begin{equation} \boldsymbol{\mu}_1 = \begin{bmatrix}0\\ 2 \end{bmatrix},\quad \boldsymbol{\mu}_2 = \begin{bmatrix}\sqrt{3}\\ -1 \end{bmatrix}, \quad \boldsymbol{\mu}_3 = \begin{bmatrix}-\sqrt{3}\\ -1 \end{bmatrix},\quad \lambda_k = 1,\quad \pi_k = \frac{1}{3} \quad (k=1,\cdots,K) \end{equation}

真の分布をヒートマップで示しておきます. また, この分布からデータを生成しておきます. このデータはこの後の推測で利用します. データサイズはn=200です.

【コード6の実行結果】

この真の分布をデータから推測するために, データ解析者がモデルとして混合数5の混合正規分布を採用したとします. すなわち, 混合数K=5に対して, 次のモデルを仮定します.

\begin{align} p(\boldsymbol{x}\mid \boldsymbol{\mu}, \boldsymbol{\lambda}, \boldsymbol{\pi}) &= \sum_{\boldsymbol{z_j}} p(\boldsymbol{x}, \boldsymbol{z}_j \mid \boldsymbol{\mu}, \boldsymbol{\lambda}, \boldsymbol{\pi}) \\ &= \sum_{\boldsymbol{z_j}} \prod_{k=1}^{K} \left\{ \pi_k\mathrm{N}\left(\boldsymbol{x}\mid \boldsymbol{\mu}_k, (\lambda_k)^{-1}\right)\right\}^{z_{jk}} \end{align}

ここで, \begin{equation} \boldsymbol{\mu} = \left\{ \boldsymbol{\mu}_1, \cdots, \boldsymbol{\mu}_K \right\},\quad \boldsymbol{\lambda} = \left\{ \lambda_1, \cdots, \lambda_K \right\},\quad \boldsymbol{\pi} = \left\{ \pi_1, \cdots, \pi_K \right\} \end{equation} です. \(\boldsymbol{z}_j\)は潜在変数で, 要素が1つだけ1で, それ以外全て0のK次元ベクトルです.

事前分布として, 次を仮定します.

\begin{align} p(\boldsymbol{\mu}, \boldsymbol{\lambda}) &\propto \prod_{k=1}^{K} \lambda_k^{\frac{n}{2}}\exp\left( -\frac{\lambda_k}{2}\|\mu_k\|^2-\lambda_k\right)\\ p(\boldsymbol{\pi}) &\propto \prod_{k=1}^{K} \pi_k^{\alpha_k-1} \end{align}

Gibbs saplerによる学習

推定にあたって, Gibbsサンプラーによって事後分布を計算することにします. Gibbsサンプラーで用いる条件付き分布は以下の通りです.

\begin{align} & z_{jk}\mid \boldsymbol{x_j}, \boldsymbol{\mu}_k, \lambda_k, \pi_k \sim \mathrm{Cat}(\boldsymbol{\eta}),\quad \eta_{jk} \propto \exp\left(-\frac{\lambda_k}{2}\|\boldsymbol{x_j}-\boldsymbol{\mu}_k\|^2+\frac{N}{2}\log \lambda_k + \log \pi_k\right) \\ & \boldsymbol{\mu}_k\mid \boldsymbol{x}, \boldsymbol{z}, \lambda_k \sim \mathrm{N}\left( \frac{\sum_{j=1}^{n} z_{jk}\boldsymbol{x_j}}{\left( 1+\sum_{j=1}^{n} z_{jk}\right)}, \quad \left\{\lambda_k\left( 1+\sum_{j=1}^{n} z_{jk}\right) I_N\right\}^{-1}\right) \\ & \lambda_k\mid \boldsymbol{z}, \boldsymbol{x}, \boldsymbol{\mu}_k \sim \mathrm{Gamma}\left( \frac{N}{2}\left(1+\sum_{j=1}^{n}z_{jk} \right)+1, \quad\frac{1}{2}\left( \sum_{j=1}^{n} z_{jk}\|\boldsymbol{x}_j-\boldsymbol{\mu}_k\|^2 + \|\boldsymbol{\mu}_k|^2 +2\right)\right) \\ & \boldsymbol{\pi} \sim \mathrm{Dir}(\hat{\boldsymbol{\alpha}}),\quad \hat{\alpha}_k = \alpha_k + \sum_{j=1}^n z_{jk} \end{align}

事後分布からのサンプルは今後, 様々な場面で使います. おまけ程度に, Gibbsサンプラーの学習の様子をアニメーションでどうぞ.

【コード11の実行結果】

Bayes推定

それでは, モデルを用いて真の分布を推測しましょう. 冒頭で述べたように, 推測方法には様々あります. 1つの例として, 平均プラグイン推測を試します. これは, パラメータの事後平均をパラメータの推定値として, モデル式に放り込む方法です. 今回のケースでは, Gibbsサンプラーから計算した標本平均を, モデル式のパラメータの部分に代入します.

平均プラグイン推測により推測した分布が以下の図です.

【コード7の実行結果】

真の分布の混合数は3ですが, 右側2つはくっついています.

平均プラグイン推測は事後分布のうち, 平均の情報しか利用していません. やはり, 事後分布全体の情報を利用した予測分布の方が, より良いと期待できます.

下図は, Gibbs samplerによるサンプルを用いて計算した予測分布です.

【コード8の実行結果】

3つの部分に分かれており, 真の分布に似ています. それではこの推測はどれくらい良いのでしょうか?いよいよモデル評価です.

モデル評価・選択

上の推測結果を評価します. この小節では, 本来は未知の汎化損失\(\mathrm{G}_n\)と, その推定値としてのWAICを計算します. そして, 以下の点を調査・確認します.

  • \(\mathrm{G}_n\)とWAICは大体同じくらいの値になっているか」.
  • 選択したモデル(混合数5の混合正規分布)は適切だったか.
  • 汎化性能の観点から, より良いモデルは何か.

今回のような人工的な実験では, 真の分布が分かっているので汎化損失が計算できます. もちろん現実のデータ解析では計算できません. ただし, 手計算は難しいので, モンテカルロ法で積分を近似しています. 真の分布から発生させた人工データから計算した汎化損失\(\mathrm{G}_n\)とWAICは以下の通りです. (四捨五入して表示しています. ) K=5のモデルは, 先ほどとは値が結構違います.

汎化損失\(\mathrm{G}_n\) WAIC
3.7885 3.6891

汎化損失の推定はうまくいっているようです.

他のモデルと比較します. ここで比較するのは, 混合数が異なる混合正規分布たちです. 混合数Kを様々変えたときの汎化損失とWAICの値です. (四捨五入して表示しています. )

混合数K 汎化損失\(\mathrm{G}_n\) WAIC
2 4.0274 3.9418
3 3.7825 3.6863
4 3.7998 3.6865
5 3.7818 3.6891
6 3.7954 3.6922
7 3.7826 3.6947
8 3.7877 3.6988

各列で, 最小の値をとる部分に色を着けています. この例では, 各モデル間の値の差よりも, 積分の近似の誤差の方が多いような気がします... それでも一応形式的には, K=3の混合モデルが選択できます.

注意点: モデル選択の失敗

逆相関の観察

予測分布やWAICはデータの出方に依存します. 特に, WAICのデータの出方に対する揺らぎに関しては次のような関係式が成り立ちます.

【定理: WAICと逆相関】

WAICについて次の関係式が成り立つ:

\begin{equation} (\mathrm{G}_n-\mathrm{S}) + (\mathrm{WAIC}-\mathrm{S}_n) = \frac{2\lambda}{n} + o_p\left( \frac{1}{n} \right) \end{equation}

ただし, \(\lambda\)は定数で, 次のように定める.

\begin{align} \mathrm{S} &= -\int q(x)\log q(x)dx \\ \mathrm{S}_n &= -\frac{1}{n} \sum_{i=1}^{n} \log q(X_i) \end{align}

データサイズ一定の場合, 右辺は定数とみなせます. したがって, 左辺の \begin{equation} \mathrm{G}_n-\mathrm{S},\quad \mathrm{WAIC}-\mathrm{S}_n \end{equation} は一方が増えれば一方が減ります. つまり負の相関があります. 特にモデルに依存する部分のみを取り出すと, 汎化損失\(\mathrm{G}_n\)とWAICは一方が増えれば一方が減ります. これはモデル選択の観点からは不都合です. 手元にあるデータから計算したWAICが偶然小さかったとしても, 実際には汎化損失が大きいということもあります. あくまでも平均的に一致するだけです. この逆相関を観察します.

下図は, 真の分布からデータを30回ほど発生させ, 各データから計算した汎化損失とWAICの値の変化をプロットしました. この2つの系列の相関係数は, -0.243です. 図からも, 一方が増えれば一方が減る傾向が見て取れます.

【コード12の実行結果】

コード

共通のコード

【Juliaコード1; インポート】
using Plots
using Random
using Distributions
using Statistics
using LinearAlgebra
pyplot()

正規分布の例

【Juliaコード2; 正規分布の例】
#真の分布
m = 1.0 #未知
λ = 2.0 #既知
q(x::Float64) = pdf(Normal(m, 1/sqrt(λ)),x)

#確率モデル
λ = 2.0
lik(x::Float64, μ::Float64) = pdf(Normal(μ, 1/sqrt(λ)), x)
loglik(x::Float64, μ::Float64) = log(pdf(Normal(μ, 1/sqrt(λ)), x))

#最尤推定
μML(X::Array{Float64,1}) = mean(X)

#事前分布
μ₀ = 0.0 #既知
λ₀ = 1.0 #既知
φ(μ::Float64) = pdf(Normal(μ₀, 1/sqrt(λ₀)), μ)

#事後分布
μ₀post(X::Array{Float64,1}) = (λ₀*μ₀+λ*sum(X))/(λ₀+length(X)*λ)
λ₀post(X::Array{Float64,1}) = λ₀+length(X)*λ
ppost(μ::Float64, X::Array{Float64,1}) = pdf(Normal(μ₀post(X), 1/sqrt(λ₀post(X))), μ)

#予測分布
μpred(X::Array{Float64,1}) = μ₀post(X)
λpred(X::Array{Float64,1}) = (λ*λ₀+length(X)*λ^2)/(λ₀+(length(X)+1)*λ)
ppred(x::Float64, X::Array{Float64,1}) = pdf(Normal(μpred(X), 1/sqrt(λpred(X))), x)

#汎化損失
Gn(X::Array{Float64,1}) = -log(λpred(X)/π/2)/2+λpred(X)/λ/2+λpred(X)*(m-μpred(X))^2/2

#AIC
AIC(X::Array{Float64,1}) = -mean(loglik.(X, μML(X)))+1/n

#WAIC
function WAIC(X::Array{Float64, 1})
    n = length(X)
   (λ/(λ₀+length(X)*λ))^2-log(λpred(X)/π/2)/2+((λ/(λ₀+length(X)*λ))^2+λpred(X)/2)*mean((X.-μpred(X)).^2)
end

n = 50
Random.seed!(42)
X = m.+randn(n)/sqrt(λ)

#データ空間のプロット
data_plot = plot(X, st=:histogram, bins=10, norm=true, label="data", title="Data", xlabel="x", color=:gray, alpha=0.5)
plot!(q, label="true", color=:red)
plot!(x->ppred(x,X), label="pred", color=:blue)
plot!(x->lik(x, μML(X)), label="ML", linestyle=:dot, color=:green, linewidth=2)

#パラメータ空間
μ_plot = plot(φ, label="prior", xlabel="μ", title="Estimate μ", color=:skyblue)
plot!(μ->ppost(μ, X), label="posterior", color=:orange)
plot!([μML(X)], st=:vline, label="ML", linestyle=:dot,color=:green)

#推定結果を表示
println("汎化損失: $(round(Gn(X), digits=4))")
println("AIC: $(round(AIC(X), digits=4))")
println("WAIC: $(round(WAIC(X), digits=4))")
fig1 = plot(data_plot, μ_plot, size=(800, 400))

#保存
savefig(fig1, "figs-WAIC/fig1.png")

t分布の例

【Juliaコード3; t分布の例】
#真の分布
ν = 4.0
q(x::Float64) = pdf(TDist(ν),x)

#確率モデル
λ = 0.5
lik(x::Float64, μ::Float64) = pdf(Normal(μ, 1/sqrt(λ)), x)
loglik(x::Float64, μ::Float64) = log(pdf(Normal(μ, 1/sqrt(λ)), x))

#最尤推定
μML(X::Array{Float64,1}) = mean(X)

#事前分布
μ₀ = 0.0 #既知
λ₀ = 1.0 #既知
φ(μ::Float64) = pdf(Normal(μ₀, 1/sqrt(λ₀)), μ)

#事後分布
μ₀post(X::Array{Float64,1}) = (λ₀*μ₀+λ*sum(X))/(λ₀+length(X)*λ)
λ₀post(X::Array{Float64,1}) = λ₀+length(X)*λ
ppost(μ::Float64, X::Array{Float64,1}) = pdf(Normal(μ₀post(X), 1/sqrt(λ₀post(X))), μ)

#予測分布
μpred(X::Array{Float64,1}) = μ₀post(X)
λpred(X::Array{Float64,1}) = (λ*λ₀+length(X)*λ^2)/(λ₀+(length(X)+1)*λ)
ppred(x::Float64, X::Array{Float64,1}) = pdf(Normal(μpred(X), 1/sqrt(λpred(X))), x)

#汎化損失(MCサンプリング)
function Gn(X::Array{Float64,1})
    n_samps = 5000 #MCサンプルのサンプル数
    samps = zeros(n_samps)
    for i in 1:n_samps
       samps[i] = log(ppred(rand(TDist(ν)), X) )
    end
    -mean(samps)
end

#AIC
AIC(X::Array{Float64,1}) = -mean(loglik.(X, μML(X)))+1/n

#WAIC
function WAIC(X::Array{Float64, 1})
    n = length(X)
   (λ/(λ₀+length(X)*λ))^2-log(λpred(X)/π/2)/2+((λ/(λ₀+length(X)*λ))^2+λpred(X)/2)*mean((X.-μpred(X)).^2)
end

n = 100
Random.seed!(42)
X = rand(TDist(ν), n)

#データ空間のプロット
data_plot = plot(X, st=:histogram, bins=20, norm=true, label="data", title="Data", xlabel="x", color=:gray, alpha=0.5)
plot!(q, label="true", color=:red)
plot!(x->ppred(x,X), label="pred", color=:blue)
plot!(x->lik(x, μML(X)), label="ML", linestyle=:dot, color=:green, linewidth=2)

#パラメータ空間
μ_plot = plot(φ, label="prior", xlabel="μ", title="Estimate μ", color=:skyblue)
plot!(μ->ppost(μ, X), label="posterior", color=:orange)
plot!([μML(X)], st=:vline, label="ML", linestyle=:dot,color=:green)

#推定結果を表示
println("汎化損失: $(round(Gn(X), digits=4))")
println("AIC: $(round(AIC(X), digits=4))")
println("WAIC: $(round(WAIC(X), digits=4))")
fig2 = plot(data_plot, μ_plot, size=(800, 400))

#保存
savefig(fig2, "figs-WAIC/fig2.png")

1次元混合正規分布の例

【Juliaコード4; 1次元正規分布の例】
#真の分布
μ₁ = -1.0
μ₂ = 1.0
σ₁ = 1.0
σ₂ = 1.0
π₁ = 0.5
π₂ = 0.5
q(x::Float64) = pdf(MixtureModel(Normal[Normal(μ₁,σ₁),Normal(μ₂,σ₂)], [π₁,π₂]), x)

#確率モデル
λ = 0.5
lik(x::Float64, μ::Float64) = pdf(Normal(μ, 1/sqrt(λ)), x)
loglik(x::Float64, μ::Float64) = log(pdf(Normal(μ, 1/sqrt(λ)), x))

#最尤推定
μML(X::Array{Float64,1}) = mean(X)

#事前分布
μ₀ = 0.0 #既知
λ₀ = 1.0 #既知
φ(μ::Float64) = pdf(Normal(μ₀, 1/sqrt(λ₀)), μ)

#事後分布
μ₀post(X::Array{Float64,1}) = (λ₀*μ₀+λ*sum(X))/(λ₀+length(X)*λ)
λ₀post(X::Array{Float64,1}) = λ₀+length(X)*λ
ppost(μ::Float64, X::Array{Float64,1}) = pdf(Normal(μ₀post(X), 1/sqrt(λ₀post(X))), μ)

#予測分布
μpred(X::Array{Float64,1}) = μ₀post(X)
λpred(X::Array{Float64,1}) = (λ*λ₀+length(X)*λ^2)/(λ₀+(length(X)+1)*λ)
ppred(x::Float64, X::Array{Float64,1}) = pdf(Normal(μpred(X), 1/sqrt(λpred(X))), x)

#汎化損失(MCサンプリング)
function Gn(X::Array{Float64,1})
    n_samps = 5000 #MCサンプルのサンプル数
    samps = zeros(n_samps)
    for i in 1:n_samps
       samps[i] = log(ppred(rand(MixtureModel(Normal[Normal(μ₁,σ₁),Normal(μ₂,σ₂)], [π₁,π₂])), X) )
    end
    -mean(samps)
end

#AIC
AIC(X::Array{Float64,1}) = -mean(loglik.(X, μML(X)))+1/n

#WAIC
function WAIC(X::Array{Float64, 1})
    n = length(X)
   (λ/(λ₀+length(X)*λ))^2-log(λpred(X)/π/2)/2+((λ/(λ₀+length(X)*λ))^2+λpred(X)/2)*mean((X.-μpred(X)).^2)
end

n = 100
Random.seed!(42)
X = rand(MixtureModel(Normal[Normal(μ₁,σ₁),Normal(μ₂,σ₂)], [π₁,π₂]), n)

#データ空間のプロット
data_plot = plot(X, st=:histogram, bins=20, norm=true, label="data", title="Data", xlabel="x", color=:gray, alpha=0.5)
plot!(q, label="true", color=:red)
plot!(x->ppred(x,X), label="pred", color=:blue)
plot!(x->lik(x, μML(X)), label="ML", linestyle=:dot, color=:green, linewidth=2)

#パラメータ空間
μ_plot = plot(φ, label="prior", xlabel="μ", title="Estimate μ", color=:skyblue)
plot!(μ->ppost(μ, X), label="posterior", color=:orange)
plot!([μML(X)], st=:vline, label="ML", linestyle=:dot,color=:green)

#推定結果を表示
println("汎化損失: $(round(Gn(X), digits=4))")
println("AIC: $(round(AIC(X), digits=4))")
println("WAIC: $(round(WAIC(X), digits=4))")
fig3 = plot(data_plot, μ_plot, size=(800, 400))

#保存
savefig(fig3, "figs-WAIC/fig3.png")

2次元混合正規分布の例

【Juliaコード5; 関数の定義】
#Zのサンプル
function zn_sample(prob::Array{Float64,1})
    K = length(prob) #クラスタ数
    zn = zeros(Int64, K) #n番目のデータ点に対応する潜在変数
    k = findfirst(rand() .<= cumsum(prob))
    k = (k==nothing) ? K : k #累積和がちょうど1にならない場合
    zn[k] = 1
    zn
end

#znのパラメータ更新
function update_zparams(X::Array{Float64,2}, μs::Array{Float64,2}, λs::Array{Float64,1}, πs::Array{Float64,1})
    D,N = size(X)
    K = length(πs)
    probs = zeros(K,N)
    for n in 1:N
        for k in 1:K
            probs[k,n] = exp(-λs[k]*norm(X[:,n].-μs[:,k])^2/2+D*log(λs[k])/2+log(πs[k]))
        end
        probs[:,n] = probs[:,n]/sum(probs[:,n])
    end
    probs
end

#μkのパラメータ更新
function update_μparams(X::Array{Float64,2}, Z::Array{Int64,2}, λs::Array{Float64,1})
    D,N = size(X)
    K,_ = size(Z)

    #平均と精度
    μprec = zeros(K)
    μmean = zeros(D,K)

    for k in 1:K
        μprec[k] = λs[k]*(1+sum(Z[k,:]))
        μmean[:,k] = (sum([Z[k,n]*X[:,n] for n in 1:N]))/(1+sum(Z[k,:]))
    end
    μmean, μprec
end

#λkのパラメータ更新
function update_λparams(X::Array{Float64,2}, Z::Array{Int64,2}, μs::Array{Float64,2})
    D,N = size(X)
    K,_ = size(Z)
    param1 = zeros(K)
    param2 = zeros(K)
    for k in 1:K
        param1[k] = D*(1+sum(Z[k,:]))/2+1
        param2[k] = (sum([Z[k,n]*norm(X[:,n].-μs[:,k])^2 for n in 1:N])+norm(μs[:,k])^2+2)/2
    end
    param1, param2
end

#πのパラメータ更新
function update_πparams(Z::Array{Int64,2}, α::Array{Float64,1})
    K = length(α)
    αhat = zeros(K)
    for k in 1:K
        αhat[k] = α[k]+sum(Z[k,:])
    end
    αhat
end

#GaussianMixtureに対するGibbsサンプラー
function myGMGS(X::Array{Float64,2}, n_samps::Int64, α::Array{Float64,1})
    #次元
    D,N = size(X)
    K = length(α)

    n_burnin = Int(n_samps/10)

    #保存用配列
    Zsamps = zeros(Int64,K,N, n_samps)
    μsamps = ones(D,K, n_samps)
    λsamps = zeros(K, n_samps)
    πsamps = zeros(K, n_samps)

    #初期値を保存
    πsamps[:,1] = rand(Dirichlet(α))
    for n in 1:N
        Zsamps[:,n,1] = zn_sample(πsamps[:,1])
    end

    #サンプル
    for s in 2:n_samps
        #znのサンプル
        probs = update_zparams(X, μsamps[:,:,s-1], λsamps[:,s-1], πsamps[:,s-1])
        for n in 1:N
            Zsamps[:,n,s] = zn_sample(probs[:,n])
        end

        #λkのサンプル
        param1, param2 = update_λparams(X, Zsamps[:,:,s], μsamps[:,:,s-1])
        for k in 1:K
           λsamps[k,s] = rand(Gamma(param1[k], 1/param2[k]))
        end

        #μkのサンプル
        μmean, μprec = update_μparams(X, Zsamps[:,:,s], λsamps[:,s])
        for k in 1:K
            μsamps[:,k,s] = rand(MvNormal(μmean[:,k], 1/μprec[k]*I(D)))
        end

        #πのサンプル
        αhat = update_πparams(Zsamps[:,:,s],α)
        πsamps[:,s] = rand(Dirichlet(αhat))
    end

    #burn-in期間は除去
    Zsamps = Zsamps[:,:,n_burnin:end]
    μsamps = μsamps[:,:,n_burnin:end]
    λsamps = λsamps[:,n_burnin:end]
    πsamps = πsamps[:,n_burnin:end]

    #平均を計算
    μmean = mean(μsamps, dims=3)
    λmean = mean(λsamps, dims=2)
    πmean = mean(πsamps, dims=2)

    #辞書に保存
    ret = Dict{String, Any}()
    ret["Zsamps"] = Zsamps
    ret["μsamps"] = μsamps
    ret["λsamps"] = λsamps
    ret["πsamps"] = πsamps
    ret["Z"] = Zsamps[:,:,end]
    ret["μmean"] = μmean[:,:,1]
    ret["λmean"] = λmean[:,1]
    ret["πmean"] = πmean[:,1]
    ret
end

#分布を返す関数
function createGM(μs::Array{Float64,2}, λs::Array{Float64,1}, πs::Array{Float64,1})
    #混合数
    K = length(πs)

    #分布の作成
    dists = (MvNormal(μs[:,k], 1/sqrt(λs[k])) for k in 1:K)
    MixtureModel(MvNormal[dists...], πs)
end

#エントロピーの計算(真の分布のみから計算)
function calcS(
        GMtrue::MixtureModel{Multivariate, Continuous, MvNormal, Categorical{Float64, Vector{Float64}}}
    )
    #真の分布からのサンプル
    n_samps = 10000
    logqs = zeros(n_samps)
    samps = rand(GMtrue, n_samps)
    for s in 1:n_samps
        logqs[s] = log(pdf(GMtrue, samps[:,s]))
    end
    -mean(logqs)
end

#経験エントロピーの計算
function calcSn(
        X::Array{Float64,2},
        GMtrue::MixtureModel{Multivariate, Continuous, MvNormal, Categorical{Float64, Vector{Float64}}}
    )
    _,N = size(X)
    logqs = zeros(N)
    for n in 1:N
       logqs[n] = log(pdf(GMtrue, X[:,n]))
    end
    -mean(logqs)
end

#平均プラグイン推測
function plugin_probs(
        xx::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},
        yy::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},
        X::Array{Float64,2},
        α::Array{Float64,1}
    )
    #モデルのクラスタ数
    K = length(α)

    #Gibbs samplerによる学習
    n_samps = 500
    ret = myGMGS(X, n_samps, α)
    μmean = ret["μmean"]
    λmean = ret["λmean"]
    πmean = ret["πmean"]

    #推測した分布
    dists = (MvNormal(μmean[:,k], 1/sqrt(λmean[k])) for k in 1:K)
    GM_plugin = MixtureModel(MvNormal[dists...],πmean)
    probs = [pdf(GM_plugin, [x;y]) for y in yy, x in xx]
    probs
end

#モデルp(x|w)に用いる関数
function tmpp(xn::Array{Float64,1}, μk::Array{Float64,1}, λk::Float64)
    pdf(MvNormal(μk, 1/sqrt(λk)*I(2)), xn)
end

#log(p(x|w))
function logp(x::Array{Float64,1}, μs::Array{Float64,2}, λs::Array{Float64,1}, πs::Array{Float64,1})
    K = length(πs)
    log(dot(πs, [tmpp(x, μs[:,k], λs[k]) for k in 1:K]))
end

#予測分布の計算
function pred_probs(
        xx::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},
        yy::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},
        X::Array{Float64,2},
        α::Array{Float64,1}
    )
    #モデルのクラスタ数
    K = length(α)

    #事後分布からのサンプル
    n_samps = 1000
    ret = myGMGS(X, n_samps, α)
    μsamps= ret["μsamps"]
    λsamps = ret["λsamps"]
    πsamps = ret["πsamps"]

    #MC積分
    _, S = size(πsamps)
    N₁ = length(xx)
    N₂ = length(yy)
    probs = zeros(N₂, N₁, S)
    for j in 1:N₂
        for i in 1:N₁
            for s in 1:S
                probs[j,i,s] = exp(logp([xx[i];yy[j]], μsamps[:,:,s], λsamps[:,s], πsamps[:,s]))
            end
        end
    end
    mean(probs, dims=3)[:,:,1]
end

#WAICの計算
function WAIC(X::Array{Float64,2}, α::Array{Float64,1})
    #モデルのクラスタ数
    K = length(α)

    #事後分布からのサンプル
    n_samps = 10000
    ret = myGMGS(X, n_samps, α)
    μsamps= ret["μsamps"]
    λsamps = ret["λsamps"]
    πsamps = ret["πsamps"]

    #WAICの計算
    _,N = size(X)
    _,S = size(πsamps)
    Elik = zeros(N,S)
    Elogp = zeros(N,S)
    for n in 1:N
        for s in 1:S
            Elik[n,s] = exp(logp(X[:,n], μsamps[:,:,s], λsamps[:,s], πsamps[:,s]))
            Elogp[n,s] = logp(X[:,n], μsamps[:,:,s], λsamps[:,s], πsamps[:,s])
        end
    end

    tmpTn = log.(mean(Elik, dims=2)) #期待値の近似計算
    Tn = -mean(tmpTn)
    tmpVn = var(Elogp, dims=2) #期待値の近似計算
    Vn = mean(tmpVn)
    Tn+Vn
end

#汎化損失
function Gn(
        X::Array{Float64,2},
        α::Array{Float64,1},
        GMtrue::MixtureModel{Multivariate, Continuous, MvNormal, Categorical{Float64, Vector{Float64}}}
    )
    #モデルのクラスタ数
    K = length(α)

    #真のクラスタ数
    Ktrue = length(GMtrue.components)

    #真の分布からのサンプル
    n_sampsx = 10000
    xsamps = rand(GMtrue, n_sampsx)

    #事後分布からのサンプル
    n_sampsw = 10000
    ret = myGMGS(X, n_sampsw, α)
    μsamps= ret["μsamps"]
    λsamps = ret["λsamps"]
    πsamps = ret["πsamps"]

    #MC積分
    _, S = size(πsamps)
    probs = zeros(n_sampsx, S)
    for i in 1:n_sampsx
        for s in 1:S
            probs[i,s] = logp(xsamps[:,i], μsamps[:,:,s], λsamps[:,s], πsamps[:,s])
        end
    end
    logpred_samps = mean(probs, dims=2)[:,1]
    -mean(logpred_samps)
end
【Juliaコード6; 真の分布】
#パラメータの真値
μtrue = [
    0. -sqrt(3) sqrt(3);
    2. -1. -1.
]
λtrue = [1.,1.,1.]
πtrue = ones(3)/3

#真の分布の作成
GMtrue = createGM(μtrue, λtrue, πtrue)

#真の分布をプロット
clims = (0,0.04)
xx = -5.0:0.1:5.0
yy = -5.0:0.1:5.0
probs = [pdf(GMtrue, [x,y]) for y in yy, x in xx]
hmap_true = heatmap(
    xx,
    yy,
    probs,
    clims=clims,
    title="Gaussian Mixture: true distribution",
    xlabel="x",
    ylabel="y"
)
savefig(hmap_true, "figs-GM/fig1.png")

#データの生成
N = 200
X = rand(GMtrue, N)

#ハイパーパラメータ
K = 5
α = 10*ones(K);
【Juliaコード7; 平均プラグイン推測】
#平均プラグイン推測
K = 5
α = 10*ones(K);
probs_plugin = plugin_probs(xx, yy , X, α)
hmap_plugin = heatmap(
    xx,
    yy,
    probs_plugin,
    clims=clims,
    title="Gaussian Mixture: Plug-in estimation",
    xlabel="x",
    ylabel="y"
)
savefig(hmap_plugin, "figs-GM/fig2.png")
【Juliaコード8; 予測分布】
#予測分布
K = 5
α = 10*ones(K);
probs_pred = pred_probs(xx, yy, X, α)
hmap_pred = heatmap(
    xx,
    yy,
    probs_pred,
    clims=clims,
    title="Gaussian Mixture: Predictive distribution",
    xlabel="x",
    ylabel="y"
)
savefig(hmap_pred, "figs-GM/fig3.png")
【Juliaコード9; 汎化損失とWAICの計算】
#WAIC, 汎化損失の計算
println("Gn=$(round(Gn(X, α, GMtrue), digits=4))")
println("WAIC=$(round(WAIC(X, α), digits=4))")
【Juliaコード10; モデル選択】
#モデル選択
Kmin = 2
Kmax = 10
for k in Kmin:Kmax
    α = 10*ones(k)
    wn = WAIC(X,α)
    gn = Gn(X,α,GMtrue)
    println("K=$(k):")
    println("\t Gn=$(round(gn, digits=4))")
    println("\t WAIC=$(round(wn, digits=4))")
end
【Juliaコード11; Gibbsサンプラーによる学習の様子】
#Gibbs samplerによる学習
K = 5
n_samps = 100
α = 10*ones(K)
ret = myGMGS(X, n_samps, α)
μsamps = ret["μsamps"]
λsamps = ret["λsamps"]
πsamps = ret["πsamps"]
_,S = size(πsamps)

#プロット用
xx = -5.:0.1:5
yy = xx
cont = contourf(xlabel="x", ylabel="y",clims=(0,0.04))

anim = @animate for s in 1:S
    params = (MvNormal(μsamps[:,k,s], 1/sqrt(λsamps[k,s])) for k in 1:K)
    GM = MixtureModel(MvNormal[params...], πsamps[:,s])
    prob = [pdf(GM, [x,y]) for x in xx, y in yy]
    contourf!(xx, yy, prob, title="MCMC sample=$(s)")
end

gif(anim, "figs-GM/anim2.gif", fps=10)

逆相関の例

【Juliaコード12; 逆相関】
#ハイパーパラメータ
K = 5
α = 10*ones(K)

#データサイズ
N = 200

#エントロピーの計算
S = calcS(GMtrue)

#複数回データを発生させる
n_expr = 30
estimation_of_GenErr = zeros(n_expr)
true_GenErr = zeros(n_expr)
for iter in 1:n_expr
    #データを生成
    X = rand(GMtrue, N)

    #Gn-Sを計算する
    true_GenErr[iter] = Gn(X, α, GMtrue)-S

    #WAIC-Snを計算する
    estimation_of_GenErr[iter] = WAIC(X, α)-calcSn(X, GMtrue)
end

#相関係数
println("Correlation=$(cor(true_GenErr, estimation_of_GenErr))")

error_plot = plot(true_GenErr, label="Gn-S", xlabel="iter", ylabel="Gn-S, Wn-Sn", markershape=:circle, ls=:dot)
plot!(estimation_of_GenErr, label="WAIC-Sn", markershape=:square, ls=:dot)
savefig(error_plot, "figs-GM/error_plot.png")
参考文献
      [1]渡辺澄夫, ベイズ統計の理論と方法, 2012, コロナ社