Bayesian Neural Network (修正版)

記事の内容


以前書いた記事を色々修正します. 主にモデルとプログラムを修正しました. Bayesian Neural Networkの様々な学習方法を比較します.

問題設定とモデル

この節では, タスクとして分類問題を導入し, そのためのモデルを導入します. モデルは通常のNeural Networkと, そのBayes版を考えます.

問題設定

次のような分類問題を考えます. 画像のようにいくつかの点とそのラベルが与えられているとします. 点1個が1つのデータであり, 対応するラベルが赤丸と青バツです. 全部で16点あります. この2つのクラスの間を仕切る境界の推定が目標です. 要するに二値分類です.

【コード3の実行結果】

記号を導入します. 平面上の各点\(\boldsymbol{x}_n\in\mathbb{R}^{D_x}\)に対して, 対応するラベルデータ\(y_n\in\{1,0\}\)が与えられているとします. サンプルサイズは\(N\)として, データを次のようにまとめておきます.

\begin{equation} X = \{ \boldsymbol{x}_n\}_{n=1}^N ,\quad Y=\{ y_n\}_{n=1}^N \end{equation}

今の場合, \(N=16\)です. 赤い点のラベルを\(y=1\)とし, 青いバツのラベルを\(y=0\)とします.

モデル

モデルを導入します. 今回は以下の2つのモデルを考えます.

  • 中間層が1層の通常のNeural Network
  • 中間層が1層のBayesian Neural Network

赤い点に分類される確率をモデル化します. まず, 通常のNeural Network\(\Phi:\mathbb{R}^{D_x}\to\mathbb{R}\)は, 以下のように設定します.

\begin{equation} \Phi \left( \boldsymbol{x}, \boldsymbol{w}\right) = \sigma\left( W^{(3)}\sigma\left( W^{(2)}\boldsymbol{x} + \boldsymbol{b}^{(2)} \right) + \boldsymbol{b}^{(3)}\right) \end{equation}

ここで, \(\sigma\)はシグモイド関数, 中間層の幅を\(D_0=5\)とし, \(\boldsymbol{w}\in\mathbb{R}^{d_w}\)はパラメータをまとめたベクトルとします. このNeural Networkの出力を赤い点に分類される確率と解釈します.

\begin{equation} p(y=1\mid \boldsymbol{x}) = \Phi\left( \boldsymbol{x}, \boldsymbol{w} \right) \end{equation}

次にBayes版のNeural Networkをモデル化します. Neural Netwrokの重みパラメータに事前分布を設定するとき, Bayesian Neural Networkといいます. 先程のNeural Networkモデルからの連想から, 次のようなBernoulliモデルを考えます.

\begin{equation} p(y\mid \boldsymbol{x},\boldsymbol{w} ) = \Phi \left( \boldsymbol{x}, \boldsymbol{w}\right)^{y}\left\{ 1-\Phi \left( \boldsymbol{x}, \boldsymbol{w} \right) \right\}^{1-y} \end{equation}

事前分布として, 正規分布を仮定しておきます.

\begin{equation} \boldsymbol{w} \sim \mathrm{N}\left( \boldsymbol{0}, \lambda_w^{-1}I_{d_w}\right) \end{equation}

Neural Networkの学習

以下, 次のような学習方法を順に試していきます.

  • 通常のNeural Networkの学習
  • Laplace近似
  • Haniltonian Monte Carlo法
  • HMC + Gibbs sampler
  • 確率的変分推論法

通常のNeural Networkの学習

まずは通常のNeural Networkの学習です. コスト関数を二乗誤差とします. バッチサイズは1, 訓練反復回数は\(10^6\), ステップサイズは\(0.1\)とします. 結果は以下の通りです. 性能は結構良さそうですね.

【コード5の実行結果】

BNNの学習1 : Laplace近似

ここからがBayesian Neural Networkです. まずはLaplace近似による学習です. Laplace近似では, 事後分布を次のように正規分布で近似します. \(\boldsymbol{w}_{\mathrm{MAP}}\)はMAP推定量, \(H(\boldsymbol{w})\)を対数事後分布のHesse行列とします. \(\lambda_w=0.0005\)とします.

\begin{equation} \mathrm{N}\left( \boldsymbol{w}_{\mathrm{MAP}}, -H(\boldsymbol{w}_{\mathrm{MAP}})\right) \end{equation}

Taylor展開から導出できます[2]. JuliaのOptimで無理やり最大化してMAP推定量を求めました. 結果は以下の通りです. 正規分布の割にうまくいってますね.

【コード8の実行結果】

BNNの学習2 : HMC

