ELBOの勾配近似方法を比較する(修正版)

記事の内容


  • 動機と概要
    • 変分推論法と勾配近似
    • 近似方法1:スコア関数推定
    • 近似方法2:reparametrization trick
  • 実験
    • 実験内容
    • 実験結果

以前書いた記事の修正版です. 扱う問題, プログラム, 構成等を変更しました.

動機と概要

変分推論法と勾配近似

タイトルにある"ELBOの勾配"の必要性について確認しておきましょう. 手元にデータ\(\left\{ \boldsymbol{x}_n\right\}_{n=1}^{N}\)があり, モデル\(p(\boldsymbol{x}\mid\boldsymbol{w})\)とパラメータ\(\boldsymbol{w}\in\mathbb{R}^d\)の事前分布\(p(\boldsymbol{w})\)を作ったとします. 事後分布の推定のために変分推論法を使う場合, 事後分布の近似分布\(r_\eta(\boldsymbol{w})\)を自分で作り, 次式で定義されるELBOを最大化します:

\begin{equation} \mathcal{L}(\boldsymbol{\eta}) = \sum_{n=1}^N\int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w} - D_{\mathrm{KL}}[r_\eta(\boldsymbol{w})\| p(\boldsymbol{w})]. \end{equation}

第2項は近似事後分布と事前分布の間のKullback-Leiblerダイバージェンスです. 確率的変分推論法を使う場合, 上式の勾配計算が必要です. 本記事では, 以下の勾配計算について考えていきましょう:

\begin{equation} \nabla_\eta \int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w}. \end{equation}

積分と微分の交換ができると仮定すると, 次式の計算が必要になります:

\begin{equation} \int \nabla_\eta r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w}. \end{equation}

この式は一般にモンテカルロ近似できるとは限らないので, 工夫が必要です. 近似方法としては以下の2つが有名です.

  • スコア関数推定:無理矢理, 密度関数の部分を作り出してからモンテカルロ近似する方法.
  • reparametrization trick:近似事後分布のパラメータをlogの部分に押し付けてからモンテカルロ近似する方法.

適用範囲が広いのは前者ですが, 後者は一般に分散が小さくなることが知られています. 分散が小さいほど, 更新の効率が良くなります. 本記事では, 両者でどれくらい勾配計算のばらつきに差が出るのかを検証します.

近似方法1:スコア関数推定

スコア関数推定についてもう少し詳しく説明します. この方法では, 次式に着目して密度関数部分を作り出します:

\begin{equation} \nabla_\eta r_\eta (\boldsymbol{w}) = r_\eta (\boldsymbol{w})\nabla_\eta(\boldsymbol{w})\log r_\eta(\boldsymbol{w}). \end{equation}

この式から, 積分の勾配は以下のように近似できます:

\begin{align} & \nabla_\eta \int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w} \\ &= \int \nabla_\eta r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w}\\ &= \int r_\eta (\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\nabla_\eta\log r_\eta(\boldsymbol{w})\mathrm{d}\boldsymbol{w}\\ &\simeq \frac{1}{S} \sum_{s=1}^S \log p(\boldsymbol{x}_n\mid \boldsymbol{w}^{(s)})\nabla_\eta \log r_\eta(\boldsymbol{w}^{(s)}),\quad \boldsymbol{w}^{(s)}\sim r_\eta. \end{align}

近似方法2:reparametrization trick

reparametrization trickについて説明します. ここでは, 近似分布\(r_\eta\)が以下の正規分布であると仮定します:

\begin{equation} \mathrm{N}(\boldsymbol{m},\mathrm{diag}(\boldsymbol{s})^2),\quad \boldsymbol{\eta} = \begin{bmatrix}\boldsymbol{m}\\ \log \boldsymbol{s} \end{bmatrix} \in \mathbb{R}^{2d}. \end{equation}

reparametrization trickでは, 以下の変数変換に着目します:

\begin{equation} \boldsymbol{w} = g(\boldsymbol{\varepsilon},\boldsymbol{\eta}) = \boldsymbol{m} + \boldsymbol{s} \odot \boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}\sim \mathrm{N}(\boldsymbol{0},I_d). \end{equation}

従って, 積分の勾配は以下のように近似できます:

\begin{align} & \nabla_\eta \int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w} \\ &=\nabla_\eta \int \pi(\boldsymbol{\varepsilon})\log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon} \\ &= \int \pi(\boldsymbol{\varepsilon})\nabla_\eta\log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon} \\ &\simeq \frac{1}{S} \sum_{s=1}^S \nabla_\eta \log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon}^{(s)},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}^{(s)} \sim \mathrm{N}(\boldsymbol{0},I_d). \end{align}

ここで, \(\pi\)は, 標準正規分布の密度関数です.

実験

本節では, ばらつきの観点からreparametrization trickが望ましいことを実験により確認します.

実験内容

前節で考えた積分の勾配のモンテカルロ近似のばらつき具合を計算します. 以下のような記号を導入しておきましょう:

