変分オートエンコーダ

記事の内容


以前書いた, Bayes変分オートエンコーダを修正します. 修正点は以下の2点です. まず, Bayes版ではない通常の変分オートエンコーダの説明に差し替えます. 2点目は, 実験をクラスタリングの実験からMNISTのデータを用いたものに変更します. クラスタリングの実験は, 学習しなくても初期パラメータでクラスタリングが達成できてしまう例でした. Neural Networkは基本的に連続なので, 何となく色分けができるのは自明でした...

変分オートエンコーダの概要

問題設定

画像などの高次元のデータは, 高次元のEuclid空間の1点として表されます. 一方, 手書き文字のデータを覗いてみると, 画像の端っこの方は黒(値としては0), 中央付近には白い部分(値としては1)が広がっています. 従って, 画像データの集合は高次元のEuclid空間において"遍在"しているというよりは, 等式や不等式によって制約された低次元の空間に"偏在"していると考えられます. 各データ点に対してこの低次元の表現を抽出できれば, データの背後の構造を探ったり, 新たな画像を生成したりできます. 変分オートエンコーダ([1],[2],[3])は, この低次元の表現を抽出するためのモデルと学習方法を提供してくれるもので, 深層生成モデルの一種です.

モデルと方針

手元に高次元のデータ\(X=\{\boldsymbol{x}_1,\cdots,\boldsymbol{x}_N\}\)があるとします. このデータに対応する低次元の表現\(Z=\{\boldsymbol{z}_1,\cdots,\boldsymbol{z}_N\}\)を推定するのが目的です. 次のような生成過程を考えます.

\begin{align} \boldsymbol{z} &\sim \mathrm{N}(\boldsymbol{0}, I)\\ \boldsymbol{x}\mid\boldsymbol{z} &\sim \mathrm{N}(\boldsymbol{u}_w(\boldsymbol{z}), \mathrm{diag}(\boldsymbol{v}_w({\boldsymbol{z}}))) \end{align}

上のモデルにおいて, 低次元から高次元への変換は複雑と考えられるため, 以下のNeural Networkによって表現します. 重みパラメータを\(\boldsymbol{w}\)とし, 隠れ層は1層とします.

\begin{equation} \begin{bmatrix} \boldsymbol{u}_w(\boldsymbol{z}) \\ \log\boldsymbol{v}_w(\boldsymbol{z}) \end{bmatrix} = \Phi_w(\boldsymbol{z}) \end{equation}

\(\boldsymbol{w}\)をデータから定め, \(Z\)の事後分布\(p_w(Z\mid X)\)を求めることが目標です. しかし, \(Z\)の事後分布への寄与は複雑であるため, 解析的に計算するのは困難です. そこで, 変分推論法を用いて事後分布を近似します.

近似分布の導入とELBOの導出

\(Z\)の事後分布を以下のように近似します. \(\boldsymbol{\eta}\)は近似分布のパラメータです.

\begin{equation} r_{\eta}(Z) = \prod_{n=1}^{N}r_{\eta}(\boldsymbol{z}_n) \simeq p_w(Z\mid X) \end{equation}

ただし, 各\(r_{\eta}(\boldsymbol{z}_n)\)は正規分布\(\mathrm{N}(\boldsymbol{m}_{\eta}(\boldsymbol{x}_n), \boldsymbol{s}_{\eta}(\boldsymbol{x}_n))\)の確率密度関数とします. ただし, 正規分布のパラメータは平均ベクトルと共分散行列の2つであり, それがデータサイズ分あると考えると, その全てを調整するのは大変ですね... そこで, これらの近似分布のパラメータをデータから回帰します. 高次元のデータから低次元のパラメータを得る変換も複雑であると想定し, 以下のようなNeural Networkを用います. 重みパラメータは\(\boldsymbol{\eta}\)とし, 隠れ層は1層とします.