次にHMCです. leap-frog法の反復回数は\(T=100\), ステップサイズは\(h=0.1\)としました. サンプル数は5000, 最初の500サンプルはburnin期間として除去しました. \(\lambda_w=10^{-3}\)としました. 以下に予測結果を示します. それっぽく分類できてますね.

【コード10の実行結果】

BNNの学習3 : HMC + Gibbs sampler

次に, HMCとGibbs samplerを組み合わせた方法です. \(\boldsymbol{w}\)はHMCで学習し, \(\lambda_w\)はGibbs samplerで学習します. \(\lambda_w\)の事前分布として以下のガンマ分布を仮定します.

\begin{equation} \lambda_w \sim \mathrm{Gamma}(\alpha, \beta),\quad \alpha=1.5,\quad \beta = 10^4 \end{equation}

このとき, 条件付き事後分布は以下のようになります.

\begin{equation} \lambda_w \mid X,Y,\boldsymbol{w} \sim \mathrm{Gamma}\left( \alpha+\frac{d_w}{2}, \beta+\frac{1}{2}\| \boldsymbol{w}\|^2\right) \end{equation}

サンプル数やburnin期間, leap-frog法の設定などは先ほどと同じです. 以下に予測結果を示します. 予測はうまくいってそうです.

【コード12の実行結果】

BNNの学習4 : 変分推論法

次は変分推論法です. こちらで解説したものと同じです. \(\lambda_w=10^{-3}\), 訓練回数は\(2000\)としました. 直前の反復でのELBOとの差が\(10^{-6}\)を下回った時点で反復を終了します. 更新幅は\(\alpha_k=\frac{0.4}{k}\)としました.

【コード14の実行結果】

以下の図は, ELBOを近似計算したものです. ELBOは大きくなっていますが, 色分けはうまくいきませんでした.

【コード15の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra
using ForwardDiff
using Optim

#statistics
using Random 
using Statistics
using Distributions

#visualize
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack
【Juliaコード2; 種々の関数定義】
#plot the data and return the figure
function plot_data(X, Y)
    _,N = size(X)
    fig = plot(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, title="data", legend=false)
    for k in 1:N
        if Y[k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

function plot_data(fig, X, Y)
    _,N = size(X)
    fig = plot!(xticks=0:0.2:1, xlim=[0,1], yticks=0:0.2:1, ylim=[0,1], aspect_ratio=:equal, legend=false)
    for k in 1:N
        if Y[k]==1
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:circle, markersize=10, color="red")
        else
            plot!([X[1,k]], [X[2,k]], st=:scatter, markershape=:x, markersize=10, color="blue")
        end
    end
    return fig
end

#initialize the parameter
function init_params(st)
    @unpack  Dx, Dy, D₀ = st
    W₂ = rand(D₀, Dx)
    W₃ = rand(Dy, D₀)
    b₂ = zeros(D₀)
    b₃ = zeros(Dy)
    return W₂, W₃, b₂, b₃
end

#stick the weights and biases to a large matrix
function stick_params(W₂, W₃, b₂, b₃, Dx)
    tmp1 = vcat(b₂', W₂')
    tmp2 = hcat(tmp1, zeros(Dx+1))
    tmp3 = hcat(W₃, b₃)
    return vcat(tmp2, tmp3)
end

#devide the paramters vector to weights and biases
function reshape_params(wvec, st)
    @unpack  Dx, Dy, D₀ = st
    W = reshape(wvec, (Dx+Dy+1, D₀+1))
    W₂ = view(W, 2:Dx+1, 1:D₀)'
    W₃ = view(W, Dx+2:Dx+Dy+1, 1:D₀)
    b₂ = view(W, 1, 1:D₀)
    b₃ = view(W, Dx+2:Dx+Dy+1, D₀+1)
    return W₂, W₃, b₂, b₃
end 

#sigmoid function
σ(ξ) = 1/(1+exp(-ξ))

#Neural Network
function nn(x, wvec, st)
    W₂, W₃, b₂, b₃ = reshape_params(wvec, st)
    return σ.(W₃*σ.(W₂*x+b₂) + b₃)
end
【Juliaコード3; データの作成】
#create the data
N = 16
X = [
    0.3 0.52 0.3 0.50 0.60 0.7 0.70 0.55  0.85 0.10 0.05 0.20 0.39 0.63 0.86 0.97;
    0.1 0.35 0.9 0.15 0.95 0.2 0.80 0.75  0.55 0.76 0.15 0.45 0.56 0.50 0.80 0.20;
]
Y = vcat(zeros(div(N,2)), ones(div(N,2)))
data = (X=X,Y=Y,N=N)

