グラフ機械学習と強化学習について

主にグラフ機械学習や強化学習手法を記載します。

変分オートエンコーダー (VAE, M1)

はじめに

Alan先生のラボの論文Chemical VAEsを理解するために、VAEについてまとめます。

Chem VAE

化学構造の文字列情報(SMILES)をVAEに適用しSMILESを生成しようとする試みです。Latent Spaceの尤度を最大とするような分子が生成されます。Inductive biasを回避するためにはVAEよりもGAN basedの方 (ORGANIC)がよいと思いますが、まずはこちらから論文を見ていきたいと思います。githubにコードがあるので試すことはできますが、有効な分子はあまり生成されませんでした。

github.com

2024-02-05追記:Kerasで解説していましたが、PyTorchに書き直しています。

VAE概要

Variational AutoEncoder (VAE)は、教師無し・反教師あり学習に用いることができるオートエンコーダを利用する生成モデルです。

敵対的生成ネットワーク(Generative Adversarial Networks, GANs)と異なり尤度の計算を明示的に行います。

VAEには3つのモデルが提案されています。論文は次の通りです。

  1. [1406.5298] Semi-Supervised Learning with Deep Generative Models (2014)
  2. [1606.05908] Tutorial on Variational Autoencoders

M1

スタンダードな生成モデルです。

M2 (CVAE)

M1を条件付きにしたVAE(Conditional VAE)です。潜在変数にラベルなどを加えて同時に生成をさせてあげることで条件付けすることができます。

M1 + M2

上記のモデルを組み合わせたものです。M1の教師無し学習モデルの隠れ変数zを用いて、M2を用いて半教師あり学習を行っていきます。

M1モデル

オートエンコーダの隠れ変数(潜在変数)$z$が互いに独立であり、これらは正規分布に従うと仮定したものです。平均場近似したものとみなせます。 このようにすることで変分推論を行うことができます。

ベイズの定理

観測変数を$x$としたとき、事後分布は

$$ p(z|x) = \frac{p(x, z)}{\int p(x, z)dz} = \frac{p(x|z)p(z)}{p(x)} $$

と表現することができます。この$p(x)$が厄介で、任意の関数で近似した場合、閉じた形(closed form)にならないため、$p(z|x)$を計算することできません。

そこで、事後分布$p(z|x)$に対して、それを近似する分布$q$を考えてやろうというのがアイディアです。

変分推論

  • エンコーダーの確率分布: $P_\theta (z | x)$
  • 真の確率分布 $Q_\phi (z | x)$

とします。エンコーダーを学習させるためにはQとPのKullback-Leibler divergence (KL-divergence)を最小化させればよいです。

$$ \begin{align} D_{KL} (Q_\phi (z | x) || P_{\theta} (z | x)) &= \mathbb{E}_{z \sim Q_\phi (z|x)} [\log Q_\phi (z|x) - \log P_{\theta} (z | x) ] \\ &= \mathbb{E}_{z \sim Q_\phi (z|x)} [\log Q_\phi (z|x) - \log \frac{P_{\theta} (x | z) P_{\theta} (z)}{P_{\theta} (x) } ] \\ &= \mathbb{E}_{z \sim Q_\phi (z|x)} [\log Q_\phi (z|x) - \log P_{\theta} (x | z) - \log P_{\theta}(z) ] + \log P_{\theta} (x) \\ &= D_{KL} (Q_\phi (z | x) || P_{\theta} (z)) - \mathbb{E}_{z \sim Q_\phi (z|x)} [\log P_{\theta} (x | z)] + \log P_{\theta} (x) \\ \\ \Leftrightarrow \log P_{\theta} (x) - D_{KL} (Q_\phi (z | x) || P_{\theta} (z | x)) &= \mathbb{E}_{z \sim Q_\phi (z|x)} [\log P_{\theta} (x | z)] - D_{KL} (Q_\phi (z | x) || P_{\theta} (z)) \end{align} $$

VAEモデルを学習するための対数尤度を直接最大化することは難しいので、その変分下限を最大化します。 左辺は変分下限(variational lower bound, VLB)もしくはELBO (Evidence Lower Bound)と呼ばれています。

一方で、MCMCなどでサンプリングして、尤度を推定いく方法はcontrasive divergenceと呼ばれます。 このような方法をとる代表的な方法に、Energy-based modelsがあります。

変分問題の定式化

ELBOは次のようにして導出できます。