\begin{equation} \begin{bmatrix} \boldsymbol{m}_\eta(\boldsymbol{x}) \\ \log\boldsymbol{s}_\eta(\boldsymbol{x}) \end{bmatrix} = \Psi_{\eta}(\boldsymbol{x}) \end{equation}

結局求めたいのは2つのNeural Networkの重みパラメータ\(\boldsymbol{w}\)と\(\boldsymbol{\eta}\)であり, 以下のような最適化問題に帰着します.

\begin{equation} \underset{w,\eta}{\mathrm{argmin}} D_{\mathrm{KL}} [r_\eta(Z)\| p_w(Z\mid X)] \end{equation}

ELBOを計算すると, 以下のようになります. \(D_z\)は\(\boldsymbol{z}\)の次元です. 実際には, このELBOのミニバッチ版を用います.

\begin{equation} \mathcal{L}(\boldsymbol{w},\boldsymbol{\eta}) = \sum_{n=1}^N\left\{ \int r_\eta(\boldsymbol{z}_n)\log p_w(\boldsymbol{x}_n\mid\boldsymbol{z}_n)\mathrm{d}\boldsymbol{z}_n)+\frac{D_z}{2}-\frac{1}{2}\|\boldsymbol{m}_\eta(\boldsymbol{x}_n)\|^2-\frac{1}{2}\|\boldsymbol{s}_\eta(\boldsymbol{x}_n)\|^2+\mathrm{sum}(\boldsymbol{s}_\eta(\boldsymbol{x}_n))\right\} \end{equation}

あとは, 確率的変分推論, reparameterization trickを組み合わせると推論ができます.

数値実験

実験内容

MNISTの手書き文字画像を用いて実験します. 書いたプログラムが遅すぎたので, ラベルが0と1の画像のうち, 最初の1000枚分だけを用いて学習しました. 画像をベクトル化するとデータの次元は\(28^2\)です. ここから, \(D_z=2\)次元の表現を抽出することにします. ミニバッチのサイズは10とし, 変分推論アルゴリズム全体の反復回数は500回としました. また, Neural Networkの活性化関数は\(\tanh\)とし, 隠れ層のユニット数は10です. なお, 4スレッド並列で計算しました.

実験結果

以下に結果を示します. まず, パラメータ学習前の\(\boldsymbol{z}_n\)の様子です. 各点が, 高次元の画像の低次元の表現に相当します. 黒い点はラベルが0の画像, 白い点はラベルが1の画像です. Neural Networkは連続関数なので, 高次元の空間で近くにある点同士は, 低次元空間でも近くにあります. よって初期パラメータでもある程度色分けがなされます.

【コード8の実行結果】

パラメータ学習後の様子を下図に示します. 白い点は一直線上に並んでいるように見えます. 黒い点は1点に集中しているように見えます. 例えば, 白い点が並んでいる直線常に新たに点をとって\(\Phi_w\)で写せば, 新たな画像が生成できます. ただし実際には低次元過ぎて情報がほとんど削ぎ落とされ, 意味のある画像は生成できません. もう少し\(\boldsymbol{z}\)の次元を大きくすればうまくはず... 黒と白が重なっている部分は, 0と1のキメラみたいな画像が誕生するかもしれません...

【コード9の実行結果】

下図はELBOの変化をプロットしたものです. 計算は結構時間がかかるので50反復ごとに記録しています.

【コード6の実行結果】

上の図では, ELBOが0に近づいているように見えますが, 実際にはかなり差があります. 下図は, ELBOの絶対値を対数スケールでプロットしたものです.

【コード7の実行結果】

コード

【Juliaコード1; 初期化】
#mathematics
using LinearAlgebra
using ForwardDiff

#datasets
using MLDatasets

#statistics
using Distributions
using Statistics
using Random
using StatsBase

#visualize
using Plots
pyplot()

#macros
using ProgressMeter
using UnPack

#threads
using Base.Threads
nthreads()
【Juliaコード2; データの読み込み】
#reshape images
function vecX(x)
    d1,d2,N = size(x)
    X = zeros(d1*d2, N)
    for n in 1:N
        X[:,n] = x[:,:,n][:]
    end
    return X