\begin{align} I_{\mathrm{score}}(S) &= \frac{1}{S} \sum_{s=1}^S \log p(\boldsymbol{x}_n\mid \boldsymbol{w}^{(s)})\nabla_\eta \log r_\eta(\boldsymbol{w}^{(s)}) \\ I_{\mathrm{RP}}(S) &= \frac{1}{S} \sum_{s=1}^S \nabla_\eta \log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon}^{(s)},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon}. \end{align}

上の2つの和はいずれも確率変数であり, これら確率変数の平均と標準偏差を調べます. 平均と標準偏差は, 上の確率変数のサンプルを\(m_{\mathrm{max}}\)個作ることで推定し, \(\boldsymbol{\eta}\)は予め適当な回数反復して更新したものを使います. サンプルサイズを変化させた時の変化も見たいので, 結局次のような手順を踏みます.

【実験手順の概要】

Initialize \(\boldsymbol{\eta}^{(0)}\).
for \(k=0,1,\cdots,k_{\mathrm{max}}\)
Update \(\boldsymbol{\eta}^{(k+1)}=\boldsymbol{\eta}^{(k)}+\alpha_k\nabla\mathcal{L}(\boldsymbol{\eta}^{(k)})\).
end
for \(S=1,\cdots,S_{\mathrm{max}}\)
for \(m=1,\cdots,m_{\mathrm{max}}\)
Calculate \(I(S)_{\mathrm{score}}\) and \(I_{\mathrm{RP}}\).
end
end

今回は, こちらの記事のロジスティック回帰の例で試してみます.

実験結果

実験結果を示します. 反復回数\(k_{\mathrm{max}}=100\), 最大サンプルサイズ\(S_{\mathrm{max}}=100\), サンプル数\(m_{\mathrm{max}}=500\)としました. 下図は, スコア関数推定の結果で, \(\eta\)の各成分ごとの微分値を表示しました. 横軸はサンプルサイズで, 青線が平均, 薄い青領域が1シグマ範囲です.

【コード6の実行結果】

下図は, reparametrization trickの推定結果です. 全体的にスコア関数推定よりもばらつきが小さいですね. 精度的には\(S=1\)サンプルでも十分かもしれません.

【コード7の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra
using ForwardDiff
using Flux

#statistics
using Random
using Distributions
using Statistics

#visualize
using Plots
pyplot()

#macros
using UnPack
using ProgressMeter
【Juliaコード2; 変分推論用の関数定義】
#split parameters
function split_params(vec)
    return vec[1:2], vec[3:end]
end

#reparametrize
function reparameterize(var_mean,var_logstd)
    var_mean + exp.(var_logstd) .* randn(2)
end

#logpmodel
function logpmodel(y,x,wvec)
    logpdf(Bernoulli(sigmoid(wvec[1]+wvec[2]*x)),y)
end

#ELBO
function ELBO(X,Y,ηvec,minibatch,m₀vec,s₀,N)
    val = 0
    for n in minibatch
        var_mean,var_logstd = split_params(ηvec)
        val += logpmodel(Y[n],X[n],reparameterize(var_mean,var_logstd))
    end
    return N*val/length(minibatch)-(norm(ηvec[1:2]-m₀vec)^2+norm(exp.(ηvec[3:end]))^2)/2/s₀^2+sum(ηvec[3:end])+(1-2*log(s₀))
end
ELBO(X,Y,ηvec,m₀vec,s₀,N) = ELBO(X,Y,ηvec,1:length(X),m₀vec,s₀,N)

#create model
function create_model(X,Y,m₀vec,s₀,N)
    ηvec = zeros(4)
    ps = Flux.params(ηvec)
    loss_func = minibatch->(-ELBO(X,Y,ηvec,minibatch,m₀vec,s₀,N))
    return ηvec,ps,loss_func
end

#stochastic variational inference
function stochastic_variational_inference(data,model_params,n_train,minibatch_size)
    @unpack X,N = data
    @unpack m₀vec,s₀ = model_params
    opt = ADAM(0.01)
    history = zeros(n_train)
    ηvec,ps,loss_func = create_model(X,Y,m₀vec,s₀,N)
    @showprogress for k in 1:n_train
        minibatch = sample(1:N,minibatch_size)
        Flux.train!(loss_func,ps,minibatch,opt)
        history[k] = ELBO(X,Y,ηvec,m₀vec,s₀,N)
    end
    return ηvec,history
end
【Juliaコード3; 実験用の関数定義】
#log likelihood
function loglik(X,Y,N,wvec)
    val = 0
    for n in 1:N
        val += logpmodel(Y[n],X[n],wvec)
    end
    val
end

#reconstruction error using reconstruction error
function loglik_RP(ηvec,X,Y,N,normal_samp)
    var_mean,var_logstd = split_params(ηvec)
    loglik(X,Y,N,var_mean+exp.(var_logstd).*normal_samp)
end