#plot the data
fig1 = plot_data(X, Y)
savefig(fig1, "figs-BNN/fig1.png")
【Juliaコード4; 通常のNeural Network: 関数定義】
#loss function 
Ln(y_pred, y_data) = norm(y_pred-y_data)^2

#train the neural network
function train_nn(data, n_train, α, wvec₀, st)
    wvec = wvec₀
    @unpack X,Y,N = data
    @unpack Dx,Dy,D₀ = st
    ∇Ln(wvec, idx, X, Y, st) = ForwardDiff.gradient(wvec->Ln(nn(X[:,idx], wvec, st)[1], Y[idx]), wvec)
    @showprogress for k in 1:n_train
        #choose the sample uniformaly
        idx = rand(1:N)
        
        #gradient descent
        wvec = wvec - α*∇Ln(wvec, idx, X, Y, st)
    end
    return wvec
end
【Juliaコード5; 通常のNeural Networkの訓練】
#size
Dx,N = size(X)
Dy = 1
D₀ = 5
st = (Dx=Dx, Dy=Dy, D₀=D₀)

#initialize NN
Random.seed!(42)
W₂, W₃, b₂, b₃ = init_params(st)
Ws = stick_params(W₂, W₃, b₂, b₃, st.Dx)
wvec₀ = Ws[:]

#train NN
n_train = Int(1e6)
α = 0.1
@time wvec = train_nn(data, n_train, α, wvec₀, st)
wvec_nn = wvec

#predict
fig2 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->nn([x1,x2], wvec, st)[1], st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig2 = plot_data(fig2,X,Y)
plot!(title="prediction")
savefig(fig2, "figs-BNN/fig2.png")
【Juliaコード6; BNN共通の関数】
#log pdf of prior, model, posterior
logpprior(wvec, λ, dw) = logpdf(MvNormal(zeros(dw),1/sqrt(λ)), wvec)
logpmodel(y, x, wvec, st) = logpdf(Bernoulli(nn(x,wvec,st)[1]), y)
loglik(X, Y, N, wvec, st) = sum([logpmodel(Y[n], X[:,n], wvec, st) for n in 1:N])

function logppost(wvec, data, model_params)
    @unpack X,Y,N = data
    @unpack λw,dw,st = model_params
    return loglik(X, Y, N, wvec, st) + logpprior(wvec, λw, dw)
end

function logppost(wvec, λw, data, model_params)
    @unpack X,Y,N = data
    @unpack dw,st = model_params
    return loglik(X, Y, N, wvec, st) + logpprior(wvec, λw, dw)
end

#predictive: returns the probability to new data classified to class 1
function ppred(x, wsamps, st)
    _, n_samps = size(wsamps)
    preds = zeros(n_samps)
    for j in 1:n_samps
        preds[j] = exp(logpmodel(1, x, wsamps[:,j], st))
    end
    return mean(preds)
end
【Juliaコード7; Laplace近似: 関数定義】
#calculate the MAP estimate of parameters
function calc_MAP(data, model_params, wvec₀)
    opt = Optim.optimize(wvec->-logppost(wvec, data, model_params), wvec₀, BFGS())
    return opt.minimizer
end

#calculate the MAP estimate and Hesse matrix
function calc_params(data, model_params, wvec₀)
    @unpack X,Y,N = data
    @unpack λw,dw,st = model_params
    wMAP = calc_MAP(data, model_params, wvec₀)
    H = ForwardDiff.hessian(wvec->logppost(wvec, data, model_params), wMAP)
    return wMAP, Matrix(Hermitian(H))
end
【Juliaコード8; Laplace近似: 実行】
#initialize NN
Random.seed!(42)
W₂, W₃, b₂, b₃ = init_params(st)
W₀ = stick_params(W₂, W₃, b₂, b₃, Dx)
wvec₀ = W₀[:]
dw = length(wvec₀)

#model params
λw = 1e-3/2
model_params = (λw=λw, dw=dw, st=st) 

#Laplace approximation
@time wMAP, H = calc_params(data, model_params, wvec₀)

#calculate and visualize predictive
n_samps = 5000
wsamps = rand(MvNormalCanon(wMAP, -H), n_samps)
fig3 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->ppred([x1,x2], wsamps, st), st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig3 = plot_data(fig3, X, Y)
plot!(title="prediction: Laplace approximation")
savefig(fig3, "figs-BNN/fig3.png")
【Juliaコード9; HMC: 関数定義】
#one step of Störmer-Verlet method
function myStörmerVerlet(qvec, pvec, h, f)
    p_mid = pvec + h * f(qvec)/2;
    q_new = qvec + h * p_mid;
    p_new = p_mid + h * f(q_new)/2;
    return q_new, p_new