end

#get the train data
train_x, train_y = MNIST.traindata()
X = vecX(train_x)

#images whose labels are 0 or 1
idx0or1 = Bool.((train_y.==0) + (train_y.==1))
X = X[:,idx0or1]
Y = train_y[idx0or1]
X = X[:,1:1000]
Y = Y[1:1000]

#data
Dx, N = size(X)
data = (X=X, Dx=Dx, N=N)
【Juliaコード3; Neural Networkの定義】
#initialize the parameter
function init_params(st)
    @unpack  DI, DO, DM = st
    W₂ = randn(DM, DI)
    W₃ = randn(DO, DM)
    b₂ = zeros(DM)
    b₃ = zeros(DO)
    return W₂, W₃, b₂, b₃
end

#stick the weights and biases to a large matrix
function stick_params(W₂, W₃, b₂, b₃, Dx)
    tmp1 = vcat(b₂', W₂')
    tmp2 = hcat(tmp1, zeros(Dx+1))
    tmp3 = hcat(W₃, b₃)
    return vcat(tmp2, tmp3)
end

#devide the paramters vector to weights and biases
function reshape_params(wvec, st)
    @unpack  DI, DO, DM = st
    W = reshape(wvec, (DI+DO+1, DM+1))
    W₂ = view(W, 2:DI+1, 1:DM)'
    W₃ = view(W, DI+2:DI+DO+1, 1:DM)
    b₂ = view(W, 1, 1:DM)
    b₃ = view(W, DI+2:DI+DO+1, DM+1)
    return W₂, W₃, b₂, b₃
end 

#Neural Network
function nn(x, wvec, st)
   @unpack  DI, DO, DM = st
    W₂, W₃, b₂, b₃ = reshape_params(wvec, st)
    return W₃*tanh.(W₂*x+b₂) + b₃
end
【Juliaコード4; 関数の定義】
#encoder and decoder
Φ(zn, wvec, stΦ) = nn(zn, wvec, stΦ)
Ψ(xn, ηvec, stΨ) = nn(xn, ηvec, stΨ)

#mean and logstd, std
function mean_and_logstd(vecn, pvec, D, st)
    tmp = nn(vecn, pvec, st)
    @inbounds tmp[1:D], tmp[D+1:end]
end

function mean_and_std(vecn, pvec, D, st)
    mvec, logsvec = mean_and_logstd(vecn, pvec, D, st)
    mvec, exp.(logsvec)
end

#used for reparameterization trick of zn
function h(ϵvec, ηvec, xn, Dz, stΨ)
    mvec, svec = mean_and_std(xn, ηvec, Dz, stΨ)
    mvec + svec .* ϵvec
end

#logpmodel
function logpmodel(xn, zn, wvec, stΦ)
    Dx = length(xn)
    mvec, svec = mean_and_std(zn, wvec, Dx, stΦ)
    logpdf(MvNormal(mvec, svec), xn)
end
logpmodel(xn, ηvec, wvec, ϵvec, Dz, stΦ, stΨ) = logpmodel(xn, h(ϵvec, ηvec, xn, Dz, stΨ), wvec, stΦ)

#used for calculation of ELBO
function f(xn, ηvec, Dz, stΨ)
    mvec, logsvec = mean_and_logstd(xn, ηvec, Dz, stΨ)
    Dz/2 + sum(logsvec) - norm(mvec)^2/2 - norm(exp.(logsvec))^2/2
end

#ELBO L
function ELBO(X, ηvec, wvec, ϵsamps, N, Dz, S, stΦ, stΨ)
    logpmodels = zeros(S)
    sumvec = zeros(N)
    for n in 1:N
        for s in 1:S
            @inbounds logpmodels[s] = logpmodel(X[:,n], ηvec, wvec, ϵsamps[s], Dz, stΦ, stΨ)
        end
        @inbounds sumvec[n] = mean(logpmodels) + f(X[:,n], ηvec, Dz, stΨ)
    end
    sum(sumvec)