#one sample for monte carlo estimate  : reparameterization trick
function MC_sample_RP(ηvec,X,Y,N,S)
    MC_samps = zeros(4,S)
    normal_samps = randn(2,S)
    for s in 1:S
        MC_samps[:,s] = ForwardDiff.gradient(ηvec->loglik_RP(ηvec,X,Y,N,normal_samps[:,s]),ηvec)
    end
    return mean(MC_samps,dims=2)
end

#log rη(wvec)
function logrηwvec(ηvec,wvec)
    var_mean,var_logstd = split_params(ηvec)
    logpdf(MvNormal(var_mean,exp.(var_logstd)),wvec)
end

#one sample for monte carlo estimate : score function estimator
function MC_sample_score(ηvec,X,Y,N,S)
    MC_samps = zeros(4,S)
    var_mean,var_logstd = split_params(ηvec)
    var_samps = var_mean .+ exp.(var_logstd) .* randn(2,S)
    for s in 1:S
        MC_samps[:,s] = (
            loglik(X,Y,N,var_samps[:,s])*ForwardDiff.gradient(ηvec->logrηwvec(ηvec,var_samps[:,s]),ηvec)
            )
    end
    return mean(MC_samps,dims=2)
end

#compute M samples of monte carlo estimate
function comp_MC_samps(m_max,ηvec,X,Y,N,S,MC_sample_func)
    MC_samps = zeros(4,m_max)
    for m in 1:m_max
        MC_samps[:,m] = MC_sample_func(ηvec,X,Y,N,S)
    end
    return MC_samps
end

#compute estimate of sample mean and sample std of the gradient
function mean_and_std_of_grad(data,S_max,m_max,ηvec,MC_sample_func)
    @unpack X,Y,N = data
    MC_samps = zeros(4,m_max)
    means = zeros(4,S_max)
    stds = zeros(4,S_max)
    @showprogress for S in 1:S_max
        MC_samps = comp_MC_samps(m_max,ηvec,X,Y,N,S,MC_sample_func)
        means[:,S] = mean(MC_samps,dims=2)
        stds[:,S] = std(MC_samps,dims=2)
    end
    return means,stds
end
mean_and_std_of_grad_RP(data,S_max,m_max,ηvec) = mean_and_std_of_grad(data,S_max,m_max,ηvec,MC_sample_RP)
mean_and_std_of_grad_score(data,S_max,m_max,ηvec) = mean_and_std_of_grad(data,S_max,m_max,ηvec,MC_sample_score)
【Juliaコード4; 変分パラメータの推定】
#create data
Random.seed!(42)
w₁ = -4.0
w₂ = 4.0
w_true = (w₁=w₁,w₂=w₂)
N = 30
X = sort(rand(-10:10,N))
Y = [rand(Bernoulli(sigmoid(w₁+w₂*X[n]))) for n in 1:N]

function true_pdf(y,x,w_true)
    @unpack w₁,w₂ = w_true
    pdf(Bernoulli(sigmoid(w₁+w₂*x)),y)
end

#data and model parameters
data = (X=X,Y=Y,N=N)
model_params = (m₀vec=zeros(2),s₀=1)

#training
n_train = 100
minibatch_size = N
@time ηvec,history = stochastic_variational_inference(data,model_params,n_train,minibatch_size)
【Juliaコード5; 実験】
#maximum n_sample and n_sample of monte carlo estimate
S_max = 100
m_max = 500
means_score,stds_score = mean_and_std_of_grad_score(data,S_max,m_max,ηvec)
means_RP,stds_RP = mean_and_std_of_grad_RP(data,S_max,m_max,ηvec)
【Juliaコード6; スコア関数推定の結果可視化】
p_score_1 = plot(means_score[1,:],ribbons=stds_score[1,:],title="gradient wrt η₁")
p_score_2 = plot(means_score[2,:],ribbons=stds_score[2,:],title="gradient wrt η₂")
p_score_3 = plot(means_score[3,:],ribbons=stds_score[3,:],title="gradient wrt η₃")
p_score_4 = plot(means_score[4,:],ribbons=stds_score[4,:],title="gradient wrt η₄")
fig1 = plot(p_score_1,p_score_2,p_score_3,p_score_4,label="1σ",ylim=(-8,8),xlabel="S",ylabel="∂η I_score(S)")
savefig(fig1,"figs-RP/fig1.png")
【Juliaコード7; reparameterization trickの結果可視化】
p_RP_1 = plot(means_RP[1,:],ribbons=stds_RP[1,:],title="gradient wrt η₁")
p_RP_2 = plot(means_RP[2,:],ribbons=stds_RP[2,:],title="gradient wrt η₂")
p_RP_3 = plot(means_RP[3,:],ribbons=stds_RP[3,:],title="gradient wrt η₃")
p_RP_4 = plot(means_RP[4,:],ribbons=stds_RP[4,:],title="gradient wrt η₄")
fig2 = plot(p_RP_1,p_RP_2,p_RP_3,p_RP_4,label="1σ",ylim=(-8,8),xlabel="S",ylabel="∂η I_RP(S)")
savefig(fig2,"figs-RP/fig2.png")