end

#update the position
function update(T, h, f, qvec, pvec)
    qvec_new = qvec
    pvec_new = pvec
    for t in 1:T
        qvec_new, pvec_new = myStörmerVerlet(qvec_new, pvec_new, h, f)
    end
    return qvec_new, pvec_new
end

#MH acceptance and rejection
function accept_or_reject(xvec, xvec_old, pvec, pvec_old, H)
    ΔH = H(xvec, pvec)-H(xvec_old, pvec_old)
    α = min(1.0, exp(-ΔH))
    u = rand()
    if u≤α
        return xvec, pvec
    else
        return xvec_old, pvec_old
    end
end

#Hamiltonian Monte Carlo
function myHMC(data, model_params, wvec₀, n_samps, n_burnin, T, h)
    #initialization
    dw = length(wvec₀)
    wsamps = zeros(dw, n_samps)
    wsamps[:,1] = wvec₀
    wvec = zeros(dw)
    pvec = zeros(dw)
    
    #Hamiltonian and potential
    U(wvec) = -logppost(wvec, data, model_params)
    ∇Uneg(wvec) = -ForwardDiff.gradient(U, wvec)
    H(wvec, pvec) = U(wvec) + norm(pvec)^2/2
    
    #sample
    wvec_old = wvec₀
    pvec_old = randn(dw)
    @showprogress for s in 2:n_samps
        pvec = randn(dw)
        wvec, pvec = update(T, h, ∇Uneg, wvec, pvec)
        wvec, pvec = accept_or_reject(wvec, wvec_old, pvec, pvec_old, H)
        wsamps[:,s] = wvec
        wvec_old = wvec
        pvec_old = pvec
    end
    return wsamps[:,n_burnin:end]
end
【Juliaコード10; HMC: 実行】
#initialize NN
Random.seed!(42)
W₂, W₃, b₂, b₃ = init_params(st)
W₀ = stick_params(W₂, W₃, b₂, b₃, Dx)
wvec₀ = W₀[:]
dw = length(wvec₀)

#model params
λw = 1e-3
model_params = (λw=λw, dw=dw, st=st) 

#HMC
n_samps = 5000
n_burnin = div(n_samps,10)
T = 100
h = 0.1
@time wsamps = myHMC(data, model_params, wvec₀, n_samps, n_burnin, T, h)

#calculate and visualize predictive
fig4 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->ppred([x1,x2], wsamps, st), st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig4 = plot_data(fig4, X, Y)
plot!(title="prediction: HMC")
savefig(fig4, "figs-BNN/fig4.png")
【Juliaコード11; HMC+Gibbs: 関数定義】
#Hamiltonian Monte Carlo for hierarchical model
function myHMC_HM(data, model_params, wvec₀, n_samps, n_burnin, T, h)
    #hyperparamters of λw
    @unpack α,β,λw = model_params
    
    #initialization
    dw = length(wvec₀)
    wsamps = zeros(dw, n_samps)
    wsamps[:,1] = wvec₀
    λwsamps = zeros(n_samps)
    λwsamps[1] = λw
    wvec = zeros(dw)
    pvec = zeros(dw)
    
    #Hamiltonian and potential
    U(wvec, λw) = -logppost(wvec, λw, data, model_params)
    ∇Uneg(wvec, λw) = -ForwardDiff.gradient(wvec->U(wvec, λw), wvec)
    H(wvec, pvec, λw) = U(wvec, λw) + norm(pvec)^2/2
    
    #sample
    wvec_old = wvec₀
    pvec_old = randn(dw)
    @showprogress for s in 2:n_samps
        pvec = randn(dw)
        wvec, pvec = update(T, h, wvec->∇Uneg(wvec, λw), wvec, pvec)
        wvec, pvec = accept_or_reject(wvec, wvec_old, pvec, pvec_old, (wvec, pvec)->H(wvec, pvec, λw))
        wsamps[:,s] = wvec
        λw = rand(Gamma(α+dw/2, 1/(β+norm(wvec)^2/2)))
        λwsamps[s] = λw
        wvec_old = wvec
        pvec_old = pvec
    end
    return wsamps[:,n_burnin:end], λwsamps[n_burnin:end]
end
【Juliaコード12; HMC+Gibbs: 実行】
#initialize NN
Random.seed!(42)
W₂, W₃, b₂, b₃ = init_params(st)
W₀ = stick_params(W₂, W₃, b₂, b₃, Dx)
wvec₀ = W₀[:]
dw = length(wvec₀)

