以前書いた記事の修正版です. 扱う問題, プログラム, 構成等を変更しました.
タイトルにある"ELBOの勾配"の必要性について確認しておきましょう. 手元にデータ\(\left\{ \boldsymbol{x}_n\right\}_{n=1}^{N}\)があり, モデル\(p(\boldsymbol{x}\mid\boldsymbol{w})\)とパラメータ\(\boldsymbol{w}\in\mathbb{R}^d\)の事前分布\(p(\boldsymbol{w})\)を作ったとします. 事後分布の推定のために変分推論法を使う場合, 事後分布の近似分布\(r_\eta(\boldsymbol{w})\)を自分で作り, 次式で定義されるELBOを最大化します:
\begin{equation} \mathcal{L}(\boldsymbol{\eta}) = \sum_{n=1}^N\int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w} - D_{\mathrm{KL}}[r_\eta(\boldsymbol{w})\| p(\boldsymbol{w})]. \end{equation}第2項は近似事後分布と事前分布の間のKullback-Leiblerダイバージェンスです. 確率的変分推論法を使う場合, 上式の勾配計算が必要です. 本記事では, 以下の勾配計算について考えていきましょう:
\begin{equation} \nabla_\eta \int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w}. \end{equation}積分と微分の交換ができると仮定すると, 次式の計算が必要になります:
\begin{equation} \int \nabla_\eta r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w}. \end{equation}この式は一般にモンテカルロ近似できるとは限らないので, 工夫が必要です. 近似方法としては以下の2つが有名です.
適用範囲が広いのは前者ですが, 後者は一般に分散が小さくなることが知られています. 分散が小さいほど, 更新の効率が良くなります. 本記事では, 両者でどれくらい勾配計算のばらつきに差が出るのかを検証します.
スコア関数推定についてもう少し詳しく説明します. この方法では, 次式に着目して密度関数部分を作り出します:
\begin{equation} \nabla_\eta r_\eta (\boldsymbol{w}) = r_\eta (\boldsymbol{w})\nabla_\eta(\boldsymbol{w})\log r_\eta(\boldsymbol{w}). \end{equation}この式から, 積分の勾配は以下のように近似できます:
\begin{align} & \nabla_\eta \int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w} \\ &= \int \nabla_\eta r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w}\\ &= \int r_\eta (\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\nabla_\eta\log r_\eta(\boldsymbol{w})\mathrm{d}\boldsymbol{w}\\ &\simeq \frac{1}{S} \sum_{s=1}^S \log p(\boldsymbol{x}_n\mid \boldsymbol{w}^{(s)})\nabla_\eta \log r_\eta(\boldsymbol{w}^{(s)}),\quad \boldsymbol{w}^{(s)}\sim r_\eta. \end{align}reparametrization trickについて説明します. ここでは, 近似分布\(r_\eta\)が以下の正規分布であると仮定します:
\begin{equation} \mathrm{N}(\boldsymbol{m},\mathrm{diag}(\boldsymbol{s})^2),\quad \boldsymbol{\eta} = \begin{bmatrix}\boldsymbol{m}\\ \log \boldsymbol{s} \end{bmatrix} \in \mathbb{R}^{2d}. \end{equation}reparametrization trickでは, 以下の変数変換に着目します:
\begin{equation} \boldsymbol{w} = g(\boldsymbol{\varepsilon},\boldsymbol{\eta}) = \boldsymbol{m} + \boldsymbol{s} \odot \boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}\sim \mathrm{N}(\boldsymbol{0},I_d). \end{equation}従って, 積分の勾配は以下のように近似できます:
\begin{align} & \nabla_\eta \int r_\eta(\boldsymbol{w})\log p(\boldsymbol{x}_n\mid \boldsymbol{w})\mathrm{d}\boldsymbol{w} \\ &=\nabla_\eta \int \pi(\boldsymbol{\varepsilon})\log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon} \\ &= \int \pi(\boldsymbol{\varepsilon})\nabla_\eta\log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon} \\ &\simeq \frac{1}{S} \sum_{s=1}^S \nabla_\eta \log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon}^{(s)},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}^{(s)} \sim \mathrm{N}(\boldsymbol{0},I_d). \end{align}ここで, \(\pi\)は, 標準正規分布の密度関数です.
本節では, ばらつきの観点からreparametrization trickが望ましいことを実験により確認します.
前節で考えた積分の勾配のモンテカルロ近似のばらつき具合を計算します. 以下のような記号を導入しておきましょう:
\begin{align} I_{\mathrm{score}}(S) &= \frac{1}{S} \sum_{s=1}^S \log p(\boldsymbol{x}_n\mid \boldsymbol{w}^{(s)})\nabla_\eta \log r_\eta(\boldsymbol{w}^{(s)}) \\ I_{\mathrm{RP}}(S) &= \frac{1}{S} \sum_{s=1}^S \nabla_\eta \log p(\boldsymbol{x}_n\mid g(\boldsymbol{\varepsilon}^{(s)},\boldsymbol{\eta}))\mathrm{d}\boldsymbol{\varepsilon}. \end{align}上の2つの和はいずれも確率変数であり, これら確率変数の平均と標準偏差を調べます. 平均と標準偏差は, 上の確率変数のサンプルを\(m_{\mathrm{max}}\)個作ることで推定し, \(\boldsymbol{\eta}\)は予め適当な回数反復して更新したものを使います. サンプルサイズを変化させた時の変化も見たいので, 結局次のような手順を踏みます.
Initialize \(\boldsymbol{\eta}^{(0)}\).
for \(k=0,1,\cdots,k_{\mathrm{max}}\)
Update \(\boldsymbol{\eta}^{(k+1)}=\boldsymbol{\eta}^{(k)}+\alpha_k\nabla\mathcal{L}(\boldsymbol{\eta}^{(k)})\).
end
for \(S=1,\cdots,S_{\mathrm{max}}\)
for \(m=1,\cdots,m_{\mathrm{max}}\)
Calculate \(I(S)_{\mathrm{score}}\) and \(I_{\mathrm{RP}}\).
end
end
今回は, こちらの記事のロジスティック回帰の例で試してみます.
実験結果を示します. 反復回数\(k_{\mathrm{max}}=100\), 最大サンプルサイズ\(S_{\mathrm{max}}=100\), サンプル数\(m_{\mathrm{max}}=500\)としました. 下図は, スコア関数推定の結果で, \(\eta\)の各成分ごとの微分値を表示しました. 横軸はサンプルサイズで, 青線が平均, 薄い青領域が1シグマ範囲です.
下図は, reparametrization trickの推定結果です. 全体的にスコア関数推定よりもばらつきが小さいですね. 精度的には\(S=1\)サンプルでも十分かもしれません.
#mathematics
using LinearAlgebra
using ForwardDiff
using Flux
#statistics
using Random
using Distributions
using Statistics
#visualize
using Plots
pyplot()
#macros
using UnPack
using ProgressMeter
#split parameters
function split_params(vec)
return vec[1:2], vec[3:end]
end
#reparametrize
function reparameterize(var_mean,var_logstd)
var_mean + exp.(var_logstd) .* randn(2)
end
#logpmodel
function logpmodel(y,x,wvec)
logpdf(Bernoulli(sigmoid(wvec[1]+wvec[2]*x)),y)
end
#ELBO
function ELBO(X,Y,ηvec,minibatch,m₀vec,s₀,N)
val = 0
for n in minibatch
var_mean,var_logstd = split_params(ηvec)
val += logpmodel(Y[n],X[n],reparameterize(var_mean,var_logstd))
end
return N*val/length(minibatch)-(norm(ηvec[1:2]-m₀vec)^2+norm(exp.(ηvec[3:end]))^2)/2/s₀^2+sum(ηvec[3:end])+(1-2*log(s₀))
end
ELBO(X,Y,ηvec,m₀vec,s₀,N) = ELBO(X,Y,ηvec,1:length(X),m₀vec,s₀,N)
#create model
function create_model(X,Y,m₀vec,s₀,N)
ηvec = zeros(4)
ps = Flux.params(ηvec)
loss_func = minibatch->(-ELBO(X,Y,ηvec,minibatch,m₀vec,s₀,N))
return ηvec,ps,loss_func
end
#stochastic variational inference
function stochastic_variational_inference(data,model_params,n_train,minibatch_size)
@unpack X,N = data
@unpack m₀vec,s₀ = model_params
opt = ADAM(0.01)
history = zeros(n_train)
ηvec,ps,loss_func = create_model(X,Y,m₀vec,s₀,N)
@showprogress for k in 1:n_train
minibatch = sample(1:N,minibatch_size)
Flux.train!(loss_func,ps,minibatch,opt)
history[k] = ELBO(X,Y,ηvec,m₀vec,s₀,N)
end
return ηvec,history
end
#log likelihood
function loglik(X,Y,N,wvec)
val = 0
for n in 1:N
val += logpmodel(Y[n],X[n],wvec)
end
val
end
#reconstruction error using reconstruction error
function loglik_RP(ηvec,X,Y,N,normal_samp)
var_mean,var_logstd = split_params(ηvec)
loglik(X,Y,N,var_mean+exp.(var_logstd).*normal_samp)
end
#one sample for monte carlo estimate : reparameterization trick
function MC_sample_RP(ηvec,X,Y,N,S)
MC_samps = zeros(4,S)
normal_samps = randn(2,S)
for s in 1:S
MC_samps[:,s] = ForwardDiff.gradient(ηvec->loglik_RP(ηvec,X,Y,N,normal_samps[:,s]),ηvec)
end
return mean(MC_samps,dims=2)
end
#log rη(wvec)
function logrηwvec(ηvec,wvec)
var_mean,var_logstd = split_params(ηvec)
logpdf(MvNormal(var_mean,exp.(var_logstd)),wvec)
end
#one sample for monte carlo estimate : score function estimator
function MC_sample_score(ηvec,X,Y,N,S)
MC_samps = zeros(4,S)
var_mean,var_logstd = split_params(ηvec)
var_samps = var_mean .+ exp.(var_logstd) .* randn(2,S)
for s in 1:S
MC_samps[:,s] = (
loglik(X,Y,N,var_samps[:,s])*ForwardDiff.gradient(ηvec->logrηwvec(ηvec,var_samps[:,s]),ηvec)
)
end
return mean(MC_samps,dims=2)
end
#compute M samples of monte carlo estimate
function comp_MC_samps(m_max,ηvec,X,Y,N,S,MC_sample_func)
MC_samps = zeros(4,m_max)
for m in 1:m_max
MC_samps[:,m] = MC_sample_func(ηvec,X,Y,N,S)
end
return MC_samps
end
#compute estimate of sample mean and sample std of the gradient
function mean_and_std_of_grad(data,S_max,m_max,ηvec,MC_sample_func)
@unpack X,Y,N = data
MC_samps = zeros(4,m_max)
means = zeros(4,S_max)
stds = zeros(4,S_max)
@showprogress for S in 1:S_max
MC_samps = comp_MC_samps(m_max,ηvec,X,Y,N,S,MC_sample_func)
means[:,S] = mean(MC_samps,dims=2)
stds[:,S] = std(MC_samps,dims=2)
end
return means,stds
end
mean_and_std_of_grad_RP(data,S_max,m_max,ηvec) = mean_and_std_of_grad(data,S_max,m_max,ηvec,MC_sample_RP)
mean_and_std_of_grad_score(data,S_max,m_max,ηvec) = mean_and_std_of_grad(data,S_max,m_max,ηvec,MC_sample_score)
#create data
Random.seed!(42)
w₁ = -4.0
w₂ = 4.0
w_true = (w₁=w₁,w₂=w₂)
N = 30
X = sort(rand(-10:10,N))
Y = [rand(Bernoulli(sigmoid(w₁+w₂*X[n]))) for n in 1:N]
function true_pdf(y,x,w_true)
@unpack w₁,w₂ = w_true
pdf(Bernoulli(sigmoid(w₁+w₂*x)),y)
end
#data and model parameters
data = (X=X,Y=Y,N=N)
model_params = (m₀vec=zeros(2),s₀=1)
#training
n_train = 100
minibatch_size = N
@time ηvec,history = stochastic_variational_inference(data,model_params,n_train,minibatch_size)
#maximum n_sample and n_sample of monte carlo estimate
S_max = 100
m_max = 500
means_score,stds_score = mean_and_std_of_grad_score(data,S_max,m_max,ηvec)
means_RP,stds_RP = mean_and_std_of_grad_RP(data,S_max,m_max,ηvec)
p_score_1 = plot(means_score[1,:],ribbons=stds_score[1,:],title="gradient wrt η₁")
p_score_2 = plot(means_score[2,:],ribbons=stds_score[2,:],title="gradient wrt η₂")
p_score_3 = plot(means_score[3,:],ribbons=stds_score[3,:],title="gradient wrt η₃")
p_score_4 = plot(means_score[4,:],ribbons=stds_score[4,:],title="gradient wrt η₄")
fig1 = plot(p_score_1,p_score_2,p_score_3,p_score_4,label="1σ",ylim=(-8,8),xlabel="S",ylabel="∂η I_score(S)")
savefig(fig1,"figs-RP/fig1.png")
p_RP_1 = plot(means_RP[1,:],ribbons=stds_RP[1,:],title="gradient wrt η₁")
p_RP_2 = plot(means_RP[2,:],ribbons=stds_RP[2,:],title="gradient wrt η₂")
p_RP_3 = plot(means_RP[3,:],ribbons=stds_RP[3,:],title="gradient wrt η₃")
p_RP_4 = plot(means_RP[4,:],ribbons=stds_RP[4,:],title="gradient wrt η₄")
fig2 = plot(p_RP_1,p_RP_2,p_RP_3,p_RP_4,label="1σ",ylim=(-8,8),xlabel="S",ylabel="∂η I_RP(S)")
savefig(fig2,"figs-RP/fig2.png")