$$ \begin{align} \log p_{\theta}(x) &= \log \int p(X, z)dz \\ &= \log \int q(z|x) \frac{p(X, z)}{q(z|x)} dz \\ & \geq \int q(z|x) \log \frac{p(X, z)}{q(z|x)} dz \\ &= L(q_\theta (z)) \end{align} $$

ここで、3行目では、Jensenの不等式を用いました。このように、汎函数$L$は周辺尤度(エビデンス)の対数値の下限となっているため、ELBOと呼ばれています。

さらに確率密度関数積分$\int q(z|x)dz = 1$である(等号成立時)ことと、同時分布(joint distribution)に対して乗法定理を用い、式変形をしていけばELBOが求まります。

$$ \begin{align} L (X, z) &= \int q(z|x) \log \frac{p(z | X) p(X)}{q(z|x)} \\ &= \int q(z|x) \log \frac{p(z | X)}{q (z|x) } + \int q(z|x) \log p(X) dz \\ &= - D_{KL} ( q(z|x) || p(z | X) ) + \mathbb{E}_{z \sim q(z|x) }[\log p(X) ] \\ &= \log p(X) - D_{KL} ( q(z|x) || p(z | X)) \end{align} $$

以上より、VAEの目的関数は次のようになります。

$$ L(x, z) = \mathbb{E}_{z \sim Q_\phi (z|x)} [\log P_{\theta} (x | z)] - D_{KL} (Q_\phi (z | x) || P_{\theta} (z)) $$

上式を最大化させていきます(最適化計算の場合は符号を逆転させて最小化する)。

  • 1つ目の項がEncoderとDecoder全体を表す損失関数で、データXの対数尤度の期待値を最大化させることを表します。
  • 2つ目の項がEncoderとDecoderのKLダイバージェンスです。分布が近づくほど、Dは0に近づくため、ELBOを最大化できます。

Reparametarization trick (サンプリング)

この$Q(z | x)$をNNsで近似しますが、出力は確率分布であるためこのままだと計算できません。

そのために、まずEncodeから平均$\mu (z)$および分散$Σ(z)$を予測ます。

次に得られた平均と分散からサンプリングした$z$の値を用いることで近似します。

  1. $\varepsilon \sim \mathcal{N} (0, 1)$
  2. $z = \mu(x) + \varepsilon Σ^{\frac{1}{2}} (x)$

PyTorchを用いると次のようになります。

# Encoderはmu, sigmaを出力する
z_mu, z_lsgms = encoder(x)

# reparameterize
z_sgm = z_lsgms.mul(0.5).exp_()
eps = torch.randn(z_sgm.size())
z = eps * z_sgm + z_mu

# Decoderはzから復元する
x = decoder(z)

expの方がlogよりも数値安定性が良いため$Σ(X) = \log Σ(X)$と変換したのち計算します。

AEと同様に入力と出力が同じになるようなモデルとするため、Reconstruction error

$$ -\mathbb{E}_{z \sim Q_\phi (z|x)} \left[\log P_{\theta} (x | z)\right] $$

を最小化するように学習させます。そのときの損失関数はMSEやクロスエントロピーとなります。これは対数尤度を最大化させる操作と同じです。

import torch.nn.functional as F

reconst_loss = F.binary_crossentropy(inputs, outputs)

ただし、VAEを定義通りに実装するには、対数尤度$\log p(x|z)$の部分をベルヌーイ分布やガウス分布と仮定し、計算する必要があります。

bernoulli, gaussian negative log likelihoodを最小化させます。

KL Divergenceの計算

KL divergenceは解析的に求まります。計算の都合上、損失関数の最小値を求める必要があるため、負の正規分布$\mathcal{N}(0,1)$のKL divergenceは次のようになります。

$$ \mathcal{N} \left( \mu (X), Σ (X) ) || \mathcal{N}(0,1) \right) = \frac{1}{2} \sum_{j} \big(1 + \log (Σ(X)) - \mu^{2} (X) - Σ(X) \big) $$

loss_kl = - 0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
vae_loss = reconst_loss  + kl_loss

このvae_lossを最小化することは、ELBOを最大化することと同じになります。

次回はM2とM1+M2について記載します。

参考サイト

wiseodd.github.io

qiita.com

cympfh.cc

musyoku.github.io

https://www.ccn.yamanashi.ac.jp/~tmiyamoto/img/variational_bayes1.pdf