end

#approximatio of ELBO Ln
ELBOn(xn,ηvec,wvec,ϵvec,Dz,stΦ,stΨ) = logpmodel(xn,ηvec,wvec,ϵvec,Dz,stΦ,stΨ) + f(xn,ηvec,Dz,stΨ)

#gradient of ELBO w.r.t w and η
∇wLn_samp(xm,wvec,ηvec,ϵsamp,Dz,stΦ,stΨ) = (
    ForwardDiff.gradient(vec->ELBOn(xm,ηvec,vec,ϵsamp,Dz,stΦ,stΨ), wvec)
)
function ∇wLn(minibatch, minibatch_size, N, ηvec, wvec, ϵsamp, Dz, stΦ, stΨ)
    ∇wLnvec = zeros(dw, minibatch_size)
    @threads for m in 1:minibatch_size
        @inbounds ∇wLnvec[:,m] = ∇wLn_samp(minibatch[:,m],wvec,ηvec,ϵsamp,Dz,stΦ,stΨ)
    end
    N*mean(∇wLnvec, dims=2)
end

∇ηLn_samp(xm,wvec,ηvec,ϵsamp,Dz,stΦ,stΨ) = (
    ForwardDiff.gradient(vec->ELBOn(xm,vec,wvec,ϵsamp,Dz,stΦ,stΨ), ηvec)
)
function ∇ηLn(minibatch, minibatch_size, N, ηvec, wvec, ϵsamp, Dz, stΦ, stΨ)
    ∇ηLnvec = zeros(dη, minibatch_size)
    @threads for m in 1:minibatch_size
        @inbounds ∇ηLnvec[:,m] =  ∇ηLn_samp(minibatch[:,m],wvec,ηvec,ϵsamp,Dz,stΦ,stΨ)
    end
    N*mean(∇ηLnvec, dims=2)
end    

#variational infernce
function myVI(data, model_params, n_train, minibatch_size, α)
    #initalize and set the parameters
    @unpack X,Dx,N = data
    @unpack dw,dη,Dz,stΦ,stΨ,wvec₀,ηvec₀ = model_params
    wvec = wvec₀
    ηvec = ηvec₀
    
    #minibatch
    minibatch = zeros(Dx,minibatch_size)
    
    #ELBO interval
    Δ = 50
    
    #AdaGrad
    δ = 1e-7
    gwvec = zeros(dw)
    gηvec = zeros(dη)
    rwvec = δ * ones(dw)
    rηvec = δ * ones(dη)
    
    #ELBO
    S = 100
    ϵsamp = rand(MvNormal(zeros(Dz), ones(Dz)))
    ϵsamps = rand(MvNormal(zeros(Dz), ones(Dz)), S)
    history = zeros(div(n_train, Δ))
    @inbounds @showprogress for k in 2:n_train
        #choose mini batch
        minibatch = X[:,sample(1:N, minibatch_size, replace=false)]
        
        #update(AdaGrad)
        ϵsamp = rand(MvNormal(zeros(Dz), ones(Dz)))
        gwvec = ∇wLn(minibatch, minibatch_size, N, ηvec, wvec, ϵsamp, Dz, stΦ, stΨ)
        gηvec = ∇ηLn(minibatch, minibatch_size, N, ηvec, wvec, ϵsamp, Dz, stΦ, stΨ)
        rwvec = rwvec + gwvec .* gwvec
        rηvec = rηvec + gηvec .* gηvec
        wvec += α * gwvec ./ sqrt.(rwvec)
        ηvec += α * gηvec ./ sqrt.(rηvec)
        
        #calculate ELBO
        if k%Δ == 0
            ϵsamps = rand(MvNormal(zeros(Dz), ones(Dz)), S)
            history[div(k,Δ)] = ELBO(X, ηvec, wvec, ϵsamps, N, Dz, S, stΦ, stΨ)
        end
    end
    return wvec, ηvec, history
end
【Juliaコード5; 実行】
#set the random seed
Random.seed!(42)