#model params
λw = 1e-3
model_params = (λw=λw, dw=dw, st=st, α=1.5, β=1e4)

#HMC+Gibbs sampler
n_samps = 5000
n_burnin = div(n_samps,10)
T = 100
h = 0.1
@time wsamps,λsamps = myHMC_HM(data, model_params, wvec₀, n_samps, n_burnin, T, h)
wsamps_HMC = wsamps

#calculate and visualize predictive
fig5 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->ppred([x1,x2], wsamps, st), st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig5 = plot_data(fig5, X, Y)
plot!(title="prediction: HMC+Gibbs")
savefig(fig5, "figs-BNN/fig5.png")
【Juliaコード13; 変分推論: 関数定義】
#∇θLn(θ)
function ∇θLn(x, y, N, wsamps, θvec, λw, st)
    dw,S = size(wsamps)
    centwvec = wsamps-θvec[1:dw]*ones(S)'
    logpmodels = zeros(S)
    for s in 1:S
        logpmodels[s] = logpmodel(y, x, wsamps[:,s], st)
    end
    arr1 = centwvec .* (ones(dw)*logpmodels')
    arr2 = centwvec .* arr1
    ∇θLnvec = zeros(2*dw)
    for j in 1:dw
        ∇θLnvec[j] = N*exp(-2*θvec[dw+j])*mean(arr1[j,:])-λw*θvec[j]
        ∇θLnvec[dw+j] = N*exp(-2*θvec[dw+j])*mean(arr2[j,:])-N*mean(logpmodels)+1-λw*exp(2*θvec[dw+j])
    end
    return ∇θLnvec
end

#calculate ELBO
function ELBO(X, Y, N, wsamps, θvec, λw, st)
    dw,S = size(wsamps)
    logpmodels = zeros(S)
    for s in 1:S
        for n in 1:N
            logpmodels[s] += logpmodel(Y[n], X[:,n], wsamps[:,s], st)
        end
    end
    return mean(logpmodels)-λw*sum(θvec[1:dw].^2)/2-λw*sum(exp.(2*θvec[dw+1:end]))/2+sum(θvec[dw+1:end])+dw/2+dw*log(λw)/2
end

#sample from approximation distribution r
post_samps(θvec, n_samps, dw) = rand(MvNormal(θvec[1:dw], exp.(θvec[dw+1:2*dw])), n_samps)

#variational inference
function myVI(data, model_params, α, n_train, tol)
    @unpack X,Y,N = data
    @unpack λw, dw, st = model_params
    θvec = vcat(zeros(dw), ones(dw))
    n_samps = 5000
    wsamps = zeros(dw, n_samps)
    history = zeros(n_train)
    history[1] = ELBO(X, Y, N, wsamps, θvec, λw, st)
    @showprogress for k in 2:n_train
        idx = rand(1:N)
        x = X[:,idx]
        y = Y[idx]
        wsamps = post_samps(θvec, n_samps, dw)
        θvec = θvec + α*∇θLn(x, y, N, wsamps, θvec, λw, st)/k
        history[k] = ELBO(X, Y, N, wsamps, θvec, λw, st)
        if abs(history[k]-history[k-1])<tol
            return θvec, history[1:k]
        end
    end
    return θvec, history
end
【Juliaコード14; 変分推論: 実行】
#initialize NN
Random.seed!(42)
W₂, W₃, b₂, b₃ = init_params(st)
W₀ = stick_params(W₂, W₃, b₂, b₃, Dx)
wvec₀ = W₀[:]
dw = length(wvec₀)

#model params
λw = 1e-3
model_params = (λw=λw, dw=dw, st=st)

#calculate the variational parameters
α = 0.4
tol = 1e-6
n_train = 2000
@time θvec, history = myVI(data, model_params, α, n_train, tol)

#posterior sample
n_samps = 5000
wsamps = post_samps(θvec, n_samps, dw)

#calculate and visualize predictive
fig6 = plot(0:0.02:1, 0:0.02:1, (x1,x2)->ppred([x1,x2], wsamps, st), st=:heatmap, c=cgrad(:coolwarm), alpha=0.6, clim=(0,1))
fig6 = plot_data(fig6, X, Y)
plot!(title="prediction: Variational Inference")
savefig(fig6, "figs-BNN/fig6.png")
【Juliaコード15; ELBOの変化】
fig7 = plot(1:n_train, history, xlabel="iter", title="ELBO", label=false)
savefig(fig7, "figs-BNN/fig7.png")
参考文献

      [1] 須山敦志, ベイズ深層学習, 講談社, 2020
      [2] K.P.Murphy, Machine Learning: A Probabilistic Perspective, The MIT Press, 2012