グラフ機械学習と強化学習について

主にグラフ機械学習や強化学習手法を記載します。

Chemical VAE

一年くらい前に書いた記事の実装を行ってみました。

udnp.hatenablog.com

コード:

github.com

正しいものを使用したいときは本家本元から参照してください。

簡単に解説

分子構造はSMILESという文字列で表すことができます。例えば、ヒドロキシ基がついたベンゼン環(フェノール)はOc1ccccc1で表すことができます。文字列ならば、生成モデルを使ってOc1ccccc1→潜在空間→Oc1ccccc1となるようなモデルを作ってしまえば、多様な分子を作り出すことができるというわけです。目的の活性や物性をもつような潜在変数を探してやって、そこからデコードしてやれば目的の分子を作り出すことができるという考えです。

実装するためのポイント

前処理

まずはSMILESを機械が読み込めるような形にしていく必要があります。

化合物DBであるZINCからSMILES中の文字を数字にマッピングするための辞書を作成します。全SMILESから重複しないように文字をリストに格納していきます。

def smiles2one_hot_chars(smi_list: list) -> list:
    """obtain character in SMILES.

    :param smi_list: SMILES list
    :return: Char list
    """
    char_lists = [list(smi) for smi in smi_list]
    chars = list(set([char for sub_list in char_lists for char in sub_list]))
    chars.append(' ')

    return chars


df = pd.read_table('./data/train.csv',  header=None)
df.columns = ['smiles']
smiles = list(df.smiles)

zinc_list = smiles2one_hot_chars(smiles)

>
[
    "7", "6", "o", "]", "3", "s", "(", "-", "S", "/", "B", "4", "[", ")", "#", "I", "l", "O", "H", "c", "1",
    "@", "=", "n", "P", "8", "C", "2", "F", "5", "r", "N", "+", "\\", " "
]

次に、文字をマッピングするための辞書を作ります。

char2id = dict((c, i) for i, c in enumerate(zinc_list))
id2char = dict((i, c) for i, c in enumerate(zinc_list))

これでSMILESを文字からid, idから文字に変換することができます。ニューラルネットに学習できるように(データ数×最大原子数×辞書数)となるように変形していきます。ZINCは低分子DBなので、最大原子数は120、辞書数は35となります。Chemical VAEでは、文字列を同じにするように空白でパディングしています。

def pad_smile(string: str, max_len: int, padding: str = 'right') -> str:
    if len(string) <= max_len:
        if padding == 'right':
            return string + " " * (max_len - len(string))
        elif padding == 'left':
            return " " * (max_len - len(string)) + string
        elif padding == 'none':
            return string


def smiles_to_hot(smiles: list, max_len: int = 120, padding: str = 'right',
                  char_to_id: Dict[str, int] = char2id) -> np.ndarray:
    """smiles list into one-hot tensors.

    :param smiles: SMILES list
    :param max_len: max length of the number of atoms
    :param padding: types of padding: ('right' or 'left' or None)
    :param char_indices: dictionary of SMILES characters
    :return: one-hot matrix (Batch × MAX_LEN × len(dict))
    """
    smiles = [pad_smile(smi, max_len, padding) for smi in smiles]
    hot_x = np.zeros((len(smiles), max_len, len(char_to_id)), dtype=np.float32)
    for list_id, smile in enumerate(smiles):
        for char_id, char in enumerate(smile):
            try:
                hot_x[list_id, char_id, char_to_id[char]] = 1
            except KeyError as e:
                print("ERROR: Check chars file. Bad SMILES:", smile)
                raise e
    return hot_x

idがあれば1を追加していくという具合で配列を作成してきます。これは文章生成で行われる前処理とほとんど同じだと思います。

phenol = 'Oc1ccccc1
vec = smiles_to_hot([phenol])
vec

>array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]]], dtype=float32)

このような(1×120×35)のテンソルが得られます。ここでの目標は、この値を入力して、元に戻るように学習していくことです。SMILESへの復元は対応する番号へのマッピング操作を行ってやればよいです。

def hot_to_smiles(hot_x: np.ndarray, id2char: Dict[int, str] = id2char) -> list:
    """one hot list to SMILES list.

    :param hot_x: smiles one hot (id, max_len, node_dict)
    :param id2char: map from node id to smiles char
    :return: smiles list
    """
    smiles = ["".join([id2char[np.argmax(j)] for j in x]) for x in hot_x]
    smiles = [re.sub(' ', '', smi) for smi in smiles]  # paddingを消す
    return smiles


hot_to_smiles(vec[0])
> ['Oc1ccccc1']

学習

VAEではEncoder、Decoderを用いたモデルですが、エンコーダで得られるmu、sigmaから正規分布によるサンプリング結果をデコーダーに渡すことがAEとの違いになります。VAEの学習の数理に関しては前回の記事を参照してください。簡単に書くとELBOを最大にすることで、対数尤度を最大にするように学習を進めていけばいいです。式変形を行っていくと、損失関数は、クロスエントロピー + カルバックライブラーダイバージェンスを計算すればよいという形になります。VAEの実装に関しては様々なレポジトリがあるので、そちらを参考に実装しています。PyTorchには便利な関数があるので、

def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.mean(1. + logvar - mu ** 2. - torch.exp(logvar))
    return BCE + KLD
  • Encoder エンコーダでは、1d convolutionを行っています。

  • Decoder デコーダーでは、サンプリングされたあと、GRUに渡せるように次元を(バッチ数×最大原子数×辞書数)となるようにRepeatさせています。その後は、線形変換→ソフトマックスという形で入力と同じ形になるようにしています。

サンプルコードではGANも実装しています。モデルは全く同じです。

結果

データ数、学習時間が少ないため、ほとんど元の分子に戻っていません。ネットワークも正しく決めないとダメなのかもしれません。ConvやGRU増やしたり、学習エポック数を増やしたり。GANほうが割と分子っぽい結果は出ています。 (実装も間違ってるかも)

所感

SMILESは化学構造を一意に表すことができません。最初に記載したフェノールはc1ccccc1Oやc1c(O)cccc1と書いても良いので、組み合わせ爆発を起きてしまいます。 改良法には文法規則を組み合わせたgrammar VAEがありますが、低分子ならJT-VAEやInventを使えば問題ないと思います。

chem VAEは分子生成には使うのは厳しいですが、特徴量抽出には使えるかもしれません。