#latent dimension
Dz = 2

#initialize Neural Network (decoder)
DΦ = 10
stΦ = (DI=Dz, DO=2*Dx, DM=DΦ)
WΦ₂, WΦ₃, bΦ₂, bΦ₃ = init_params(stΦ)
WΦs = stick_params(WΦ₂, WΦ₃, bΦ₂, bΦ₃, stΦ.DI)
wvec₀ = WΦs[:]
dw = length(wvec₀)

#initialize Neural Network (encoder)
DΨ = 10
stΨ = (DI=Dx, DO=2*Dz, DM=DΨ)
WΨ₂, WΨ₃, bΨ₂, bΨ₃ = init_params(stΨ)
WΨs = stick_params(WΨ₂, WΨ₃, bΨ₂, bΨ₃, stΨ.DI)
ηvec₀ = WΨs[:]
dη = length(ηvec₀)

#model parameters
model_params = (dw=dw, dη=dη, Dz=Dz, stΦ=stΦ, stΨ=stΨ, wvec₀=wvec₀, ηvec₀=ηvec₀)

#variational inference
n_train = 500
α = 0.1
minibatch_size = 10
@time wvec, ηvec, history = myVI(data, model_params, n_train, minibatch_size, α)
【Juliaコード6; ELBOの表示】
fig1 = plot(1:50:500, history, label=false, marker=:circle, markersize=10, markerstrokewidth=0.5, 
    title="ELBO", xlabel="iter", ylabel="ELBO")
savefig(fig1, "figs-VAE/fig1.png")
【Juliaコード7; ELBOの絶対値の対数スケール表示】
fig2 = plot(1:50:500, abs.(history), label=false, marker=:circle, markersize=10, markerstrokewidth=0.5, 
    yscale=:log10, title="ELBO", xlabel="iter", ylabel="ELBO")
savefig(fig2, "figs-VAE/fig2.png")
【Juliaコード8; 初期パラメータでの潜在変数の分布】
Zs = zeros(Dz, N)
for n in 1:N
    Zs[:,n] = Ψ(X[:,n], ηvec₀, stΨ)[1:Dz]
end
fig3 = plot(Zs[1,:],Zs[2,:],st=:scatter,zcolor=Y,markersize=10,
    c=palette(:grays),label=false,markerstrokewidth=0.5,alpha=0.8, xlabel="z₁", ylabel="z₂", title="latent space(initial)")
savefig(fig3, "figs-VAE/fig3.png")
【Juliaコード9; 学習後のパラメータでの潜在変数の分布】
Zs = zeros(Dz, N)
for n in 1:N
    Zs[:,n] = Ψ(X[:,n], ηvec, stΨ)[1:Dz]
end
fig4=plot(Zs[1,:],Zs[2,:],st=:scatter,zcolor=Y,markersize=10,c=palette(:grays),
    label=false,markerstrokewidth=0.5,alpha=0.8, xlabel="z₁", ylabel="z₂", title="latent space")
savefig(fig4, "figs-VAE/fig4.png")
参考文献

      [1]C.Doersch, Tutorial on Variational Autoencoders, arXiv:1606.05908, 2016.
      [2]D.P.Kingma, M.Welling, Auto-Encoding Variational Bayes, 2nd International Conference on Learning Representations, 2014.
      [3]D.P.Kingma, M.Welling, An Introduction to Variational Autoencoders, Foundations and Trends in Machine Learning, 12(4), pp.307-392, 2019.
      [4]S.Mohamed, M.Rosca, M.Figurnov, A.Mnih, Monte Carlo Gradient Estimation in Machine Learning, Journal of Machine Learning Research, 21(132), pp.1-62, 2020
      [5]M.Xu, M.Quiroz, R.Kohn, S.A.Sisson, Variance reduction properties of the reparameterization trick,Proceedings of Machine Learning Research, 89, pp.2711–2720, 2019.
      [6]須山敦志, ベイズ深層学習, 講談社, 2020