[変分オートエンコーダー (VAE, M1)]の続きです。
conditional VAE, M2
M1モデルに対して、ラベル付きのデータを入力できるようにしたモデルです。
モデル
をクラスラベルを表すとします。 M2のグラフィカルモデルは下図のようになります。
この図より同時分布は下式のようになります。
\begin{align} q_\phi(x, y, z) &= q(z | x, y) q(y | x) q (x) \\ p_\theta(x, y, z) &= p(x | y, z) p (y) p (z) \end{align}
また推論モデルであるはカテゴリカル分布です。
\begin{align}
q_\phi(y | x) = Categorical ( y = k | \theta_k )
\end{align}
参考までに
統計モデリングの視点からカテゴリカル分布を考えます。この分布はベルヌーイ分布の多変量版のようなものです。各カテゴリーの生起確率、すなわちkの目がでる確率がであるようなK面のサイコロを表します。多項ロジスティク回帰で用いられます。
多項ロジスティク回帰の例では、ある購入した商品のカテゴリーYのデータがあるとします。その説明変数であるとすると、購入者ごとに各カテゴリーを選ぶ確率をとすると、以下のモデルが考えられます。
\begin{align}
\textbf{y}&= \textbf{w}^{T} \textbf{X} \\
\theta &= \text{softmax} (\textbf{y}) \\
\textbf{Y} & \sim \text{Categorical} (\theta)
\end{align}
はてなでのローマン体の書き方がわかりませんでした・・・
多クラス分類ではこの考え方に基づいています。
各モデルの確率分布
変分下限
- ラベル付き対数尤度のELBO \begin{align} \log p_\theta (x, y) &= \log \int p_\theta (x |z, y) p(z) p(y) dz \\ &= \log \int q_{\phi}(z|x,y) \frac{p_\theta (x |z, y) p(z) p(y)}{q_{\phi}(z|x,y)} dz \\ & \geq \int q_{\phi}(z|x,y) \log \bigg(\frac{p_\theta (x |z, y) p(z) p(y)}{q_{\phi}(z|x,y)} \bigg) dz \\ &= \mathbb{E}_{z \sim q_\phi (z|x, y)} [\log p_\theta (x|z, y) + \log p(z) + \log p(y) - \log q_\phi(z | x,y) ] \\ &= \mathbb{E}_{z \sim q_\phi (z|x, y)} [\log p_\theta (x|z, y) + \log p(y)] + \mathbb{E}_{z \sim q_\phi (z|x, y)}[\log \frac{p(z)}{q_\phi(z | x,y)}] \\ &= \mathbb{E}_{z \sim q_\phi (z|x, y)} [\log p_\theta (x|z, y) + \log p(y)] - D_{KL} (q_\phi(z | x,y) || p(z)) \\ & \simeq [\log p_\theta (x|z, y) + \log p(y)] - D_{KL} (q_\phi(z | x,y) || p(z)) \\ &= L(x, y) \end{align}
これを最大化させればよいです。
- ラベルなしの対数尤度のELBO。グラフの左側の方。 \begin{align} \log p_\theta (x) & \geq \mathbb{E}_{z, y \sim q_\phi (x|z, y)} [\log p_\theta (x|z, y) + \log p(z) + \log p(y) - \log q_\phi(z, y |x) ] \\ &= \mathbb{E}_{y \sim q_\phi (x|y)} [ \mathbb{E}_{z \sim q_\phi (x|z)} [ \log p_\theta (x|z, y) + \log p(z) + \log p(y) - \log q_\phi(z|x, y) - \log q_\phi(y|x) ]] \\ &= \mathbb{E}_{y \sim q_\phi (x|y)} [ - L(x,y) - \mathbb{E}_{z \sim q_\phi (x|z)} [ \log q_\phi(y|x) ]] \\ &= \mathbb{E}_{y \sim q_\phi (x|y)} [ - L(x,y) - \log q_\phi(y|x) ] \\ &= - U(x) \end{align}
以上より、目的関数は
\begin{align} J = \sum_{label} L(x,y) + \sum_{unlabel} U(x) \end{align}
また、はxの属するクラス確率を与えるため、これをクラス分類に使用することができます。そこで学習を次のようにします。
\begin{align} J = \sum_{label} L(x,y) + \sum_{unlabel} U(x) + \alpha \mathbb{E}_{x,y \sim p_l} [- \log q_\phi (y|x)] \end{align}
ただ、上式を用いなくても条件付VAEの実装はできます。一例です。
from keras.layers import Lambda, Input, Dense, merge from keras.models import Model from keras.losses import mse, binary_crossentropy from keras.utils import plot_model from keras import backend as K # Conditional Variational Autoenvoder (M2) mnist = input_data.read_data_sets("MNIST_data", one_hot=True) X_train, y_train = mnist.train.images, mnist.train.labels X_test, y_test = mnist.test.images, mnist.test.labels m = 50 x_dim = X_train.shape[1] y_dim = y_train.shape[1] z_dim = 2 n_epoch = 20 # Encoder: Q(z|X,y) X = Input(batch_shape=(m, x_dim) cond = Input(batch_shape=(m, y_dim)) inputs = merge([X, cond], mode="concat", concat_axis=1) h_q = Dense(512, activation="relu")(inputs) z_mean = Dense(z_dim, activation="linear")(h_q) z_log_var = Dense(z_dim, activation="linear")(h_q) def sampling(args): z_mean, z_log_var = args eps = K.random_normal(shape=(m, z_dim), mean=0, std=1.) return z_mean + K.exp(0.5*z_log_var) + K.random_normal(shape=(m, z_dim)) # Sample z ~ Q(z|X,y) z = Lambda(sample_z)([z_mean, z_log_var]) z_cond = merge([z, cond], mode="concat", concat_axis=1) # Decoder: P(X|z,y) decoder_h = Dense(512, activation="relu") decoder_out = Dense(784, activation="sigmoid") h_p = decoder_h(z_cond) outputs = decoder_out(h_p) reconstruction_error = K.sum(K.binary_crossentropy(inputs, outputs), axis=-1) kl_divergence = 0.5*K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) vae_loss = K.mean(reconstruction_error - kl_divergence) vae = Model(inputs, outputs, name="cvae") vae.add_loss(vae_loss) vae.compile(optimizer="adam")
次回
次はを含め実際にコーディングしたMNISTの結果とM1 + M2およびchemical VAEについて記載します。
- Auxiliary Deep Generative Modelsも試したい。